From 91fc71cd3292c59166b9a3155b38996b21a24e69 Mon Sep 17 00:00:00 2001 From: davebshow <davebshow@gmail.com> Date: Sun, 17 May 2015 14:49:44 -0400 Subject: [PATCH] working on some refactoring --- aiogremlin/__init__.py | 4 +- aiogremlin/abc.py | 27 +-- aiogremlin/client.py | 122 ++++++++---- aiogremlin/connection.py | 367 ++++++++++++++++++----------------- aiogremlin/contextmanager.py | 38 ++++ aiogremlin/log.py | 3 +- aiogremlin/pool.py | 139 +++++++++++++ aiogremlin/protocol.py | 39 +++- tests/tests.py | 104 ++++++---- 9 files changed, 557 insertions(+), 286 deletions(-) create mode 100644 aiogremlin/contextmanager.py create mode 100644 aiogremlin/pool.py diff --git a/aiogremlin/__init__.py b/aiogremlin/__init__.py index e8e3780..f6a17d5 100644 --- a/aiogremlin/__init__.py +++ b/aiogremlin/__init__.py @@ -1,8 +1,8 @@ from .abc import AbstractFactory, AbstractConnection -from .connection import (WebsocketPool, AiohttpFactory, BaseFactory, - BaseConnection) +from .connection import AiohttpFactory, BaseFactory, BaseConnection from .client import (create_client, GremlinClient, GremlinResponse, GremlinResponseStream) from .exceptions import RequestError, GremlinServerError, SocketClientError +from .pool import WebSocketPool from .protocol import GremlinWriter __version__ = "0.0.6" diff --git a/aiogremlin/abc.py b/aiogremlin/abc.py index 9400e07..dd8391e 100644 --- a/aiogremlin/abc.py +++ b/aiogremlin/abc.py @@ -18,32 +18,23 @@ class AbstractFactory(metaclass=ABCMeta): class AbstractConnection(metaclass=ABCMeta): - @abstractmethod - def feed_pool(self): - pass - - @abstractmethod - def release(self): - pass - - @property - @abstractmethod - def pool(self): - pass + # @property + # @abstractmethod + # def closed(self): + # pass - @property @abstractmethod - def closed(self): + def close(): pass @abstractmethod - def close(): + def _close(): pass @abstractmethod def send(self): pass - @abstractmethod - def _receive(self): - pass + # @abstractmethod + # def receive(self): + # pass diff --git a/aiogremlin/client.py b/aiogremlin/client.py index 9b40de2..518f61e 100644 --- a/aiogremlin/client.py +++ b/aiogremlin/client.py @@ -2,12 +2,14 @@ import asyncio import ssl -import uuid import aiohttp -from aiogremlin.connection import WebsocketPool -from aiogremlin.log import client_logger, INFO +from aiogremlin.connection import AiohttpFactory +from aiogremlin.contextmanager import ClientContextManager +from aiogremlin.exceptions import RequestError +from aiogremlin.log import logger, INFO +from aiogremlin.pool import WebSocketPool from aiogremlin.protocol import gremlin_response_parser, GremlinWriter @@ -16,14 +18,14 @@ def create_client(uri='ws://localhost:8182/', loop=None, ssl=None, protocol=None, lang="gremlin-groovy", op="eval", processor="", pool=None, factory=None, poolsize=10, timeout=None, verbose=False, **kwargs): - pool = WebsocketPool(uri, + pool = WebSocketPool(uri, factory=factory, poolsize=poolsize, timeout=timeout, loop=loop, verbose=verbose) - yield from pool.init_pool() + yield from pool.fill_pool() return GremlinClient(uri=uri, loop=loop, @@ -42,7 +44,7 @@ class GremlinClient: def __init__(self, uri='ws://localhost:8182/', loop=None, ssl=None, protocol=None, lang="gremlin-groovy", op="eval", processor="", pool=None, factory=None, poolsize=10, - timeout=None, verbose=True, **kwargs): + timeout=None, verbose=False, **kwargs): """ """ self.uri = uri @@ -60,11 +62,15 @@ class GremlinClient: self.processor = processor or "" self.poolsize = poolsize self.timeout = timeout - self.pool = pool or WebsocketPool(uri, factory=factory, - poolsize=poolsize, timeout=timeout, loop=self._loop) - self.factory = factory or self.pool.factory + self._pool = pool + self._factory = factory or AiohttpFactory + if self._pool is None: + self._connected = False + self._conn = asyncio.async(self._connect(), loop=self._loop) + else: + self._connected = self._pool._connected if verbose: - client_logger.setLevel(INFO) + logger.setLevel(INFO) @property def loop(self): @@ -72,45 +78,54 @@ class GremlinClient: @asyncio.coroutine def close(self): - yield from self.pool.close() + try: + if self._pool: + yield from self._pool.close() + elif self._connected: + yield from self._conn.close() + finally: + self._connected = False @asyncio.coroutine - def connect(self, **kwargs): + def _connect(self, **kwargs): """ """ - loop = kwargs.get("loop", "") or self.loop - connection = yield from self.factory.connect(self.uri, loop=loop, - **kwargs) + loop = kwargs.get("loop", "") or self._loop + connection = yield from self._factory.connect(self.uri, loop=loop) + self._connected = True return connection @asyncio.coroutine - def submit(self, gremlin, connection=None, bindings=None, lang=None, - op=None, processor=None, session=None, binary=True): + def _acquire(self, **kwargs): + if self._pool: + conn = yield from self._pool.acquire() + elif self._connected: + conn = self._conn + else: + conn = yield from self._conn + return conn + # Check here for error + # except Error: + # conn = yield from self._connect() + + @asyncio.coroutine + def submit(self, gremlin, conn=None, bindings=None, lang=None, op=None, + processor=None, session=None, binary=True): """ """ lang = lang or self.lang op = op or self.op processor = processor or self.processor - message = { - "requestId": str(uuid.uuid4()), - "op": op, - "processor": processor, - "args":{ - "gremlin": gremlin, - "bindings": bindings, - "language": lang - } - } - if processor == "session": - session = session or str(uuid.uuid4()) - message["args"]["session"] = session - client_logger.info( - "Session ID: {}".format(message["args"]["session"])) - if connection is None: - connection = yield from self.pool.connect(self.uri, loop=self.loop) - writer = GremlinWriter(connection) - connection = yield from writer.write(message, binary=binary) - return GremlinResponse(connection, session=session, loop=self._loop) + if conn is None: + conn = yield from self._acquire() + writer = GremlinWriter(conn) + conn = yield from writer.write(gremlin, bindings=bindings, + lang=lang, op=op, processor=processor, session=session, + binary=binary) + return GremlinResponse(conn, + self, + session=session, + loop=self._loop) @asyncio.coroutine def execute(self, gremlin, bindings=None, lang=None, @@ -127,10 +142,11 @@ class GremlinClient: class GremlinResponse: - def __init__(self, conn, session=None, loop=None): + def __init__(self, conn, client, session=None, loop=None): self._loop = loop or asyncio.get_event_loop() + self._client = client self._session = session - self._stream = GremlinResponseStream(conn, loop=self._loop) + self._stream = GremlinResponseStream(conn, self, loop=self._loop) @property def stream(self): @@ -156,11 +172,24 @@ class GremlinResponse: results.append(message) return results + # aioredis style + def __enter__(self): + raise RuntimeError( + "'yield from' should be used as a context manager expression") + + def __exit__(self, *args): + pass + + def __iter__(self): + yield from self._pool.create_pool() + return ClientContextManager(self) + class GremlinResponseStream: - def __init__(self, conn, loop=None): + def __init__(self, conn, resp, loop=None): self._conn = conn + self._resp = resp self._loop = loop or asyncio.get_event_loop() data_stream = aiohttp.DataQueue(loop=self._loop) self._stream = self._conn.parser.set_parser(gremlin_response_parser, @@ -170,13 +199,20 @@ class GremlinResponseStream: def read(self): # For 3.0.0.M9 # if self._stream.at_eof(): - # self._conn.feed_pool() + # self._pool.release(self._conn) # message = None # else: # This will be different 3.0.0.M9 - yield from self._conn._receive() + pool = self._resp._client._pool + try: + yield from self._conn.read() + except RequestError: + if pool: + pool.release(self._conn) + print("fed pool") if self._stream.is_eof(): - self._conn.feed_pool() + if pool: + pool.release(self._conn) message = None else: message = yield from self._stream.read() diff --git a/aiogremlin/connection.py b/aiogremlin/connection.py index 633a502..26cc132 100644 --- a/aiogremlin/connection.py +++ b/aiogremlin/connection.py @@ -1,128 +1,85 @@ """ """ import asyncio - -import aiohttp +import base64 +import hashlib +import os + +from aiohttp import (client, hdrs, DataQueue, StreamParser, + WSServerHandshakeError) +from aiohttp.errors import WSServerHandshakeError +from aiohttp.websocket import WS_KEY, Message +from aiohttp.websocket import WebSocketParser, WebSocketWriter, WebSocketError +from aiohttp.websocket import (MSG_BINARY, MSG_TEXT, MSG_CLOSE, MSG_PING, + MSG_PONG) +from aiohttp.websocket_client import (MsgType, closedMessage, + ClientWebSocketResponse) from aiogremlin.abc import AbstractFactory, AbstractConnection from aiogremlin.exceptions import SocketClientError -from aiogremlin.log import INFO, conn_logger - - -class WebsocketPool: - - def __init__(self, uri='ws://localhost:8182/', factory=None, poolsize=10, - max_retries=10, timeout=None, loop=None, verbose=False): - """ - """ - self.uri = uri - self._factory = factory or AiohttpFactory - self.poolsize = poolsize - self.max_retries = max_retries - self.timeout = timeout - self._loop = loop or asyncio.get_event_loop() - self.pool = asyncio.Queue(maxsize=self.poolsize, loop=self._loop) - self.active_conns = set() - self.num_connecting = 0 - self._closed = False - if verbose: - conn_logger.setLevel(INFO) - - @asyncio.coroutine - def init_pool(self): - for i in range(self.poolsize): - conn = yield from self.factory.connect(self.uri, pool=self, - loop=self._loop) - self._put(conn) - - @property - def loop(self): - return self._loop - - @property - def factory(self): - return self._factory - - @property - def closed(self): - return self._closed - - @property - def num_active_conns(self): - return len(self.active_conns) - - def feed_pool(self, conn): - if self._closed: - raise RuntimeError("WebsocketPool is closed.") - self.active_conns.discard(conn) - self._put(conn) - - @asyncio.coroutine - def close(self): - if not self._closed: - if self.active_conns: - yield from self._close_active_conns() - yield from self._purge_pool() - self._closed = True - - @asyncio.coroutine - def _close_active_conns(self): - tasks = [asyncio.async(conn.close(), loop=self.loop) for conn - in self.active_conns] - yield from asyncio.wait(tasks, loop=self.loop) - - @asyncio.coroutine - def _purge_pool(self): - while True: - try: - conn = self.pool.get_nowait() - except asyncio.QueueEmpty: +from aiogremlin.log import INFO, logger + + +# This is temporary until aiohttp pull #367 is merged/released. +@asyncio.coroutine +def ws_connect(url, protocols=(), timeout=10.0, connector=None, + response_class=None, autoclose=True, autoping=True, loop=None): + """Initiate websocket connection.""" + if loop is None: + loop = asyncio.get_event_loop() + + sec_key = base64.b64encode(os.urandom(16)) + + headers = { + hdrs.UPGRADE: hdrs.WEBSOCKET, + hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.SEC_WEBSOCKET_VERSION: '13', + hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(), + } + if protocols: + headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols) + + # send request + resp = yield from client.request( + 'get', url, headers=headers, + read_until_eof=False, + connector=connector, loop=loop) + + # check handshake + if resp.status != 101: + raise WSServerHandshakeError('Invalid response status') + + if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket': + raise WSServerHandshakeError('Invalid upgrade header') + + if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade': + raise WSServerHandshakeError('Invalid connection header') + + # key calculation + key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '') + match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode() + if key != match: + raise WSServerHandshakeError('Invalid challenge response') + + # websocket protocol + protocol = None + if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers: + resp_protocols = [proto.strip() for proto in + resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')] + + for proto in resp_protocols: + if proto in protocols: + protocol = proto break - else: - yield from conn.close() - @asyncio.coroutine - def connect(self, uri=None, loop=None, num_retries=None): - if self._closed: - raise RuntimeError("WebsocketPool is closed.") - if num_retries is None: - num_retries = self.max_retries - uri = uri or self.uri - loop = loop or self.loop - if not self.pool.empty(): - socket = self.pool.get_nowait() - conn_logger.info("Reusing socket: {} at {}".format(socket, uri)) - elif self.num_active_conns + self.num_connecting >= self.poolsize: - conn_logger.info("Waiting for socket...") - socket = yield from asyncio.wait_for(self.pool.get(), - self.timeout, loop=loop) - conn_logger.info("Socket acquired: {} at {}".format(socket, uri)) - else: - self.num_connecting += 1 - try: - socket = yield from self.factory.connect(uri, pool=self, - loop=loop) - finally: - self.num_connecting -= 1 - if not socket.closed: - conn_logger.info("New connection on socket: {} at {}".format( - socket, uri)) - self.active_conns.add(socket) - # Untested. - elif num_retries > 0: - conn_logger.warning("Got bad socket, retry...") - socket = yield from self.connect(uri, loop, num_retries - 1) - else: - raise RuntimeError("Unable to connect, max retries exceeded.") - return socket + reader = resp.connection.reader.set_parser(WebSocketParser) + writer = WebSocketWriter(resp.connection.writer, use_mask=True) - def _put(self, socket): - try: - self.pool.put_nowait(socket) - except asyncio.QueueFull: - pass - # This should be - not working - # yield from socket.release() + if response_class is None: + response_class = ClientWebSocketResponse + + return response_class( + reader, writer, protocol, resp, timeout, autoclose, autoping, loop) class BaseFactory(AbstractFactory): @@ -141,107 +98,163 @@ class AiohttpFactory(BaseFactory): if pool: loop = loop or pool.loop try: - socket = yield from aiohttp.ws_connect(uri, protocols=protocols, - connector=connector, autoclose=False, autoping=True, - loop=loop) - except aiohttp.WSServerHandshakeError as e: + return (yield from ws_connect( + uri, protocols=protocols, connector=connector, + response_class=GremlinClientWebSocketResponse, + autoclose=True, autoping=True, loop=loop)) + except WSServerHandshakeError as e: raise SocketClientError(e.message) - return AiohttpConnection(socket, pool, loop=loop) class BaseConnection(AbstractConnection): - def __init__(self, socket, pool=None, loop=None): - self.socket = socket + def __init__(self, loop=None): self._loop = loop or asyncio.get_event_loop() - self._pool = pool - self._parser = aiohttp.StreamParser( - buf=aiohttp.DataQueue(loop=self._loop), loop=self._loop) + self._parser = StreamParser( + buf=DataQueue(loop=self._loop), loop=self._loop) @property def parser(self): return self._parser - def feed_pool(self): - if self.pool: - if self in self.pool.active_conns: - self.pool.feed_pool(self) - @asyncio.coroutine - def release(self): - try: - yield from self.close() - finally: - if self in self.pool.active_conns: - self.pool.active_conns.discard(self) - - @property - def pool(self): - return self._pool +class GremlinClientWebSocketResponse(BaseConnection, ClientWebSocketResponse): + def __init__(self, reader, writer, protocol, response, timeout, autoclose, + autoping, loop): + BaseConnection.__init__(self, loop=loop) + ClientWebSocketResponse.__init__(self, reader, writer, protocol, + response, timeout, autoclose, autoping, loop) -class AiohttpConnection(BaseConnection): + @asyncio.coroutine + def close(self, *, code=1000, message=b''): + if not self._closed: + self._closed = True + closed = self._close() + if closed: + return True + while True: + try: + msg = yield from asyncio.wait_for( + self._reader.read(), self._timeout, loop=self._loop) + except asyncio.CancelledError: + self._close_code = 1006 + self._response.close(force=True) + raise + except Exception as exc: + self._close_code = 1006 + self._exception = exc + self._response.close(force=True) + return True + + if msg.tp == MsgType.close: + self._close_code = msg.data + self._response.close(force=True) + return True + else: + return False - @property - def closed(self): - return self.socket.closed + def _close(self): + try: + self._writer.close(code, message) + except asyncio.CancelledError: + self._close_code = 1006 + self._response.close(force=True) + raise + except Exception as exc: + self._close_code = 1006 + self._exception = exc + self._response.close(force=True) + return True - @asyncio.coroutine - def close(self): - if not self.socket._closed: - try: - yield from self.socket.close() - finally: - # Socket should close despite errors. - self._closed = True + if self._closing: + self._response.close(force=True) + return True - @asyncio.coroutine + # @asyncio.coroutine def send(self, message, binary=True): if binary: - method = self.socket.send_bytes + method = self.send_bytes else: - method = self.socket.send_str + method = self.send_str try: method(message) except RuntimeError: # Socket closed. - yield from self.release() raise except TypeError: # Bytes/string input error. - yield from self.release() raise @asyncio.coroutine - def _receive(self): + def read(self): """Implements a dispatcher using the aiohttp websocket protocol.""" try: - message = yield from self.socket.receive() + message = yield from self.receive() except (asyncio.CancelledError, asyncio.TimeoutError): - yield from self.release() - raise - except RuntimeError: - yield from self.release() + # Hmm maybe don't close here + yield from self.close() raise - if message.tp == aiohttp.MsgType.binary: + if message.tp == MsgType.binary: try: self.parser.feed_data(message.data.decode()) except Exception: - self.release() + # Hmm maybe don't close here + yield from self.close() raise - elif message.tp == aiohttp.MsgType.text: + elif message.tp == MsgType.text: try: self.parser.feed_data(message.data.strip()) except Exception: - self.release() + # Hmm maybe don't close here + yield from self.close() raise else: - try: - if message.tp == aiohttp.MsgType.close: - raise RuntimeError("Socket connection closed by server.") - elif message.tp == aiohttp.MsgType.error: - raise SocketClientError(self.socket.exception()) - elif message.tp == aiohttp.MsgType.closed: - raise RuntimeError("Socket closed.") - finally: - yield from self.release() + if message.tp == MsgType.close: + raise RuntimeError("Socket connection closed by server.") + elif message.tp == MsgType.error: + raise SocketClientError(self.socket.exception()) + elif message.tp == MsgType.closed: + raise RuntimeError("Socket closed.") + + # @asyncio.coroutine + # def receive(self): + # if self._waiting: + # raise RuntimeError('Concurrent call to receive() is not allowed') + # + # self._waiting = True + # try: + # while True: + # if self._closed: + # return closedMessage + # + # try: + # msg = yield from self._reader.read() + # except (asyncio.CancelledError, asyncio.TimeoutError): + # raise + # except WebSocketError as exc: + # self._close_code = exc.code + # yield from self.close(code=exc.code) + # return Message(MsgType.error, exc, None) + # except Exception as exc: + # self._exception = exc + # self._closing = True + # self._close_code = 1006 + # yield from self.close() + # return Message(MsgType.error, exc, None) + # + # if msg.tp == MsgType.close: + # self._closing = True + # self._close_code = msg.data + # if not self._closed and self._autoclose: + # yield from self.close() + # return msg + # elif not self._closed: + # if msg.tp == MsgType.ping and self._autoping: + # self._writer.pong(msg.data) + # elif msg.tp == MsgType.pong and self._autoping: + # continue + # else: + # return msg + # finally: + # self._waiting = False diff --git a/aiogremlin/contextmanager.py b/aiogremlin/contextmanager.py new file mode 100644 index 0000000..f5c1f3a --- /dev/null +++ b/aiogremlin/contextmanager.py @@ -0,0 +1,38 @@ +class ClientContextManager: + + __slots__ = ("_client") + + def __init__(self, client): + self._client = client + + def __enter__(self): + return self._client + + def __exit__(self, *args): + try: + yield from self._client.close() + finally: + self._client = None + + +class ConnectionContextManager: + + __slots__ = ("_conn", "_pool") + + def __init__(self, conn, pool): + self._conn = conn + self._pool = pool + + def __enter__(self): + return self._conn + + def __exit__(self, exception_type, exception_value, traceback): + print("in __exit__") + import ipdb; ipdb.set_trace() + print("agains") + try: + print("hello") + yield from self._conn.release() + finally: + self._conn = None + self._pool = None diff --git a/aiogremlin/log.py b/aiogremlin/log.py index 8a4a47d..4db9367 100644 --- a/aiogremlin/log.py +++ b/aiogremlin/log.py @@ -8,5 +8,4 @@ logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") -client_logger = logging.getLogger("aiogremlin.client") -conn_logger = logging.getLogger("aiogremlin.connection") +logger = logging.getLogger("aiogremlin") diff --git a/aiogremlin/pool.py b/aiogremlin/pool.py new file mode 100644 index 0000000..8a7b6c6 --- /dev/null +++ b/aiogremlin/pool.py @@ -0,0 +1,139 @@ +import asyncio + +from aiogremlin.connection import AiohttpFactory +from aiogremlin.contextmanager import ConnectionContextManager +from aiogremlin.log import logger + + +def create_pool(): + pass + + +class WebSocketPool: + + def __init__(self, uri='ws://localhost:8182/', factory=None, poolsize=10, + max_retries=10, timeout=None, loop=None, verbose=False): + """ + """ + self.uri = uri + self._factory = factory or AiohttpFactory + self.poolsize = poolsize + self.max_retries = max_retries + self.timeout = timeout + self._connected = False + self._loop = loop or asyncio.get_event_loop() + self._pool = asyncio.Queue(maxsize=self.poolsize, loop=self._loop) + self.active_conns = set() + self.num_connecting = 0 + self._closed = False + if verbose: + logger.setLevel(INFO) + + @asyncio.coroutine + def fill_pool(self): + for i in range(self.poolsize): + conn = yield from self.factory.connect(self.uri, pool=self, + loop=self._loop) + self._put(conn) + self._connected = True + + @property + def loop(self): + return self._loop + + @property + def factory(self): + return self._factory + + @property + def closed(self): + return self._closed + + @property + def num_active_conns(self): + return len(self.active_conns) + + def release(self, conn): + if self._closed: + raise RuntimeError("WebsocketPool is closed.") + self.active_conns.discard(conn) + self._put(conn) + + @asyncio.coroutine + def close(self): + if not self._closed: + if self.active_conns: + yield from self._close_active_conns() + yield from self._purge_pool() + self._closed = True + + @asyncio.coroutine + def _close_active_conns(self): + tasks = [asyncio.async(conn.close(), loop=self.loop) for conn + in self.active_conns] + yield from asyncio.wait(tasks, loop=self.loop) + + @asyncio.coroutine + def _purge_pool(self): + while True: + try: + conn = self._pool.get_nowait() + except asyncio.QueueEmpty: + break + else: + yield from conn.close() + + @asyncio.coroutine + def acquire(self, uri=None, loop=None, num_retries=None): + if self._closed: + raise RuntimeError("WebsocketPool is closed.") + if num_retries is None: + num_retries = self.max_retries + uri = uri or self.uri + loop = loop or self.loop + if not self._pool.empty(): + socket = self._pool.get_nowait() + logger.info("Reusing socket: {} at {}".format(socket, uri)) + elif self.num_active_conns + self.num_connecting >= self.poolsize: + logger.info("Waiting for socket...") + socket = yield from asyncio.wait_for(self._pool.get(), + self.timeout, loop=loop) + logger.info("Socket acquired: {} at {}".format(socket, uri)) + else: + self.num_connecting += 1 + try: + socket = yield from self.factory.connect(uri, pool=self, + loop=loop) + finally: + self.num_connecting -= 1 + if not socket.closed: + logger.info("New connection on socket: {} at {}".format( + socket, uri)) + self.active_conns.add(socket) + # Untested. + elif num_retries > 0: + logger.warning("Got bad socket, retry...") + socket = yield from self.acquire(uri, loop, num_retries - 1) + else: + raise RuntimeError("Unable to connect, max retries exceeded.") + return socket + + def _put(self, socket): + try: + self._pool.put_nowait(socket) + except asyncio.QueueFull: + pass + # This should be - not working + # yield from socket.release() + + # aioredis style + def __enter__(self): + raise RuntimeError( + "'yield from' should be used as a context manager expression") + + def __exit__(self, *args): + pass + + def __iter__(self): + conn = yield from self.acquire() + return ConnectionContextManager(conn, self) diff --git a/aiogremlin/protocol.py b/aiogremlin/protocol.py index 897acac..9144cb7 100644 --- a/aiogremlin/protocol.py +++ b/aiogremlin/protocol.py @@ -2,10 +2,15 @@ import asyncio import collections +import uuid -import ujson +try: + import ujson as json +except ImportError: + import json from aiogremlin.exceptions import RequestError, GremlinServerError +from aiogremlin.log import logger Message = collections.namedtuple("Message", ["status_code", "data", "message", @@ -15,7 +20,7 @@ Message = collections.namedtuple("Message", ["status_code", "data", "message", def gremlin_response_parser(out, buf): while True: message = yield - message = ujson.loads(message) + message = json.loads(message) message = Message(message["status"]["code"], message["result"]["data"], message["result"]["meta"], @@ -46,11 +51,15 @@ class GremlinWriter: self._connection = connection @asyncio.coroutine - def write(self, message, binary=True, mime_type="application/json"): - message = ujson.dumps(message) + def write(self, gremlin, bindings=None, lang="gremlin-groovy", op="eval", + processor="", session=None, binary=True, + mime_type="application/json"): + message = self._prepare_message(gremlin, bindings=bindings, + lang=lang, op=op, processor=processor, session=session) + message = json.dumps(message) if binary: message = self._set_message_header(message, mime_type) - yield from self._connection.send(message, binary) + self._connection.send(message, binary) return self._connection @staticmethod @@ -61,3 +70,23 @@ class GremlinWriter: else: raise ValueError("Unknown mime type.") return b"".join([mime_len, mime_type, bytes(message, "utf-8")]) + + @staticmethod + def _prepare_message(gremlin, bindings=None, lang="gremlin-groovy", op="eval", + processor="", session=None): + message = { + "requestId": str(uuid.uuid4()), + "op": op, + "processor": processor, + "args":{ + "gremlin": gremlin, + "bindings": bindings, + "language": lang + } + } + if processor == "session": + session = session or str(uuid.uuid4()) + message["args"]["session"] = session + logger.info( + "Session ID: {}".format(message["args"]["session"])) + return message diff --git a/tests/tests.py b/tests/tests.py index 5ba9c13..28c2d0f 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -6,16 +6,19 @@ import itertools import unittest import uuid from aiogremlin import (GremlinClient, RequestError, GremlinServerError, - SocketClientError, WebsocketPool, AiohttpFactory, create_client) + SocketClientError, WebSocketPool, AiohttpFactory, create_client, + GremlinWriter, GremlinResponse) # -class GremlinClientTests(unittest.TestCase): +class GremlinClientPoolTests(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(None) self.gc = GremlinClient("ws://localhost:8182/", - factory=AiohttpFactory, loop=self.loop) + factory=AiohttpFactory, pool=WebSocketPool("ws://localhost:8182/", + loop=self.loop), + loop=self.loop) def tearDown(self): self.loop.run_until_complete(self.gc.close()) @@ -24,7 +27,7 @@ class GremlinClientTests(unittest.TestCase): def test_connection(self): @asyncio.coroutine def conn_coro(): - conn = yield from self.gc.connect() + conn = yield from self.gc._acquire() self.assertFalse(conn.closed) return conn conn = self.loop.run_until_complete(conn_coro()) @@ -96,12 +99,12 @@ class GremlinClientTests(unittest.TestCase): self.loop.run_until_complete(stream_coro()) -class WebsocketPoolTests(unittest.TestCase): +class WebSocketPoolTests(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(None) - self.pool = WebsocketPool(poolsize=2, timeout=1, loop=self.loop, + self.pool = WebSocketPool(poolsize=2, timeout=1, loop=self.loop, factory=AiohttpFactory) def tearDown(self): @@ -112,10 +115,9 @@ class WebsocketPoolTests(unittest.TestCase): @asyncio.coroutine def conn(): - conn = yield from self.pool.connect() - self.assertIsNotNone(conn.socket) + conn = yield from self.pool.acquire() self.assertFalse(conn.closed) - conn.feed_pool() + self.pool.release(conn) self.assertEqual(self.pool.num_active_conns, 0) self.loop.run_until_complete(conn()) @@ -124,15 +126,13 @@ class WebsocketPoolTests(unittest.TestCase): @asyncio.coroutine def conn(): - conn1 = yield from self.pool.connect() - conn2 = yield from self.pool.connect() - self.assertIsNotNone(conn1.socket) + conn1 = yield from self.pool.acquire() + conn2 = yield from self.pool.acquire() self.assertFalse(conn1.closed) - self.assertIsNotNone(conn2.socket) self.assertFalse(conn2.closed) - conn1.feed_pool() + self.pool.release(conn1) self.assertEqual(self.pool.num_active_conns, 1) - conn2.feed_pool() + self.pool.release(conn2) self.assertEqual(self.pool.num_active_conns, 0) self.loop.run_until_complete(conn()) @@ -141,10 +141,10 @@ class WebsocketPoolTests(unittest.TestCase): @asyncio.coroutine def conn(): - conn1 = yield from self.pool.connect() - conn2 = yield from self.pool.connect() + conn1 = yield from self.pool.acquire() + conn2 = yield from self.pool.acquire() try: - conn3 = yield from self.pool.connect() + conn3 = yield from self.pool.acquire() timeout = False except asyncio.TimeoutError: timeout = True @@ -156,21 +156,19 @@ class WebsocketPoolTests(unittest.TestCase): @asyncio.coroutine def conn(): - conn1 = yield from self.pool.connect() - conn2 = yield from self.pool.connect() + conn1 = yield from self.pool.acquire() + conn2 = yield from self.pool.acquire() try: - conn3 = yield from self.pool.connect() + conn3 = yield from self.pool.acquire() timeout = False except asyncio.TimeoutError: timeout = True self.assertTrue(timeout) - conn2.feed_pool() - conn3 = yield from self.pool.connect() - self.assertIsNotNone(conn1.socket) + self.pool.release(conn2) + conn3 = yield from self.pool.acquire() self.assertFalse(conn1.closed) - self.assertIsNotNone(conn3.socket) self.assertFalse(conn3.closed) - self.assertEqual(conn2.socket, conn3.socket) + self.assertEqual(conn2, conn3) self.loop.run_until_complete(conn()) @@ -178,26 +176,53 @@ class WebsocketPoolTests(unittest.TestCase): @asyncio.coroutine def conn(): - conn1 = yield from self.pool.connect() - conn2 = yield from self.pool.connect() - self.assertIsNotNone(conn1.socket) + conn1 = yield from self.pool.acquire() + conn2 = yield from self.pool.acquire() self.assertFalse(conn1.closed) - self.assertIsNotNone(conn2.socket) self.assertFalse(conn2.closed) - yield from conn1.socket.close() - yield from conn2.socket.close() + yield from conn1.close() + yield from conn2.close() self.assertTrue(conn2.closed) self.assertTrue(conn2.closed) - conn1.feed_pool() - conn2.feed_pool() - conn1 = yield from self.pool.connect() - conn2 = yield from self.pool.connect() - self.assertIsNotNone(conn1.socket) + self.pool.release(conn1) + self.pool.release(conn2) + conn1 = yield from self.pool.acquire() + conn2 = yield from self.pool.acquire() self.assertFalse(conn1.closed) - self.assertIsNotNone(conn2.socket) self.assertFalse(conn2.closed) self.loop.run_until_complete(conn()) +# +# class ContextMngrTest(unittest.TestCase): +# +# def setUp(self): +# self.loop = asyncio.new_event_loop() +# asyncio.set_event_loop(None) +# self.pool = WebSocketPool(poolsize=1, loop=self.loop, +# factory=AiohttpFactory) +# +# def tearDown(self): +# self.loop.run_until_complete(self.pool.close()) +# self.loop.close() +# +# def test_connection_manager(self): +# results = [] +# @asyncio.coroutine +# def go(): +# +# # import ipdb; ipdb.set_trace() +# with (yield from self.pool) as conn: +# writer = GremlinWriter(conn) +# conn = yield from writer.write("1 + 1") +# resp = GremlinResponse(conn, loop=self.loop) +# while True: +# mssg = yield from resp.stream.read() +# if mssg is None: +# break +# results.append(mssg) +# conn = self.pool._pool.get_nowait() +# self.assertTrue(conn.closed) +# self.loop.run_until_complete(go()) class CreateClientTests(unittest.TestCase): @@ -206,7 +231,8 @@ class CreateClientTests(unittest.TestCase): @asyncio.coroutine def go(loop): gc = yield from create_client(poolsize=10, loop=loop) - self.assertEqual(gc.pool.pool.qsize(), 10) + self.assertEqual(gc._pool._pool.qsize(), 10) + yield from gc.close() loop = asyncio.new_event_loop() asyncio.set_event_loop(None) -- GitLab