From 98cea15e51913ae9cc4d220814497bf4f3fd78e4 Mon Sep 17 00:00:00 2001
From: davebshow <davebshow@gmail.com>
Date: Mon, 18 May 2015 13:52:10 -0400
Subject: [PATCH] using a subclassed ClientSession as factory to leverage
 aiohttp connection pooling

---
 aiogremlin/__init__.py   |   3 +-
 aiogremlin/abc.py        |   8 +-
 aiogremlin/client.py     |   4 +-
 aiogremlin/connection.py | 143 +++++++-----
 aiogremlin/pool.py       |  38 ++--
 benchmark.py             |  18 +-
 conn_bench.py            |  55 +++++
 tests/tests.py           | 465 +++++++++++++++++++++------------------
 8 files changed, 438 insertions(+), 296 deletions(-)
 create mode 100644 conn_bench.py

diff --git a/aiogremlin/__init__.py b/aiogremlin/__init__.py
index f6a17d5..667336e 100644
--- a/aiogremlin/__init__.py
+++ b/aiogremlin/__init__.py
@@ -1,5 +1,6 @@
 from .abc import AbstractFactory, AbstractConnection
-from .connection import AiohttpFactory, BaseFactory, BaseConnection
+from .connection import (AiohttpFactory, BaseFactory, BaseConnection,
+    WebSocketSession)
 from .client import (create_client, GremlinClient, GremlinResponse,
     GremlinResponseStream)
 from .exceptions import RequestError, GremlinServerError, SocketClientError
diff --git a/aiogremlin/abc.py b/aiogremlin/abc.py
index 83bc1e0..79a9bb2 100644
--- a/aiogremlin/abc.py
+++ b/aiogremlin/abc.py
@@ -5,14 +5,8 @@ from abc import ABCMeta, abstractmethod
 
 class AbstractFactory(metaclass=ABCMeta):
 
-    @classmethod
     @abstractmethod
-    def connect(cls):
-        pass
-
-    @property
-    @abstractmethod
-    def factory(self):
+    def ws_connect(cls):
         pass
 
 
diff --git a/aiogremlin/client.py b/aiogremlin/client.py
index 340f28a..af4d270 100644
--- a/aiogremlin/client.py
+++ b/aiogremlin/client.py
@@ -62,7 +62,7 @@ class GremlinClient:
         self.poolsize = poolsize
         self.timeout = timeout
         self._pool = pool
-        self._factory = factory or AiohttpFactory
+        self._factory = factory or AiohttpFactory()
         if self._pool is None:
             self._connected = False
             self._conn = asyncio.async(self._connect(), loop=self._loop)
@@ -90,7 +90,7 @@ class GremlinClient:
         """
         """
         loop = kwargs.get("loop", "") or self._loop
-        connection = yield from self._factory.connect(self.uri, loop=loop)
+        connection = yield from self._factory.ws_connect(self.uri, loop=loop)
         self._connected = True
         return connection
 
diff --git a/aiogremlin/connection.py b/aiogremlin/connection.py
index 95ab592..019c3bd 100644
--- a/aiogremlin/connection.py
+++ b/aiogremlin/connection.py
@@ -6,7 +6,7 @@ import hashlib
 import os
 
 from aiohttp import (client, hdrs, DataQueue, StreamParser,
-    WSServerHandshakeError)
+    WSServerHandshakeError, ClientSession, TCPConnector)
 from aiohttp.errors import WSServerHandshakeError
 from aiohttp.websocket import WS_KEY, Message
 from aiohttp.websocket import WebSocketParser, WebSocketWriter, WebSocketError
@@ -20,68 +20,98 @@ from aiogremlin.exceptions import SocketClientError
 from aiogremlin.log import INFO, logger
 
 
-# This is temporary until aiohttp pull #367 is merged/released.
-@asyncio.coroutine
-def ws_connect(url, protocols=(), timeout=10.0, connector=None,
-               response_class=None, autoclose=True, autoping=True, loop=None):
-    """Initiate websocket connection."""
-    if loop is None:
-        loop = asyncio.get_event_loop()
+class WebSocketSession(AbstractFactory, ClientSession):
+
+    @asyncio.coroutine
+    def ws_connect(self, url, protocols=(), timeout=10.0, connector=None,
+                   response_class=None, autoclose=True, autoping=True,
+                   loop=None):
+        """Initiate websocket connection."""
 
-    sec_key = base64.b64encode(os.urandom(16))
+        sec_key = base64.b64encode(os.urandom(16))
 
-    headers = {
-        hdrs.UPGRADE: hdrs.WEBSOCKET,
-        hdrs.CONNECTION: hdrs.UPGRADE,
-        hdrs.SEC_WEBSOCKET_VERSION: '13',
-        hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(),
-    }
-    if protocols:
-        headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols)
+        headers = {
+            hdrs.UPGRADE: hdrs.WEBSOCKET,
+            hdrs.CONNECTION: hdrs.UPGRADE,
+            hdrs.SEC_WEBSOCKET_VERSION: '13',
+            hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(),
+        }
+        if protocols:
+            headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols)
 
-    # send request
-    resp = yield from client.request(
-        'get', url, headers=headers,
-        read_until_eof=False,
-        connector=connector, loop=loop)
+        # send request
+        resp = yield from self.request('get', url, headers=headers,
+            read_until_eof=False)
 
-    # check handshake
-    if resp.status != 101:
-        raise WSServerHandshakeError('Invalid response status')
+        # check handshake
+        if resp.status != 101:
+            raise WSServerHandshakeError('Invalid response status')
 
-    if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket':
-        raise WSServerHandshakeError('Invalid upgrade header')
+        if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket':
+            raise WSServerHandshakeError('Invalid upgrade header')
 
-    if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade':
-        raise WSServerHandshakeError('Invalid connection header')
+        if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade':
+            raise WSServerHandshakeError('Invalid connection header')
 
-    # key calculation
-    key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
-    match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode()
-    if key != match:
-        raise WSServerHandshakeError('Invalid challenge response')
+        # key calculation
+        key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
+        match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode()
+        if key != match:
+            raise WSServerHandshakeError('Invalid challenge response')
 
-    # websocket protocol
-    protocol = None
-    if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers:
-        resp_protocols = [proto.strip() for proto in
-                          resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
+        # websocket protocol
+        protocol = None
+        if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers:
+            resp_protocols = [proto.strip() for proto in
+                              resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
 
-        for proto in resp_protocols:
-            if proto in protocols:
-                protocol = proto
-                break
+            for proto in resp_protocols:
+                if proto in protocols:
+                    protocol = proto
+                    break
 
-    reader = resp.connection.reader.set_parser(WebSocketParser)
-    writer = WebSocketWriter(resp.connection.writer, use_mask=True)
+        reader = resp.connection.reader.set_parser(WebSocketParser)
+        writer = WebSocketWriter(resp.connection.writer, use_mask=True)
 
-    if response_class is None:
-        response_class = ClientWebSocketResponse
+        if response_class is None:
+            response_class = ClientWebSocketResponse
 
-    return response_class(
-        reader, writer, protocol, resp, timeout, autoclose, autoping, loop)
+        return response_class(
+            reader, writer, protocol, resp, timeout, autoclose, autoping, loop)
 
+    def detach(self):
+        """Detach connector from session without closing the former.
+        Session is switched to closed state anyway.
+        """
+        self._connector = None
 
+
+def ws_connect(url, protocols=(), timeout=10.0, connector=None,
+               response_class=None, autoclose=True, autoping=True,
+               loop=None):
+    if loop is None:
+        asyncio.get_event_loop()
+    if connector is None:
+        connector = TCPConnector(loop=loop, force_close=True)
+
+    ws_session = WebSocketSession(loop=loop, connector=connector)
+
+    try:
+        resp = yield from ws_session.ws_connect(url,
+                                                protocols=protocols,
+                                                timeout=timeout,
+                                                connector=connector,
+                                                response_class=response_class,
+                                                autoclose=autoclose,
+                                                autoping=autoping,
+                                                loop=loop)
+        return resp
+
+    finally:
+        ws_session.detach()
+
+
+ # Will drop 'pluggable sockets' implementation in favour of aiohttp default.
 class BaseFactory(AbstractFactory):
 
     @property
@@ -91,17 +121,18 @@ class BaseFactory(AbstractFactory):
 
 class AiohttpFactory(BaseFactory):
 
-    @classmethod
     @asyncio.coroutine
-    def connect(cls, uri='ws://localhost:8182/', pool=None, protocols=(),
-                connector=None, autoclose=False, autoping=True, loop=None):
-        if pool:
-            loop = loop or pool.loop
+    def ws_connect(cls, uri='ws://localhost:8182/', protocols=(),
+                   connector=None, autoclose=False, autoping=True,
+                   response_class=None, loop=None):
+        if response_class is None:
+            response_class = GremlinClientWebSocketResponse
+
         try:
             return (yield from ws_connect(
                 uri, protocols=protocols, connector=connector,
-                response_class=GremlinClientWebSocketResponse,
-                autoclose=True, autoping=True, loop=loop))
+                response_class=response_class, autoclose=True, autoping=True,
+                loop=loop))
         except WSServerHandshakeError as e:
             raise SocketClientError(e.message)
 
