From a717c621d6a3f8fedcf6a4f10b3c7d25ca647525 Mon Sep 17 00:00:00 2001
From: davebshow <davebshow@gmail.com>
Date: Fri, 22 May 2015 20:58:41 -0400
Subject: [PATCH] simplified client receive, made some pep fixes

---
 aiogremlin/__init__.py   |  1 +
 aiogremlin/client.py     |  5 ++-
 aiogremlin/connection.py | 93 ++++++++++++++--------------------------
 aiogremlin/exceptions.py | 40 +++++++++--------
 aiogremlin/pool.py       |  2 +-
 aiogremlin/protocol.py   | 10 +++--
 benchmark.py             | 11 +++--
 conn_bench.py            |  9 ++--
 tests/tests.py           | 54 +++++++++++++----------
 9 files changed, 111 insertions(+), 114 deletions(-)

diff --git a/aiogremlin/__init__.py b/aiogremlin/__init__.py
index 0fc5fac..6c957c5 100644
--- a/aiogremlin/__init__.py
+++ b/aiogremlin/__init__.py
@@ -3,4 +3,5 @@ from .client import *
 from .exceptions import *
 from .pool import *
 from .protocol import *
+
 __version__ = "0.0.8"
diff --git a/aiogremlin/client.py b/aiogremlin/client.py
index 224e193..965a8e2 100644
--- a/aiogremlin/client.py
+++ b/aiogremlin/client.py
@@ -23,7 +23,8 @@ def create_client(*, url='ws://localhost:8182/', loop=None,
                   timeout=None, verbose=False, fill_pool=True, connector=None):
 
     if factory is None:
-        factory = WebSocketSession(connector=connector,
+        factory = WebSocketSession(
+            connector=connector,
             ws_response_class=GremlinClientWebSocketResponse,
             loop=loop)
 
@@ -141,7 +142,7 @@ class GremlinClient:
 
     @asyncio.coroutine
     def execute(self, gremlin, *, bindings=None, lang=None,
-               op=None, processor=None, consumer=None, collect=True, **kwargs):
+                op=None, processor=None, consumer=None, collect=True):
         """
         """
         lang = lang or self.lang
diff --git a/aiogremlin/connection.py b/aiogremlin/connection.py
index 30ce5eb..f249ee8 100644
--- a/aiogremlin/connection.py
+++ b/aiogremlin/connection.py
@@ -6,14 +6,14 @@ import hashlib
 import os
 
 from aiohttp import (client, hdrs, DataQueue, StreamParser,
-    WSServerHandshakeError, ClientSession, TCPConnector)
+                     WSServerHandshakeError, ClientSession, TCPConnector)
 from aiohttp.errors import WSServerHandshakeError
 from aiohttp.websocket import WS_KEY, Message
 from aiohttp.websocket import WebSocketParser, WebSocketWriter, WebSocketError
 from aiohttp.websocket import (MSG_BINARY, MSG_TEXT, MSG_CLOSE, MSG_PING,
-    MSG_PONG)
+                               MSG_PONG)
 from aiohttp.websocket_client import (MsgType, closedMessage,
-    ClientWebSocketResponse)
+                                      ClientWebSocketResponse)
 
 from aiogremlin.exceptions import SocketClientError
 from aiogremlin.log import INFO, logger
@@ -34,7 +34,6 @@ class WebSocketSession(ClientSession):
 
         self._ws_response_class = ws_response_class
 
