diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py index 671454aece9329e0990250e6fb48c2e4f6f90191..2f2bd52947c098d4d3f0c667f3e2192eda534586 100644 --- a/goblin/driver/connection.py +++ b/goblin/driver/connection.py @@ -153,8 +153,6 @@ class Connection(AbstractConnection): self._receive_task = self._loop.create_task(self._receive()) self._semaphore = asyncio.Semaphore(value=max_inflight, loop=self._loop) - if isinstance(message_serializer, type): - message_serializer = message_serializer() self._message_serializer = message_serializer @classmethod diff --git a/goblin/driver/graph.py b/goblin/driver/graph.py index 4720b8ddcb7bf9b1e3d986b9498e27d943130b1c..b302222597998da228c325c56972fd6e20d53a03 100644 --- a/goblin/driver/graph.py +++ b/goblin/driver/graph.py @@ -40,8 +40,8 @@ class AsyncRemoteTraversalSideEffects(RemoteTraversalSideEffects): class AsyncRemoteStrategy(RemoteStrategy): async def apply(self, traversal): - if isinstance(self.remote_connection.message_serializer, - GraphSON2MessageSerializer): + serializer = self.remote_connection.message_serializer + if serializer is GraphSON2MessageSerializer: processor = 'traversal' op = 'bytecode' side_effects = AsyncRemoteTraversalSideEffects diff --git a/goblin/driver/serializer.py b/goblin/driver/serializer.py index 5205c659293db487c1be71ee18da49ec10975494..474686bc5d7baa28ccd75f783f5420c561e0b6f7 100644 --- a/goblin/driver/serializer.py +++ b/goblin/driver/serializer.py @@ -57,37 +57,42 @@ class GraphSONMessageSerializer: pass - def get_processor(self, processor): - processor = getattr(self, processor, None) + @classmethod + def get_processor(cls, processor): + processor = getattr(cls, processor, None) if not processor: raise Exception("Unknown processor") return processor() - def serialize_message(self, request_id, processor, op, **args): + @classmethod + def serialize_message(cls, request_id, processor, op, **args): if not processor: - processor_obj = self.get_processor('standard') + processor_obj = cls.get_processor('standard') else: - processor_obj = self.get_processor(processor) + processor_obj = cls.get_processor(processor) op_method = processor_obj.get_op(op) args = op_method(args) - message = self.build_message(request_id, processor, op, args) + message = cls.build_message(request_id, processor, op, args) return message - def build_message(self, request_id, processor, op, args): + @classmethod + def build_message(cls, request_id, processor, op, args): message = { 'requestId': request_id, 'processor': processor, 'op': op, 'args': args } - return self.finalize_message(message, b'\x10', b'application/json') + return cls.finalize_message(message, b'\x10', b'application/json') - def finalize_message(self, message, mime_len, mime_type): + @classmethod + def finalize_message(cls, message, mime_len, mime_type): message = json.dumps(message) message = b''.join([mime_len, mime_type, message.encode('utf-8')]) return message - def deserialize_message(self, message): + @classmethod + def deserialize_message(cls, message): return Traverser(message) @@ -131,17 +136,19 @@ class GraphSON2MessageSerializer(GraphSONMessageSerializer): args['sideEffect'] = {'@type': 'g:UUID', '@value': side_effect} return args - def build_message(self, request_id, processor, op, args): + @classmethod + def build_message(cls, request_id, processor, op, args): message = { 'requestId': {'@type': 'g:UUID', '@value': request_id}, 'processor': processor, 'op': op, 'args': args } - return self.finalize_message(message, b"\x21", + return cls.finalize_message(message, b"\x21", b"application/vnd.gremlin-v2.0+json") - def deserialize_message(self, message): + @classmethod + def deserialize_message(cls, message): if isinstance(message, dict): if message.get('@type', '') == 'g:Traverser': obj = GraphSONReader._objectify(message) diff --git a/tests/test_graph.py b/tests/test_graph.py index 2967b641d90804e809f74a6c7d6ab2109bb28ddc..73ecd5324d34dbaa279c21beb736735a6c2cea95 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -57,7 +57,7 @@ async def test_submit_traversal(event_loop, remote_graph, connection): @pytest.mark.asyncio async def test_side_effects(remote_graph, connection): async with connection: - connection._message_serializer = serializer.GraphSON2MessageSerializer() + connection._message_serializer = serializer.GraphSON2MessageSerializer g = remote_graph.traversal().withRemote(connection) # create some nodes resp = g.addV('person').property('name', 'leifur') @@ -87,3 +87,41 @@ async def test_side_effects(remote_graph, connection): async for msg in resp: side_effects.append(msg) assert side_effects + + +@pytest.mark.asyncio +async def test_side_effects_with_client(event_loop, remote_graph): + cluster = await driver.Cluster.open(event_loop) + client = await cluster.connect() + + g = remote_graph.traversal().withRemote(client) + # create some nodes + resp = g.addV('person').property('name', 'leifur') + leif = await resp.next() + resp.traversers.close() + resp = g.addV('person').property('name', 'dave') + dave = await resp.next() + resp.traversers.close() + resp = g.addV('person').property('name', 'jon') + jonthan = await resp.next() + resp.traversers.close() + traversal = g.V().aggregate('a').aggregate('b') + async for msg in traversal: + pass + keys = [] + resp = await traversal.side_effects.keys() + async for msg in resp: + keys.append(msg) + assert keys == ['a', 'b'] + side_effects = [] + resp = await traversal.side_effects.get('a') + async for msg in resp: + side_effects.append(msg) + assert side_effects + side_effects = [] + resp = await traversal.side_effects.get('b') + async for msg in resp: + side_effects.append(msg) + assert side_effects + + await cluster.close()