diff --git a/aiogremlin/pool.py b/aiogremlin/pool.py
index 8a7b6c6..2f31578 100644
--- a/aiogremlin/pool.py
+++ b/aiogremlin/pool.py
@@ -1,6 +1,7 @@
 import asyncio
 
-from aiogremlin.connection import AiohttpFactory
+from aiogremlin.connection import (AiohttpFactory,
+    GremlinClientWebSocketResponse)
 from aiogremlin.contextmanager import ConnectionContextManager
 from aiogremlin.log import logger
 
@@ -11,11 +12,12 @@ def create_pool():
 
 class WebSocketPool:
 
-    def __init__(self, uri='ws://localhost:8182/', factory=None, poolsize=10,
-                 max_retries=10, timeout=None, loop=None, verbose=False):
+    def __init__(self, url='ws://localhost:8182/', factory=None, poolsize=10,
+                 max_retries=10, timeout=None, loop=None, verbose=False,
+                 response_class=None):
         """
         """
-        self.uri = uri
+        self.url = url
         self._factory = factory or AiohttpFactory
         self.poolsize = poolsize
         self.max_retries = max_retries
@@ -25,15 +27,21 @@ class WebSocketPool:
         self._pool = asyncio.Queue(maxsize=self.poolsize, loop=self._loop)
         self.active_conns = set()
         self.num_connecting = 0
+        self._response_class = response_class or GremlinClientWebSocketResponse
         self._closed = False
         if verbose:
             logger.setLevel(INFO)
 
     @asyncio.coroutine
     def fill_pool(self):
-        for i in range(self.poolsize):
-            conn = yield from self.factory.connect(self.uri, pool=self,
-                loop=self._loop)
+        tasks = []
+        poolsize = self.poolsize
+        for i in range(poolsize):
+            task = asyncio.async(self.factory.ws_connect(self.url,
+                response_class=self._response_class, loop=self._loop), loop=self._loop)
+            tasks.append(task)
+        for f in asyncio.as_completed(tasks, loop=self._loop):
+            conn = yield from f
             self._put(conn)
         self._connected = True
 
@@ -84,36 +92,36 @@ class WebSocketPool:
                 yield from conn.close()
 
     @asyncio.coroutine
-    def acquire(self, uri=None, loop=None, num_retries=None):
+    def acquire(self, url=None, loop=None, num_retries=None):
         if self._closed:
             raise RuntimeError("WebsocketPool is closed.")
         if num_retries is None:
             num_retries = self.max_retries
-        uri = uri or self.uri
+        url = url or self.url
         loop = loop or self.loop
         if not self._pool.empty():
             socket = self._pool.get_nowait()
-            logger.info("Reusing socket: {} at {}".format(socket, uri))
+            logger.info("Reusing socket: {} at {}".format(socket, url))
         elif self.num_active_conns + self.num_connecting >= self.poolsize:
             logger.info("Waiting for socket...")
             socket = yield from asyncio.wait_for(self._pool.get(),
                 self.timeout, loop=loop)
-            logger.info("Socket acquired: {} at {}".format(socket, uri))
+            logger.info("Socket acquired: {} at {}".format(socket, url))
         else:
             self.num_connecting += 1
             try:
-                socket = yield from self.factory.connect(uri, pool=self,
-                    loop=loop)
+                socket = yield from self.factory.ws_connect(url,
+                    response_class=self._response_class, loop=loop)
             finally:
                 self.num_connecting -= 1
         if not socket.closed:
             logger.info("New connection on socket: {} at {}".format(
-                socket, uri))
+                socket, url))
             self.active_conns.add(socket)
         # Untested.
         elif num_retries > 0:
             logger.warning("Got bad socket, retry...")
