From 98cea15e51913ae9cc4d220814497bf4f3fd78e4 Mon Sep 17 00:00:00 2001 From: davebshow <davebshow@gmail.com> Date: Mon, 18 May 2015 13:52:10 -0400 Subject: [PATCH] using a subclassed ClientSession as factory to leverage aiohttp connection pooling --- aiogremlin/__init__.py | 3 +- aiogremlin/abc.py | 8 +- aiogremlin/client.py | 4 +- aiogremlin/connection.py | 143 +++++++----- aiogremlin/pool.py | 38 ++-- benchmark.py | 18 +- conn_bench.py | 55 +++++ tests/tests.py | 465 +++++++++++++++++++++------------------ 8 files changed, 438 insertions(+), 296 deletions(-) create mode 100644 conn_bench.py diff --git a/aiogremlin/__init__.py b/aiogremlin/__init__.py index f6a17d5..667336e 100644 --- a/aiogremlin/__init__.py +++ b/aiogremlin/__init__.py @@ -1,5 +1,6 @@ from .abc import AbstractFactory, AbstractConnection -from .connection import AiohttpFactory, BaseFactory, BaseConnection +from .connection import (AiohttpFactory, BaseFactory, BaseConnection, + WebSocketSession) from .client import (create_client, GremlinClient, GremlinResponse, GremlinResponseStream) from .exceptions import RequestError, GremlinServerError, SocketClientError diff --git a/aiogremlin/abc.py b/aiogremlin/abc.py index 83bc1e0..79a9bb2 100644 --- a/aiogremlin/abc.py +++ b/aiogremlin/abc.py @@ -5,14 +5,8 @@ from abc import ABCMeta, abstractmethod class AbstractFactory(metaclass=ABCMeta): - @classmethod @abstractmethod - def connect(cls): - pass - - @property - @abstractmethod - def factory(self): + def ws_connect(cls): pass diff --git a/aiogremlin/client.py b/aiogremlin/client.py index 340f28a..af4d270 100644 --- a/aiogremlin/client.py +++ b/aiogremlin/client.py @@ -62,7 +62,7 @@ class GremlinClient: self.poolsize = poolsize self.timeout = timeout self._pool = pool - self._factory = factory or AiohttpFactory + self._factory = factory or AiohttpFactory() if self._pool is None: self._connected = False self._conn = asyncio.async(self._connect(), loop=self._loop) @@ -90,7 +90,7 @@ class GremlinClient: """ """ loop = kwargs.get("loop", "") or self._loop - connection = yield from self._factory.connect(self.uri, loop=loop) + connection = yield from self._factory.ws_connect(self.uri, loop=loop) self._connected = True return connection diff --git a/aiogremlin/connection.py b/aiogremlin/connection.py index 95ab592..019c3bd 100644 --- a/aiogremlin/connection.py +++ b/aiogremlin/connection.py @@ -6,7 +6,7 @@ import hashlib import os from aiohttp import (client, hdrs, DataQueue, StreamParser, - WSServerHandshakeError) + WSServerHandshakeError, ClientSession, TCPConnector) from aiohttp.errors import WSServerHandshakeError from aiohttp.websocket import WS_KEY, Message from aiohttp.websocket import WebSocketParser, WebSocketWriter, WebSocketError @@ -20,68 +20,98 @@ from aiogremlin.exceptions import SocketClientError from aiogremlin.log import INFO, logger -# This is temporary until aiohttp pull #367 is merged/released. -@asyncio.coroutine -def ws_connect(url, protocols=(), timeout=10.0, connector=None, - response_class=None, autoclose=True, autoping=True, loop=None): - """Initiate websocket connection.""" - if loop is None: - loop = asyncio.get_event_loop() +class WebSocketSession(AbstractFactory, ClientSession): + + @asyncio.coroutine + def ws_connect(self, url, protocols=(), timeout=10.0, connector=None, + response_class=None, autoclose=True, autoping=True, + loop=None): + """Initiate websocket connection.""" - sec_key = base64.b64encode(os.urandom(16)) + sec_key = base64.b64encode(os.urandom(16)) - headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, - hdrs.SEC_WEBSOCKET_VERSION: '13', - hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(), - } - if protocols: - headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols) + headers = { + hdrs.UPGRADE: hdrs.WEBSOCKET, + hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.SEC_WEBSOCKET_VERSION: '13', + hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(), + } + if protocols: + headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols) - # send request - resp = yield from client.request( - 'get', url, headers=headers, - read_until_eof=False, - connector=connector, loop=loop) + # send request + resp = yield from self.request('get', url, headers=headers, + read_until_eof=False) - # check handshake - if resp.status != 101: - raise WSServerHandshakeError('Invalid response status') + # check handshake + if resp.status != 101: + raise WSServerHandshakeError('Invalid response status') - if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket': - raise WSServerHandshakeError('Invalid upgrade header') + if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket': + raise WSServerHandshakeError('Invalid upgrade header') - if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade': - raise WSServerHandshakeError('Invalid connection header') + if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade': + raise WSServerHandshakeError('Invalid connection header') - # key calculation - key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '') - match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode() - if key != match: - raise WSServerHandshakeError('Invalid challenge response') + # key calculation + key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '') + match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode() + if key != match: + raise WSServerHandshakeError('Invalid challenge response') - # websocket protocol - protocol = None - if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers: - resp_protocols = [proto.strip() for proto in - resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')] + # websocket protocol + protocol = None + if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers: + resp_protocols = [proto.strip() for proto in + resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')] - for proto in resp_protocols: - if proto in protocols: - protocol = proto - break + for proto in resp_protocols: + if proto in protocols: + protocol = proto + break - reader = resp.connection.reader.set_parser(WebSocketParser) - writer = WebSocketWriter(resp.connection.writer, use_mask=True) + reader = resp.connection.reader.set_parser(WebSocketParser) + writer = WebSocketWriter(resp.connection.writer, use_mask=True) - if response_class is None: - response_class = ClientWebSocketResponse + if response_class is None: + response_class = ClientWebSocketResponse - return response_class( - reader, writer, protocol, resp, timeout, autoclose, autoping, loop) + return response_class( + reader, writer, protocol, resp, timeout, autoclose, autoping, loop) + def detach(self): + """Detach connector from session without closing the former. + Session is switched to closed state anyway. + """ + self._connector = None + +def ws_connect(url, protocols=(), timeout=10.0, connector=None, + response_class=None, autoclose=True, autoping=True, + loop=None): + if loop is None: + asyncio.get_event_loop() + if connector is None: + connector = TCPConnector(loop=loop, force_close=True) + + ws_session = WebSocketSession(loop=loop, connector=connector) + + try: + resp = yield from ws_session.ws_connect(url, + protocols=protocols, + timeout=timeout, + connector=connector, + response_class=response_class, + autoclose=autoclose, + autoping=autoping, + loop=loop) + return resp + + finally: + ws_session.detach() + + + # Will drop 'pluggable sockets' implementation in favour of aiohttp default. class BaseFactory(AbstractFactory): @property @@ -91,17 +121,18 @@ class BaseFactory(AbstractFactory): class AiohttpFactory(BaseFactory): - @classmethod @asyncio.coroutine - def connect(cls, uri='ws://localhost:8182/', pool=None, protocols=(), - connector=None, autoclose=False, autoping=True, loop=None): - if pool: - loop = loop or pool.loop + def ws_connect(cls, uri='ws://localhost:8182/', protocols=(), + connector=None, autoclose=False, autoping=True, + response_class=None, loop=None): + if response_class is None: + response_class = GremlinClientWebSocketResponse + try: return (yield from ws_connect( uri, protocols=protocols, connector=connector, - response_class=GremlinClientWebSocketResponse, - autoclose=True, autoping=True, loop=loop)) + response_class=response_class, autoclose=True, autoping=True, + loop=loop)) except WSServerHandshakeError as e: raise SocketClientError(e.message) diff --git a/aiogremlin/pool.py b/aiogremlin/pool.py index 8a7b6c6..2f31578 100644 --- a/aiogremlin/pool.py +++ b/aiogremlin/pool.py @@ -1,6 +1,7 @@ import asyncio -from aiogremlin.connection import AiohttpFactory +from aiogremlin.connection import (AiohttpFactory, + GremlinClientWebSocketResponse) from aiogremlin.contextmanager import ConnectionContextManager from aiogremlin.log import logger @@ -11,11 +12,12 @@ def create_pool(): class WebSocketPool: - def __init__(self, uri='ws://localhost:8182/', factory=None, poolsize=10, - max_retries=10, timeout=None, loop=None, verbose=False): + def __init__(self, url='ws://localhost:8182/', factory=None, poolsize=10, + max_retries=10, timeout=None, loop=None, verbose=False, + response_class=None): """ """ - self.uri = uri + self.url = url self._factory = factory or AiohttpFactory self.poolsize = poolsize self.max_retries = max_retries @@ -25,15 +27,21 @@ class WebSocketPool: self._pool = asyncio.Queue(maxsize=self.poolsize, loop=self._loop) self.active_conns = set() self.num_connecting = 0 + self._response_class = response_class or GremlinClientWebSocketResponse self._closed = False if verbose: logger.setLevel(INFO) @asyncio.coroutine def fill_pool(self): - for i in range(self.poolsize): - conn = yield from self.factory.connect(self.uri, pool=self, - loop=self._loop) + tasks = [] + poolsize = self.poolsize + for i in range(poolsize): + task = asyncio.async(self.factory.ws_connect(self.url, + response_class=self._response_class, loop=self._loop), loop=self._loop) + tasks.append(task) + for f in asyncio.as_completed(tasks, loop=self._loop): + conn = yield from f self._put(conn) self._connected = True @@ -84,36 +92,36 @@ class WebSocketPool: yield from conn.close() @asyncio.coroutine - def acquire(self, uri=None, loop=None, num_retries=None): + def acquire(self, url=None, loop=None, num_retries=None): if self._closed: raise RuntimeError("WebsocketPool is closed.") if num_retries is None: num_retries = self.max_retries - uri = uri or self.uri + url = url or self.url loop = loop or self.loop if not self._pool.empty(): socket = self._pool.get_nowait() - logger.info("Reusing socket: {} at {}".format(socket, uri)) + logger.info("Reusing socket: {} at {}".format(socket, url)) elif self.num_active_conns + self.num_connecting >= self.poolsize: logger.info("Waiting for socket...") socket = yield from asyncio.wait_for(self._pool.get(), self.timeout, loop=loop) - logger.info("Socket acquired: {} at {}".format(socket, uri)) + logger.info("Socket acquired: {} at {}".format(socket, url)) else: self.num_connecting += 1 try: - socket = yield from self.factory.connect(uri, pool=self, - loop=loop) + socket = yield from self.factory.ws_connect(url, + response_class=self._response_class, loop=loop) finally: self.num_connecting -= 1 if not socket.closed: logger.info("New connection on socket: {} at {}".format( - socket, uri)) + socket, url)) self.active_conns.add(socket) # Untested. elif num_retries > 0: logger.warning("Got bad socket, retry...") - socket = yield from self.acquire(uri, loop, num_retries - 1) + socket = yield from self.acquire(url, loop, num_retries - 1) else: raise RuntimeError("Unable to connect, max retries exceeded.") return socket diff --git a/benchmark.py b/benchmark.py index 8ca0273..2e80803 100644 --- a/benchmark.py +++ b/benchmark.py @@ -89,6 +89,10 @@ ARGS.add_argument( '-w', '--warmups', action="store", nargs='?', type=int, default=5, help='num warmups (default: `%(default)s`)') +ARGS.add_argument( + '-s', '--session', action="store", + nargs='?', type=str, default="false", + help='use session to establish connections (default: `%(default)s`)') if __name__ == "__main__": @@ -98,12 +102,20 @@ if __name__ == "__main__": concurr = args.concurrency poolsize = args.poolsize num_warmups = args.warmups + session = args.session loop = asyncio.get_event_loop() + t1 = loop.time() + if session == "true": + factory = aiogremlin.WebSocketSession() + else: + factory = aiogremlin.AiohttpFactory() client = loop.run_until_complete( - aiogremlin.create_client(loop=loop, poolsize=poolsize)) + aiogremlin.create_client(loop=loop, factory=factory, poolsize=poolsize)) + t2 = loop.time() + print("time to establish conns: {}".format(t2 - t1)) try: - print("Runs: {}. Warmups: {}. Messages: {}. Concurrency: {}. Poolsize: {}".format( - num_tests, num_warmups, num_mssg, concurr, poolsize)) + print("Runs: {}. Warmups: {}. Messages: {}. Concurrency: {}. Poolsize: {}. Use Session: {}".format( + num_tests, num_warmups, num_mssg, concurr, poolsize, session)) main = main(client, num_tests, num_mssg, concurr, num_warmups, loop) loop.run_until_complete(main) finally: diff --git a/conn_bench.py b/conn_bench.py new file mode 100644 index 0000000..d05cff0 --- /dev/null +++ b/conn_bench.py @@ -0,0 +1,55 @@ +import argparse +import asyncio +import aiogremlin + + +@asyncio.coroutine +def create_destroy(loop, factory, poolsize): + client = yield from aiogremlin.create_client(loop=loop, + factory=factory, + poolsize=poolsize) + yield from client.close() + +# NEED TO ADD MORE ARGS/CLEAN UP like benchmark.py +ARGS = argparse.ArgumentParser(description="Run benchmark.") +ARGS.add_argument( + '-t', '--tests', action="store", + nargs='?', type=int, default=10, + help='number of tests (default: `%(default)s`)') + +ARGS.add_argument( + '-s', '--session', action="store", + nargs='?', type=str, default="false", + help='use session to establish connections (default: `%(default)s`)') + + +if __name__ == "__main__": + args = ARGS.parse_args() + tests = args.tests + print("tests", tests) + session = args.session + loop = asyncio.get_event_loop() + if session == "true": + factory = aiogremlin.WebSocketSession() + else: + factory = aiogremlin.AiohttpFactory() + print("factory: {}".format(factory)) + try: + m1 = loop.time() + for x in range(50): + tasks = [] + for x in range(tests): + task = asyncio.async( + create_destroy(loop, factory, 100) + ) + tasks.append(task) + t1 = loop.time() + loop.run_until_complete(asyncio.async(asyncio.gather(*tasks, loop=loop))) + t2 = loop.time() + print("avg: time to establish conn: {}".format((t2 - t1) / (tests * 50))) + m2 = loop.time() + print("time to establish conns: {}".format((m2 - m1))) + print("avg time to establish conns: {}".format((m2 - m1) / (tests * 100 * 50))) + finally: + loop.close() + print("CLOSED CLIENT AND LOOP") diff --git a/tests/tests.py b/tests/tests.py index 6896c02..1c8f5e3 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -7,44 +7,264 @@ import unittest import uuid from aiogremlin import (GremlinClient, RequestError, GremlinServerError, SocketClientError, WebSocketPool, AiohttpFactory, create_client, - GremlinWriter, GremlinResponse) - - -class GremlinClientTests(unittest.TestCase): - - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - self.gc = GremlinClient("ws://localhost:8182/", loop=self.loop) - - def tearDown(self): - self.loop.run_until_complete(self.gc.close()) - self.loop.close() - - def test_connection(self): - @asyncio.coroutine - def conn_coro(): - conn = yield from self.gc._acquire() - self.assertFalse(conn.closed) - return conn - conn = self.loop.run_until_complete(conn_coro()) - # Clean up the resource. - self.loop.run_until_complete(conn.close()) - - def test_sub(self): - execute = self.gc.execute("x + x", bindings={"x": 4}) - results = self.loop.run_until_complete(execute) - self.assertEqual(results[0].data[0], 8) - - -class GremlinClientPoolTests(unittest.TestCase): + GremlinWriter, GremlinResponse, WebSocketSession) +# +# +# class GremlinClientTests(unittest.TestCase): +# +# def setUp(self): +# self.loop = asyncio.new_event_loop() +# asyncio.set_event_loop(None) +# self.gc = GremlinClient("ws://localhost:8182/", loop=self.loop) +# +# def tearDown(self): +# self.loop.run_until_complete(self.gc.close()) +# self.loop.close() +# +# def test_connection(self): +# @asyncio.coroutine +# def conn_coro(): +# conn = yield from self.gc._acquire() +# self.assertFalse(conn.closed) +# return conn +# conn = self.loop.run_until_complete(conn_coro()) +# # Clean up the resource. +# self.loop.run_until_complete(conn.close()) +# +# def test_sub(self): +# execute = self.gc.execute("x + x", bindings={"x": 4}) +# results = self.loop.run_until_complete(execute) +# self.assertEqual(results[0].data[0], 8) +# +# +# class GremlinClientPoolTests(unittest.TestCase): +# +# def setUp(self): +# self.loop = asyncio.new_event_loop() +# asyncio.set_event_loop(None) +# self.gc = GremlinClient("ws://localhost:8182/", +# factory=AiohttpFactory(), pool=WebSocketPool("ws://localhost:8182/", +# loop=self.loop), +# loop=self.loop) +# +# def tearDown(self): +# self.loop.run_until_complete(self.gc.close()) +# self.loop.close() +# +# def test_connection(self): +# @asyncio.coroutine +# def conn_coro(): +# conn = yield from self.gc._acquire() +# self.assertFalse(conn.closed) +# return conn +# conn = self.loop.run_until_complete(conn_coro()) +# # Clean up the resource. +# self.loop.run_until_complete(conn.close()) +# +# def test_sub(self): +# execute = self.gc.execute("x + x", bindings={"x": 4}) +# results = self.loop.run_until_complete(execute) +# self.assertEqual(results[0].data[0], 8) +# +# def test_sub_waitfor(self): +# sub1 = self.gc.execute("x + x", bindings={"x": 1}) +# sub2 = self.gc.execute("x + x", bindings={"x": 2}) +# sub3 = self.gc.execute("x + x", bindings={"x": 4}) +# coro = asyncio.gather(*[asyncio.async(sub1, loop=self.loop), +# asyncio.async(sub2, loop=self.loop), +# asyncio.async(sub3, loop=self.loop)], loop=self.loop) +# # Here I am looking for resource warnings. +# results = self.loop.run_until_complete(coro) +# self.assertIsNotNone(results) +# +# def test_resp_stream(self): +# @asyncio.coroutine +# def stream_coro(): +# results = [] +# resp = yield from self.gc.submit("x + x", bindings={"x": 4}) +# while True: +# f = yield from resp.stream.read() +# if f is None: +# break +# results.append(f) +# self.assertEqual(results[0].data[0], 8) +# self.loop.run_until_complete(stream_coro()) +# +# def test_resp_get(self): +# @asyncio.coroutine +# def get_coro(): +# conn = yield from self.gc.submit("x + x", bindings={"x": 4}) +# results = yield from conn.get() +# self.assertEqual(results[0].data[0], 8) +# self.loop.run_until_complete(get_coro()) +# +# def test_execute_error(self): +# execute = self.gc.execute("x + x g.asdfas", bindings={"x": 4}) +# try: +# self.loop.run_until_complete(execute) +# error = False +# except: +# error = True +# self.assertTrue(error) +# +# def test_session_gen(self): +# execute = self.gc.execute("x + x", processor="session", bindings={"x": 4}) +# results = self.loop.run_until_complete(execute) +# self.assertEqual(results[0].data[0], 8) +# +# def test_session(self): +# @asyncio.coroutine +# def stream_coro(): +# session = str(uuid.uuid4()) +# resp = yield from self.gc.submit("x + x", bindings={"x": 4}, +# session=session) +# while True: +# f = yield from resp.stream.read() +# if f is None: +# break +# self.assertEqual(resp.session, session) +# self.loop.run_until_complete(stream_coro()) +# +# +# class WebSocketPoolTests(unittest.TestCase): +# +# def setUp(self): +# self.loop = asyncio.new_event_loop() +# asyncio.set_event_loop(None) +# self.pool = WebSocketPool(poolsize=2, timeout=1, loop=self.loop, +# factory=AiohttpFactory()) +# +# def tearDown(self): +# self.loop.run_until_complete(self.pool.close()) +# self.loop.close() +# +# def test_connect(self): +# +# @asyncio.coroutine +# def conn(): +# conn = yield from self.pool.acquire() +# self.assertFalse(conn.closed) +# self.pool.release(conn) +# self.assertEqual(self.pool.num_active_conns, 0) +# +# self.loop.run_until_complete(conn()) +# +# def test_multi_connect(self): +# +# @asyncio.coroutine +# def conn(): +# conn1 = yield from self.pool.acquire() +# conn2 = yield from self.pool.acquire() +# self.assertFalse(conn1.closed) +# self.assertFalse(conn2.closed) +# self.pool.release(conn1) +# self.assertEqual(self.pool.num_active_conns, 1) +# self.pool.release(conn2) +# self.assertEqual(self.pool.num_active_conns, 0) +# +# self.loop.run_until_complete(conn()) +# +# def test_timeout(self): +# +# @asyncio.coroutine +# def conn(): +# conn1 = yield from self.pool.acquire() +# conn2 = yield from self.pool.acquire() +# try: +# conn3 = yield from self.pool.acquire() +# timeout = False +# except asyncio.TimeoutError: +# timeout = True +# self.assertTrue(timeout) +# +# self.loop.run_until_complete(conn()) +# +# def test_socket_reuse(self): +# +# @asyncio.coroutine +# def conn(): +# conn1 = yield from self.pool.acquire() +# conn2 = yield from self.pool.acquire() +# try: +# conn3 = yield from self.pool.acquire() +# timeout = False +# except asyncio.TimeoutError: +# timeout = True +# self.assertTrue(timeout) +# self.pool.release(conn2) +# conn3 = yield from self.pool.acquire() +# self.assertFalse(conn1.closed) +# self.assertFalse(conn3.closed) +# self.assertEqual(conn2, conn3) +# +# self.loop.run_until_complete(conn()) +# +# def test_socket_repare(self): +# +# @asyncio.coroutine +# def conn(): +# conn1 = yield from self.pool.acquire() +# conn2 = yield from self.pool.acquire() +# self.assertFalse(conn1.closed) +# self.assertFalse(conn2.closed) +# yield from conn1.close() +# yield from conn2.close() +# self.assertTrue(conn2.closed) +# self.assertTrue(conn2.closed) +# self.pool.release(conn1) +# self.pool.release(conn2) +# conn1 = yield from self.pool.acquire() +# conn2 = yield from self.pool.acquire() +# self.assertFalse(conn1.closed) +# self.assertFalse(conn2.closed) +# +# self.loop.run_until_complete(conn()) +# +# class ContextMngrTest(unittest.TestCase): +# +# def setUp(self): +# self.loop = asyncio.new_event_loop() +# asyncio.set_event_loop(None) +# self.pool = WebSocketPool(poolsize=1, loop=self.loop, +# factory=AiohttpFactory, max_retries=0) +# +# def tearDown(self): +# self.loop.run_until_complete(self.pool.close()) +# self.loop.close() +# +# def test_connection_manager(self): +# results = [] +# @asyncio.coroutine +# def go(): +# with (yield from self.pool) as conn: +# writer = GremlinWriter(conn) +# conn = writer.write("1 + 1") +# resp = GremlinResponse(conn, self.pool, loop=self.loop) +# while True: +# mssg = yield from resp.stream.read() +# if mssg is None: +# break +# results.append(mssg) +# conn = self.pool._pool.get_nowait() +# self.assertTrue(conn.closed) +# writer = GremlinWriter(conn) +# try: +# conn = yield from writer.write("1 + 1") +# error = False +# except RuntimeError: +# error = True +# self.assertTrue(error) +# self.loop.run_until_complete(go()) + + +class GremlinClientPoolSessionTests(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(None) self.gc = GremlinClient("ws://localhost:8182/", - factory=AiohttpFactory, pool=WebSocketPool("ws://localhost:8182/", - loop=self.loop), + pool=WebSocketPool("ws://localhost:8182/", loop=self.loop, + factory=WebSocketSession(loop=self.loop)), loop=self.loop) def tearDown(self): @@ -77,185 +297,6 @@ class GremlinClientPoolTests(unittest.TestCase): results = self.loop.run_until_complete(coro) self.assertIsNotNone(results) - def test_resp_stream(self): - @asyncio.coroutine - def stream_coro(): - results = [] - resp = yield from self.gc.submit("x + x", bindings={"x": 4}) - while True: - f = yield from resp.stream.read() - if f is None: - break - results.append(f) - self.assertEqual(results[0].data[0], 8) - self.loop.run_until_complete(stream_coro()) - - def test_resp_get(self): - @asyncio.coroutine - def get_coro(): - conn = yield from self.gc.submit("x + x", bindings={"x": 4}) - results = yield from conn.get() - self.assertEqual(results[0].data[0], 8) - self.loop.run_until_complete(get_coro()) - - def test_execute_error(self): - execute = self.gc.execute("x + x g.asdfas", bindings={"x": 4}) - try: - self.loop.run_until_complete(execute) - error = False - except: - error = True - self.assertTrue(error) - - def test_session_gen(self): - execute = self.gc.execute("x + x", processor="session", bindings={"x": 4}) - results = self.loop.run_until_complete(execute) - self.assertEqual(results[0].data[0], 8) - - def test_session(self): - @asyncio.coroutine - def stream_coro(): - session = str(uuid.uuid4()) - resp = yield from self.gc.submit("x + x", bindings={"x": 4}, - session=session) - while True: - f = yield from resp.stream.read() - if f is None: - break - self.assertEqual(resp.session, session) - self.loop.run_until_complete(stream_coro()) - - -class WebSocketPoolTests(unittest.TestCase): - - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - self.pool = WebSocketPool(poolsize=2, timeout=1, loop=self.loop, - factory=AiohttpFactory) - - def tearDown(self): - self.loop.run_until_complete(self.pool.close()) - self.loop.close() - - def test_connect(self): - - @asyncio.coroutine - def conn(): - conn = yield from self.pool.acquire() - self.assertFalse(conn.closed) - self.pool.release(conn) - self.assertEqual(self.pool.num_active_conns, 0) - - self.loop.run_until_complete(conn()) - - def test_multi_connect(self): - - @asyncio.coroutine - def conn(): - conn1 = yield from self.pool.acquire() - conn2 = yield from self.pool.acquire() - self.assertFalse(conn1.closed) - self.assertFalse(conn2.closed) - self.pool.release(conn1) - self.assertEqual(self.pool.num_active_conns, 1) - self.pool.release(conn2) - self.assertEqual(self.pool.num_active_conns, 0) - - self.loop.run_until_complete(conn()) - - def test_timeout(self): - - @asyncio.coroutine - def conn(): - conn1 = yield from self.pool.acquire() - conn2 = yield from self.pool.acquire() - try: - conn3 = yield from self.pool.acquire() - timeout = False - except asyncio.TimeoutError: - timeout = True - self.assertTrue(timeout) - - self.loop.run_until_complete(conn()) - - def test_socket_reuse(self): - - @asyncio.coroutine - def conn(): - conn1 = yield from self.pool.acquire() - conn2 = yield from self.pool.acquire() - try: - conn3 = yield from self.pool.acquire() - timeout = False - except asyncio.TimeoutError: - timeout = True - self.assertTrue(timeout) - self.pool.release(conn2) - conn3 = yield from self.pool.acquire() - self.assertFalse(conn1.closed) - self.assertFalse(conn3.closed) - self.assertEqual(conn2, conn3) - - self.loop.run_until_complete(conn()) - - def test_socket_repare(self): - - @asyncio.coroutine - def conn(): - conn1 = yield from self.pool.acquire() - conn2 = yield from self.pool.acquire() - self.assertFalse(conn1.closed) - self.assertFalse(conn2.closed) - yield from conn1.close() - yield from conn2.close() - self.assertTrue(conn2.closed) - self.assertTrue(conn2.closed) - self.pool.release(conn1) - self.pool.release(conn2) - conn1 = yield from self.pool.acquire() - conn2 = yield from self.pool.acquire() - self.assertFalse(conn1.closed) - self.assertFalse(conn2.closed) - - self.loop.run_until_complete(conn()) - -class ContextMngrTest(unittest.TestCase): - - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - self.pool = WebSocketPool(poolsize=1, loop=self.loop, - factory=AiohttpFactory, max_retries=0) - - def tearDown(self): - self.loop.run_until_complete(self.pool.close()) - self.loop.close() - - def test_connection_manager(self): - results = [] - @asyncio.coroutine - def go(): - with (yield from self.pool) as conn: - writer = GremlinWriter(conn) - conn = writer.write("1 + 1") - resp = GremlinResponse(conn, self.pool, loop=self.loop) - while True: - mssg = yield from resp.stream.read() - if mssg is None: - break - results.append(mssg) - conn = self.pool._pool.get_nowait() - self.assertTrue(conn.closed) - writer = GremlinWriter(conn) - try: - conn = yield from writer.write("1 + 1") - error = False - except RuntimeError: - error = True - self.assertTrue(error) - self.loop.run_until_complete(go()) - class CreateClientTests(unittest.TestCase): -- GitLab