diff --git a/.gitignore b/.gitignore index 72364f99fe4bf8d5262df3b19b33102aeaa791e5..e56ebb8acc0bbd50d0d5241a19049af6b0540e63 100644 --- a/.gitignore +++ b/.gitignore @@ -87,3 +87,6 @@ ENV/ # Rope project settings .ropeproject + +# Pycharm +.idea/ diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py index 063dddd26794d8abf1d05b143d0785ad8d530ed5..82d4e3a686e3e2004cc00657ebfa1a2538beeeec 100644 --- a/goblin/driver/connection.py +++ b/goblin/driver/connection.py @@ -14,12 +14,12 @@ Message = collections.namedtuple( ["status_code", "data", "message", "metadata"]) -class AsyncResponseIter: +class Response: - def __init__(self, response_queue, loop, conn): + def __init__(self, response_queue, loop): self._response_queue = response_queue self._loop = loop - self._conn = conn + self._done = False async def __aiter__(self): return self @@ -32,12 +32,14 @@ class AsyncResponseIter: raise StopAsyncIteration async def fetch_data(self): - if not self._response_queue.empty(): - message = self._response_queue.get_nowait() - else: - self._loop.create_task(self._conn.get_data()) - message = await self._response_queue.get() - return message + if self._done: + return None + + msg = await self._response_queue.get() + if msg is None: # end of response sentinel + self._done = True + + return msg class AbstractConnection(abc.ABC): @@ -112,7 +114,8 @@ class Connection(AbstractConnection): response_queue = asyncio.Queue(loop=self._loop) self.response_queues[request_id] = response_queue self._ws.send_bytes(message) - return AsyncResponseIter(response_queue, self._loop, self) + self._loop.create_task(self.receive()) + return Response(response_queue, self._loop, self) async def close(self): await self._ws.close() @@ -168,7 +171,7 @@ class Connection(AbstractConnection): raise ValueError("Unknown mime type.") return b"".join([mime_len, mime_type, message.encode("utf-8")]) - async def get_data(self): + async def receive(self): data = await self._ws.receive() # parse aiohttp response here message = json.loads(data.data.decode("utf-8")) @@ -177,18 +180,20 @@ class Connection(AbstractConnection): message["result"]["data"], message["status"]["message"], message["result"]["meta"]) + response_queue = self._response_queues[request_id] + if message.status_code not in (206, 407): + # this message concludes the response + await response_queue.put(None) + del self._response_queues[request_id] if message.status_code in [200, 206, 204]: - response_queue = self.response_queues[request_id] response_queue.put_nowait(message) - if message.status_code != 206: - await self.term() - response_queue.put_nowait(None) + if message.status_code == 206: + self._loop.create_task(self.receive()) elif message.status_code == 407: self._authenticate(self._username, self._password, self._processor, self._session) - message = await self.fetch_data() + self._loop.create_task(self.receive()) else: - await self.term() raise RuntimeError("{0} {1}".format(message.status_code, message.message))