-            socket = yield from self.acquire(uri, loop, num_retries - 1)
+            socket = yield from self.acquire(url, loop, num_retries - 1)
         else:
             raise RuntimeError("Unable to connect, max retries exceeded.")
         return socket
diff --git a/benchmark.py b/benchmark.py
index 8ca0273..2e80803 100644
--- a/benchmark.py
+++ b/benchmark.py
@@ -89,6 +89,10 @@ ARGS.add_argument(
     '-w', '--warmups', action="store",
     nargs='?', type=int, default=5,
     help='num warmups (default: `%(default)s`)')
+ARGS.add_argument(
+    '-s', '--session', action="store",
+    nargs='?', type=str, default="false",
+    help='use session to establish connections (default: `%(default)s`)')
 
 
 if __name__ == "__main__":
@@ -98,12 +102,20 @@ if __name__ == "__main__":
     concurr = args.concurrency
     poolsize = args.poolsize
     num_warmups = args.warmups
+    session = args.session
     loop = asyncio.get_event_loop()
+    t1 = loop.time()
+    if session == "true":
+        factory = aiogremlin.WebSocketSession()
+    else:
+        factory = aiogremlin.AiohttpFactory()
     client = loop.run_until_complete(
-        aiogremlin.create_client(loop=loop, poolsize=poolsize))
+        aiogremlin.create_client(loop=loop, factory=factory, poolsize=poolsize))
+    t2 = loop.time()
+    print("time to establish conns: {}".format(t2 - t1))
     try:
-        print("Runs: {}. Warmups: {}. Messages: {}. Concurrency: {}. Poolsize: {}".format(
-            num_tests, num_warmups, num_mssg, concurr, poolsize))
+        print("Runs: {}. Warmups: {}. Messages: {}. Concurrency: {}. Poolsize: {}. Use Session: {}".format(
+            num_tests, num_warmups, num_mssg, concurr, poolsize, session))
         main = main(client, num_tests, num_mssg, concurr, num_warmups, loop)
         loop.run_until_complete(main)
     finally:
diff --git a/conn_bench.py b/conn_bench.py
new file mode 100644
index 0000000..d05cff0
--- /dev/null
+++ b/conn_bench.py
@@ -0,0 +1,55 @@
+import argparse
+import asyncio
+import aiogremlin
+
+
+@asyncio.coroutine
+def create_destroy(loop, factory, poolsize):
+    client = yield from aiogremlin.create_client(loop=loop,
+                                                 factory=factory,
+                                                 poolsize=poolsize)
+    yield from client.close()
+
+# NEED TO ADD MORE ARGS/CLEAN UP like benchmark.py
+ARGS = argparse.ArgumentParser(description="Run benchmark.")
+ARGS.add_argument(
+    '-t', '--tests', action="store",
+    nargs='?', type=int, default=10,
+    help='number of tests (default: `%(default)s`)')
+
+ARGS.add_argument(
+    '-s', '--session', action="store",
+    nargs='?', type=str, default="false",
+    help='use session to establish connections (default: `%(default)s`)')
+
+
+if __name__ == "__main__":
+    args = ARGS.parse_args()
+    tests = args.tests
+    print("tests", tests)
+    session = args.session
+    loop = asyncio.get_event_loop()
+    if session == "true":
+        factory = aiogremlin.WebSocketSession()
+    else:
+        factory = aiogremlin.AiohttpFactory()
+    print("factory: {}".format(factory))
+    try:
+        m1 = loop.time()
+        for x in range(50):
+            tasks = []
+            for x in range(tests):
+                task = asyncio.async(
+                    create_destroy(loop, factory, 100)
+                )
+                tasks.append(task)
+            t1 = loop.time()
+            loop.run_until_complete(asyncio.async(asyncio.gather(*tasks, loop=loop)))
+            t2 = loop.time()
+            print("avg: time to establish conn: {}".format((t2 - t1) / (tests * 50)))
+        m2 = loop.time()
+        print("time to establish conns: {}".format((m2 - m1)))
+        print("avg time to establish conns: {}".format((m2 - m1) / (tests * 100 * 50)))
+    finally:
+        loop.close()
+        print("CLOSED CLIENT AND LOOP")
diff --git a/tests/tests.py b/tests/tests.py
index 6896c02..1c8f5e3 100644
--- a/tests/tests.py
+++ b/tests/tests.py
@@ -7,44 +7,264 @@ import unittest
 import uuid
 from aiogremlin import (GremlinClient, RequestError, GremlinServerError,
     SocketClientError, WebSocketPool, AiohttpFactory, create_client,
-    GremlinWriter, GremlinResponse)
-
-
-class GremlinClientTests(unittest.TestCase):
-
-    def setUp(self):
-        self.loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(None)
-        self.gc = GremlinClient("ws://localhost:8182/", loop=self.loop)
-
-    def tearDown(self):
-        self.loop.run_until_complete(self.gc.close())
-        self.loop.close()
-
-    def test_connection(self):
-        @asyncio.coroutine
-        def conn_coro():
-            conn = yield from self.gc._acquire()
-            self.assertFalse(conn.closed)
-            return conn
-        conn = self.loop.run_until_complete(conn_coro())
-        # Clean up the resource.
-        self.loop.run_until_complete(conn.close())
-
-    def test_sub(self):
-        execute = self.gc.execute("x + x", bindings={"x": 4})
-        results = self.loop.run_until_complete(execute)
-        self.assertEqual(results[0].data[0], 8)
-
-
-class GremlinClientPoolTests(unittest.TestCase):
+    GremlinWriter, GremlinResponse, WebSocketSession)
+# 
+#
+# class GremlinClientTests(unittest.TestCase):
+#
+#     def setUp(self):
+#         self.loop = asyncio.new_event_loop()
+#         asyncio.set_event_loop(None)
+#         self.gc = GremlinClient("ws://localhost:8182/", loop=self.loop)
+#
+#     def tearDown(self):
+#         self.loop.run_until_complete(self.gc.close())
+#         self.loop.close()
+#
+#     def test_connection(self):
+#         @asyncio.coroutine
+#         def conn_coro():
+#             conn = yield from self.gc._acquire()
+#             self.assertFalse(conn.closed)
+#             return conn
+#         conn = self.loop.run_until_complete(conn_coro())
+#         # Clean up the resource.
+#         self.loop.run_until_complete(conn.close())
+#
+#     def test_sub(self):
+#         execute = self.gc.execute("x + x", bindings={"x": 4})
+#         results = self.loop.run_until_complete(execute)
+#         self.assertEqual(results[0].data[0], 8)
+#
+#
+# class GremlinClientPoolTests(unittest.TestCase):
+#
+#     def setUp(self):
+#         self.loop = asyncio.new_event_loop()
+#         asyncio.set_event_loop(None)
+#         self.gc = GremlinClient("ws://localhost:8182/",
+#             factory=AiohttpFactory(), pool=WebSocketPool("ws://localhost:8182/",
+#                                                        loop=self.loop),
+#             loop=self.loop)
+#
+#     def tearDown(self):
+#         self.loop.run_until_complete(self.gc.close())
+#         self.loop.close()
+#
+#     def test_connection(self):
+#         @asyncio.coroutine
+#         def conn_coro():
+#             conn = yield from self.gc._acquire()
+#             self.assertFalse(conn.closed)
+#             return conn
+#         conn = self.loop.run_until_complete(conn_coro())
+#         # Clean up the resource.
+#         self.loop.run_until_complete(conn.close())
+#
+#     def test_sub(self):
+#         execute = self.gc.execute("x + x", bindings={"x": 4})
+#         results = self.loop.run_until_complete(execute)
+#         self.assertEqual(results[0].data[0], 8)
+#
+#     def test_sub_waitfor(self):
+#         sub1 = self.gc.execute("x + x", bindings={"x": 1})
+#         sub2 = self.gc.execute("x + x", bindings={"x": 2})
+#         sub3 = self.gc.execute("x + x", bindings={"x": 4})
+#         coro = asyncio.gather(*[asyncio.async(sub1, loop=self.loop),
+#                              asyncio.async(sub2, loop=self.loop),
+#                              asyncio.async(sub3, loop=self.loop)], loop=self.loop)
+#         # Here I am looking for resource warnings.
+#         results = self.loop.run_until_complete(coro)
+#         self.assertIsNotNone(results)
+#
+#     def test_resp_stream(self):
+#         @asyncio.coroutine
+#         def stream_coro():
+#             results = []
+#             resp = yield from self.gc.submit("x + x", bindings={"x": 4})
+#             while True:
+#                 f = yield from resp.stream.read()
+#                 if f is None:
+#                     break
+#                 results.append(f)
+#             self.assertEqual(results[0].data[0], 8)
+#         self.loop.run_until_complete(stream_coro())
+#
+#     def test_resp_get(self):
+#         @asyncio.coroutine
+#         def get_coro():
+#             conn = yield from self.gc.submit("x + x", bindings={"x": 4})
+#             results = yield from conn.get()
+#             self.assertEqual(results[0].data[0], 8)
+#         self.loop.run_until_complete(get_coro())
+#
+#     def test_execute_error(self):
+#         execute = self.gc.execute("x + x g.asdfas", bindings={"x": 4})
+#         try:
+#             self.loop.run_until_complete(execute)
+#             error = False
+#         except:
+#             error = True
+#         self.assertTrue(error)
+#
+#     def test_session_gen(self):
+#         execute = self.gc.execute("x + x", processor="session", bindings={"x": 4})
+#         results = self.loop.run_until_complete(execute)
+#         self.assertEqual(results[0].data[0], 8)
+#
+#     def test_session(self):
+#         @asyncio.coroutine
+#         def stream_coro():
+#             session = str(uuid.uuid4())
+#             resp = yield from self.gc.submit("x + x", bindings={"x": 4},
+#                 session=session)
+#             while True:
+#                 f = yield from resp.stream.read()
+#                 if f is None:
+#                     break
+#             self.assertEqual(resp.session, session)
+#         self.loop.run_until_complete(stream_coro())
+#
+#
+# class WebSocketPoolTests(unittest.TestCase):
+#
+#     def setUp(self):
+#         self.loop = asyncio.new_event_loop()
+#         asyncio.set_event_loop(None)
+#         self.pool = WebSocketPool(poolsize=2, timeout=1, loop=self.loop,
+#             factory=AiohttpFactory())
+#
+#     def tearDown(self):
+#         self.loop.run_until_complete(self.pool.close())
+#         self.loop.close()
+#
+#     def test_connect(self):
+#
+#         @asyncio.coroutine
+#         def conn():
+#             conn = yield from self.pool.acquire()
+#             self.assertFalse(conn.closed)
+#             self.pool.release(conn)
+#             self.assertEqual(self.pool.num_active_conns, 0)
+#
+#         self.loop.run_until_complete(conn())
+#
+#     def test_multi_connect(self):
+#
+#         @asyncio.coroutine
+#         def conn():
+#             conn1 = yield from self.pool.acquire()
+#             conn2 = yield from self.pool.acquire()
+#             self.assertFalse(conn1.closed)
+#             self.assertFalse(conn2.closed)
+#             self.pool.release(conn1)
+#             self.assertEqual(self.pool.num_active_conns, 1)
+#             self.pool.release(conn2)
+#             self.assertEqual(self.pool.num_active_conns, 0)
+#
+#         self.loop.run_until_complete(conn())
+#
+#     def test_timeout(self):
+#
+#         @asyncio.coroutine
+#         def conn():
+#             conn1 = yield from self.pool.acquire()
+#             conn2 = yield from self.pool.acquire()
+#             try:
+#                 conn3 = yield from self.pool.acquire()
+#                 timeout = False
+#             except asyncio.TimeoutError:
+#                 timeout = True
+#             self.assertTrue(timeout)
+#
+#         self.loop.run_until_complete(conn())
+#
+#     def test_socket_reuse(self):
+#
+#         @asyncio.coroutine
+#         def conn():
+#             conn1 = yield from self.pool.acquire()
+#             conn2 = yield from self.pool.acquire()
+#             try:
+#                 conn3 = yield from self.pool.acquire()
+#                 timeout = False
+#             except asyncio.TimeoutError:
+#                 timeout = True
+#             self.assertTrue(timeout)
+#             self.pool.release(conn2)
+#             conn3 = yield from self.pool.acquire()
+#             self.assertFalse(conn1.closed)
+#             self.assertFalse(conn3.closed)
+#             self.assertEqual(conn2, conn3)
+#
+#         self.loop.run_until_complete(conn())
+#
+#     def test_socket_repare(self):
+#
+#         @asyncio.coroutine
+#         def conn():
+#             conn1 = yield from self.pool.acquire()
+#             conn2 = yield from self.pool.acquire()
+#             self.assertFalse(conn1.closed)
+#             self.assertFalse(conn2.closed)
+#             yield from conn1.close()
+#             yield from conn2.close()
+#             self.assertTrue(conn2.closed)
+#             self.assertTrue(conn2.closed)
+#             self.pool.release(conn1)
+#             self.pool.release(conn2)
+#             conn1 = yield from self.pool.acquire()
+#             conn2 = yield from self.pool.acquire()
+#             self.assertFalse(conn1.closed)
+#             self.assertFalse(conn2.closed)
+#
+#         self.loop.run_until_complete(conn())
+#
+# class ContextMngrTest(unittest.TestCase):
+#
+#     def setUp(self):
+#         self.loop = asyncio.new_event_loop()
+#         asyncio.set_event_loop(None)
+#         self.pool = WebSocketPool(poolsize=1, loop=self.loop,
+#             factory=AiohttpFactory, max_retries=0)
+#
+#     def tearDown(self):
+#         self.loop.run_until_complete(self.pool.close())
+#         self.loop.close()
+#
+#     def test_connection_manager(self):
+#         results = []
+#         @asyncio.coroutine
+#         def go():
+#             with (yield from self.pool) as conn:
+#                 writer = GremlinWriter(conn)
+#                 conn = writer.write("1 + 1")
+#                 resp = GremlinResponse(conn, self.pool, loop=self.loop)
+#                 while True:
+#                     mssg = yield from resp.stream.read()
+#                     if mssg is None:
+#                         break
+#                     results.append(mssg)
+#             conn = self.pool._pool.get_nowait()
+#             self.assertTrue(conn.closed)
+#             writer = GremlinWriter(conn)
+#             try:
+#                 conn = yield from writer.write("1 + 1")
+#                 error = False
+#             except RuntimeError:
+#                 error = True
+#             self.assertTrue(error)
+#         self.loop.run_until_complete(go())
+
+
+class GremlinClientPoolSessionTests(unittest.TestCase):
 
     def setUp(self):
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(None)
         self.gc = GremlinClient("ws://localhost:8182/",
-            factory=AiohttpFactory, pool=WebSocketPool("ws://localhost:8182/",
-                                                       loop=self.loop),
+            pool=WebSocketPool("ws://localhost:8182/", loop=self.loop,
+                              factory=WebSocketSession(loop=self.loop)),
             loop=self.loop)
 
     def tearDown(self):
