diff --git a/aiogremlin/driver/client.py b/aiogremlin/driver/client.py index 2759d68649b0864dd729b1fbdd4006257bd7e792..73852b581c47d412552c5e4b69afb0c01ac38a26 100644 --- a/aiogremlin/driver/client.py +++ b/aiogremlin/driver/client.py @@ -16,11 +16,12 @@ class Client: :param asyncio.BaseEventLoop loop: :param dict aliases: Optional mapping for aliases. Default is `None` """ - def __init__(self, cluster, loop, *, aliases=None): + def __init__(self, cluster, loop, *, hostname=None, aliases=None): self._cluster = cluster self._loop = loop if aliases is None: aliases = {} + self._hostname = hostname self._aliases = aliases @property @@ -76,7 +77,7 @@ class Client: 'aliases': self._aliases}) if bindings: message.args.update({'bindings': bindings}) - conn = await self.cluster.get_connection() + conn = await self.cluster.get_connection(hostname=self._hostname) resp = await conn.write(message) self._loop.create_task(conn.release_task(resp)) return resp diff --git a/aiogremlin/driver/cluster.py b/aiogremlin/driver/cluster.py index 6cc403fb81d0243c1da49a0a6362c4dd125e0b50..6f1e02f68d5bc8d05240e6b48ddfef38d289aed8 100644 --- a/aiogremlin/driver/cluster.py +++ b/aiogremlin/driver/cluster.py @@ -63,6 +63,7 @@ class Cluster: default_config.update(config) self._config = self._process_config_imports(default_config) self._hosts = collections.deque() + self._hostmap = {} self._closed = False if aliases is None: aliases = {} @@ -100,7 +101,7 @@ class Cluster: """ return self._config - async def get_connection(self): + async def get_connection(self, hostname=None): """ **coroutine** Get connection from next available host in a round robin fashion. @@ -109,7 +110,14 @@ class Cluster: """ if not self._hosts: await self.establish_hosts() - host = self._hosts.popleft() + if hostname: + try: + host = self._hostmap[hostname] + except KeyError: + raise exception.ConfigError( + 'Unknown host: {}'.format(hostname)) + else: + host = self._hosts.popleft() conn = await host.get_connection() self._hosts.append(host) return conn @@ -121,11 +129,12 @@ class Cluster: scheme = self._config['scheme'] hosts = self._config['hosts'] port = self._config['port'] - for host in hosts: - url = '{}://{}:{}/gremlin'.format(scheme, host, port) + for hostname in hosts: + url = '{}://{}:{}/gremlin'.format(scheme, hostname, port) host = await driver.GremlinServer.open( url, self._loop, **dict(self._config)) self._hosts.append(host) + self._hostmap[hostname] = host def config_from_file(self, filename): """ @@ -186,7 +195,7 @@ class Cluster: config = self._process_config_imports(config) self.config.update(config) - async def connect(self, aliases=None): + async def connect(self, hostname=None, aliases=None): """ **coroutine** Get a connected client. Main API method. @@ -202,7 +211,8 @@ class Cluster: # aliases=aliases) # self._hosts.append(host) # else: - client = driver.Client(self, self._loop, aliases=aliases) + client = driver.Client(self, self._loop, hostname=hostname, + aliases=aliases) return client async def close(self): diff --git a/aiogremlin/gremlin_python/__init__.py b/aiogremlin/gremlin_python/__init__.py index 41aee8bc82bf4bb98ed95c3c6982612273e92e68..a2a5a4bd6bd921bfa65b3665da8858bc7b9846cd 100644 --- a/aiogremlin/gremlin_python/__init__.py +++ b/aiogremlin/gremlin_python/__init__.py @@ -22,4 +22,4 @@ __author__ = 'Marko A. Rodriguez (http://markorodriguez.com)' from aiogremlin.gremlin_python.statics import * from aiogremlin.gremlin_python.process.graph_traversal import __ from aiogremlin.gremlin_python.process.strategies import * -from aiogremlin.gremlin_python.process.traversal import Binding +from aiogremlin.gremlin_python.process.traversal import Binding, Cardinality diff --git a/setup.py b/setup.py index 3e068f0afe905c4c46fa3fefcae5aa082259b8a8..9a8b7bc3196939d463e033217526552467e8d54e 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup setup( name='aiogremlin', - version='3.2.4b1', + version='3.2.4b2', url='', license='Apache Software License', author='davebshow', diff --git a/tests/test_client.py b/tests/test_client.py index bf9aacde9766d4d5fc3c4ecc04142058b1c066f9..226e4dd2f2834ce992681f43f73455840b1cd7c6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -45,6 +45,17 @@ async def test_alias(cluster): await cluster.close() +@pytest.mark.asyncio +async def test_client_auto_release(cluster): + client = await cluster.connect(hostname='localhost') + resp = await client.submit("1 + 1") + async for msg in resp: + assert msg == 2 + assert client._hostname == 'localhost' + await cluster.close() + + + # @pytest.mark.asyncio # async def test_sessioned_client(cluster): # session = str(uuid.uuid4())