diff --git a/aiogremlin/connection.py b/aiogremlin/connection.py index f249ee855304ced69d18230027516bd0371cbfd0..f68e4308816236dc4aca446638141d593396c954 100644 --- a/aiogremlin/connection.py +++ b/aiogremlin/connection.py @@ -22,12 +22,105 @@ __all__ = ('WebSocketSession', 'GremlinFactory', 'GremlinClientWebSocketResponse') +class GremlinClientWebSocketResponse(ClientWebSocketResponse): + + def __init__(self, reader, writer, protocol, response, timeout, autoclose, + autoping, loop): + ClientWebSocketResponse.__init__(self, reader, writer, protocol, + response, timeout, autoclose, + autoping, loop) + self._parser = StreamParser(buf=DataQueue(loop=loop), loop=loop) + + @property + def parser(self): + return self._parser + + @property + def closed(self): + """Required by ABC.""" + return self._closed + + @asyncio.coroutine + def close(self, *, code=1000, message=b''): + if not self._closed: + do_close = self._close() + if do_close: + return True + while True: + try: + msg = yield from asyncio.wait_for( + self._reader.read(), self._timeout, loop=self._loop) + except asyncio.CancelledError: + self._close_code = 1006 + self._response.close(force=True) + raise + except Exception as exc: + self._close_code = 1006 + self._exception = exc + self._response.close(force=True) + return True + + if msg.tp == MsgType.close: + self._close_code = msg.data + self._response.close(force=True) + return True + else: + return False + + def _close(self, code=1000, message=b''): + self._closed = True + try: + self._writer.close(code, message) + except asyncio.CancelledError: + self._close_code = 1006 + self._response.close(force=True) + raise + except Exception as exc: + self._close_code = 1006 + self._exception = exc + self._response.close(force=True) + return True + + if self._closing: + self._response.close(force=True) + return True + + def send(self, message, binary=True): + if binary: + method = self.send_bytes + else: + method = self.send_str + try: + method(message) + except RuntimeError: + # Socket closed. + raise + except TypeError: + # Bytes/string input error. + raise + + @asyncio.coroutine + def receive(self): + msg = yield from super().receive() + if msg.tp == MsgType.binary: + self.parser.feed_data(msg.data.decode()) + elif msg.tp == MsgType.text: + self.parser.feed_data(msg.data.strip()) + else: + if msg.tp == MsgType.close: + yield from ws.close() + elif msg.tp == MsgType.error: + raise msg[1] + elif msg.tp == MsgType.closed: + pass + + # Basically cut and paste from aiohttp until merge/release of #374 class WebSocketSession(ClientSession): def __init__(self, *, connector=None, loop=None, cookies=None, headers=None, auth=None, - ws_response_class=None): + ws_response_class=GremlinClientWebSocketResponse): super().__init__(connector=connector, loop=loop, cookies=cookies, headers=headers, auth=auth) @@ -40,7 +133,6 @@ class WebSocketSession(ClientSession): timeout=10.0, autoclose=True, autoping=True, - ws_response_class=None, loop=None): """Initiate websocket connection.""" @@ -91,11 +183,7 @@ class WebSocketSession(ClientSession): reader = resp.connection.reader.set_parser(WebSocketParser) writer = WebSocketWriter(resp.connection.writer, use_mask=True) - if ws_response_class is None: - ws_response_class = (self._ws_response_class or - ClientWebSocketResponse) - - return ws_response_class( + return self._ws_response_class( reader, writer, protocol, resp, timeout, autoclose, autoping, loop) def detach(self): @@ -113,14 +201,16 @@ def ws_connect(url, *, protocols=(), timeout=10.0, connector=None, asyncio.get_event_loop() if connector is None: connector = TCPConnector(loop=loop, force_close=True) + if ws_response_class is None: + ws_response_class = GremlinClientWebSocketResponse - ws_session = WebSocketSession(loop=loop, connector=connector) + ws_session = WebSocketSession(loop=loop, connector=connector, + ws_response_class=ws_response_class) try: resp = yield from ws_session.ws_connect( url, protocols=protocols, timeout=timeout, - ws_response_class=ws_response_class, autoclose=autoclose, autoping=autoping, loop=loop) @@ -132,114 +222,19 @@ def ws_connect(url, *, protocols=(), timeout=10.0, connector=None, class GremlinFactory: - def __init__(self, connector=None): + def __init__(self, connector=None, ws_response_class=None): self._connector = connector + if ws_response_class is None: + ws_response_class = GremlinClientWebSocketResponse + self._ws_response_class = ws_response_class @asyncio.coroutine def ws_connect(self, url='ws://localhost:8182/', protocols=(), - connector=None, autoclose=False, autoping=True, - ws_response_class=None, loop=None): - if connector is None: - connector = self._connector - if ws_response_class is None: - ws_response_class = GremlinClientWebSocketResponse + autoclose=False, autoping=True, loop=None): try: return (yield from ws_connect( - url, protocols=protocols, connector=connector, - ws_response_class=ws_response_class, autoclose=True, + url, protocols=protocols, connector=self._connector, + ws_response_class=self._ws_response_class, autoclose=True, autoping=True, loop=loop)) except WSServerHandshakeError as e: raise SocketClientError(e.message) - - -class GremlinClientWebSocketResponse(ClientWebSocketResponse): - - def __init__(self, reader, writer, protocol, response, timeout, autoclose, - autoping, loop): - ClientWebSocketResponse.__init__(self, reader, writer, protocol, - response, timeout, autoclose, - autoping, loop) - self._parser = StreamParser(buf=DataQueue(loop=loop), loop=loop) - - @property - def parser(self): - return self._parser - - @property - def closed(self): - """Required by ABC.""" - return self._closed - - @asyncio.coroutine - def close(self, *, code=1000, message=b''): - if not self._closed: - do_close = self._close() - if do_close: - return True - while True: - try: - msg = yield from asyncio.wait_for( - self._reader.read(), self._timeout, loop=self._loop) - except asyncio.CancelledError: - self._close_code = 1006 - self._response.close(force=True) - raise - except Exception as exc: - self._close_code = 1006 - self._exception = exc - self._response.close(force=True) - return True - - if msg.tp == MsgType.close: - self._close_code = msg.data - self._response.close(force=True) - return True - else: - return False - - def _close(self, code=1000, message=b''): - self._closed = True - try: - self._writer.close(code, message) - except asyncio.CancelledError: - self._close_code = 1006 - self._response.close(force=True) - raise - except Exception as exc: - self._close_code = 1006 - self._exception = exc - self._response.close(force=True) - return True - - if self._closing: - self._response.close(force=True) - return True - - def send(self, message, binary=True): - if binary: - method = self.send_bytes - else: - method = self.send_str - try: - method(message) - except RuntimeError: - # Socket closed. - raise - except TypeError: - # Bytes/string input error. - raise - - @asyncio.coroutine - def receive(self): - msg = yield from super().receive() - if msg.tp == MsgType.binary: - self.parser.feed_data(msg.data.decode()) - elif msg.tp == MsgType.text: - self.parser.feed_data(msg.data.strip()) - else: - if msg.tp == MsgType.close: - yield from ws.close() - elif msg.tp == MsgType.error: - raise msg[1] - elif msg.tp == MsgType.closed: - pass diff --git a/aiogremlin/pool.py b/aiogremlin/pool.py index d53314ee6cec714dd0c2da8b126a74540e2df8b0..c6e90b05d28a368f06fb62476827a395bcb30488 100644 --- a/aiogremlin/pool.py +++ b/aiogremlin/pool.py @@ -16,7 +16,11 @@ class WebSocketPool: """ """ self.url = url - self._factory = factory or GremlinFactory(connector=connector) + self._ws_response_class = (ws_response_class or + GremlinClientWebSocketResponse) + self._factory = factory or GremlinFactory( + connector=connector, + ws_response_class=self._ws_response_class) self.poolsize = poolsize self.max_retries = max_retries self.timeout = timeout @@ -25,8 +29,6 @@ class WebSocketPool: self._pool = asyncio.Queue(maxsize=self.poolsize, loop=self._loop) self.active_conns = set() self.num_connecting = 0 - self._response_class = (ws_response_class or - GremlinClientWebSocketResponse) self._closed = False if verbose: logger.setLevel(INFO) @@ -38,7 +40,6 @@ class WebSocketPool: for i in range(poolsize): coro = self.factory.ws_connect( self.url, - ws_response_class=self._response_class, loop=self._loop) task = asyncio.async(coro, loop=self._loop) tasks.append(task) @@ -110,7 +111,6 @@ class WebSocketPool: try: socket = yield from self.factory.ws_connect( url, - ws_response_class=self._response_class, loop=loop) finally: self.num_connecting -= 1