-
     @asyncio.coroutine
     def ws_connect(self, url, *,
                    protocols=(),
@@ -72,15 +71,17 @@ class WebSocketSession(ClientSession):
 
         # key calculation
         key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
-        match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode()
+        match = base64.b64encode(
+            hashlib.sha1(sec_key + WS_KEY).digest()).decode()
         if key != match:
             raise WSServerHandshakeError('Invalid challenge response')
 
         # websocket protocol
         protocol = None
         if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers:
-            resp_protocols = [proto.strip() for proto in
-                              resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
+            resp_protocols = [
+                proto.strip() for proto in
+                resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
 
             for proto in resp_protocols:
                 if proto in protocols:
@@ -115,13 +116,14 @@ def ws_connect(url, *, protocols=(), timeout=10.0, connector=None,
 
     ws_session = WebSocketSession(loop=loop, connector=connector)
     try:
-        resp = yield from ws_session.ws_connect(url,
-                                                protocols=protocols,
-                                                timeout=timeout,
-                                                ws_response_class=ws_response_class,
-                                                autoclose=autoclose,
-                                                autoping=autoping,
-                                                loop=loop)
+        resp = yield from ws_session.ws_connect(
+            url,
+            protocols=protocols,
+            timeout=timeout,
+            ws_response_class=ws_response_class,
+            autoclose=autoclose,
+            autoping=autoping,
+            loop=loop)
         return resp
 
     finally:
@@ -144,8 +146,8 @@ class GremlinFactory:
         try:
             return (yield from ws_connect(
                 url, protocols=protocols, connector=connector,
-                ws_response_class=ws_response_class, autoclose=True, autoping=True,
-                loop=loop))
+                ws_response_class=ws_response_class, autoclose=True,
+                autoping=True, loop=loop))
         except WSServerHandshakeError as e:
             raise SocketClientError(e.message)
 
@@ -155,7 +157,8 @@ class GremlinClientWebSocketResponse(ClientWebSocketResponse):
     def __init__(self, reader, writer, protocol, response, timeout, autoclose,
                  autoping, loop):
         ClientWebSocketResponse.__init__(self, reader, writer, protocol,
-            response, timeout, autoclose, autoping, loop)
+                                         response, timeout, autoclose,
+                                         autoping, loop)
         self._parser = StreamParser(buf=DataQueue(loop=loop), loop=loop)
 
     @property
@@ -228,47 +231,15 @@ class GremlinClientWebSocketResponse(ClientWebSocketResponse):
 
     @asyncio.coroutine
     def receive(self):
-        if self._waiting:
-            raise RuntimeError('Concurrent call to receive() is not allowed')
-
-        self._waiting = True
-        try:
-            while True:
-                if self._closed:
-                    return closedMessage
-
-                try:
-                    msg = yield from self._reader.read()
-                except (asyncio.CancelledError, asyncio.TimeoutError):
-                    raise
-                except WebSocketError as exc:
-                    self._close_code = exc.code
-                    yield from self.close(code=exc.code)
-                    raise
-                except Exception as exc:
-                    self._exception = exc
-                    self._closing = True
-                    self._close_code = 1006
-                    yield from self.close()
-                    raise
-                if msg.tp == MsgType.close:
-                    self._closing = True
-                    self._close_code = msg.data
-                    if not self._closed and self._autoclose:
-                        yield from self.close()
-                    raise RuntimeError("Socket connection closed by server.")
-                elif not self._closed:
-                    if msg.tp == MsgType.ping and self._autoping:
-                        self._writer.pong(msg.data)
-                    elif msg.tp == MsgType.pong and self._autoping:
-                        continue
-                    else:
-                        if msg.tp == MsgType.binary:
-                            self.parser.feed_data(msg.data.decode())
-                        elif msg.tp == MsgType.text:
-                            self.parser.feed_data(msg.data.strip())
-                        else:
-                            raise RuntimeError("Unknown message type.")
-                        break
-        finally:
-            self._waiting = False
+        msg = yield from super().receive()
+        if msg.tp == MsgType.binary:
+            self.parser.feed_data(msg.data.decode())
+        elif msg.tp == MsgType.text:
+            self.parser.feed_data(msg.data.strip())
+        else:
+            if msg.tp == MsgType.close:
+                yield from ws.close()
+            elif msg.tp == MsgType.error:
+                raise msg[1]
+            elif msg.tp == MsgType.closed:
+                pass
diff --git a/aiogremlin/exceptions.py b/aiogremlin/exceptions.py
index 8b42d0b..285e83d 100644
--- a/aiogremlin/exceptions.py
+++ b/aiogremlin/exceptions.py
@@ -5,7 +5,8 @@ Gremlin Server exceptions.
 __all__ = ("RequestError", "GremlinServerError", "SocketClientError")
 
 
-class SocketClientError(IOError): pass
+class SocketClientError(IOError):
+    pass
 
 
 class StatusException(IOError):
@@ -18,22 +19,22 @@ class StatusException(IOError):
         self.response = {
             498: ("MALFORMED_REQUEST",
                   ("The request message was not properly formatted which " +
-                   "means it could not be parsed at all or the 'op' code was " +
-                   "not recognized such that Gremlin Server could properly " +
-                   "route it for processing. Check the message format and " +
-                   "retry the request")),
+                   "means it could not be parsed at all or the 'op' code " +
+                   "was not recognized such that Gremlin Server could " +
+                   "properly route it for processing. Check the message " +
+                   "format and retry the request")),
             499: ("INVALID_REQUEST_ARGUMENTS",
                   ("The request message was parseable, but the arguments " +
                    "supplied in the message were in conflict or incomplete. " +
                    "Check the message format and retry the request.")),
             500: ("SERVER_ERROR",
                   ("A general server error occurred that prevented the " +
-                  "request from being processed.")),
+                   "request from being processed.")),
             596: ("TRAVERSAL_EVALUATION",
                   ("The remote " +
                    "{@link org.apache.tinkerpop.gremlin.process.Traversal} " +
-                   "submitted for processing evaluated in on the server with " +
-                   "errors and could not be processed")),
+                   "submitted for processing evaluated in on the server " +
+                   "with errors and could not be processed")),
             597: ("SCRIPT_EVALUATION",
                   ("The script submitted for processing evaluated in the " +
                    "{@code ScriptEngine} with errors and could not be  " +
@@ -44,20 +45,25 @@ class StatusException(IOError):
                    "request and could therefore only partially respond or " +
                    " not respond at all.")),
             599: ("SERIALIZATION",
-                  ("The server was not capable of serializing an object that " +
-                   "was returned from the script supplied on the request. " +
-                   "Either transform the object into something Gremlin " +
-                   "Server can process within the script or install mapper " +
-                   "serialization classes to Gremlin Server."))
+                  ("The server was not capable of serializing an object " +
+                   "that was returned from the script supplied on the " +
+                   "requst. Either transform the object into something " +
+                   "Gremlin Server can process within the script or install " +
+                   "mapper serialization classes to Gremlin Server."))
         }
         if result:
             result = "\n\n{}".format(result)
