diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 87cbcc5c27346e46897617002c7b585eff95ca96..9254b59934de4de1957d1297ae6f1ea6f92a5e93 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -2,3 +2,4 @@ * Jeffrey Phillips Freeman <jeffrey.freeman@syncleus.com> - Project owner from 2019 to present. * David M. Brown <davebshow@gmail.com> - Project founder and project owner from 2016 - 2019. +* Guy Rozendorn <guy@rzn.co.il> diff --git a/aiogremlin/driver/aiohttp/transport.py b/aiogremlin/driver/aiohttp/transport.py index 3fe33ed1bdca38772b29cb9f1cea7af2dbed2126..620131d72cfd41b476a409aff926837671e7f05a 100644 --- a/aiogremlin/driver/aiohttp/transport.py +++ b/aiogremlin/driver/aiohttp/transport.py @@ -10,13 +10,13 @@ class AiohttpTransport(transport.AbstractBaseTransport): self._loop = loop self._connected = False - async def connect(self, url, *, ssl_context=None): + async def connect(self, url, *, ssl_context=None, headers=None): await self.close() connector = aiohttp.TCPConnector( ssl_context=ssl_context, loop=self._loop) self._client_session = aiohttp.ClientSession( loop=self._loop, connector=connector) - self._ws = await self._client_session.ws_connect(url) + self._ws = await self._client_session.ws_connect(url, headers=headers) self._connected = True async def write(self, message): diff --git a/aiogremlin/driver/cluster.py b/aiogremlin/driver/cluster.py index 9895cda1003021f82b06c8a2b0d279b075fbf246..d10325ea757561caeff893b1659717c06867daab 100644 --- a/aiogremlin/driver/cluster.py +++ b/aiogremlin/driver/cluster.py @@ -45,6 +45,7 @@ class Cluster: 'ssl_certfile': '', 'ssl_keyfile': '', 'ssl_password': '', + 'headers': None, 'username': '', 'password': '', 'response_timeout': None, diff --git a/aiogremlin/driver/connection.py b/aiogremlin/driver/connection.py index cc018a6fdfa911af5f21b4c83be923d61dd14a4b..ff7f52f05450282c1c2cb5d7d207fc2335f682d0 100644 --- a/aiogremlin/driver/connection.py +++ b/aiogremlin/driver/connection.py @@ -58,6 +58,7 @@ class Connection: protocol=None, transport_factory=None, ssl_context=None, + headers=None, username='', password='', max_inflight=64, @@ -73,6 +74,7 @@ class Connection: Protocol implementation :param transport_factory: Factory function for transports :param ssl.SSLContext ssl_context: + :param dict(str, str) headers: :param str username: Username for database auth :param str password: Password for database auth @@ -89,7 +91,7 @@ class Connection: if not transport_factory: transport_factory = lambda: AiohttpTransport(loop) transport = transport_factory() - await transport.connect(url, ssl_context=ssl_context) + await transport.connect(url, ssl_context=ssl_context, headers=headers) return cls(url, transport, protocol, loop, username, password, max_inflight, response_timeout, message_serializer, provider) diff --git a/aiogremlin/driver/pool.py b/aiogremlin/driver/pool.py index 793d82eea7984defe6fc640f794bb463ce7784cb..9505d44450b788aa8b56aaa168bcd06203c6b83b 100644 --- a/aiogremlin/driver/pool.py +++ b/aiogremlin/driver/pool.py @@ -81,6 +81,7 @@ class ConnectionPool: :param str url: url for host Gremlin Server :param asyncio.BaseEventLoop loop: :param ssl.SSLContext ssl_context: + :param dict(str, str) headers: :param str username: Username for database auth :param str password: Password for database auth :param float response_timeout: (optional) `None` by default @@ -92,12 +93,13 @@ class ConnectionPool: one time on the connection """ - def __init__(self, url, loop, ssl_context, username, password, max_conns, + def __init__(self, url, loop, ssl_context, headers, username, password, max_conns, min_conns, max_times_acquired, max_inflight, response_timeout, message_serializer, provider): self._url = url self._loop = loop self._ssl_context = ssl_context + self._headers = headers self._username = username self._password = password self._max_conns = max_conns @@ -194,7 +196,7 @@ class ConnectionPool: async def _get_connection(self, username, password, max_inflight, response_timeout, message_serializer, provider): conn = await connection.Connection.open( - self._url, self._loop, ssl_context=self._ssl_context, + self._url, self._loop, ssl_context=self._ssl_context, headers=self._headers, username=username, password=password, response_timeout=response_timeout, message_serializer=message_serializer, provider=provider) diff --git a/aiogremlin/driver/server.py b/aiogremlin/driver/server.py index 5a1739037853d405fc3b5cd243ffc83996312242..2c3272c12fa57273f5cda6638f47b21f60173495 100644 --- a/aiogremlin/driver/server.py +++ b/aiogremlin/driver/server.py @@ -35,6 +35,7 @@ class GremlinServer: self._ssl_context = ssl_context else: self._ssl_context = None + self._headers = config['headers'] @property def url(self): @@ -66,8 +67,8 @@ class GremlinServer: async def initialize(self): conn_pool = pool.ConnectionPool( - self._url, self._loop, self._ssl_context, self._username, - self._password, self._max_conns, self._min_conns, + self._url, self._loop, self._ssl_context, self._headers, + self._username, self._password, self._max_conns, self._min_conns, self._max_times_acquired, self._max_inflight, self._response_timeout, self._message_serializer, self._provider) await conn_pool.init_pool() @@ -81,6 +82,7 @@ class GremlinServer: :param str url: url for host Gremlin Server :param asyncio.BaseEventLoop loop: :param ssl.SSLContext ssl_context: + :param dict(str, str) headers: :param str username: Username for database auth :param str password: Password for database auth :param float response_timeout: (optional) `None` by default diff --git a/setup.py b/setup.py index 8d892d52ca21e43d794795c628a55053250494fb..7a14f3f2ca85c7dc0a88122d33dc90758cab6ca0 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ setup( 'aiogremlin.remote'], python_requires='>=3.5', install_requires=[ - 'gremlinpython<=3.4.3', + 'gremlinpython>=3.4.8', 'aenum>=1.4.5', # required gremlinpython dep 'aiohttp>=2.2.5', 'PyYAML>=3.12',