@@ -77,185 +297,6 @@ class GremlinClientPoolTests(unittest.TestCase):
         results = self.loop.run_until_complete(coro)
         self.assertIsNotNone(results)
 
-    def test_resp_stream(self):
-        @asyncio.coroutine
-        def stream_coro():
-            results = []
-            resp = yield from self.gc.submit("x + x", bindings={"x": 4})
-            while True:
-                f = yield from resp.stream.read()
-                if f is None:
-                    break
-                results.append(f)
-            self.assertEqual(results[0].data[0], 8)
-        self.loop.run_until_complete(stream_coro())
-
-    def test_resp_get(self):
-        @asyncio.coroutine
-        def get_coro():
-            conn = yield from self.gc.submit("x + x", bindings={"x": 4})
-            results = yield from conn.get()
-            self.assertEqual(results[0].data[0], 8)
-        self.loop.run_until_complete(get_coro())
-
-    def test_execute_error(self):
-        execute = self.gc.execute("x + x g.asdfas", bindings={"x": 4})
-        try:
-            self.loop.run_until_complete(execute)
-            error = False
-        except:
-            error = True
-        self.assertTrue(error)
-
-    def test_session_gen(self):
-        execute = self.gc.execute("x + x", processor="session", bindings={"x": 4})
-        results = self.loop.run_until_complete(execute)
-        self.assertEqual(results[0].data[0], 8)
-
-    def test_session(self):
-        @asyncio.coroutine
-        def stream_coro():
-            session = str(uuid.uuid4())
-            resp = yield from self.gc.submit("x + x", bindings={"x": 4},
-                session=session)
-            while True:
-                f = yield from resp.stream.read()
-                if f is None:
-                    break
-            self.assertEqual(resp.session, session)
-        self.loop.run_until_complete(stream_coro())
-
-
-class WebSocketPoolTests(unittest.TestCase):
-
-    def setUp(self):
-        self.loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(None)
-        self.pool = WebSocketPool(poolsize=2, timeout=1, loop=self.loop,
-            factory=AiohttpFactory)
-
-    def tearDown(self):
-        self.loop.run_until_complete(self.pool.close())
-        self.loop.close()
-
-    def test_connect(self):
-
-        @asyncio.coroutine
-        def conn():
-            conn = yield from self.pool.acquire()
-            self.assertFalse(conn.closed)
-            self.pool.release(conn)
-            self.assertEqual(self.pool.num_active_conns, 0)
-
-        self.loop.run_until_complete(conn())
-
-    def test_multi_connect(self):
-
-        @asyncio.coroutine
-        def conn():
-            conn1 = yield from self.pool.acquire()
-            conn2 = yield from self.pool.acquire()
-            self.assertFalse(conn1.closed)
-            self.assertFalse(conn2.closed)
-            self.pool.release(conn1)
-            self.assertEqual(self.pool.num_active_conns, 1)
-            self.pool.release(conn2)
-            self.assertEqual(self.pool.num_active_conns, 0)
-
-        self.loop.run_until_complete(conn())
-
-    def test_timeout(self):
-
-        @asyncio.coroutine
-        def conn():
-            conn1 = yield from self.pool.acquire()
-            conn2 = yield from self.pool.acquire()
-            try:
-                conn3 = yield from self.pool.acquire()
-                timeout = False
-            except asyncio.TimeoutError:
-                timeout = True
-            self.assertTrue(timeout)
-
-        self.loop.run_until_complete(conn())
-
-    def test_socket_reuse(self):
-
-        @asyncio.coroutine
-        def conn():
-            conn1 = yield from self.pool.acquire()
-            conn2 = yield from self.pool.acquire()
-            try:
-                conn3 = yield from self.pool.acquire()
-                timeout = False
-            except asyncio.TimeoutError:
-                timeout = True
-            self.assertTrue(timeout)
-            self.pool.release(conn2)
-            conn3 = yield from self.pool.acquire()
-            self.assertFalse(conn1.closed)
-            self.assertFalse(conn3.closed)
-            self.assertEqual(conn2, conn3)
-
-        self.loop.run_until_complete(conn())
-
-    def test_socket_repare(self):
-
-        @asyncio.coroutine
-        def conn():
-            conn1 = yield from self.pool.acquire()
-            conn2 = yield from self.pool.acquire()
-            self.assertFalse(conn1.closed)
-            self.assertFalse(conn2.closed)
-            yield from conn1.close()
-            yield from conn2.close()
-            self.assertTrue(conn2.closed)
-            self.assertTrue(conn2.closed)
-            self.pool.release(conn1)
-            self.pool.release(conn2)
-            conn1 = yield from self.pool.acquire()
-            conn2 = yield from self.pool.acquire()
-            self.assertFalse(conn1.closed)
-            self.assertFalse(conn2.closed)
-
-        self.loop.run_until_complete(conn())
-
-class ContextMngrTest(unittest.TestCase):
-
-    def setUp(self):
-        self.loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(None)
-        self.pool = WebSocketPool(poolsize=1, loop=self.loop,
-            factory=AiohttpFactory, max_retries=0)
-
-    def tearDown(self):
-        self.loop.run_until_complete(self.pool.close())
-        self.loop.close()
-
-    def test_connection_manager(self):
-        results = []
-        @asyncio.coroutine
-        def go():
-            with (yield from self.pool) as conn:
-                writer = GremlinWriter(conn)
-                conn = writer.write("1 + 1")
-                resp = GremlinResponse(conn, self.pool, loop=self.loop)
-                while True:
-                    mssg = yield from resp.stream.read()
-                    if mssg is None:
-                        break
-                    results.append(mssg)
-            conn = self.pool._pool.get_nowait()
-            self.assertTrue(conn.closed)
-            writer = GremlinWriter(conn)
-            try:
-                conn = yield from writer.write("1 + 1")
-                error = False
-            except RuntimeError:
-                error = True
-            self.assertTrue(error)
-        self.loop.run_until_complete(go())
-
 
 class CreateClientTests(unittest.TestCase):
 
-- 
GitLab