-        self.message = 'Code [{}]: {}. {}.{}'.format(self.value,
-            self.response[self.value][0], self.response[self.value][1], result)
+        self.message = 'Code [{}]: {}. {}.{}'.format(
+            self.value,
+            self.response[self.value][0],
+            self.response[self.value][1],
+            result)
         super().__init__(self.message)
 
 
-class RequestError(StatusException): pass
+class RequestError(StatusException):
+    pass
 
 
-class GremlinServerError(StatusException): pass
+class GremlinServerError(StatusException):
+    pass
diff --git a/aiogremlin/pool.py b/aiogremlin/pool.py
index 117e0c4..d53314e 100644
--- a/aiogremlin/pool.py
+++ b/aiogremlin/pool.py
@@ -1,7 +1,7 @@
 import asyncio
 
 from aiogremlin.connection import (GremlinFactory,
-    GremlinClientWebSocketResponse)
+                                   GremlinClientWebSocketResponse)
 from aiogremlin.contextmanager import ConnectionContextManager
 from aiogremlin.log import logger
 
diff --git a/aiogremlin/protocol.py b/aiogremlin/protocol.py
index ddfb11e..87bdcec 100644
--- a/aiogremlin/protocol.py
+++ b/aiogremlin/protocol.py
@@ -15,8 +15,9 @@ from aiogremlin.log import logger
 __all__ = ("GremlinWriter",)
 
 
-Message = collections.namedtuple("Message", ["status_code", "data", "message",
-    "metadata"])
+Message = collections.namedtuple(
+    "Message",
+    ["status_code", "data", "message", "metadata"])
 
 
 def gremlin_response_parser(out, buf):
@@ -55,7 +56,8 @@ class GremlinWriter:
     def write(self, gremlin, bindings=None, lang="gremlin-groovy", op="eval",
               processor="", session=None, binary=True,
               mime_type="application/json"):
-        message = self._prepare_message(gremlin,
+        message = self._prepare_message(
+            gremlin,
             bindings=bindings,
             lang=lang,
             op=op,
@@ -83,7 +85,7 @@ class GremlinWriter:
             "requestId": str(uuid.uuid4()),
             "op": op,
             "processor": processor,
-            "args":{
+            "args": {
                 "gremlin": gremlin,
                 "bindings": bindings,
                 "language":  lang
diff --git a/benchmark.py b/benchmark.py
index 8310dd9..f78f804 100644
--- a/benchmark.py
+++ b/benchmark.py
@@ -45,8 +45,8 @@ def run(client, count, concurrency, loop):
     yield from asyncio.gather(*bombers, loop=loop)
     t2 = loop.time()
     mps = processed_count / (t2 - t1)
-    print("Benchmark complete: {} mps. {} messages in {}".format(mps,
-        processed_count, t2-t1))
+    print("Benchmark complete: {} mps. {} messages in {}".format(
+        mps, processed_count, t2-t1))
     return mps
 
 
@@ -102,11 +102,14 @@ if __name__ == "__main__":
     t1 = loop.time()
     factory = aiogremlin.GremlinFactory()
     client = loop.run_until_complete(
-        aiogremlin.create_client(loop=loop, factory=factory, poolsize=poolsize))
+        aiogremlin.create_client(loop=loop,
+                                 factory=factory,
+                                 poolsize=poolsize))
     t2 = loop.time()
     print("time to establish conns: {}".format(t2 - t1))
     try:
-        print("Runs: {}. Warmups: {}. Messages: {}. Concurrency: {}. Poolsize: {}.".format(
+        print(
+            "Runs: {}. Warmups: {}. Messages: {}. Concurrency: {}. Poolsize: {}.".format(
             num_tests, num_warmups, num_mssg, concurr, poolsize))
         main = main(client, num_tests, num_mssg, concurr, num_warmups, loop)
         loop.run_until_complete(main)
diff --git a/conn_bench.py b/conn_bench.py
index c099cdc..5eefd2c 100644
--- a/conn_bench.py
+++ b/conn_bench.py
@@ -44,12 +44,15 @@ if __name__ == "__main__":
                 )
                 tasks.append(task)
             t1 = loop.time()
-            loop.run_until_complete(asyncio.async(asyncio.gather(*tasks, loop=loop)))
+            loop.run_until_complete(
+                asyncio.async(asyncio.gather(*tasks, loop=loop)))
             t2 = loop.time()
-            print("avg: time to establish conn: {}".format((t2 - t1) / (tests * 100)))
+            print("avg: time to establish conn: {}".format(
+                (t2 - t1) / (tests * 100)))
         m2 = loop.time()
         print("time to establish conns: {}".format((m2 - m1)))
-        print("avg time to establish conns: {}".format((m2 - m1) / (tests * 100 * 50)))
+        print("avg time to establish conns: {}".format(
+            (m2 - m1) / (tests * 100 * 50)))
     finally:
         loop.close()
         print("CLOSED CLIENT AND LOOP")
diff --git a/tests/tests.py b/tests/tests.py
index 7937833..2079052 100644
--- a/tests/tests.py
+++ b/tests/tests.py
@@ -6,8 +6,9 @@ import itertools
 import unittest
 import uuid
 from aiogremlin import (GremlinClient, RequestError, GremlinServerError,
-    SocketClientError, WebSocketPool, GremlinFactory, create_client,
-    GremlinWriter, GremlinResponse, WebSocketSession)
+                        SocketClientError, WebSocketPool, GremlinFactory,
+                        create_client, GremlinWriter, GremlinResponse,
+                        WebSocketSession)
 
 
 class GremlinClientTests(unittest.TestCase):
@@ -45,10 +46,11 @@ class GremlinClientPoolTests(unittest.TestCase):
     def setUp(self):
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(None)
+        pool = WebSocketPool("ws://localhost:8182/", loop=self.loop)
         self.gc = GremlinClient(url="ws://localhost:8182/",
-            factory=GremlinFactory(), pool=WebSocketPool("ws://localhost:8182/",
-                                                       loop=self.loop),
-            loop=self.loop)
+                                factory=GremlinFactory(),
+                                pool=pool,
+                                loop=self.loop)
 
     def tearDown(self):
         self.loop.run_until_complete(self.gc.close())
@@ -74,8 +76,9 @@ class GremlinClientPoolTests(unittest.TestCase):
         sub2 = self.gc.execute("x + x", bindings={"x": 2})
         sub3 = self.gc.execute("x + x", bindings={"x": 4})
         coro = asyncio.gather(*[asyncio.async(sub1, loop=self.loop),
-                             asyncio.async(sub2, loop=self.loop),
-                             asyncio.async(sub3, loop=self.loop)], loop=self.loop)
+                              asyncio.async(sub2, loop=self.loop),
+                              asyncio.async(sub3, loop=self.loop)],
+                              loop=self.loop)
         # Here I am looking for resource warnings.
         results = self.loop.run_until_complete(coro)
         self.assertIsNotNone(results)
