diff --git a/aiogremlin/client.py b/aiogremlin/client.py index 83cc0ab5ac002f4259eba4c21a33b017c9510f63..e35ec163a7e8cd1ece1da0e0a1dae5dd6a772a69 100644 --- a/aiogremlin/client.py +++ b/aiogremlin/client.py @@ -5,7 +5,7 @@ import ssl import aiohttp -from aiogremlin.connection import (GremlinFactory, WebSocketSession, +from aiogremlin.connection import (GremlinFactory, GremlinClientWebSocketResponse) from aiogremlin.exceptions import RequestError from aiogremlin.log import logger, INFO @@ -23,7 +23,7 @@ def create_client(*, url='ws://localhost:8182/', loop=None, timeout=None, verbose=False, fill_pool=True, connector=None): if factory is None: - factory = WebSocketSession( + factory = aiohttp.ClientSession( connector=connector, ws_response_class=GremlinClientWebSocketResponse, loop=loop) @@ -79,7 +79,8 @@ class GremlinClient: else: self._connected = False self._conn = asyncio.async(self._connect(), loop=self._loop) - self._factory = factory or GremlinFactory(connector=self._connector) + self._factory = factory or GremlinFactory(connector=self._connector, + loop=self._loop) if verbose: logger.setLevel(INFO) @@ -101,8 +102,7 @@ class GremlinClient: def _connect(self): """ """ - connection = yield from self._factory.ws_connect(self.url, - loop=self._loop) + connection = yield from self._factory.ws_connect(self.url) self._connected = True return connection diff --git a/aiogremlin/connection.py b/aiogremlin/connection.py index ddc78134836324871303b8219740378e0e12cb82..5fbb6bc3623ae9b421cb754c684ed19a4e414b51 100644 --- a/aiogremlin/connection.py +++ b/aiogremlin/connection.py @@ -5,21 +5,13 @@ import base64 import hashlib import os -from aiohttp import (client, hdrs, DataQueue, StreamParser, - WSServerHandshakeError, ClientSession, TCPConnector) -from aiohttp.errors import WSServerHandshakeError -from aiohttp.websocket import WS_KEY, Message -from aiohttp.websocket import WebSocketParser, WebSocketWriter, WebSocketError -from aiohttp.websocket import (MSG_BINARY, MSG_TEXT, MSG_CLOSE, MSG_PING, - MSG_PONG) -from aiohttp.websocket_client import (MsgType, closedMessage, - ClientWebSocketResponse) +import aiohttp +from aiohttp.websocket_client import ClientWebSocketResponse from aiogremlin.exceptions import SocketClientError from aiogremlin.log import INFO, logger -__all__ = ('WebSocketSession', 'GremlinFactory', - 'GremlinClientWebSocketResponse') +__all__ = ('GremlinFactory', 'GremlinClientWebSocketResponse') class GremlinClientWebSocketResponse(ClientWebSocketResponse): @@ -29,7 +21,8 @@ class GremlinClientWebSocketResponse(ClientWebSocketResponse): ClientWebSocketResponse.__init__(self, reader, writer, protocol, response, timeout, autoclose, autoping, loop) - self._parser = StreamParser(buf=DataQueue(loop=loop), loop=loop) + self._parser = aiohttp.StreamParser(buf=aiohttp.DataQueue(loop=loop), + loop=loop) @property def parser(self): @@ -55,7 +48,7 @@ class GremlinClientWebSocketResponse(ClientWebSocketResponse): self._response.close(force=True) return True - if msg.tp == MsgType.close: + if msg.tp == aiohttp.MsgType.close: self._close_code = msg.data self._response.close(force=True) return True @@ -97,139 +90,32 @@ class GremlinClientWebSocketResponse(ClientWebSocketResponse): @asyncio.coroutine def receive(self): msg = yield from super().receive() - if msg.tp == MsgType.binary: + if msg.tp == aiohttp.MsgType.binary: self.parser.feed_data(msg.data.decode()) - elif msg.tp == MsgType.text: + elif msg.tp == aiohttp.MsgType.text: self.parser.feed_data(msg.data.strip()) else: - if msg.tp == MsgType.close: + if msg.tp == aiohttp.MsgType.close: yield from ws.close() - elif msg.tp == MsgType.error: + elif msg.tp == aiohttp.MsgType.error: raise msg.data - elif msg.tp == MsgType.closed: + elif msg.tp == aiohttp.MsgType.closed: pass -# Basically cut and paste from aiohttp until merge/release of #374 -class WebSocketSession(ClientSession): - - def __init__(self, *, connector=None, loop=None, - cookies=None, headers=None, auth=None, - ws_response_class=GremlinClientWebSocketResponse): - - super().__init__(connector=connector, loop=loop, - cookies=cookies, headers=headers, auth=auth) - - self._ws_response_class = ws_response_class - - @asyncio.coroutine - def ws_connect(self, url, *, - protocols=(), - timeout=10.0, - autoclose=True, - autoping=True, - loop=None): - """Initiate websocket connection.""" - - sec_key = base64.b64encode(os.urandom(16)) - - headers = { - hdrs.UPGRADE: hdrs.WEBSOCKET, - hdrs.CONNECTION: hdrs.UPGRADE, - hdrs.SEC_WEBSOCKET_VERSION: '13', - hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(), - } - if protocols: - headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols) - - # send request - resp = yield from self.request('get', url, headers=headers, - read_until_eof=False) - - # check handshake - if resp.status != 101: - raise WSServerHandshakeError('Invalid response status') - - if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket': - raise WSServerHandshakeError('Invalid upgrade header') - - if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade': - raise WSServerHandshakeError('Invalid connection header') - - # key calculation - key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '') - match = base64.b64encode( - hashlib.sha1(sec_key + WS_KEY).digest()).decode() - if key != match: - raise WSServerHandshakeError('Invalid challenge response') - - # websocket protocol - protocol = None - if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers: - resp_protocols = [ - proto.strip() for proto in - resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')] - - for proto in resp_protocols: - if proto in protocols: - protocol = proto - break - - reader = resp.connection.reader.set_parser(WebSocketParser) - writer = WebSocketWriter(resp.connection.writer, use_mask=True) - - return self._ws_response_class( - reader, writer, protocol, resp, timeout, autoclose, autoping, loop) - - def detach(self): - """Detach connector from session without closing the former. - Session is switched to closed state anyway. - """ - self._connector = None - - -# Cut and paste from aiohttp until merge/release of #374 -def ws_connect(url, *, protocols=(), timeout=10.0, connector=None, - ws_response_class=None, autoclose=True, autoping=True, - loop=None): - if loop is None: - asyncio.get_event_loop() - if connector is None: - connector = TCPConnector(loop=loop, force_close=True) - if ws_response_class is None: - ws_response_class = GremlinClientWebSocketResponse - - ws_session = WebSocketSession(loop=loop, connector=connector, - ws_response_class=ws_response_class) - try: - resp = yield from ws_session.ws_connect( - url, - protocols=protocols, - timeout=timeout, - autoclose=autoclose, - autoping=autoping, - loop=loop) - return resp - - finally: - ws_session.detach() - - class GremlinFactory: - def __init__(self, connector=None, ws_response_class=None): + def __init__(self, connector=None, loop=None): self._connector = connector - if ws_response_class is None: - ws_response_class = GremlinClientWebSocketResponse - self._ws_response_class = ws_response_class + self._loop = loop or asyncio.get_event_loop() @asyncio.coroutine def ws_connect(self, url='ws://localhost:8182/', protocols=(), - autoclose=False, autoping=True, loop=None): + autoclose=False, autoping=True): try: - return (yield from ws_connect( + return (yield from aiohttp.ws_connect( url, protocols=protocols, connector=self._connector, - ws_response_class=self._ws_response_class, autoclose=True, - autoping=True, loop=loop)) - except WSServerHandshakeError as e: + ws_response_class=GremlinClientWebSocketResponse, + autoclose=True, autoping=True, loop=self._loop)) + except aiohttp.WSServerHandshakeError as e: raise SocketClientError(e.message) diff --git a/aiogremlin/pool.py b/aiogremlin/pool.py index 328e780c9f69576acb34779258c3f7d23220aeba..2ef8496cbf9f7d7489e5dd94515e00a450b28332 100644 --- a/aiogremlin/pool.py +++ b/aiogremlin/pool.py @@ -18,14 +18,13 @@ class WebSocketPool: self.url = url if ws_response_class is None: ws_response_class = GremlinClientWebSocketResponse - self._factory = factory or GremlinFactory( - connector=connector, - ws_response_class=ws_response_class) self.poolsize = poolsize self.max_retries = max_retries self.timeout = timeout self._connected = False self._loop = loop or asyncio.get_event_loop() + self._factory = factory or GremlinFactory(connector=connector, + loop=self._loop) self._pool = asyncio.Queue(maxsize=self.poolsize, loop=self._loop) self.active_conns = set() self.num_connecting = 0 @@ -38,9 +37,7 @@ class WebSocketPool: tasks = [] poolsize = self.poolsize for i in range(poolsize): - coro = self.factory.ws_connect( - self.url, - loop=self._loop) + coro = self.factory.ws_connect(self.url) task = asyncio.async(coro, loop=self._loop) tasks.append(task) for f in asyncio.as_completed(tasks, loop=self._loop): @@ -72,6 +69,10 @@ class WebSocketPool: @asyncio.coroutine def close(self): + try: + self._factory.close() + except AttributeError: + pass if not self._closed: if self.active_conns: yield from self._close_active_conns() @@ -109,9 +110,7 @@ class WebSocketPool: else: self.num_connecting += 1 try: - socket = yield from self.factory.ws_connect( - url, - loop=loop) + socket = yield from self.factory.ws_connect(url) finally: self.num_connecting -= 1 if not socket.closed: diff --git a/setup.py b/setup.py index 44e90480e6c0de45ee7d59cdf33a5a072620ca0d..2f609875498bf288b13a5ce36bd7ccbac7b5650b 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( long_description=open("README.txt").read(), packages=["aiogremlin", "tests"], install_requires=[ - "aiohttp==0.15.3" + "aiohttp==0.16.0" ], test_suite="tests", classifiers=[ diff --git a/tests/tests.py b/tests/tests.py index 2079052b02274e1f7906e5c59bf348a1d07fb6fb..6d31329e40fb0bc2c6e75927653da31ee62ed9be 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -5,10 +5,12 @@ import asyncio import itertools import unittest import uuid + +import aiohttp from aiogremlin import (GremlinClient, RequestError, GremlinServerError, SocketClientError, WebSocketPool, GremlinFactory, create_client, GremlinWriter, GremlinResponse, - WebSocketSession) + GremlinClientWebSocketResponse) class GremlinClientTests(unittest.TestCase): @@ -48,7 +50,7 @@ class GremlinClientPoolTests(unittest.TestCase): asyncio.set_event_loop(None) pool = WebSocketPool("ws://localhost:8182/", loop=self.loop) self.gc = GremlinClient(url="ws://localhost:8182/", - factory=GremlinFactory(), + factory=GremlinFactory(loop=self.loop), pool=pool, loop=self.loop) @@ -142,7 +144,7 @@ class WebSocketPoolTests(unittest.TestCase): poolsize=2, timeout=1, loop=self.loop, - factory=GremlinFactory()) + factory=GremlinFactory(loop=self.loop)) def tearDown(self): self.loop.run_until_complete(self.pool.close()) @@ -239,7 +241,7 @@ class ContextMngrTest(unittest.TestCase): self.pool = WebSocketPool("ws://localhost:8182/", poolsize=1, loop=self.loop, - factory=GremlinFactory(), + factory=GremlinFactory(loop=self.loop), max_retries=0) def tearDown(self): @@ -338,14 +340,18 @@ class GremlinClientPoolSessionTests(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(None) - pool = WebSocketPool("ws://localhost:8182/", - loop=self.loop, - factory=WebSocketSession(loop=self.loop)) + pool = WebSocketPool( + "ws://localhost:8182/", + loop=self.loop, + factory=aiohttp.ClientSession( + loop=self.loop, + ws_response_class=GremlinClientWebSocketResponse)) self.gc = GremlinClient("ws://localhost:8182/", pool=pool, loop=self.loop) def tearDown(self): + self.gc._pool._factory.close() self.loop.run_until_complete(self.gc.close()) self.loop.close()