From aaa87f83c5d38df458b85a3c6096192c3a585c05 Mon Sep 17 00:00:00 2001
From: Guy Rozendorn <guy@rzn.co.il>
Date: Mon, 10 Aug 2020 21:24:31 +0300
Subject: [PATCH] Feat(other): add support for custom request headers

Add support for custom request headers

ISSUES CLOSED: #2
---
 CONTRIBUTORS.md                        | 1 +
 aiogremlin/driver/aiohttp/transport.py | 4 ++--
 aiogremlin/driver/cluster.py           | 1 +
 aiogremlin/driver/connection.py        | 4 +++-
 aiogremlin/driver/pool.py              | 6 ++++--
 aiogremlin/driver/server.py            | 6 ++++--
 setup.py                               | 2 +-
 7 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 87cbcc5..9254b59 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -2,3 +2,4 @@
 
 * Jeffrey Phillips Freeman <jeffrey.freeman@syncleus.com> - Project owner from 2019 to present.
 * David M. Brown <davebshow@gmail.com> - Project founder and project owner from 2016 - 2019.
+* Guy Rozendorn <guy@rzn.co.il> 
diff --git a/aiogremlin/driver/aiohttp/transport.py b/aiogremlin/driver/aiohttp/transport.py
index 3fe33ed..620131d 100644
--- a/aiogremlin/driver/aiohttp/transport.py
+++ b/aiogremlin/driver/aiohttp/transport.py
@@ -10,13 +10,13 @@ class AiohttpTransport(transport.AbstractBaseTransport):
         self._loop = loop
         self._connected = False
 
-    async def connect(self, url, *, ssl_context=None):
+    async def connect(self, url, *, ssl_context=None, headers=None):
         await self.close()
         connector = aiohttp.TCPConnector(
             ssl_context=ssl_context, loop=self._loop)
         self._client_session = aiohttp.ClientSession(
             loop=self._loop, connector=connector)
-        self._ws = await self._client_session.ws_connect(url)
+        self._ws = await self._client_session.ws_connect(url, headers=headers)
         self._connected = True
 
     async def write(self, message):
diff --git a/aiogremlin/driver/cluster.py b/aiogremlin/driver/cluster.py
index 9895cda..d10325e 100644
--- a/aiogremlin/driver/cluster.py
+++ b/aiogremlin/driver/cluster.py
@@ -45,6 +45,7 @@ class Cluster:
         'ssl_certfile': '',
         'ssl_keyfile': '',
         'ssl_password': '',
+        'headers': None,
         'username': '',
         'password': '',
         'response_timeout': None,
diff --git a/aiogremlin/driver/connection.py b/aiogremlin/driver/connection.py
index cc018a6..ff7f52f 100644
--- a/aiogremlin/driver/connection.py
+++ b/aiogremlin/driver/connection.py
@@ -58,6 +58,7 @@ class Connection:
                    protocol=None,
                    transport_factory=None,
                    ssl_context=None,
+                   headers=None,
                    username='',
                    password='',
                    max_inflight=64,
@@ -73,6 +74,7 @@ class Connection:
             Protocol implementation
         :param transport_factory: Factory function for transports
         :param ssl.SSLContext ssl_context:
+        :param dict(str, str) headers:
         :param str username: Username for database auth
         :param str password: Password for database auth
 
@@ -89,7 +91,7 @@ class Connection:
         if not transport_factory:
             transport_factory = lambda: AiohttpTransport(loop)
         transport = transport_factory()
-        await transport.connect(url, ssl_context=ssl_context)
+        await transport.connect(url, ssl_context=ssl_context, headers=headers)
         return cls(url, transport, protocol, loop, username, password,
                    max_inflight, response_timeout, message_serializer,
                    provider)
diff --git a/aiogremlin/driver/pool.py b/aiogremlin/driver/pool.py
index 793d82e..9505d44 100644
--- a/aiogremlin/driver/pool.py
+++ b/aiogremlin/driver/pool.py
@@ -81,6 +81,7 @@ class ConnectionPool:
     :param str url: url for host Gremlin Server
     :param asyncio.BaseEventLoop loop:
     :param ssl.SSLContext ssl_context:
+    :param dict(str, str) headers:
     :param str username: Username for database auth
     :param str password: Password for database auth
     :param float response_timeout: (optional) `None` by default
@@ -92,12 +93,13 @@ class ConnectionPool:
         one time on the connection
     """
 
-    def __init__(self, url, loop, ssl_context, username, password, max_conns,
+    def __init__(self, url, loop, ssl_context, headers, username, password, max_conns,
                  min_conns, max_times_acquired, max_inflight, response_timeout,
                  message_serializer, provider):
         self._url = url
         self._loop = loop
         self._ssl_context = ssl_context
+        self._headers = headers
         self._username = username
         self._password = password
         self._max_conns = max_conns
@@ -194,7 +196,7 @@ class ConnectionPool:
     async def _get_connection(self, username, password, max_inflight,
                               response_timeout, message_serializer, provider):
         conn = await connection.Connection.open(
-            self._url, self._loop, ssl_context=self._ssl_context,
+            self._url, self._loop, ssl_context=self._ssl_context, headers=self._headers,
             username=username, password=password,
             response_timeout=response_timeout,
             message_serializer=message_serializer, provider=provider)
diff --git a/aiogremlin/driver/server.py b/aiogremlin/driver/server.py
index 5a17390..2c3272c 100644
--- a/aiogremlin/driver/server.py
+++ b/aiogremlin/driver/server.py
@@ -35,6 +35,7 @@ class GremlinServer:
             self._ssl_context = ssl_context
         else:
             self._ssl_context = None
+        self._headers = config['headers']
 
     @property
     def url(self):
@@ -66,8 +67,8 @@ class GremlinServer:
 
     async def initialize(self):
         conn_pool = pool.ConnectionPool(
-            self._url, self._loop, self._ssl_context, self._username,
-            self._password, self._max_conns, self._min_conns,
+            self._url, self._loop, self._ssl_context, self._headers,
+            self._username, self._password, self._max_conns, self._min_conns,
             self._max_times_acquired, self._max_inflight,
             self._response_timeout, self._message_serializer, self._provider)
         await conn_pool.init_pool()
@@ -81,6 +82,7 @@ class GremlinServer:
         :param str url: url for host Gremlin Server
         :param asyncio.BaseEventLoop loop:
         :param ssl.SSLContext ssl_context:
+        :param dict(str, str) headers:
         :param str username: Username for database auth
         :param str password: Password for database auth
         :param float response_timeout: (optional) `None` by default
diff --git a/setup.py b/setup.py
index 8d892d5..7a14f3f 100644
--- a/setup.py
+++ b/setup.py
@@ -32,7 +32,7 @@ setup(
               'aiogremlin.remote'],
     python_requires='>=3.5',
     install_requires=[
-        'gremlinpython<=3.4.3',
+        'gremlinpython>=3.4.8',
         'aenum>=1.4.5',  # required gremlinpython dep
         'aiohttp>=2.2.5',
         'PyYAML>=3.12',
-- 
GitLab