From fd312df9c07feb914cf7e152c04e8ff0c290b68d Mon Sep 17 00:00:00 2001 From: Leifur Halldor Asgeirsson <lasgeirsson@zerofail.com> Date: Tue, 1 Nov 2016 13:32:34 -0400 Subject: [PATCH] SASL auth --- goblin/driver/connection.py | 13 ++++----- tests/test_connection.py | 57 +++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 7 deletions(-) diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py index f7a5040..4acf871 100644 --- a/goblin/driver/connection.py +++ b/goblin/driver/connection.py @@ -17,6 +17,7 @@ import abc import asyncio +import base64 import collections import functools import logging @@ -230,14 +231,13 @@ class Connection(AbstractConnection): self._loop.create_task(self._terminate_response(resp, request_id)) return resp - def _authenticate(self, username, password, session): + def _authenticate(self, username, password, request_id): auth = b''.join([b'\x00', username.encode('utf-8'), b'\x00', password.encode('utf-8')]) - request_id = str(uuid.uuid4()) - args = {'sasl': base64.b64encode(auth).decode()} + args = {'sasl': base64.b64encode(auth).decode(), 'saslMechanism': 'PLAIN'} message = self._message_serializer.serialize_message( request_id, '', 'authentication', **args) - self._ws.send_bytes(message, binary=True) + self._ws.send_bytes(message) async def close(self): """**coroutine** Close underlying connection and mark as closed.""" @@ -264,7 +264,7 @@ class Connection(AbstractConnection): if data.tp == aiohttp.MsgType.binary: data = data.data.decode() elif data.tp == aiohttp.MsgType.text: - data = data.strip() + data = data.data.strip() message = json.loads(data) request_id = message['requestId'] status_code = message['status']['code'] @@ -272,8 +272,7 @@ class Connection(AbstractConnection): msg = message['status']['message'] response_queue = self._response_queues[request_id] if status_code == 407: - await self._authenticate(self._username, self._password, - self._processor) + self._authenticate(self._username, self._password, request_id) elif status_code == 204: response_queue.put_nowait(None) else: diff --git a/tests/test_connection.py b/tests/test_connection.py index 5ed44f6..39d6ce6 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -15,8 +15,15 @@ # You should have received a copy of the GNU Affero General Public License # along with Goblin. If not, see <http://www.gnu.org/licenses/>. import asyncio +import json + +import base64 import pytest +import aiohttp +from aiohttp import web + +from goblin import driver from goblin import exception @@ -101,3 +108,53 @@ async def test_connection_response_timeout(connection): stream = await connection.submit(gremlin="1 + 1") async for msg in stream: pass + + +@pytest.mark.asyncio +async def test_authenticated_connection(event_loop, unused_tcp_port): + authentication_request_queue = asyncio.Queue(loop=event_loop) + + username, password = 'test_username', 'test_password' + + async def fake_auth(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive() + data = json.loads(msg.data.decode()[17:]) + await authentication_request_queue.put(data) + + auth_resp = { + "requestId": data["requestId"], + "status": {"code": 407, "attributes": {}, "message": ""}, + "result": {"data": None, "meta": {}} + } + resp_payload = json.dumps(auth_resp) + ws.send_str(resp_payload) + + auth_msg = await ws.receive() + auth_msg_data = json.loads(auth_msg.data.decode()[17:]) + await authentication_request_queue.put(auth_msg_data) + + return ws + + aiohttp_app = web.Application(loop=event_loop) + aiohttp_app.router.add_route('GET', '/gremlin', fake_auth) + handler = aiohttp_app.make_handler() + srv = await event_loop.create_server(handler, '0.0.0.0', unused_tcp_port) + + async with aiohttp.ClientSession(loop=event_loop) as session: + url = 'ws://0.0.0.0:{}/gremlin'.format(unused_tcp_port) + async with session.ws_connect(url) as ws_client: + connection = driver.Connection( + url=url, ws=ws_client, loop=event_loop, client_session=session, + username=username, password=password, max_inflight=64, response_timeout=None, + message_serializer=driver.GraphSONMessageSerializer + ) + event_loop.create_task(connection.submit(gremlin="1+1")) + initial_request = await authentication_request_queue.get() + auth_request = await authentication_request_queue.get() + print(auth_request) + auth_str = auth_request['args']['sasl'] + assert base64.b64decode(auth_str).decode().split('\x00')[1:] == [username, password] + assert auth_request['requestId'] == initial_request['requestId'] -- GitLab