@@ -111,7 +114,8 @@ class GremlinClientPoolTests(unittest.TestCase):
         self.assertTrue(error)
 
     def test_session_gen(self):
-        execute = self.gc.execute("x + x", processor="session", bindings={"x": 4})
+        execute = self.gc.execute("x + x", processor="session",
+                                  bindings={"x": 4})
         results = self.loop.run_until_complete(execute)
         self.assertEqual(results[0].data[0], 8)
 
@@ -120,7 +124,7 @@ class GremlinClientPoolTests(unittest.TestCase):
         def stream_coro():
             session = str(uuid.uuid4())
             resp = yield from self.gc.submit("x + x", bindings={"x": 4},
-                session=session)
+                                             session=session)
             while True:
                 f = yield from resp.stream.read()
                 if f is None:
@@ -135,10 +139,10 @@ class WebSocketPoolTests(unittest.TestCase):
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(None)
         self.pool = WebSocketPool("ws://localhost:8182/",
-            poolsize=2,
-            timeout=1,
-            loop=self.loop,
-            factory=GremlinFactory())
+                                  poolsize=2,
+                                  timeout=1,
+                                  loop=self.loop,
+                                  factory=GremlinFactory())
 
     def tearDown(self):
         self.loop.run_until_complete(self.pool.close())
@@ -226,16 +230,17 @@ class WebSocketPoolTests(unittest.TestCase):
 
         self.loop.run_until_complete(conn())
 
+
 class ContextMngrTest(unittest.TestCase):
 
     def setUp(self):
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(None)
         self.pool = WebSocketPool("ws://localhost:8182/",
-            poolsize=1,
-            loop=self.loop,
-            factory=GremlinFactory(),
-            max_retries=0)
+                                  poolsize=1,
+                                  loop=self.loop,
+                                  factory=GremlinFactory(),
+                                  max_retries=0)
 
     def tearDown(self):
         self.loop.run_until_complete(self.pool.close())
@@ -255,6 +260,7 @@ class ContextMngrTest(unittest.TestCase):
 
     def test_connection_manager(self):
         results = []
+
         @asyncio.coroutine
         def go():
             with (yield from self.pool) as conn:
@@ -302,6 +308,7 @@ class ContextMngrTest(unittest.TestCase):
 
     def test_connection_manager_error(self):
         results = []
+
         @asyncio.coroutine
         def go():
             with (yield from self.pool) as conn:
@@ -331,10 +338,12 @@ class GremlinClientPoolSessionTests(unittest.TestCase):
     def setUp(self):
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(None)
+        pool = WebSocketPool("ws://localhost:8182/",
+                             loop=self.loop,
+                             factory=WebSocketSession(loop=self.loop))
         self.gc = GremlinClient("ws://localhost:8182/",
-            pool=WebSocketPool("ws://localhost:8182/", loop=self.loop,
-                               factory=WebSocketSession(loop=self.loop)),
-            loop=self.loop)
+                                pool=pool,
+                                loop=self.loop)
 
     def tearDown(self):
         self.loop.run_until_complete(self.gc.close())
@@ -360,8 +369,9 @@ class GremlinClientPoolSessionTests(unittest.TestCase):
         sub2 = self.gc.execute("x + x", bindings={"x": 2})
         sub3 = self.gc.execute("x + x", bindings={"x": 4})
         coro = asyncio.gather(*[asyncio.async(sub1, loop=self.loop),
-                             asyncio.async(sub2, loop=self.loop),
-                             asyncio.async(sub3, loop=self.loop)], loop=self.loop)
+                              asyncio.async(sub2, loop=self.loop),
+                              asyncio.async(sub3, loop=self.loop)],
+                              loop=self.loop)
         # Here I am looking for resource warnings.
         results = self.loop.run_until_complete(coro)
         self.assertIsNotNone(results)
-- 
GitLab