From aaa87f83c5d38df458b85a3c6096192c3a585c05 Mon Sep 17 00:00:00 2001 From: Guy Rozendorn <guy@rzn.co.il> Date: Mon, 10 Aug 2020 21:24:31 +0300 Subject: [PATCH] Feat(other): add support for custom request headers Add support for custom request headers ISSUES CLOSED: #2 --- CONTRIBUTORS.md | 1 + aiogremlin/driver/aiohttp/transport.py | 4 ++-- aiogremlin/driver/cluster.py | 1 + aiogremlin/driver/connection.py | 4 +++- aiogremlin/driver/pool.py | 6 ++++-- aiogremlin/driver/server.py | 6 ++++-- setup.py | 2 +- 7 files changed, 16 insertions(+), 8 deletions(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 87cbcc5..9254b59 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -2,3 +2,4 @@ * Jeffrey Phillips Freeman <jeffrey.freeman@syncleus.com> - Project owner from 2019 to present. * David M. Brown <davebshow@gmail.com> - Project founder and project owner from 2016 - 2019. +* Guy Rozendorn <guy@rzn.co.il> diff --git a/aiogremlin/driver/aiohttp/transport.py b/aiogremlin/driver/aiohttp/transport.py index 3fe33ed..620131d 100644 --- a/aiogremlin/driver/aiohttp/transport.py +++ b/aiogremlin/driver/aiohttp/transport.py @@ -10,13 +10,13 @@ class AiohttpTransport(transport.AbstractBaseTransport): self._loop = loop self._connected = False - async def connect(self, url, *, ssl_context=None): + async def connect(self, url, *, ssl_context=None, headers=None): await self.close() connector = aiohttp.TCPConnector( ssl_context=ssl_context, loop=self._loop) self._client_session = aiohttp.ClientSession( loop=self._loop, connector=connector) - self._ws = await self._client_session.ws_connect(url) + self._ws = await self._client_session.ws_connect(url, headers=headers) self._connected = True async def write(self, message): diff --git a/aiogremlin/driver/cluster.py b/aiogremlin/driver/cluster.py index 9895cda..d10325e 100644 --- a/aiogremlin/driver/cluster.py +++ b/aiogremlin/driver/cluster.py @@ -45,6 +45,7 @@ class Cluster: 'ssl_certfile': '', 'ssl_keyfile': '', 'ssl_password': '', + 'headers': None, 'username': '', 'password': '', 'response_timeout': None, diff --git a/aiogremlin/driver/connection.py b/aiogremlin/driver/connection.py index cc018a6..ff7f52f 100644 --- a/aiogremlin/driver/connection.py +++ b/aiogremlin/driver/connection.py @@ -58,6 +58,7 @@ class Connection: protocol=None, transport_factory=None, ssl_context=None, + headers=None, username='', password='', max_inflight=64, @@ -73,6 +74,7 @@ class Connection: Protocol implementation :param transport_factory: Factory function for transports :param ssl.SSLContext ssl_context: + :param dict(str, str) headers: :param str username: Username for database auth :param str password: Password for database auth @@ -89,7 +91,7 @@ class Connection: if not transport_factory: transport_factory = lambda: AiohttpTransport(loop) transport = transport_factory() - await transport.connect(url, ssl_context=ssl_context) + await transport.connect(url, ssl_context=ssl_context, headers=headers) return cls(url, transport, protocol, loop, username, password, max_inflight, response_timeout, message_serializer, provider) diff --git a/aiogremlin/driver/pool.py b/aiogremlin/driver/pool.py index 793d82e..9505d44 100644 --- a/aiogremlin/driver/pool.py +++ b/aiogremlin/driver/pool.py @@ -81,6 +81,7 @@ class ConnectionPool: :param str url: url for host Gremlin Server :param asyncio.BaseEventLoop loop: :param ssl.SSLContext ssl_context: + :param dict(str, str) headers: :param str username: Username for database auth :param str password: Password for database auth :param float response_timeout: (optional) `None` by default @@ -92,12 +93,13 @@ class ConnectionPool: one time on the connection """ - def __init__(self, url, loop, ssl_context, username, password, max_conns, + def __init__(self, url, loop, ssl_context, headers, username, password, max_conns, min_conns, max_times_acquired, max_inflight, response_timeout, message_serializer, provider): self._url = url self._loop = loop self._ssl_context = ssl_context + self._headers = headers self._username = username self._password = password self._max_conns = max_conns @@ -194,7 +196,7 @@ class ConnectionPool: async def _get_connection(self, username, password, max_inflight, response_timeout, message_serializer, provider): conn = await connection.Connection.open( - self._url, self._loop, ssl_context=self._ssl_context, + self._url, self._loop, ssl_context=self._ssl_context, headers=self._headers, username=username, password=password, response_timeout=response_timeout, message_serializer=message_serializer, provider=provider) diff --git a/aiogremlin/driver/server.py b/aiogremlin/driver/server.py index 5a17390..2c3272c 100644 --- a/aiogremlin/driver/server.py +++ b/aiogremlin/driver/server.py @@ -35,6 +35,7 @@ class GremlinServer: self._ssl_context = ssl_context else: self._ssl_context = None + self._headers = config['headers'] @property def url(self): @@ -66,8 +67,8 @@ class GremlinServer: async def initialize(self): conn_pool = pool.ConnectionPool( - self._url, self._loop, self._ssl_context, self._username, - self._password, self._max_conns, self._min_conns, + self._url, self._loop, self._ssl_context, self._headers, + self._username, self._password, self._max_conns, self._min_conns, self._max_times_acquired, self._max_inflight, self._response_timeout, self._message_serializer, self._provider) await conn_pool.init_pool() @@ -81,6 +82,7 @@ class GremlinServer: :param str url: url for host Gremlin Server :param asyncio.BaseEventLoop loop: :param ssl.SSLContext ssl_context: + :param dict(str, str) headers: :param str username: Username for database auth :param str password: Password for database auth :param float response_timeout: (optional) `None` by default diff --git a/setup.py b/setup.py index 8d892d5..7a14f3f 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ setup( 'aiogremlin.remote'], python_requires='>=3.5', install_requires=[ - 'gremlinpython<=3.4.3', + 'gremlinpython>=3.4.8', 'aenum>=1.4.5', # required gremlinpython dep 'aiohttp>=2.2.5', 'PyYAML>=3.12', -- GitLab