From 91fc71cd3292c59166b9a3155b38996b21a24e69 Mon Sep 17 00:00:00 2001
From: davebshow <davebshow@gmail.com>
Date: Sun, 17 May 2015 14:49:44 -0400
Subject: [PATCH] working on some refactoring

---
 aiogremlin/__init__.py       |   4 +-
 aiogremlin/abc.py            |  27 +--
 aiogremlin/client.py         | 122 ++++++++----
 aiogremlin/connection.py     | 367 ++++++++++++++++++-----------------
 aiogremlin/contextmanager.py |  38 ++++
 aiogremlin/log.py            |   3 +-
 aiogremlin/pool.py           | 139 +++++++++++++
 aiogremlin/protocol.py       |  39 +++-
 tests/tests.py               | 104 ++++++----
 9 files changed, 557 insertions(+), 286 deletions(-)
 create mode 100644 aiogremlin/contextmanager.py
 create mode 100644 aiogremlin/pool.py

diff --git a/aiogremlin/__init__.py b/aiogremlin/__init__.py
index e8e3780..f6a17d5 100644
--- a/aiogremlin/__init__.py
+++ b/aiogremlin/__init__.py
@@ -1,8 +1,8 @@
 from .abc import AbstractFactory, AbstractConnection
-from .connection import (WebsocketPool, AiohttpFactory, BaseFactory,
-    BaseConnection)
+from .connection import AiohttpFactory, BaseFactory, BaseConnection
 from .client import (create_client, GremlinClient, GremlinResponse,
     GremlinResponseStream)
 from .exceptions import RequestError, GremlinServerError, SocketClientError
+from .pool import WebSocketPool
 from .protocol import GremlinWriter
 __version__ = "0.0.6"
diff --git a/aiogremlin/abc.py b/aiogremlin/abc.py
index 9400e07..dd8391e 100644
--- a/aiogremlin/abc.py
+++ b/aiogremlin/abc.py
@@ -18,32 +18,23 @@ class AbstractFactory(metaclass=ABCMeta):
 
 class AbstractConnection(metaclass=ABCMeta):
 
-    @abstractmethod
-    def feed_pool(self):
-        pass
-
-    @abstractmethod
-    def release(self):
-        pass
-
-    @property
-    @abstractmethod
-    def pool(self):
-        pass
+    # @property
+    # @abstractmethod
+    # def closed(self):
+    #     pass
 
-    @property
     @abstractmethod
-    def closed(self):
+    def close():
         pass
 
     @abstractmethod
-    def close():
+    def _close():
         pass
 
     @abstractmethod
     def send(self):
         pass
 
-    @abstractmethod
-    def _receive(self):
-        pass
+    # @abstractmethod
+    # def receive(self):
+    #     pass
diff --git a/aiogremlin/client.py b/aiogremlin/client.py
index 9b40de2..518f61e 100644
--- a/aiogremlin/client.py
+++ b/aiogremlin/client.py
@@ -2,12 +2,14 @@
 
 import asyncio
 import ssl
-import uuid
 
 import aiohttp
 
-from aiogremlin.connection import WebsocketPool
-from aiogremlin.log import client_logger, INFO
+from aiogremlin.connection import AiohttpFactory
+from aiogremlin.contextmanager import ClientContextManager
+from aiogremlin.exceptions import RequestError
+from aiogremlin.log import logger, INFO
+from aiogremlin.pool import WebSocketPool
 from aiogremlin.protocol import gremlin_response_parser, GremlinWriter
 
 
@@ -16,14 +18,14 @@ def create_client(uri='ws://localhost:8182/', loop=None, ssl=None,
                   protocol=None, lang="gremlin-groovy", op="eval",
                   processor="", pool=None, factory=None, poolsize=10,
                   timeout=None, verbose=False, **kwargs):
