diff --git a/goblin/app.py b/goblin/app.py index 6bd6fc3eea9181901b958a0f77ba886940ef9481..25c12da1b37baa40a180120c301d0ced09108db7 100644 --- a/goblin/app.py +++ b/goblin/app.py @@ -40,11 +40,11 @@ class Goblin: :param dict config: Config parameters for application """ - def __init__(self, cluster, *, translator=None, aliases=None, + def __init__(self, cluster, *, translator=None, traversal_source=None, get_hashable_id=None): self._cluster = cluster self._loop = self._cluster._loop - self._aliases = aliases + self._traversal_source = traversal_source self._transactions = None self._cluster = cluster self._vertices = collections.defaultdict( @@ -121,14 +121,14 @@ class Goblin: self._get_hashable_id, transactions, use_session=use_session, - aliases=self._aliases) + traversal_source=self._traversal_source) async def supports_transactions(self): if self._transactions is None: conn = await self._cluster.get_connection() stream = await conn.submit( 'graph.features().graph().supportsTransactions()', - aliases=self._aliases) + traversal_source=self._traversal_source) msg = await stream.fetch_data() stream.close() self._transactions = msg diff --git a/goblin/driver/client.py b/goblin/driver/client.py index 75d0694f4b5f4fe0a88b5cec3a379daffd322e4f..05a460a9f2872a12da59ec7583eb500bc9709b35 100644 --- a/goblin/driver/client.py +++ b/goblin/driver/client.py @@ -13,17 +13,16 @@ class Client: *, bindings=None, lang=None, - aliases=None, + traversal_source=None, session=None): conn = await self.cluster.get_connection() resp = await conn.submit(gremlin, bindings=bindings, lang=lang, - aliases=aliases, + traversal_source=traversal_source, session=session) self._loop.create_task(conn.release_task(resp)) return resp async def close(self): - await self._cluster.close() self._cluster = None diff --git a/goblin/driver/cluster.py b/goblin/driver/cluster.py index ad7f9274a828fc988458dcc1e10a206d4bd0703b..798d76016ef7ddccad0fb10978d2ff02f44b73d2 100644 --- a/goblin/driver/cluster.py +++ b/goblin/driver/cluster.py @@ -18,7 +18,11 @@ class Cluster: 'ssl_password': '', 'username': '', 'password': '', - 'response_timeout': None + 'response_timeout': None, + 'max_conns': 4, + 'min_conns': 1, + 'max_times_acquired': 16, + 'max_inflight': 64 } def __init__(self, loop, **config): @@ -56,6 +60,10 @@ class Cluster: response_timeout = self._config['response_timeout'] username = self._config['username'] password = self._config['password'] + max_times_acquired = self._config['max_times_acquired'] + max_conns = self._config['max_conns'] + min_conns = self._config['min_conns'] + max_inflight = self._config['max_inflight'] if scheme in ['https', 'wss']: certfile = self._config['ssl_certfile'] keyfile = self._config['ssl_keyfile'] @@ -70,7 +78,9 @@ class Cluster: host = await driver.GremlinServer.open( url, self._loop, ssl_context=ssl_context, response_timeout=response_timeout, username=username, - password=password) + password=password, max_times_acquired=max_times_acquired, + max_conns=max_conns, min_conns=min_conns, + max_inflight=max_inflight) self._hosts.append(host) def config_from_file(self, filename): diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py index 0867b1d90587d114b0292062a8447cdd7a59b39b..d3fd357b506e027a897a99c4f1d037a922a840cb 100644 --- a/goblin/driver/connection.py +++ b/goblin/driver/connection.py @@ -109,16 +109,16 @@ class Connection(AbstractConnection): websocket connection. Not instantiated directly. Instead use :py:meth:`connect<goblin.driver.server.connect>`. """ - def __init__(self, url, ws, loop, conn_factory, *, aliases=None, + def __init__(self, url, ws, loop, conn_factory, *, traversal_source=None, response_timeout=None, lang='gremlin-groovy', username=None, password=None, max_inflight=64): self._url = url self._ws = ws self._loop = loop self._conn_factory = conn_factory - if aliases is None: - aliases = {} - self._aliases = aliases + if traversal_source is None: + traversal_source = {} + self._traversal_source = traversal_source self._response_timeout = response_timeout self._lang = lang self._username = username @@ -150,7 +150,7 @@ class Connection(AbstractConnection): *, bindings=None, lang=None, - aliases=None, + traversal_source=None, session=None): """ Submit a script and bindings to the Gremlin Server @@ -159,7 +159,7 @@ class Connection(AbstractConnection): :param dict bindings: A mapping of bindings for Gremlin script. :param str lang: Language of scripts submitted to the server. "gremlin-groovy" by default - :param dict aliases: Rebind ``Graph`` and ``TraversalSource`` + :param dict traversal_source: Rebind ``Graph`` and ``TraversalSource`` objects to different variable names in the current request :param str op: Gremlin Server op argument. "eval" by default. :param str processor: Gremlin Server processor argument. "" by default. @@ -169,14 +169,14 @@ class Connection(AbstractConnection): :returns: :py:class:`Response` object """ await self.semaphore.acquire() - if aliases is None: - aliases = self._aliases + if traversal_source is None: + traversal_source = self._traversal_source lang = lang or self._lang request_id = str(uuid.uuid4()) message = self._prepare_message(gremlin, bindings, lang, - aliases, + traversal_source, session, request_id) response_queue = asyncio.Queue(loop=self._loop) @@ -195,7 +195,7 @@ class Connection(AbstractConnection): self._closed = True await self._conn_factory.close() - def _prepare_message(self, gremlin, bindings, lang, aliases, session, + def _prepare_message(self, gremlin, bindings, lang, traversal_source, session, request_id): message = { 'requestId': request_id, @@ -205,7 +205,7 @@ class Connection(AbstractConnection): 'gremlin': gremlin, 'bindings': bindings, 'language': lang, - 'aliases': aliases + 'aliases': traversal_source } } message = self._finalize_message(message, session) diff --git a/goblin/driver/pool.py b/goblin/driver/pool.py index caabe3b331f518f331d671e6f57a96034d8ba362..6e46984874fd565a0f81276e557a9765f13037a1 100644 --- a/goblin/driver/pool.py +++ b/goblin/driver/pool.py @@ -7,13 +7,15 @@ from goblin.driver import connection async def connect(url, loop, *, ssl_context=None, username='', password='', - lang='gremlin-groovy', aliases=None): + lang='gremlin-groovy', traversal_source=None, + max_inflight=64, response_timeout=None): connector = aiohttp.TCPConnector(ssl_context=ssl_context, loop=loop) client_session = aiohttp.ClientSession(loop=loop, connector=connector) ws = await client_session.ws_connect(url) return connection.Connection(url, ws, loop, client_session, - aliases=aliases, lang=lang, - username=username, password=password) + traversal_source=traversal_source, lang=lang, + username=username, password=password, + response_timeout=response_timeout) class PooledConnection: @@ -38,17 +40,17 @@ class PooledConnection: *, bindings=None, lang=None, - aliases=None, + traversal_source=None, session=None): return await self._conn.submit(gremlin, bindings=bindings, lang=lang, - aliases=aliases, session=session) + traversal_source=traversal_source, session=session) async def release_task(self, resp): await resp.done.wait() - await self.release() + self.release() - async def release(self): - await self._pool.release(self) + def release(self): + self._pool.release(self) async def close(self): # close pool? @@ -65,7 +67,8 @@ class ConnectionPool: def __init__(self, url, loop, *, ssl_context=None, username='', password='', lang='gremlin-groovy', max_conns=4, - max_times_acquired=8, aliases=None, **kwargs): + min_conns=1, max_times_acquired=16, max_inflight=64, + traversal_source=None, response_timeout=None): self._url = url self._loop = loop self._ssl_context = ssl_context @@ -73,23 +76,33 @@ class ConnectionPool: self._password = password self._lang = lang self._max_conns = max_conns + self._min_conns = min_conns self._max_times_acquired = max_times_acquired + self._max_inflight = max_inflight + self._response_timeout = response_timeout self._condition = asyncio.Condition(loop=self._loop) - self._lock = asyncio.Lock(loop=self._loop) self._available = collections.deque() self._acquired = collections.deque() - self._aliases = aliases + self._traversal_source = traversal_source @property def url(self): return self._url - async def release(self, conn): + async def init_pool(self): + for i in range(self._min_conns): + conn = await self._get_connection() + self._available.append(conn) + + def release(self, conn): + conn.decrement_acquired() + if not conn.times_acquired: + self._acquired.remove(conn) + self._available.append(conn) + self._loop.create_task(self._notify()) + + async def _notify(self): async with self._condition: - conn.decrement_acquired() - if not conn.times_acquired: - self._acquired.remove(conn) - self._available.append(conn) self._condition.notify() async def acquire(self): @@ -102,7 +115,7 @@ class ConnectionPool: self._acquired.append(conn) return conn if len(self._acquired) < self._max_conns: - conn = await self.get_connection() + conn = await self._get_connection() conn.increment_acquired() self._acquired.append(conn) return conn @@ -111,7 +124,7 @@ class ConnectionPool: conn = self._acquired.popleft() if conn.times_acquired < self._max_times_acquired: conn.increment_acquired() - self._aquired.append(conn) + self._acquired.append(conn) return conn self._acquired.append(conn) else: @@ -127,8 +140,9 @@ class ConnectionPool: waiters.append(conn.close()) await asyncio.gather(*waiters) - async def get_connection(self, username=None, password=None, lang=None, - aliases=None): + async def _get_connection(self, username=None, password=None, lang=None, + traversal_source=None, max_inflight=None, + response_timeout=None): """ Open a connection to the Gremlin Server. @@ -141,10 +155,15 @@ class ConnectionPool: """ username = username or self._username password = password or self._password - aliasess = aliases or self._aliases + traversal_source = traversal_source or self._traversal_source + response_timeout = response_timeout or self._response_timeout + max_inflight = max_inflight or self._max_inflight lang = lang or self._lang conn = await connect(self._url, self._loop, ssl_context=self._ssl_context, username=username, password=password, lang=lang, - aliases=aliases) - return PooledConnection(conn, self) + traversal_source=traversal_source, + max_inflight=max_inflight, + response_timeout=response_timeout) + conn = PooledConnection(conn, self) + return conn diff --git a/goblin/driver/server.py b/goblin/driver/server.py index 2c33c4c0fcbb9937d35407e52f463eba59c52057..fa9d7f190a04ab20e146afdc1b89ec48b53dea97 100644 --- a/goblin/driver/server.py +++ b/goblin/driver/server.py @@ -26,7 +26,7 @@ class GremlinServer: def __init__(self, pool, *, ssl_context=None, username='', password='', lang='gremlin-groovy', - aliases=None): + traversal_source=None): self._pool = pool self._url = self._pool.url self._loop = self._pool._loop @@ -34,7 +34,7 @@ class GremlinServer: self._username = username self._password = password self._lang = lang - self._aliases = aliases + self._traversal_source = traversal_source async def close(self): await self._pool.close() @@ -47,10 +47,16 @@ class GremlinServer: @classmethod async def open(cls, url, loop, *, ssl_context=None, username='', password='', lang='gremlin-groovy', - aliases=None, **kwargs): + traversal_source=None, max_conns=4, min_conns=1, + max_times_acquired=16, max_inflight=64, + response_timeout=None): conn_pool = pool.ConnectionPool( url, loop, ssl_context=ssl_context, username=username, - password=password, lang=lang, aliases=aliases) + password=password, lang=lang, traversal_source=traversal_source, + max_conns=max_conns, min_conns=min_conns, + max_times_acquired=max_times_acquired, max_inflight=max_inflight, + response_timeout=response_timeout) + await conn_pool.init_pool() return cls(conn_pool, ssl_context=ssl_context, username=username, - password=password, lang=lang, aliases=aliases) + password=password, lang=lang, traversal_source=traversal_source) diff --git a/goblin/session.py b/goblin/session.py index a0e18a2bc14ea2d7b62d72dad6b4a7d1f4b87548..0586def4565e23c21319cc01a882dda442565ffa 100644 --- a/goblin/session.py +++ b/goblin/session.py @@ -42,12 +42,12 @@ class Session(connection.AbstractConnection): """ def __init__(self, app, conn, get_hashable_id, transactions, *, - use_session=False, aliases=None): + use_session=False, traversal_source=None): self._app = app self._conn = conn self._loop = self._app._loop self._use_session = False - self._aliases = aliases or dict() + self._traversal_source = traversal_source or dict() self._pending = collections.deque() self._current = weakref.WeakValueDictionary() self._get_hashable_id = get_hashable_id @@ -80,8 +80,6 @@ class Session(connection.AbstractConnection): def close(self): """ - Close the underlying db connection and disconnect session from Goblin - application. """ # await self.conn.close() self._conn = None @@ -130,7 +128,7 @@ class Session(connection.AbstractConnection): """ await self.flush() async_iter = await self.conn.submit( - gremlin, bindings=bindings, lang=lang, aliases=self._aliases) + gremlin, bindings=bindings, lang=lang, traversal_source=self._traversal_source) response_queue = asyncio.Queue(loop=self._loop) self._loop.create_task( self._receive(async_iter, response_queue)) @@ -320,7 +318,7 @@ class Session(connection.AbstractConnection): async def _simple_traversal(self, traversal, element): stream = await self.conn.submit( repr(traversal), bindings=traversal.bindings, - aliases=self._aliases) + traversal_source=self._traversal_source) msg = await stream.fetch_data() stream.close() if msg: @@ -369,7 +367,7 @@ class Session(connection.AbstractConnection): async def _check_vertex(self, vertex): """Used to check for existence, does not update session vertex""" traversal = self.g.V(vertex.id) - stream = await self.conn.submit(repr(traversal), aliases=self._aliases) + stream = await self.conn.submit(repr(traversal), traversal_source=self._traversal_source) msg = await stream.fetch_data() stream.close() return msg @@ -377,7 +375,7 @@ class Session(connection.AbstractConnection): async def _check_edge(self, edge): """Used to check for existence, does not update session edge""" traversal = self.g.E(edge.id) - stream = await self.conn.submit(repr(traversal), aliases=self._aliases) + stream = await self.conn.submit(repr(traversal), traversal_source=self._traversal_source) msg = await stream.fetch_data() stream.close() return msg @@ -414,7 +412,7 @@ class Session(connection.AbstractConnection): db_name).hasValue(value).property(key, val) stream = await self.conn.submit( repr(traversal), bindings=traversal.bindings, - aliases=self._aliases) + traversal_source=self._traversal_source) await stream.fetch_data() stream.close() else: diff --git a/tests/conftest.py b/tests/conftest.py index c9e89934f1d667d60a3e3bb3e20221a7cbf54a0c..0e4fb14f0e533afa428316b4c0995a5d16333fed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,7 @@ import asyncio import pytest from goblin import Goblin, driver, element, properties, Cardinality +from goblin.driver import pool from gremlin_python import process @@ -69,6 +70,11 @@ def connection(event_loop): return conn +@pytest.fixture +def connection_pool(event_loop): + return pool.ConnectionPool("http://localhost:8182/", event_loop) + + @pytest.fixture def cluster(event_loop): return driver.Cluster(event_loop) diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000000000000000000000000000000000000..5f6b9ea4df4a91af40f0b7fecfefc570a89fe7d4 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,31 @@ +# Copyright 2016 ZEROFAIL +# +# This file is part of Goblin. +# +# Goblin is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Goblin is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with Goblin. If not, see <http://www.gnu.org/licenses/>. + +import asyncio +import pytest + + +@pytest.mark.asyncio +async def test_client_auto_release(cluster): + client = await cluster.connect() + resp = await client.submit("1 + 1") + async for msg in resp: + pass + await asyncio.sleep(0) + host = cluster._hosts.popleft() + assert len(host._pool._available) == 1 + await host.close() diff --git a/tests/test_connection.py b/tests/test_connection.py index 691b5a9efcc233a4617c02276fd247420babe36c..a2c3a84dd8cdb31163c5e437ca8e9fc33c880424 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -79,7 +79,7 @@ async def test_resp_queue_removed_from_conn(connection): stream = await connection.submit("1 + 1") async for msg in stream: pass - await asyncio.sleep(0.1) + await asyncio.sleep(0) assert stream._response_queue not in list( connection._response_queues.values()) diff --git a/tests/test_pool.py b/tests/test_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..6f9e65b7738492189a1638431a38b97b4953b510 --- /dev/null +++ b/tests/test_pool.py @@ -0,0 +1,105 @@ +# Copyright 2016 ZEROFAIL +# +# This file is part of Goblin. +# +# Goblin is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Goblin is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with Goblin. If not, see <http://www.gnu.org/licenses/>. + +import asyncio +import pytest + + +@pytest.mark.asyncio +async def test_pool_init(connection_pool): + await connection_pool.init_pool() + assert len(connection_pool._available) == 1 + await connection_pool.close() + + +@pytest.mark.asyncio +async def test_acquire_release(connection_pool): + conn = await connection_pool.acquire() + assert not len(connection_pool._available) + assert len(connection_pool._acquired) == 1 + assert conn.times_acquired == 1 + connection_pool.release(conn) + assert len(connection_pool._available) == 1 + assert not len(connection_pool._acquired) + assert not conn.times_acquired + await connection_pool.close() + + +@pytest.mark.asyncio +async def test_acquire_multiple(connection_pool): + conn1 = await connection_pool.acquire() + conn2 = await connection_pool.acquire() + assert not conn1 is conn2 + assert len(connection_pool._acquired) == 2 + await connection_pool.close() + + +@pytest.mark.asyncio +async def test_share(connection_pool): + connection_pool._max_conns = 1 + conn1 = await connection_pool.acquire() + conn2 = await connection_pool.acquire() + assert conn1 is conn2 + assert conn1.times_acquired == 2 + await connection_pool.close() + + +@pytest.mark.asyncio +async def test_acquire_multiple_and_share(connection_pool): + connection_pool._max_conns = 2 + connection_pool._max_times_acquired = 2 + conn1 = await connection_pool.acquire() + conn2 = await connection_pool.acquire() + assert not conn1 is conn2 + conn3 = await connection_pool.acquire() + conn4 = await connection_pool.acquire() + assert not conn3 is conn4 + assert conn3 is conn1 + assert conn4 is conn2 + await connection_pool.close() + + +@pytest.mark.asyncio +async def test_max_acquired(connection_pool): + connection_pool._max_conns = 2 + connection_pool._max_times_acquired = 2 + conn1 = await connection_pool.acquire() + conn2 = await connection_pool.acquire() + conn3 = await connection_pool.acquire() + conn4 = await connection_pool.acquire() + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(connection_pool.acquire(), timeout=0.1) + await connection_pool.close() + + +@pytest.mark.asyncio +async def test_release_notify(connection_pool): + connection_pool._max_conns = 2 + connection_pool._max_times_acquired = 2 + conn1 = await connection_pool.acquire() + conn2 = await connection_pool.acquire() + conn3 = await connection_pool.acquire() + conn4 = await connection_pool.acquire() + + async def release(conn): + conn.release() + + results = await asyncio.gather( + *[connection_pool.acquire(), release(conn4)]) + conn4 = results[0] + assert conn4 is conn2 + await connection_pool.close()