diff --git a/goblin/driver/api.py b/goblin/driver/api.py index b3fecf9042709515152bbe291b261e980a3bb8da..1c12b3ea74362623e57b516058c4561609860758 100644 --- a/goblin/driver/api.py +++ b/goblin/driver/api.py @@ -12,29 +12,24 @@ class GremlinServer: url: str, loop: asyncio.BaseEventLoop, *, - conn_factory: aiohttp.ClientSession=None, - max_inflight: int=None, + client_session: aiohttp.ClientSession=None, force_close: bool=False, - force_release: bool=False, - pool: pool.Pool=None, username: str=None, password: str=None) -> connection.Connection: - if conn_factory is None: - conn_factory = aiohttp.ClientSession(loop=loop) - ws = await conn_factory.ws_connect(url) - return connection.Connection(ws, loop, conn_factory, - max_inflight=max_inflight, + # Use connection factory here + if client_session is None: + client_session = aiohttp.ClientSession(loop=loop) + ws = await client_session.ws_connect(url) + return connection.Connection(ws, loop, client_session, force_close=force_close, - force_release=force_release, - pool=pool, username=username, - password=password) + username=username, password=password) @classmethod async def create_client(cls, url: str, loop: asyncio.BaseEventLoop, *, - conn_factory: aiohttp.ClientSession=None, + client_session: aiohttp.ClientSession=None, max_inflight: int=None, max_connections: int=None, force_close: bool=False, diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py index 1f249703ccc369991aed18f67b5c111eed2fe88b..063dddd26794d8abf1d05b143d0785ad8d530ed5 100644 --- a/goblin/driver/connection.py +++ b/goblin/driver/connection.py @@ -53,9 +53,9 @@ class AbstractConnection(abc.ABC): class Connection(AbstractConnection): - def __init__(self, ws, loop, conn_factory, *, max_inflight=None, - force_close=True, force_release=False, - pool=None, username=None, password=None): + def __init__(self, ws, loop, conn_factory, *, force_close=True, + force_release=False, pool=None, username=None, + password=None): self._ws = ws self._loop = loop self._conn_factory = conn_factory @@ -66,32 +66,11 @@ class Connection(AbstractConnection): self._password = password self._closed = False self._response_queues = {} - self._inflight = 0 - if not max_inflight: - max_inflight = 32 - self._max_inflight = 32 - self._semaphore = asyncio.Semaphore(self._max_inflight, - loop=self._loop) - - @property - def max_inflight(self): - return self._max_inflight - - @property - def max_inflight(self): - return self._max_inflight - - def remove_inflight(self): - self._inflight -= 1 @property def response_queues(self): return self._response_queues - @property - def semaphore(self): - return self._semaphore - @property def closed(self): return self._closed @@ -130,8 +109,6 @@ class Connection(AbstractConnection): processor, session, request_id) - await self.semaphore.acquire() - self._inflight += 1 response_queue = asyncio.Queue(loop=self._loop) self.response_queues[request_id] = response_queue self._ws.send_bytes(message) @@ -216,8 +193,6 @@ class Connection(AbstractConnection): message.message)) async def term(self): - self.remove_inflight() - self.semaphore.release() if self._force_close: await self.close() elif self._force_release: