From 0bd35d3c0db7f44217af203b02f074340ea74fbe Mon Sep 17 00:00:00 2001 From: davebshow <davebshow@gmail.com> Date: Mon, 22 Jan 2018 16:10:04 -0800 Subject: [PATCH] cleaner message deserialzation --- aiogremlin/driver/protocol.py | 27 +++++-------------- .../driver/test_driver_remote_connection.py | 1 + 2 files changed, 7 insertions(+), 21 deletions(-) diff --git a/aiogremlin/driver/protocol.py b/aiogremlin/driver/protocol.py index 02a3e90..5a53ef5 100644 --- a/aiogremlin/driver/protocol.py +++ b/aiogremlin/driver/protocol.py @@ -42,23 +42,15 @@ class GremlinServerWSProtocol(protocol.AbstractBaseProtocol): await func async def data_received(self, data, results_dict): - serializer_version = self._message_serializer.version data = data.decode('utf-8') - message = json.loads(data) + message = self._message_serializer.deserialize_message(json.loads(data)) request_id = message['requestId'] status_code = message['status']['code'] data = message['result']['data'] msg = message['status']['message'] if request_id in results_dict: result_set = results_dict[request_id] - if serializer_version == b"application/vnd.gremlin-v2.0+json": - aggregate_to = data['result']['meta'].get('aggregateTo', 'list') - else: - meta_aggregate_to = message['result']['meta']['@value'] - if len(meta_aggregate_to) > 1: - aggregate_to = meta_aggregate_to[1] - else: - aggregate_to = 'list' + aggregate_to = message['result']['meta'].get('aggregateTo', 'list') result_set.aggregate_to = aggregate_to if status_code == 407: @@ -72,18 +64,11 @@ class GremlinServerWSProtocol(protocol.AbstractBaseProtocol): result_set.queue_result(None) else: if data: - if serializer_version == b"application/vnd.gremlin-v2.0+json": - for result in data: - result = self._message_serializer.deserialize_message(result) - message = Message(status_code, result, msg) - result_set.queue_result(message) - else: - results = self._message_serializer.deserialize_message(data['@value']) - for result in results: - message = Message(status_code, result, msg) - result_set.queue_result(message) + for result in data: + result = self._message_serializer.deserialize_message(result) + message = Message(status_code, result, msg) + result_set.queue_result(message) else: - data = self._message_serializer.deserialize_message(data) message = Message(status_code, data, msg) result_set.queue_result(message) if status_code != 206: diff --git a/tests/test_gremlin_python/driver/test_driver_remote_connection.py b/tests/test_gremlin_python/driver/test_driver_remote_connection.py index 2cdbe84..883d4c1 100644 --- a/tests/test_gremlin_python/driver/test_driver_remote_connection.py +++ b/tests/test_gremlin_python/driver/test_driver_remote_connection.py @@ -38,6 +38,7 @@ class TestDriverRemoteConnection(object): statics.load_statics(globals()) g = Graph().traversal().withRemote(remote_connection) result = await g.V().limit(1).toList() + await remote_connection.close() @pytest.mark.asyncio async def test_traversals(self, remote_connection): -- GitLab