Commit 566f65a5 authored by davebshow's avatar davebshow
Browse files

clients can be assigned to a specific host

parent d2812818
......@@ -16,11 +16,12 @@ class Client:
:param asyncio.BaseEventLoop loop:
:param dict aliases: Optional mapping for aliases. Default is `None`
"""
def __init__(self, cluster, loop, *, aliases=None):
def __init__(self, cluster, loop, *, hostname=None, aliases=None):
self._cluster = cluster
self._loop = loop
if aliases is None:
aliases = {}
self._hostname = hostname
self._aliases = aliases
@property
......@@ -76,7 +77,7 @@ class Client:
'aliases': self._aliases})
if bindings:
message.args.update({'bindings': bindings})
conn = await self.cluster.get_connection()
conn = await self.cluster.get_connection(hostname=self._hostname)
resp = await conn.write(message)
self._loop.create_task(conn.release_task(resp))
return resp
......@@ -63,6 +63,7 @@ class Cluster:
default_config.update(config)
self._config = self._process_config_imports(default_config)
self._hosts = collections.deque()
self._hostmap = {}
self._closed = False
if aliases is None:
aliases = {}
......@@ -100,7 +101,7 @@ class Cluster:
"""
return self._config
async def get_connection(self):
async def get_connection(self, hostname=None):
"""
**coroutine** Get connection from next available host in a round robin
fashion.
......@@ -109,7 +110,14 @@ class Cluster:
"""
if not self._hosts:
await self.establish_hosts()
host = self._hosts.popleft()
if hostname:
try:
host = self._hostmap[hostname]
except KeyError:
raise exception.ConfigError(
'Unknown host: {}'.format(hostname))
else:
host = self._hosts.popleft()
conn = await host.get_connection()
self._hosts.append(host)
return conn
......@@ -121,11 +129,12 @@ class Cluster:
scheme = self._config['scheme']
hosts = self._config['hosts']
port = self._config['port']
for host in hosts:
url = '{}://{}:{}/gremlin'.format(scheme, host, port)
for hostname in hosts:
url = '{}://{}:{}/gremlin'.format(scheme, hostname, port)
host = await driver.GremlinServer.open(
url, self._loop, **dict(self._config))
self._hosts.append(host)
self._hostmap[hostname] = host
def config_from_file(self, filename):
"""
......@@ -186,7 +195,7 @@ class Cluster:
config = self._process_config_imports(config)
self.config.update(config)
async def connect(self, aliases=None):
async def connect(self, hostname=None, aliases=None):
"""
**coroutine** Get a connected client. Main API method.
......@@ -202,7 +211,8 @@ class Cluster:
# aliases=aliases)
# self._hosts.append(host)
# else:
client = driver.Client(self, self._loop, aliases=aliases)
client = driver.Client(self, self._loop, hostname=hostname,
aliases=aliases)
return client
async def close(self):
......
......@@ -22,4 +22,4 @@ __author__ = 'Marko A. Rodriguez (http://markorodriguez.com)'
from aiogremlin.gremlin_python.statics import *
from aiogremlin.gremlin_python.process.graph_traversal import __
from aiogremlin.gremlin_python.process.strategies import *
from aiogremlin.gremlin_python.process.traversal import Binding
from aiogremlin.gremlin_python.process.traversal import Binding, Cardinality
......@@ -3,7 +3,7 @@ from setuptools import setup
setup(
name='aiogremlin',
version='3.2.4b1',
version='3.2.4b2',
url='',
license='Apache Software License',
author='davebshow',
......
......@@ -45,6 +45,17 @@ async def test_alias(cluster):
await cluster.close()
@pytest.mark.asyncio
async def test_client_auto_release(cluster):
client = await cluster.connect(hostname='localhost')
resp = await client.submit("1 + 1")
async for msg in resp:
assert msg == 2
assert client._hostname == 'localhost'
await cluster.close()
# @pytest.mark.asyncio
# async def test_sessioned_client(cluster):
# session = str(uuid.uuid4())
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment