diff --git a/aiogremlin/driver/protocol.py b/aiogremlin/driver/protocol.py index 84432f729082c4f2d4125d61b9d090ac61bd6871..97f1292d82a4b228dac40a88df756c2e9c980f5a 100644 --- a/aiogremlin/driver/protocol.py +++ b/aiogremlin/driver/protocol.py @@ -1,9 +1,8 @@ +import asyncio import base64 import collections import logging -import aiohttp - try: import ujson as json except ImportError: @@ -35,53 +34,43 @@ class GremlinServerWSProtocol(protocol.AbstractBaseProtocol): def connection_made(self, transport): self._transport = transport - def write(self, request_id, request_message): + async def write(self, request_id, request_message): message = self._message_serializer.serialize_message( request_id, request_message) - self._transport.write(message) + func = self._transport.write(message) + if asyncio.iscoroutine(func): + await func async def data_received(self, data, results_dict): - if data.tp == aiohttp.WSMsgType.close: - await self._transport.close() - elif data.tp == aiohttp.WSMsgType.error: - # This won't raise properly, fix - raise data.data - elif data.tp == aiohttp.WSMsgType.closed: - # Hmm - pass - else: - if data.tp == aiohttp.WSMsgType.binary: - data = data.data.decode() - elif data.tp == aiohttp.WSMsgType.text: - data = data.data.strip() - message = json.loads(data) - request_id = message['requestId'] - status_code = message['status']['code'] - data = message['result']['data'] - msg = message['status']['message'] - if request_id in results_dict: - result_set = results_dict[request_id] - aggregate_to = message['result']['meta'].get('aggregateTo', - 'list') - result_set.aggregate_to = aggregate_to - if status_code == 407: - auth = b''.join([b'\x00', self._username.encode('utf-8'), - b'\x00', self._password.encode('utf-8')]) - request_message = request.RequestMessage( - 'traversal', 'authentication', - {'sasl': base64.b64encode(auth).decode()}) - self.write(request_id, request_message) - elif status_code == 204: - result_set.queue_result(None) - else: - if data: - for result in data: - result = self._message_serializer.deserialize_message(result) - message = Message(status_code, result, msg) - result_set.queue_result(message) - else: - data = self._message_serializer.deserialize_message(data) - message = Message(status_code, data, msg) + data = data.decode('utf-8') + message = json.loads(data) + request_id = message['requestId'] + status_code = message['status']['code'] + data = message['result']['data'] + msg = message['status']['message'] + if request_id in results_dict: + result_set = results_dict[request_id] + aggregate_to = message['result']['meta'].get('aggregateTo', + 'list') + result_set.aggregate_to = aggregate_to + if status_code == 407: + auth = b''.join([b'\x00', self._username.encode('utf-8'), + b'\x00', self._password.encode('utf-8')]) + request_message = request.RequestMessage( + 'traversal', 'authentication', + {'sasl': base64.b64encode(auth).decode()}) + await self.write(request_id, request_message) + elif status_code == 204: + result_set.queue_result(None) + else: + if data: + for result in data: + result = self._message_serializer.deserialize_message(result) + message = Message(status_code, result, msg) result_set.queue_result(message) - if status_code != 206: - result_set.queue_result(None) + else: + data = self._message_serializer.deserialize_message(data) + message = Message(status_code, data, msg) + result_set.queue_result(message) + if status_code != 206: + result_set.queue_result(None) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f5fd80c8093647e552bd88e8ae291f274d1c0d85 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +gremlinpython==3.2.6 diff --git a/setup.py b/setup.py index 55f988786ff21a74940b056d5514029a087fb5b9..e6004dbcbe44e8a78b812589379664722960ae27 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ from setuptools import setup setup( name='aiogremlin', - version='3.2.6rc1', + version='3.2.6rc2', url='', license='Apache Software License', author='davebshow',