-    pool = WebsocketPool(uri,
+    pool = WebSocketPool(uri,
                          factory=factory,
                          poolsize=poolsize,
                          timeout=timeout,
                          loop=loop,
                          verbose=verbose)
 
-    yield from pool.init_pool()
+    yield from pool.fill_pool()
 
     return GremlinClient(uri=uri,
                          loop=loop,
@@ -42,7 +44,7 @@ class GremlinClient:
     def __init__(self, uri='ws://localhost:8182/', loop=None, ssl=None,
                  protocol=None, lang="gremlin-groovy", op="eval",
                  processor="", pool=None, factory=None, poolsize=10,
-                 timeout=None, verbose=True, **kwargs):
+                 timeout=None, verbose=False, **kwargs):
         """
         """
         self.uri = uri
@@ -60,11 +62,15 @@ class GremlinClient:
         self.processor = processor or ""
         self.poolsize = poolsize
         self.timeout = timeout
-        self.pool = pool or WebsocketPool(uri, factory=factory,
-            poolsize=poolsize, timeout=timeout, loop=self._loop)
-        self.factory = factory or self.pool.factory
+        self._pool = pool
+        self._factory = factory or AiohttpFactory
+        if self._pool is None:
+            self._connected = False
+            self._conn = asyncio.async(self._connect(), loop=self._loop)
+        else:
+            self._connected = self._pool._connected
         if verbose:
-            client_logger.setLevel(INFO)
+            logger.setLevel(INFO)
 
     @property
     def loop(self):
@@ -72,45 +78,54 @@ class GremlinClient:
 
     @asyncio.coroutine
     def close(self):
-        yield from self.pool.close()
+        try:
+            if self._pool:
+                yield from self._pool.close()
+            elif self._connected:
+                yield from self._conn.close()
+        finally:
+            self._connected = False
 
     @asyncio.coroutine
-    def connect(self, **kwargs):
+    def _connect(self, **kwargs):
         """
         """
-        loop = kwargs.get("loop", "") or self.loop
-        connection = yield from self.factory.connect(self.uri, loop=loop,
-            **kwargs)
+        loop = kwargs.get("loop", "") or self._loop
+        connection = yield from self._factory.connect(self.uri, loop=loop)
+        self._connected = True
         return connection
 
     @asyncio.coroutine
-    def submit(self, gremlin, connection=None, bindings=None, lang=None,
-               op=None, processor=None, session=None, binary=True):
+    def _acquire(self, **kwargs):
+        if self._pool:
+            conn = yield from self._pool.acquire()
+        elif self._connected:
+            conn = self._conn
+        else:
+            conn = yield from self._conn
+        return conn
+            # Check here for error
+            # except Error:
+                # conn = yield from self._connect()
+
+    @asyncio.coroutine
+    def submit(self, gremlin, conn=None, bindings=None, lang=None, op=None,
+               processor=None, session=None, binary=True):
         """
         """
         lang = lang or self.lang
         op = op or self.op
         processor = processor or self.processor
-        message = {
-            "requestId": str(uuid.uuid4()),
-            "op": op,
-            "processor": processor,
-            "args":{
-                "gremlin": gremlin,
-                "bindings": bindings,
-                "language":  lang
-            }
-        }
-        if processor == "session":
-            session = session or str(uuid.uuid4())
-            message["args"]["session"] = session
-            client_logger.info(
-                "Session ID: {}".format(message["args"]["session"]))
-        if connection is None:
-            connection = yield from self.pool.connect(self.uri, loop=self.loop)
-        writer = GremlinWriter(connection)
-        connection = yield from writer.write(message, binary=binary)
-        return GremlinResponse(connection, session=session, loop=self._loop)
+        if conn is None:
+            conn = yield from self._acquire()
+        writer = GremlinWriter(conn)
+        conn = yield from writer.write(gremlin, bindings=bindings,
+            lang=lang, op=op, processor=processor, session=session,
+            binary=binary)
+        return GremlinResponse(conn,
+                               self,
+                               session=session,
+                               loop=self._loop)
 
     @asyncio.coroutine
     def execute(self, gremlin, bindings=None, lang=None,
@@ -127,10 +142,11 @@ class GremlinClient:
 
 class GremlinResponse:
 
-    def __init__(self, conn, session=None, loop=None):
+    def __init__(self, conn, client, session=None, loop=None):
         self._loop = loop or asyncio.get_event_loop()
+        self._client = client
         self._session = session
-        self._stream = GremlinResponseStream(conn, loop=self._loop)
+        self._stream = GremlinResponseStream(conn, self, loop=self._loop)
 
     @property
     def stream(self):
@@ -156,11 +172,24 @@ class GremlinResponse:
             results.append(message)
         return results
 
+    # aioredis style
+    def __enter__(self):
+        raise RuntimeError(
+            "'yield from' should be used as a context manager expression")
+
+    def __exit__(self, *args):
+        pass
+
+    def __iter__(self):
+        yield from self._pool.create_pool()
+        return ClientContextManager(self)
+
 
 class GremlinResponseStream:
 
-    def __init__(self, conn, loop=None):
+    def __init__(self, conn, resp, loop=None):
         self._conn = conn
+        self._resp = resp
         self._loop = loop or asyncio.get_event_loop()
         data_stream = aiohttp.DataQueue(loop=self._loop)
         self._stream = self._conn.parser.set_parser(gremlin_response_parser,
@@ -170,13 +199,20 @@ class GremlinResponseStream:
     def read(self):
         # For 3.0.0.M9
         # if self._stream.at_eof():
-        #     self._conn.feed_pool()
+        #     self._pool.release(self._conn)
         #     message = None
         # else:
         # This will be different 3.0.0.M9
-        yield from self._conn._receive()
+        pool = self._resp._client._pool
+        try:
+            yield from self._conn.read()
+        except RequestError:
+            if pool:
+                pool.release(self._conn)
+            print("fed pool")
         if self._stream.is_eof():
-            self._conn.feed_pool()
+            if pool:
+                pool.release(self._conn)
             message = None
         else:
             message = yield from self._stream.read()
diff --git a/aiogremlin/connection.py b/aiogremlin/connection.py
index 633a502..26cc132 100644
--- a/aiogremlin/connection.py
+++ b/aiogremlin/connection.py
@@ -1,128 +1,85 @@
 """
 """
 import asyncio
-
-import aiohttp
+import base64
+import hashlib
+import os
+
+from aiohttp import (client, hdrs, DataQueue, StreamParser,
+    WSServerHandshakeError)
+from aiohttp.errors import WSServerHandshakeError
+from aiohttp.websocket import WS_KEY, Message
+from aiohttp.websocket import WebSocketParser, WebSocketWriter, WebSocketError
+from aiohttp.websocket import (MSG_BINARY, MSG_TEXT, MSG_CLOSE, MSG_PING,
+    MSG_PONG)
+from aiohttp.websocket_client import (MsgType, closedMessage,
+    ClientWebSocketResponse)
 
 from aiogremlin.abc import AbstractFactory, AbstractConnection
 from aiogremlin.exceptions import SocketClientError
-from aiogremlin.log import INFO, conn_logger
-
-
-class WebsocketPool:
-
-    def __init__(self, uri='ws://localhost:8182/', factory=None, poolsize=10,
-                 max_retries=10, timeout=None, loop=None, verbose=False):
-        """
-        """
-        self.uri = uri
-        self._factory = factory or AiohttpFactory
-        self.poolsize = poolsize
-        self.max_retries = max_retries
-        self.timeout = timeout
-        self._loop = loop or asyncio.get_event_loop()
-        self.pool = asyncio.Queue(maxsize=self.poolsize, loop=self._loop)
-        self.active_conns = set()
-        self.num_connecting = 0
-        self._closed = False
-        if verbose:
-            conn_logger.setLevel(INFO)
-
-    @asyncio.coroutine
-    def init_pool(self):
-        for i in range(self.poolsize):
-            conn = yield from self.factory.connect(self.uri, pool=self,
-                loop=self._loop)
-            self._put(conn)
-
-    @property
-    def loop(self):
-        return self._loop
-
-    @property
-    def factory(self):
-        return self._factory
-
-    @property
-    def closed(self):
-        return self._closed
-
-    @property
-    def num_active_conns(self):
-        return len(self.active_conns)
-
-    def feed_pool(self, conn):
-        if self._closed:
-            raise RuntimeError("WebsocketPool is closed.")
-        self.active_conns.discard(conn)
-        self._put(conn)
-
-    @asyncio.coroutine
-    def close(self):
-        if not self._closed:
-            if self.active_conns:
-                yield from self._close_active_conns()
-            yield from self._purge_pool()
-            self._closed = True
-
-    @asyncio.coroutine
-    def _close_active_conns(self):
-        tasks = [asyncio.async(conn.close(), loop=self.loop) for conn
-            in self.active_conns]
-        yield from asyncio.wait(tasks, loop=self.loop)
-
-    @asyncio.coroutine
-    def _purge_pool(self):
-        while True:
-            try:
-                conn = self.pool.get_nowait()
-            except asyncio.QueueEmpty:
+from aiogremlin.log import INFO, logger
+
+
+# This is temporary until aiohttp pull #367 is merged/released.
+@asyncio.coroutine
+def ws_connect(url, protocols=(), timeout=10.0, connector=None,
+               response_class=None, autoclose=True, autoping=True, loop=None):
+    """Initiate websocket connection."""
+    if loop is None:
+        loop = asyncio.get_event_loop()
+
+    sec_key = base64.b64encode(os.urandom(16))
+
+    headers = {
+        hdrs.UPGRADE: hdrs.WEBSOCKET,
+        hdrs.CONNECTION: hdrs.UPGRADE,
+        hdrs.SEC_WEBSOCKET_VERSION: '13',
+        hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(),
+    }
+    if protocols:
+        headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols)
+
+    # send request
+    resp = yield from client.request(
+        'get', url, headers=headers,
+        read_until_eof=False,
+        connector=connector, loop=loop)
+
+    # check handshake
+    if resp.status != 101:
+        raise WSServerHandshakeError('Invalid response status')
+
+    if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket':
+        raise WSServerHandshakeError('Invalid upgrade header')
+
+    if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade':
+        raise WSServerHandshakeError('Invalid connection header')
+
+    # key calculation
+    key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
+    match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode()
+    if key != match:
+        raise WSServerHandshakeError('Invalid challenge response')
+
+    # websocket protocol
+    protocol = None
+    if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers:
+        resp_protocols = [proto.strip() for proto in
+                          resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
+
+        for proto in resp_protocols:
+            if proto in protocols:
+                protocol = proto
                 break
-            else:
-                yield from conn.close()
 
-    @asyncio.coroutine
-    def connect(self, uri=None, loop=None, num_retries=None):
-        if self._closed:
-            raise RuntimeError("WebsocketPool is closed.")
-        if num_retries is None:
-            num_retries = self.max_retries
-        uri = uri or self.uri
-        loop = loop or self.loop
-        if not self.pool.empty():
-            socket = self.pool.get_nowait()
-            conn_logger.info("Reusing socket: {} at {}".format(socket, uri))
-        elif self.num_active_conns + self.num_connecting >= self.poolsize:
-            conn_logger.info("Waiting for socket...")
-            socket = yield from asyncio.wait_for(self.pool.get(),
-                self.timeout, loop=loop)
-            conn_logger.info("Socket acquired: {} at {}".format(socket, uri))
-        else:
-            self.num_connecting += 1
-            try:
-                socket = yield from self.factory.connect(uri, pool=self,
-                    loop=loop)
-            finally:
-                self.num_connecting -= 1
-        if not socket.closed:
-            conn_logger.info("New connection on socket: {} at {}".format(
-                socket, uri))
-            self.active_conns.add(socket)
-        # Untested.
-        elif num_retries > 0:
-            conn_logger.warning("Got bad socket, retry...")
-            socket = yield from self.connect(uri, loop, num_retries - 1)
-        else:
-            raise RuntimeError("Unable to connect, max retries exceeded.")
-        return socket
+    reader = resp.connection.reader.set_parser(WebSocketParser)
+    writer = WebSocketWriter(resp.connection.writer, use_mask=True)
 
-    def _put(self, socket):
-        try:
-            self.pool.put_nowait(socket)
-        except asyncio.QueueFull:
-            pass
-            # This should be - not working
-            # yield from socket.release()
+    if response_class is None:
+        response_class = ClientWebSocketResponse
+
+    return response_class(
+        reader, writer, protocol, resp, timeout, autoclose, autoping, loop)
 
 
 class BaseFactory(AbstractFactory):
@@ -141,107 +98,163 @@ class AiohttpFactory(BaseFactory):
         if pool:
             loop = loop or pool.loop
         try:
-            socket = yield from aiohttp.ws_connect(uri, protocols=protocols,
-                connector=connector, autoclose=False, autoping=True,
-                loop=loop)
-        except aiohttp.WSServerHandshakeError as e:
+            return (yield from ws_connect(
+                uri, protocols=protocols, connector=connector,
+                response_class=GremlinClientWebSocketResponse,
+                autoclose=True, autoping=True, loop=loop))
+        except WSServerHandshakeError as e:
             raise SocketClientError(e.message)
-        return AiohttpConnection(socket, pool, loop=loop)
 
 
 class BaseConnection(AbstractConnection):
 
-    def __init__(self, socket, pool=None, loop=None):
-        self.socket = socket
+    def __init__(self, loop=None):
         self._loop = loop or asyncio.get_event_loop()
-        self._pool = pool
-        self._parser = aiohttp.StreamParser(
-            buf=aiohttp.DataQueue(loop=self._loop), loop=self._loop)
+        self._parser = StreamParser(
+            buf=DataQueue(loop=self._loop), loop=self._loop)
 
     @property
     def parser(self):
         return self._parser
 
-    def feed_pool(self):
-        if self.pool:
-            if self in self.pool.active_conns:
-                self.pool.feed_pool(self)
 
-    @asyncio.coroutine
-    def release(self):
-        try:
-            yield from self.close()
-        finally:
-            if self in self.pool.active_conns:
-                self.pool.active_conns.discard(self)
-
-    @property
-    def pool(self):
-        return self._pool
+class GremlinClientWebSocketResponse(BaseConnection, ClientWebSocketResponse):
 
+    def __init__(self, reader, writer, protocol, response, timeout, autoclose,
+                 autoping, loop):
+        BaseConnection.__init__(self, loop=loop)
+        ClientWebSocketResponse.__init__(self, reader, writer, protocol,
+            response, timeout, autoclose, autoping, loop)
 
-class AiohttpConnection(BaseConnection):
+    @asyncio.coroutine
+    def close(self, *, code=1000, message=b''):
+        if not self._closed:
+            self._closed = True
+            closed = self._close()
+            if closed:
+                return True
+            while True:
+                try:
+                    msg = yield from asyncio.wait_for(
+                        self._reader.read(), self._timeout, loop=self._loop)
+                except asyncio.CancelledError:
+                    self._close_code = 1006
+                    self._response.close(force=True)
+                    raise
+                except Exception as exc:
+                    self._close_code = 1006
+                    self._exception = exc
+                    self._response.close(force=True)
+                    return True
+
+                if msg.tp == MsgType.close:
+                    self._close_code = msg.data
+                    self._response.close(force=True)
+                    return True
+        else:
+            return False
 
-    @property
-    def closed(self):
-        return self.socket.closed
+    def _close(self):
+        try:
+            self._writer.close(code, message)
+        except asyncio.CancelledError:
+            self._close_code = 1006
+            self._response.close(force=True)
+            raise
+        except Exception as exc:
+            self._close_code = 1006
+            self._exception = exc
+            self._response.close(force=True)
+            return True
 
-    @asyncio.coroutine
-    def close(self):
-        if not self.socket._closed:
-            try:
-                yield from self.socket.close()
-            finally:
-                # Socket should close despite errors.
-                self._closed = True
+        if self._closing:
+            self._response.close(force=True)
+            return True
 
-    @asyncio.coroutine
+    # @asyncio.coroutine
     def send(self, message, binary=True):
         if binary:
-            method = self.socket.send_bytes
+            method = self.send_bytes
         else:
-            method = self.socket.send_str
+            method = self.send_str
         try:
             method(message)
         except RuntimeError:
             # Socket closed.
-            yield from self.release()
             raise
         except TypeError:
             # Bytes/string input error.
-            yield from self.release()
             raise
 
     @asyncio.coroutine
-    def _receive(self):
+    def read(self):
         """Implements a dispatcher using the aiohttp websocket protocol."""
         try:
-            message = yield from self.socket.receive()
+            message = yield from self.receive()
         except (asyncio.CancelledError, asyncio.TimeoutError):
-            yield from self.release()
-            raise
-        except RuntimeError:
-            yield from self.release()
+            # Hmm maybe don't close here
+            yield from self.close()
             raise
-        if message.tp == aiohttp.MsgType.binary:
+        if message.tp == MsgType.binary:
             try:
                 self.parser.feed_data(message.data.decode())
             except Exception:
-                self.release()
+                # Hmm maybe don't close here
+                yield from self.close()
                 raise
-        elif message.tp == aiohttp.MsgType.text:
+        elif message.tp == MsgType.text:
             try:
                 self.parser.feed_data(message.data.strip())
             except Exception:
-                self.release()
+                # Hmm maybe don't close here
+                yield from self.close()
                 raise
         else:
-            try:
-                if message.tp == aiohttp.MsgType.close:
-                    raise RuntimeError("Socket connection closed by server.")
-                elif message.tp == aiohttp.MsgType.error:
-                    raise SocketClientError(self.socket.exception())
-                elif message.tp == aiohttp.MsgType.closed:
-                    raise RuntimeError("Socket closed.")
-            finally:
-                yield from self.release()
+            if message.tp == MsgType.close:
+                raise RuntimeError("Socket connection closed by server.")
+            elif message.tp == MsgType.error:
+                raise SocketClientError(self.socket.exception())
+            elif message.tp == MsgType.closed:
+                raise RuntimeError("Socket closed.")
+
+    # @asyncio.coroutine
+    # def receive(self):
+    #     if self._waiting:
+    #         raise RuntimeError('Concurrent call to receive() is not allowed')
+    #
+    #     self._waiting = True
+    #     try:
+    #         while True:
+    #             if self._closed:
+    #                 return closedMessage
+    #
+    #             try:
+    #                 msg = yield from self._reader.read()
+    #             except (asyncio.CancelledError, asyncio.TimeoutError):
+    #                 raise
+    #             except WebSocketError as exc:
+    #                 self._close_code = exc.code
+    #                 yield from self.close(code=exc.code)
+    #                 return Message(MsgType.error, exc, None)
+    #             except Exception as exc:
+    #                 self._exception = exc
+    #                 self._closing = True
+    #                 self._close_code = 1006
+    #                 yield from self.close()
+    #                 return Message(MsgType.error, exc, None)
+    #
+    #             if msg.tp == MsgType.close:
+    #                 self._closing = True
+    #                 self._close_code = msg.data
+    #                 if not self._closed and self._autoclose:
+    #                     yield from self.close()
+    #                 return msg
+    #             elif not self._closed:
+    #                 if msg.tp == MsgType.ping and self._autoping:
+    #                     self._writer.pong(msg.data)
+    #                 elif msg.tp == MsgType.pong and self._autoping:
+    #                     continue
+    #                 else:
+    #                     return msg
+    #     finally:
+    #         self._waiting = False
diff --git a/aiogremlin/contextmanager.py b/aiogremlin/contextmanager.py
new file mode 100644
index 0000000..f5c1f3a
--- /dev/null
+++ b/aiogremlin/contextmanager.py
@@ -0,0 +1,38 @@
+class ClientContextManager:
+
+    __slots__ = ("_client")
+
+    def __init__(self, client):
+        self._client = client
+
+    def __enter__(self):
+        return self._client
+
+    def __exit__(self, *args):
+        try:
+            yield from self._client.close()
+        finally:
+            self._client = None
+
+
+class ConnectionContextManager:
+
+    __slots__ = ("_conn", "_pool")
+
+    def __init__(self, conn, pool):
+        self._conn = conn
+        self._pool = pool
+
+    def __enter__(self):
+        return self._conn
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        print("in __exit__")
+        import ipdb; ipdb.set_trace()
+        print("agains")
+        try:
+            print("hello")
+            yield from self._conn.release()
+        finally:
+            self._conn = None
+            self._pool = None
diff --git a/aiogremlin/log.py b/aiogremlin/log.py
index 8a4a47d..4db9367 100644
--- a/aiogremlin/log.py
+++ b/aiogremlin/log.py
@@ -8,5 +8,4 @@ logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 
 
-client_logger = logging.getLogger("aiogremlin.client")
-conn_logger = logging.getLogger("aiogremlin.connection")
+logger = logging.getLogger("aiogremlin")
diff --git a/aiogremlin/pool.py b/aiogremlin/pool.py
new file mode 100644
index 0000000..8a7b6c6
--- /dev/null
+++ b/aiogremlin/pool.py
@@ -0,0 +1,139 @@
+import asyncio
+
+from aiogremlin.connection import AiohttpFactory
+from aiogremlin.contextmanager import ConnectionContextManager
+from aiogremlin.log import logger
+
+
+def create_pool():
+    pass
+
+
+class WebSocketPool:
+
+    def __init__(self, uri='ws://localhost:8182/', factory=None, poolsize=10,
+                 max_retries=10, timeout=None, loop=None, verbose=False):
+        """
+        """
+        self.uri = uri
+        self._factory = factory or AiohttpFactory
+        self.poolsize = poolsize
+        self.max_retries = max_retries
+        self.timeout = timeout
+        self._connected = False
+        self._loop = loop or asyncio.get_event_loop()
+        self._pool = asyncio.Queue(maxsize=self.poolsize, loop=self._loop)
+        self.active_conns = set()
+        self.num_connecting = 0
+        self._closed = False
+        if verbose:
+            logger.setLevel(INFO)
+
+    @asyncio.coroutine
+    def fill_pool(self):
+        for i in range(self.poolsize):
+            conn = yield from self.factory.connect(self.uri, pool=self,
+                loop=self._loop)
+            self._put(conn)
+        self._connected = True
+
+    @property
+    def loop(self):
+        return self._loop
+
+    @property
+    def factory(self):
+        return self._factory
+
+    @property
+    def closed(self):
+        return self._closed
+
+    @property
+    def num_active_conns(self):
+        return len(self.active_conns)
+
+    def release(self, conn):
+        if self._closed:
+            raise RuntimeError("WebsocketPool is closed.")
+        self.active_conns.discard(conn)
+        self._put(conn)
+
+    @asyncio.coroutine
+    def close(self):
+        if not self._closed:
+            if self.active_conns:
+                yield from self._close_active_conns()
+            yield from self._purge_pool()
+            self._closed = True
+
+    @asyncio.coroutine
+    def _close_active_conns(self):
+        tasks = [asyncio.async(conn.close(), loop=self.loop) for conn
+            in self.active_conns]
+        yield from asyncio.wait(tasks, loop=self.loop)
+
+    @asyncio.coroutine
+    def _purge_pool(self):
+        while True:
+            try:
+                conn = self._pool.get_nowait()
+            except asyncio.QueueEmpty:
+                break
+            else:
+                yield from conn.close()
+
+    @asyncio.coroutine
+    def acquire(self, uri=None, loop=None, num_retries=None):
+        if self._closed:
+            raise RuntimeError("WebsocketPool is closed.")
+        if num_retries is None:
+            num_retries = self.max_retries
+        uri = uri or self.uri
+        loop = loop or self.loop
+        if not self._pool.empty():
+            socket = self._pool.get_nowait()
+            logger.info("Reusing socket: {} at {}".format(socket, uri))
+        elif self.num_active_conns + self.num_connecting >= self.poolsize:
+            logger.info("Waiting for socket...")
+            socket = yield from asyncio.wait_for(self._pool.get(),
+                self.timeout, loop=loop)
+            logger.info("Socket acquired: {} at {}".format(socket, uri))
+        else:
+            self.num_connecting += 1
+            try:
+                socket = yield from self.factory.connect(uri, pool=self,
+                    loop=loop)
+            finally:
+                self.num_connecting -= 1
+        if not socket.closed:
+            logger.info("New connection on socket: {} at {}".format(
+                socket, uri))
+            self.active_conns.add(socket)
+        # Untested.
+        elif num_retries > 0:
+            logger.warning("Got bad socket, retry...")
+            socket = yield from self.acquire(uri, loop, num_retries - 1)
+        else:
+            raise RuntimeError("Unable to connect, max retries exceeded.")
+        return socket
+
+    def _put(self, socket):
+        try:
+            self._pool.put_nowait(socket)
+        except asyncio.QueueFull:
+            pass
+            # This should be - not working
+            # yield from socket.release()
+
+    # aioredis style
+    def __enter__(self):
+        raise RuntimeError(
+            "'yield from' should be used as a context manager expression")
+
+    def __exit__(self, *args):
+        pass
+
+    def __iter__(self):
+        conn = yield from self.acquire()
+        return ConnectionContextManager(conn, self)
diff --git a/aiogremlin/protocol.py b/aiogremlin/protocol.py
index 897acac..9144cb7 100644
--- a/aiogremlin/protocol.py
+++ b/aiogremlin/protocol.py
@@ -2,10 +2,15 @@
 
 import asyncio
 import collections
+import uuid
 
-import ujson
+try:
+    import ujson as json
+except ImportError:
+    import json
 
 from aiogremlin.exceptions import RequestError, GremlinServerError
+from aiogremlin.log import logger
 
 
 Message = collections.namedtuple("Message", ["status_code", "data", "message",
@@ -15,7 +20,7 @@ Message = collections.namedtuple("Message", ["status_code", "data", "message",
 def gremlin_response_parser(out, buf):
     while True:
         message = yield
-        message = ujson.loads(message)
+        message = json.loads(message)
         message = Message(message["status"]["code"],
                           message["result"]["data"],
                           message["result"]["meta"],
@@ -46,11 +51,15 @@ class GremlinWriter:
         self._connection = connection
 
     @asyncio.coroutine
-    def write(self, message, binary=True, mime_type="application/json"):
-        message = ujson.dumps(message)
+    def write(self, gremlin, bindings=None, lang="gremlin-groovy", op="eval",
+              processor="", session=None, binary=True,
+              mime_type="application/json"):
+        message = self._prepare_message(gremlin, bindings=bindings,
+            lang=lang, op=op, processor=processor, session=session)
+        message = json.dumps(message)
         if binary:
             message = self._set_message_header(message, mime_type)
-        yield from self._connection.send(message, binary)
+        self._connection.send(message, binary)
         return self._connection
 
     @staticmethod
@@ -61,3 +70,23 @@ class GremlinWriter:
         else:
             raise ValueError("Unknown mime type.")
         return b"".join([mime_len, mime_type, bytes(message, "utf-8")])
+
+    @staticmethod
+    def _prepare_message(gremlin, bindings=None, lang="gremlin-groovy", op="eval",
+                        processor="", session=None):
+        message = {
+            "requestId": str(uuid.uuid4()),
+            "op": op,
+            "processor": processor,
+            "args":{
+                "gremlin": gremlin,
+                "bindings": bindings,
+                "language":  lang
+            }
+        }
+        if processor == "session":
+            session = session or str(uuid.uuid4())
+            message["args"]["session"] = session
+            logger.info(
+                "Session ID: {}".format(message["args"]["session"]))
+        return message
diff --git a/tests/tests.py b/tests/tests.py
index 5ba9c13..28c2d0f 100644
--- a/tests/tests.py
+++ b/tests/tests.py
@@ -6,16 +6,19 @@ import itertools
 import unittest
 import uuid
 from aiogremlin import (GremlinClient, RequestError, GremlinServerError,
-    SocketClientError, WebsocketPool, AiohttpFactory, create_client)
+    SocketClientError, WebSocketPool, AiohttpFactory, create_client,
+    GremlinWriter, GremlinResponse)
 
 #
-class GremlinClientTests(unittest.TestCase):
+class GremlinClientPoolTests(unittest.TestCase):
 
     def setUp(self):
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(None)
         self.gc = GremlinClient("ws://localhost:8182/",
-            factory=AiohttpFactory, loop=self.loop)
+            factory=AiohttpFactory, pool=WebSocketPool("ws://localhost:8182/",
+                                                       loop=self.loop),
+            loop=self.loop)
 
     def tearDown(self):
         self.loop.run_until_complete(self.gc.close())
@@ -24,7 +27,7 @@ class GremlinClientTests(unittest.TestCase):
     def test_connection(self):
         @asyncio.coroutine
         def conn_coro():
-            conn = yield from self.gc.connect()
+            conn = yield from self.gc._acquire()
             self.assertFalse(conn.closed)
             return conn
         conn = self.loop.run_until_complete(conn_coro())
@@ -96,12 +99,12 @@ class GremlinClientTests(unittest.TestCase):
         self.loop.run_until_complete(stream_coro())
 
 
-class WebsocketPoolTests(unittest.TestCase):
+class WebSocketPoolTests(unittest.TestCase):
 
     def setUp(self):
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(None)
-        self.pool = WebsocketPool(poolsize=2, timeout=1, loop=self.loop,
+        self.pool = WebSocketPool(poolsize=2, timeout=1, loop=self.loop,
             factory=AiohttpFactory)
 
     def tearDown(self):
@@ -112,10 +115,9 @@ class WebsocketPoolTests(unittest.TestCase):
 
         @asyncio.coroutine
         def conn():
-            conn = yield from self.pool.connect()
-            self.assertIsNotNone(conn.socket)
+            conn = yield from self.pool.acquire()
             self.assertFalse(conn.closed)
-            conn.feed_pool()
+            self.pool.release(conn)
             self.assertEqual(self.pool.num_active_conns, 0)
 
         self.loop.run_until_complete(conn())
@@ -124,15 +126,13 @@ class WebsocketPoolTests(unittest.TestCase):
 
         @asyncio.coroutine
         def conn():
-            conn1 = yield from self.pool.connect()
-            conn2 = yield from self.pool.connect()
-            self.assertIsNotNone(conn1.socket)
+            conn1 = yield from self.pool.acquire()
+            conn2 = yield from self.pool.acquire()
             self.assertFalse(conn1.closed)
-            self.assertIsNotNone(conn2.socket)
             self.assertFalse(conn2.closed)
-            conn1.feed_pool()
+            self.pool.release(conn1)
             self.assertEqual(self.pool.num_active_conns, 1)
-            conn2.feed_pool()
+            self.pool.release(conn2)
             self.assertEqual(self.pool.num_active_conns, 0)
 
         self.loop.run_until_complete(conn())
@@ -141,10 +141,10 @@ class WebsocketPoolTests(unittest.TestCase):
 
         @asyncio.coroutine
         def conn():
-            conn1 = yield from self.pool.connect()
-            conn2 = yield from self.pool.connect()
+            conn1 = yield from self.pool.acquire()
+            conn2 = yield from self.pool.acquire()
             try:
-                conn3 = yield from self.pool.connect()
+                conn3 = yield from self.pool.acquire()
                 timeout = False
             except asyncio.TimeoutError:
                 timeout = True
@@ -156,21 +156,19 @@ class WebsocketPoolTests(unittest.TestCase):
 
         @asyncio.coroutine
         def conn():
-            conn1 = yield from self.pool.connect()
-            conn2 = yield from self.pool.connect()
+            conn1 = yield from self.pool.acquire()
+            conn2 = yield from self.pool.acquire()
             try:
-                conn3 = yield from self.pool.connect()
+                conn3 = yield from self.pool.acquire()
                 timeout = False
             except asyncio.TimeoutError:
                 timeout = True
             self.assertTrue(timeout)
-            conn2.feed_pool()
-            conn3 = yield from self.pool.connect()
-            self.assertIsNotNone(conn1.socket)
+            self.pool.release(conn2)
+            conn3 = yield from self.pool.acquire()
             self.assertFalse(conn1.closed)
-            self.assertIsNotNone(conn3.socket)
             self.assertFalse(conn3.closed)
-            self.assertEqual(conn2.socket, conn3.socket)
+            self.assertEqual(conn2, conn3)
 
         self.loop.run_until_complete(conn())
 
@@ -178,26 +176,53 @@ class WebsocketPoolTests(unittest.TestCase):
 
         @asyncio.coroutine
         def conn():
-            conn1 = yield from self.pool.connect()
-            conn2 = yield from self.pool.connect()
-            self.assertIsNotNone(conn1.socket)
+            conn1 = yield from self.pool.acquire()
+            conn2 = yield from self.pool.acquire()
             self.assertFalse(conn1.closed)
-            self.assertIsNotNone(conn2.socket)
             self.assertFalse(conn2.closed)
-            yield from conn1.socket.close()
-            yield from conn2.socket.close()
+            yield from conn1.close()
+            yield from conn2.close()
             self.assertTrue(conn2.closed)
             self.assertTrue(conn2.closed)
-            conn1.feed_pool()
-            conn2.feed_pool()
-            conn1 = yield from self.pool.connect()
-            conn2 = yield from self.pool.connect()
-            self.assertIsNotNone(conn1.socket)
+            self.pool.release(conn1)
+            self.pool.release(conn2)
+            conn1 = yield from self.pool.acquire()
+            conn2 = yield from self.pool.acquire()
             self.assertFalse(conn1.closed)
-            self.assertIsNotNone(conn2.socket)
             self.assertFalse(conn2.closed)
 
         self.loop.run_until_complete(conn())
+#
+# class ContextMngrTest(unittest.TestCase):
+#
+#     def setUp(self):
+#         self.loop = asyncio.new_event_loop()
+#         asyncio.set_event_loop(None)
+#         self.pool = WebSocketPool(poolsize=1, loop=self.loop,
+#             factory=AiohttpFactory)
+#
+#     def tearDown(self):
+#         self.loop.run_until_complete(self.pool.close())
+#         self.loop.close()
+#
+#     def test_connection_manager(self):
+#         results = []
+#         @asyncio.coroutine
+#         def go():
+#
+#             # import ipdb; ipdb.set_trace()
+#             with (yield from self.pool) as conn:
+#                 writer = GremlinWriter(conn)
+#                 conn = yield from writer.write("1 + 1")
+#                 resp = GremlinResponse(conn, loop=self.loop)
+#                 while True:
+#                     mssg = yield from resp.stream.read()
+#                     if mssg is None:
+#                         break
+#                     results.append(mssg)
+#             conn = self.pool._pool.get_nowait()
+#             self.assertTrue(conn.closed)
+#         self.loop.run_until_complete(go())
 
 
 class CreateClientTests(unittest.TestCase):
@@ -206,7 +231,8 @@ class CreateClientTests(unittest.TestCase):
         @asyncio.coroutine
         def go(loop):
             gc = yield from create_client(poolsize=10, loop=loop)
-            self.assertEqual(gc.pool.pool.qsize(), 10)
+            self.assertEqual(gc._pool._pool.qsize(), 10)
+            yield from gc.close()
 
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(None)
-- 
GitLab