From 075a9f1344a407ea8b9030cdf4881278f5f58966 Mon Sep 17 00:00:00 2001 From: davebshow <davebshow@gmail.com> Date: Mon, 30 Oct 2017 15:06:36 -0700 Subject: [PATCH] starting to decouple code from aiohttp --- aiogremlin/driver/aiohttp/transport.py | 17 ++++++++++++++++- aiogremlin/driver/connection.py | 9 +++------ requirements.txt | 2 -- .../driver/test_driver_remote_connection.py | 8 ++++++-- 4 files changed, 25 insertions(+), 11 deletions(-) delete mode 100644 requirements.txt diff --git a/aiogremlin/driver/aiohttp/transport.py b/aiogremlin/driver/aiohttp/transport.py index b9d0114..7801f5b 100644 --- a/aiogremlin/driver/aiohttp/transport.py +++ b/aiogremlin/driver/aiohttp/transport.py @@ -22,7 +22,22 @@ class AiohttpTransport(transport.AbstractBaseTransport): self._ws.send_bytes(message) async def read(self): - return await self._ws.receive() + data = await self._ws.receive() + if data.tp == aiohttp.WSMsgType.close: + await self._transport.close() + raise RuntimeError("Connection closed by server") + elif data.tp == aiohttp.WSMsgType.error: + # This won't raise properly, fix + raise data.data + elif data.tp == aiohttp.WSMsgType.closed: + # Hmm + raise RuntimeError("Connection closed by server") + elif data.tp == aiohttp.WSMsgType.text: + # Should return bytes + data = data.data.strip().encode('utf-8') + else: + data = data.data + return data async def close(self): if self._connected: diff --git a/aiogremlin/driver/connection.py b/aiogremlin/driver/connection.py index d0c04e1..3c171e0 100644 --- a/aiogremlin/driver/connection.py +++ b/aiogremlin/driver/connection.py @@ -1,12 +1,7 @@ -import abc import asyncio -import base64 -import collections import logging import uuid -import aiohttp - try: import ujson as json except ImportError: @@ -135,7 +130,9 @@ class Connection: request_id, message) if self._transport.closed: await self._transport.connect(self.url) - self._transport.write(message) + func = self._transport.write(message) + if asyncio.iscoroutine(func): + await func result_set = resultset.ResultSet(request_id, self._response_timeout, self._loop) self._result_sets[request_id] = result_set diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index e746163..0000000 --- a/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -aiohttp==1.3.3 -PyYAML==3.12 diff --git a/tests/test_gremlin_python/driver/test_driver_remote_connection.py b/tests/test_gremlin_python/driver/test_driver_remote_connection.py index 0432fd5..2cdbe84 100644 --- a/tests/test_gremlin_python/driver/test_driver_remote_connection.py +++ b/tests/test_gremlin_python/driver/test_driver_remote_connection.py @@ -21,8 +21,6 @@ import pytest from gremlin_python import statics from gremlin_python.statics import long -from aiogremlin.remote.driver_remote_connection import ( - DriverRemoteConnection) from gremlin_python.process.traversal import Traverser from gremlin_python.process.traversal import TraversalStrategy from gremlin_python.process.graph_traversal import __ @@ -35,6 +33,12 @@ __author__ = 'Marko A. Rodriguez (http://markorodriguez.com)' class TestDriverRemoteConnection(object): + @pytest.mark.asyncio + async def test_label(self, remote_connection): + statics.load_statics(globals()) + g = Graph().traversal().withRemote(remote_connection) + result = await g.V().limit(1).toList() + @pytest.mark.asyncio async def test_traversals(self, remote_connection): statics.load_statics(globals()) -- GitLab