Commit 075a9f13 authored by davebshow's avatar davebshow
Browse files

starting to decouple code from aiohttp

parent 4317096c
......@@ -22,7 +22,22 @@ class AiohttpTransport(transport.AbstractBaseTransport):
self._ws.send_bytes(message)
async def read(self):
return await self._ws.receive()
data = await self._ws.receive()
if data.tp == aiohttp.WSMsgType.close:
await self._transport.close()
raise RuntimeError("Connection closed by server")
elif data.tp == aiohttp.WSMsgType.error:
# This won't raise properly, fix
raise data.data
elif data.tp == aiohttp.WSMsgType.closed:
# Hmm
raise RuntimeError("Connection closed by server")
elif data.tp == aiohttp.WSMsgType.text:
# Should return bytes
data = data.data.strip().encode('utf-8')
else:
data = data.data
return data
async def close(self):
if self._connected:
......
import abc
import asyncio
import base64
import collections
import logging
import uuid
import aiohttp
try:
import ujson as json
except ImportError:
......@@ -135,7 +130,9 @@ class Connection:
request_id, message)
if self._transport.closed:
await self._transport.connect(self.url)
self._transport.write(message)
func = self._transport.write(message)
if asyncio.iscoroutine(func):
await func
result_set = resultset.ResultSet(request_id, self._response_timeout,
self._loop)
self._result_sets[request_id] = result_set
......
......@@ -21,8 +21,6 @@ import pytest
from gremlin_python import statics
from gremlin_python.statics import long
from aiogremlin.remote.driver_remote_connection import (
DriverRemoteConnection)
from gremlin_python.process.traversal import Traverser
from gremlin_python.process.traversal import TraversalStrategy
from gremlin_python.process.graph_traversal import __
......@@ -35,6 +33,12 @@ __author__ = 'Marko A. Rodriguez (http://markorodriguez.com)'
class TestDriverRemoteConnection(object):
@pytest.mark.asyncio
async def test_label(self, remote_connection):
statics.load_statics(globals())
g = Graph().traversal().withRemote(remote_connection)
result = await g.V().limit(1).toList()
@pytest.mark.asyncio
async def test_traversals(self, remote_connection):
statics.load_statics(globals())
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment