diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py index 2cbf049e6503ea825af90dbbcde1f2b2f8063cc0..50373c588efec2e8ccbf2915dbddf54e129be3b4 100644 --- a/goblin/driver/connection.py +++ b/goblin/driver/connection.py @@ -14,15 +14,12 @@ Message = collections.namedtuple( ["status_code", "data", "message", "metadata"]) -class AsyncResponseIter: +class Response: - def __init__(self, response_queue, loop, conn, username, password, - processor, session): + def __init__(self, response_queue, loop): self._response_queue = response_queue self._loop = loop - self._conn = conn - self._force_close = self._conn.force_close - self._force_release = self._conn.force_release + self._done = False async def __aiter__(self): return self @@ -35,17 +32,14 @@ class AsyncResponseIter: raise StopAsyncIteration async def fetch_data(self): - if not self._response_queue.empty(): - message = self._response_queue.get_nowait() - else: - self._loop.create_task(self._conn.get_data()) - message = await self._response_queue.get() - return message + if self._done: + return None - async def close(self): - if self._conn: - await self._conn.close() - self._conn = None + msg = await self._response_queue.get() + if msg is None: # end of response sentinel + self._done = True + + return msg class AbstractConnection(abc.ABC): @@ -143,9 +137,10 @@ class Connection(AbstractConnection): response_queue = asyncio.Queue(loop=self._loop) self.response_queues[request_id] = response_queue self._ws.send_bytes(message) - return AsyncResponseIter(response_queue, self._loop, self, - self._username, self._password, - processor, session) + self._loop.create_task(self.receive()) + return Response(response_queue, self._loop, self, + self._username, self._password, + processor, session) async def close(self): await self._ws.close() @@ -201,7 +196,7 @@ class Connection(AbstractConnection): raise ValueError("Unknown mime type.") return b"".join([mime_len, mime_type, message.encode("utf-8")]) - async def get_data(self): + async def receive(self): data = await self._ws.receive() # parse aiohttp response here message = json.loads(data.data.decode("utf-8")) @@ -210,29 +205,23 @@ class Connection(AbstractConnection): message["result"]["data"], message["status"]["message"], message["result"]["meta"]) + response_queue = self._response_queues[request_id] + if message.status_code not in (206, 407): + # this message concludes the response + await response_queue.put(None) + del self._response_queues[request_id] if message.status_code in [200, 206, 204]: - response_queue = self.response_queues[request_id] response_queue.put_nowait(message) - if message.status_code != 206: - await self.term() - response_queue.put_nowait(None) + if message.status_code == 206: + self._loop.create_task(self.receive()) elif message.status_code == 407: self._authenticate(self._username, self._password, self._processor, self._session) - message = await self.fetch_data() + self._loop.create_task(self.receive()) else: - await self.term() raise RuntimeError("{0} {1}".format(message.status_code, message.message)) - async def term(self): - self.remove_inflight() - self.semaphore.release() - if self._force_close: - await self.close() - elif self._force_release: - await self.release() - async def __aenter__(self): return self