diff --git a/goblin/api.py b/goblin/api.py index 54f92e9517857dbb5ebb06c6512ff58fef162952..ece472224136c2b35def526a9d63edf9bb061067 100644 --- a/goblin/api.py +++ b/goblin/api.py @@ -92,6 +92,7 @@ class Session: def __init__(self, engine, *, use_session=False): self._engine = engine + self._loop = self._engine._loop self._use_session = False self._session = None self._traversal = traversal.TraversalSource(self.engine.translator) diff --git a/goblin/query.py b/goblin/query.py index d378e429cd2201148e6381100ca485b8d7ae6313..9b6d9fa36de9cd574cea5f97374afe6e41c2802c 100644 --- a/goblin/query.py +++ b/goblin/query.py @@ -1,5 +1,5 @@ """Query API and helpers""" -import collections +import asyncio import logging from goblin import mapper @@ -13,31 +13,24 @@ def parse_traversal(traversal): return script, bindings -class AsyncQueryResponseIter: +class QueryResponse: - def __init__(self, async_iter, query): - self._async_iter = async_iter - self._query = query - self._queue = collections.deque() + def __init__(self, response_queue): + self._queue = response_queue + self._done = False async def __aiter__(self): return self async def __anext__(self): - if not self._queue: - msg = await self._async_iter.fetch_data() - if msg: - results = msg.data - for result in results: - current = self._query.session.current.get(result['id'], None) - if not current: - current = self._query._element_class() - element = self._query._mapper(result, current, - current.__mapping__) - self._queue.append(element) - else: - raise StopAsyncIteration - return self._queue.popleft() + if self._done: + return + msg = await self._queue.get() + if msg: + return msg + else: + self._done = True + raise StopAsyncIteration class Query: @@ -46,6 +39,7 @@ class Query: self._session = session self._engine = session.engine self._element_class = element_class + self._loop = self._session._loop if element_class.__type__ == 'vertex': self._traversal = self.session.traversal.g.V().hasLabel( element_class.__mapping__.label) @@ -70,4 +64,18 @@ class Query: async def all(self): """Get all results generated by query""" async_iter = await self.session.execute_traversal(self._traversal) - return AsyncQueryResponseIter(async_iter, self) + response_queue = asyncio.Queue(loop=self._loop) + self._loop.create_task(self._receive(async_iter, response_queue)) + return QueryResponse(response_queue) + + async def _receive(self, async_iter, response_queue): + async for msg in async_iter: + results = msg.data + for result in results: + current = self.session.current.get(result['id'], None) + if not current: + current = self._element_class() + element = self._mapper(result, current, + current.__mapping__) + response_queue.put_nowait(element) + response_queue.put_nowait(None)