diff --git a/goblin/api.py b/goblin/api.py index 67d8e767053a6487d90a0f3f8ac73f2a00b382d3..6e4d72a0052800caf3fc0d118891db7b9a34279a 100644 --- a/goblin/api.py +++ b/goblin/api.py @@ -1,20 +1,33 @@ """Main OGM API classes and constructors""" import collections +import logging from goblin import gremlin_python from goblin import mapper from goblin import properties from goblin import query -from goblin.gremlin_python_driver import driver +from goblin import gremlin_python_driver + + +logger = logging.getLogger(__name__) # Constructor API -async def create_engine(url, loop): +async def create_engine(url, + loop, + maxsize=256, + force_close=False, + force_release=True): """Constructor function for :py:class:`Engine`. Connects to database and builds a dictionary of relevant vendor implmentation features""" features = {} # Will use a pool here - async with driver.create_connection(url, loop) as conn: + pool = gremlin_python_driver.create_pool(url, + loop, + maxsize=maxsize, + force_close=force_close, + force_release=force_release) + async with pool.driver.get() as conn: # Propbably just use a parser to parse the whole feature list stream = conn.submit( 'graph.features().graph().supportsComputer()') @@ -37,7 +50,7 @@ async def create_engine(url, loop): msg = await stream.fetch_data() features['threaded_transactions'] = msg.data[0] - return Engine(url, loop, **features) + return Engine(url, loop, pool=pool, **features) # Main API classes @@ -46,17 +59,16 @@ class Engine: database connections. Used as a factory to create :py:class:`Session` objects. More config coming soon.""" - def __init__(self, url, loop, *, force_close=True, **features): + def __init__(self, url, loop, *, pool=None, force_close=True, **features): self._url = url self._loop = loop self._force_close = force_close self._features = features self._translator = gremlin_python.GroovyTranslator('g') # This will be a pool - self._driver = driver.Driver(self._url, self._loop) - - def session(self, *, use_session=False): - return Session(self, use_session=use_session) + if pool is None: + pool = gremlin_python_driver.Pool(url, loop) + self._pool = pool @property def translator(self): @@ -64,15 +76,22 @@ class Engine: @property def url(self): - return url + return self._url + + @property + def pool(self): + return self._pool + + def session(self, *, use_session=False): + return Session(self, use_session=use_session) async def execute(self, query, *, bindings=None, session=None): - conn = await self._driver.connect(force_close=self._force_close) + conn = await self.pool.acquire() return conn.submit(query, bindings=bindings) async def close(self): - await self._driver.close() - self._driver = None + await self.pool.close() + self._pool = None class Session: @@ -278,6 +297,7 @@ class ElementMeta(type): new_namespace[k] = v new_namespace['__mapping__'] = mapper.create_mapping(namespace, props) + logger.warning("Creating new Element class: {}".format(name)) result = type.__new__(cls, name, bases, new_namespace) return result diff --git a/goblin/gremlin_python_driver/__init__.py b/goblin/gremlin_python_driver/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..e0283f0ee682ea821400f3cb5bbe9eb2c51b8493 100644 --- a/goblin/gremlin_python_driver/__init__.py +++ b/goblin/gremlin_python_driver/__init__.py @@ -0,0 +1,2 @@ +from goblin.gremlin_python_driver.driver import Driver +from goblin.gremlin_python_driver.pool import create_pool, Pool diff --git a/goblin/gremlin_python_driver/driver.py b/goblin/gremlin_python_driver/driver.py index e99c5653974735860fd35945ecf68a23e98ac1ab..dd341181cd84aea398a2caf6414c2fec48aa1fa3 100644 --- a/goblin/gremlin_python_driver/driver.py +++ b/goblin/gremlin_python_driver/driver.py @@ -12,40 +12,50 @@ Message = collections.namedtuple( ["status_code", "data", "message", "metadata"]) -def create_connection(url, loop): - """Driver constructor function.""" - return Driver(url, loop) - - class Driver: - def __init__(self, url, loop): + def __init__(self, url, loop, *, client_session=None): self._url = url self._loop = loop - self._session = aiohttp.ClientSession(loop=self._loop) - self._conn = None + if not client_session: + client_session = aiohttp.ClientSession(loop=self._loop) + self._client_session = client_session @property def conn(self): return self._conn - async def __aenter__(self): - conn = await self.connect(force_close=False) - self._conn = conn - return conn - - async def __aexit__(self, exc_type, exc, tb): - await self.close() + def get(self): + return AsyncDriverConnectionContextManager(self) - async def connect(self, *, force_close=True): - ws = await self._session.ws_connect(self._url) - return Connection(ws, self._loop, force_close=force_close) + async def connect(self, *, force_close=True, force_release=False, pool=None): + ws = await self._client_session.ws_connect(self._url) + return Connection(ws, self._loop, force_close=force_close, + force_release=force_release, pool=pool) async def close(self): if self._conn: await self._conn.close() self._conn = None - await self._session.close() + await self._client_session.close() + self._client_session = None + + +class AsyncDriverConnectionContextManager: + + __slots__ = ('_driver', '_conn') + + def __init__(self, driver): + self._driver = driver + + async def __aenter__(self): + self._conn = await self._driver.connect(force_close=False) + return self._conn + + async def __aexit__(self, exc_type, exc, tb): + await self._conn.close() + self._conn = None + self._driver = None class AsyncResponseIter: @@ -55,6 +65,7 @@ class AsyncResponseIter: self._loop = loop self._conn = conn self._force_close = self._conn.force_close + self._force_release = self._conn.force_release self._closed = False async def __aiter__(self): @@ -84,29 +95,53 @@ class AsyncResponseIter: message["result"]["meta"]) if message.status_code in [200, 206, 204]: if message.status_code != 206: - self._closed = True - if self._force_close: - await self.close() + await self._term() return message elif message.status_code == 407: pass # auth else: + await self._term() raise RuntimeError("{0} {1}".format(message.status_code, message.message)) - + async def _term(self): + self._closed = True + if self._force_close: + await self.close() + elif self._force_release: + await self._conn.release() class Connection: - def __init__(self, ws, loop, *, force_close=True): + def __init__(self, ws, loop, *, force_close=True, force_release=False, + pool=None): self._ws = ws self._loop = loop self._force_close = force_close + self._force_release = force_release + self._pool = pool + self._closed = False + + @property + def closed(self): + return self._closed @property def force_close(self): return self._force_close + @property + def force_release(self): + return self._force_release + + @property + def pool(self): + return self._pool + + async def release(self): + if self._pool: + await self._pool.release(self) + def submit(self, gremlin, *, @@ -133,6 +168,7 @@ class Connection: async def close(self): await self._ws.close() + self._closed = True def _prepare_message(self, gremlin, bindings, lang, aliases, op, processor, session, request_id): diff --git a/goblin/gremlin_python_driver/pool.py b/goblin/gremlin_python_driver/pool.py new file mode 100644 index 0000000000000000000000000000000000000000..bc7f76d6b4ce8650e76f6176f11fb61e20894f86 --- /dev/null +++ b/goblin/gremlin_python_driver/pool.py @@ -0,0 +1,155 @@ +import collections +import logging + +from goblin.gremlin_python_driver import driver + + +logger = logging.getLogger(__name__) + + +def create_pool(url, + loop, + maxsize=256, + force_close=False, + force_release=True): + return Pool(url, loop, maxsize=maxsize, force_close=force_close, + force_release=force_release) + + +class Pool(object): + """ + Pool of :py:class:`goblin.gremlin_python_driver.Connection` objects. + :param str url: url for Gremlin Server. + :param float timeout: timeout for establishing connection (optional). + Values ``0`` or ``None`` mean no timeout + :param str username: Username for SASL auth + :param str password: Password for SASL auth + :param gremlinclient.graph.GraphDatabase graph: The graph instances + used to create connections + :param int maxsize: Maximum number of connections. + :param loop: event loop + """ + def __init__(self, url, loop, maxsize=256, force_close=False, + force_release=True): + self._graph = url + self._loop = loop + self._maxsize = maxsize + self._force_close = force_close + self._force_release = force_release + self._pool = collections.deque() + self._acquired = set() + self._acquiring = 0 + self._closed = False + self._driver = driver.Driver(url, loop) + self._conn = None + + @property + def freesize(self): + """ + Number of free connections + :returns: int + """ + return len(self._pool) + + @property + def size(self): + """ + Total number of connections + :returns: int + """ + return len(self._acquired) + self._acquiring + self.freesize + + @property + def maxsize(self): + """ + Maximum number of connections + :returns: in + """ + return self._maxsize + + @property + def driver(self): + """ + Associated graph instance used for creating connections + :returns: :py:class:`gremlinclient.graph.GraphDatabase` + """ + return self._driver + + @property + def pool(self): + """ + Object that stores unused connections + :returns: :py:class:`collections.deque` + """ + return self._pool + + @property + def closed(self): + """ + Check if pool has been closed + :returns: bool + """ + return self._closed or self._graph is None + + def get(self): + return AsyncPoolConnectionContextManager(self) + + async def acquire(self): + """ + Acquire a connection from the Pool + :returns: Future - + :py:class:`asyncio.Future`, :py:class:`trollius.Future`, or + :py:class:`tornado.concurrent.Future` + """ + if self._pool: + while self._pool: + conn = self._pool.popleft() + if not conn.closed: + logger.debug("Reusing connection: {}".format(conn)) + self._acquired.add(conn) + break + else: + logger.debug( + "Discarded closed connection: {}".format(conn)) + elif self.size < self.maxsize: + self._acquiring += 1 + conn = await self.driver.connect(force_close=self._force_close, + force_release=self._force_release, pool=self) + self._acquiring -= 1 + self._acquired.add(conn) + logger.debug( + "Acquired new connection: {}".format(conn)) + + return conn + + async def release(self, conn): + """ + Release a connection back to the pool. + :param gremlinclient.connection.Connection: The connection to be + released + """ + if self.size <= self.maxsize: + if conn.closed: + # conn has been closed + logger.info( + "Released closed connection: {}".format(conn)) + self._acquired.remove(conn) + conn = None + else: + self._pool.append(conn) + self._acquired.remove(conn) + else: + await conn.close() + + async def close(self): + """ + Close pool + """ + while self.pool: + conn = self.pool.popleft() + await conn.close() + await self.driver.close() + self._driver = None + self._closed = True + logger.info( + "Connection pool {} has been closed".format(self)) diff --git a/goblin/properties.py b/goblin/properties.py index 9be2fd7be769ed57030964733c41d0ebcc3e108a..77e70969f0bf7037bf2a37b566ca2bed19950e2b 100644 --- a/goblin/properties.py +++ b/goblin/properties.py @@ -1,5 +1,9 @@ """Classes to handle proerties and data type definitions""" import abc +import logging + + +logger = logging.getLogger(__name__) class Property: diff --git a/tests/test_driver.py b/tests/test_driver.py index ea573a1487ce141f83a0c309fd5146431a16622a..ad00e9477231840b6dad30191eed2e16c3631b1e 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -1,7 +1,7 @@ import asyncio import unittest -from goblin.gremlin_python_driver.driver import create_connection +from goblin.gremlin_python_driver.driver import Driver class TestDriver(unittest.TestCase): @@ -13,16 +13,20 @@ class TestDriver(unittest.TestCase): def test_connect(self): async def go(): - async with create_connection("http://localhost:8182/", self.loop) as conn: + driver = Driver("http://localhost:8182/", self.loop) + async with driver.get() as conn: self.assertFalse(conn._ws.closed) + await driver.close() self.loop.run_until_complete(go()) def test_submit(self): async def go(): - async with create_connection("http://localhost:8182/", self.loop) as conn: + driver = Driver("http://localhost:8182/", self.loop) + async with driver.get() as conn: async for msg in conn.submit("1 + 1"): self.assertEqual(msg.data[0], 2) + await driver.close() self.loop.run_until_complete(go()) diff --git a/tests/test_engine.py b/tests/test_engine.py index fb0d3081ccf74e8c5317ceff0c9ab966ee6e216e..42f40902b5d2a7c87c18e324a54923eae8611e0d 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -41,6 +41,7 @@ class TestEngine(unittest.TestCase): self.assertIs(leif, current) self.assertEqual(leif.id, current.id) await engine.close() + print(engine) self.loop.run_until_complete(go())