From fd312df9c07feb914cf7e152c04e8ff0c290b68d Mon Sep 17 00:00:00 2001
From: Leifur Halldor Asgeirsson <lasgeirsson@zerofail.com>
Date: Tue, 1 Nov 2016 13:32:34 -0400
Subject: [PATCH] SASL auth

---
 goblin/driver/connection.py | 13 ++++-----
 tests/test_connection.py    | 57 +++++++++++++++++++++++++++++++++++++
 2 files changed, 63 insertions(+), 7 deletions(-)

diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py
index f7a5040..4acf871 100644
--- a/goblin/driver/connection.py
+++ b/goblin/driver/connection.py
@@ -17,6 +17,7 @@
 
 import abc
 import asyncio
+import base64
 import collections
 import functools
 import logging
@@ -230,14 +231,13 @@ class Connection(AbstractConnection):
         self._loop.create_task(self._terminate_response(resp, request_id))
         return resp
 
-    def _authenticate(self, username, password, session):
+    def _authenticate(self, username, password, request_id):
         auth = b''.join([b'\x00', username.encode('utf-8'),
                          b'\x00', password.encode('utf-8')])
-        request_id = str(uuid.uuid4())
-        args = {'sasl': base64.b64encode(auth).decode()}
+        args = {'sasl': base64.b64encode(auth).decode(), 'saslMechanism': 'PLAIN'}
         message = self._message_serializer.serialize_message(
             request_id, '', 'authentication', **args)
-        self._ws.send_bytes(message, binary=True)
+        self._ws.send_bytes(message)
 
     async def close(self):
         """**coroutine** Close underlying connection and mark as closed."""
@@ -264,7 +264,7 @@ class Connection(AbstractConnection):
                 if data.tp == aiohttp.MsgType.binary:
                     data = data.data.decode()
                 elif data.tp == aiohttp.MsgType.text:
-                    data = data.strip()
+                    data = data.data.strip()
                 message = json.loads(data)
                 request_id = message['requestId']
                 status_code = message['status']['code']
@@ -272,8 +272,7 @@ class Connection(AbstractConnection):
                 msg = message['status']['message']
                 response_queue = self._response_queues[request_id]
                 if status_code == 407:
-                    await self._authenticate(self._username, self._password,
-                                             self._processor)
+                    self._authenticate(self._username, self._password, request_id)
                 elif status_code == 204:
                     response_queue.put_nowait(None)
                 else:
diff --git a/tests/test_connection.py b/tests/test_connection.py
index 5ed44f6..39d6ce6 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -15,8 +15,15 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with Goblin.  If not, see <http://www.gnu.org/licenses/>.
 import asyncio
+import json
+
+import base64
 import pytest
 
+import aiohttp
+from aiohttp import web
+
+from goblin import driver
 from goblin import exception
 
 
@@ -101,3 +108,53 @@ async def test_connection_response_timeout(connection):
             stream = await connection.submit(gremlin="1 + 1")
             async for msg in stream:
                 pass
+
+
+@pytest.mark.asyncio
+async def test_authenticated_connection(event_loop, unused_tcp_port):
+    authentication_request_queue = asyncio.Queue(loop=event_loop)
+
+    username, password = 'test_username', 'test_password'
+
+    async def fake_auth(request):
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        msg = await ws.receive()
+        data = json.loads(msg.data.decode()[17:])
+        await authentication_request_queue.put(data)
+
+        auth_resp = {
+            "requestId": data["requestId"],
+            "status": {"code": 407, "attributes": {}, "message": ""},
+            "result": {"data": None, "meta": {}}
+        }
+        resp_payload = json.dumps(auth_resp)
+        ws.send_str(resp_payload)
+
+        auth_msg = await ws.receive()
+        auth_msg_data = json.loads(auth_msg.data.decode()[17:])
+        await authentication_request_queue.put(auth_msg_data)
+
+        return ws
+
+    aiohttp_app = web.Application(loop=event_loop)
+    aiohttp_app.router.add_route('GET', '/gremlin', fake_auth)
+    handler = aiohttp_app.make_handler()
+    srv = await event_loop.create_server(handler, '0.0.0.0', unused_tcp_port)
+
+    async with aiohttp.ClientSession(loop=event_loop) as session:
+        url = 'ws://0.0.0.0:{}/gremlin'.format(unused_tcp_port)
+        async with session.ws_connect(url) as ws_client:
+            connection = driver.Connection(
+                url=url, ws=ws_client, loop=event_loop, client_session=session,
+                username=username, password=password, max_inflight=64, response_timeout=None,
+                message_serializer=driver.GraphSONMessageSerializer
+            )
+            event_loop.create_task(connection.submit(gremlin="1+1"))
+            initial_request = await authentication_request_queue.get()
+            auth_request = await authentication_request_queue.get()
+            print(auth_request)
+            auth_str = auth_request['args']['sasl']
+            assert base64.b64decode(auth_str).decode().split('\x00')[1:] == [username, password]
+            assert auth_request['requestId'] == initial_request['requestId']
-- 
GitLab