From fbaf4a4257cfd7e1a57d9cfbb22d54303b2a0495 Mon Sep 17 00:00:00 2001
From: davebshow <davebshow@gmail.com>
Date: Tue, 5 Jul 2016 10:32:44 -0400
Subject: [PATCH] working on driver inflight control

---
 goblin/api.py                          |  12 +-
 goblin/gremlin_python_driver/driver.py | 159 +++++++++++++++++--------
 tests/test_driver.py                   |   3 +-
 3 files changed, 118 insertions(+), 56 deletions(-)

diff --git a/goblin/api.py b/goblin/api.py
index b3ac5f0..65408ad 100644
--- a/goblin/api.py
+++ b/goblin/api.py
@@ -25,23 +25,23 @@ async def create_engine(url,
     driver = gremlin_python_driver.Driver(url, loop)
     async with driver.get() as conn:
         # Propbably just use a parser to parse the whole feature list
-        stream = conn.submit(
+        stream = await conn.submit(
             'graph.features().graph().supportsComputer()')
         msg = await stream.fetch_data()
         features['computer'] = msg.data[0]
-        stream = conn.submit(
+        stream = await conn.submit(
             'graph.features().graph().supportsTransactions()')
         msg = await stream.fetch_data()
         features['transactions'] = msg.data[0]
-        stream = conn.submit(
+        stream = await conn.submit(
             'graph.features().graph().supportsPersistence()')
         msg = await stream.fetch_data()
         features['persistence'] = msg.data[0]
-        stream = conn.submit(
+        stream = await conn.submit(
             'graph.features().graph().supportsConcurrentAccess()')
         msg = await stream.fetch_data()
         features['concurrent_access'] = msg.data[0]
-        stream = conn.submit(
+        stream = await conn.submit(
             'graph.features().graph().supportsThreadedTransactions()')
         msg = await stream.fetch_data()
         features['threaded_transactions'] = msg.data[0]
@@ -83,7 +83,7 @@ class Engine:
 
     async def execute(self, query, *, bindings=None, session=None):
         conn = await self.driver.recycle()
-        return conn.submit(query, bindings=bindings)
+        return await conn.submit(query, bindings=bindings)
 
     async def close(self):
         await self.driver.close()
diff --git a/goblin/gremlin_python_driver/driver.py b/goblin/gremlin_python_driver/driver.py
index 7313409..8379d77 100644
--- a/goblin/gremlin_python_driver/driver.py
+++ b/goblin/gremlin_python_driver/driver.py
@@ -25,12 +25,17 @@ class Driver:
             client_session = aiohttp.ClientSession(loop=self._loop)
         self._client_session = client_session
         self._reclaimed = collections.deque()
+        self._driver_condition = asyncio.Condition(loop=loop)
         self._open_connections = 0
         self._inflight_messages = 0
         self._connecting = 0
-        self._max_connections = 32
+        self._max_connections = 4
         self._max_inflight_messages = 128
 
+    @property
+    def driver_condition(self):
+        return self._driver_condition
+
     @property
     def max_connections(self):
         return self._max_connections
@@ -43,38 +48,71 @@ class Driver:
     def total_connections(self):
         return self._connecting + self._open_connections
 
+    @property
+    def inflight_messages(self):
+        return self._inflight_messages
+
+    def add_inflight(self):
+        self._inflight_messages += 1
+
+    def remove_inflight(self):
+        self._inflight_messages -= 1
+
     def get(self):
         return AsyncDriverConnectionContextManager(self)
 
-    async def connect(self, *, force_close=True, recycle=False):
+    async def submit(self,
+                     gremlin,
+                     *,
+                     bindings=None,
+                     lang='gremlin-groovy',
+                     aliases=None,
+                     op="eval",
+                     processor="",
+                     session=None,
+                     request_id=None):
+        pass
+
+
+    async def connect(self, *, force_close=True, force_reclaim=False):
+        async with self.driver_condition:
+            conn = await self._get_new_connection(force_close, force_reclaim)
+            return conn
+
+    async def _get_new_connection(self, force_close, force_reclaim):
         if self.total_connections <= self._max_connections:
             self._connecting += 1
             try:
                 ws = await self._client_session.ws_connect(self._url)
                 self._open_connections +=1
                 return Connection(ws, self._loop, force_close=force_close,
-                                  recycle=recycle, driver=self)
+                                  force_reclaim=force_reclaim, driver=self)
             finally:
                 self._connecting -= 1
         else:
             raise RuntimeError("To many connections, try recycling")
 
-    async def recycle(self, *, force_close=False, recycle=True):
-        if self._reclaimed:
-            while self._reclaimed:
-                conn = self._reclaimed.popleft()
-                if not conn.closed:
-                    logger.info("Reusing connection: {}".format(conn))
-                    break
+    async def recycle(self, *, force_close=False, force_reclaim=True):
+        async with self.driver_condition:
+            while True:
+                if self._reclaimed:
+                    while self._reclaimed:
+                        conn = self._reclaimed.popleft()
+                        if not conn.closed:
+                            logger.info("Reusing connection: {}".format(conn))
+                            return conn
+                        else:
+                            self._open_connections -= 1
+                            logger.debug(
+                                "Discarded closed connection: {}".format(conn))
+                elif self.total_connections < self.max_connections:
+                    conn = await self._get_new_connection(force_close,
+                                                          force_reclaim)
+                    logger.info("Acquired new connection: {}".format(conn))
+                    return conn
                 else:
-                    self._open_connections -= 1
-                    logger.debug(
-                        "Discarded closed connection: {}".format(conn))
-        elif self.total_connections < self.max_connections:
-            conn = await self.connect(force_close=force_close,
-                                      recycle=recycle)
-            logger.info("Acquired new connection: {}".format(conn))
-        return conn
+                    await self.driver_condition.wait()
+
 
     async def reclaim(self, conn):
         if self.total_connections <= self.max_connections:
@@ -86,20 +124,28 @@ class Driver:
                 conn = None
             else:
                 self._reclaimed.append(conn)
+            await self._wakeup()
         else:
             if conn.driver is self:
                 # hmmm
                 await conn.close()
                 self._open_connections -= 1
 
+    async def _wakeup(self):
+        async with self.driver_condition:
+            self.driver_condition.notify()
+
     async def close(self):
-        while self._reclaimed:
-            conn = self._reclaimed.popleft()
-            await conn.close()
-        await self._client_session.close()
-        self._client_session = None
-        self._closed = True
-        logger.debug("Driver {} has been closed".format(self))
+        async with self.driver_condition:
+            waiters = []
+            while self._reclaimed:
+                conn = self._reclaimed.popleft()
+                waiters.append(conn.close())
+            await asyncio.gather(*waiters, loop=self._loop)
+            await self._client_session.close()
+            self._client_session = None
+            self._closed = True
+            logger.debug("Driver {} has been closed".format(self))
 
 
 class AsyncDriverConnectionContextManager:
@@ -130,7 +176,7 @@ class AsyncResponseIter:
         self._processor = processor
         self._session = session
         self._force_close = self._conn.force_close
-        self._recycle = self._conn.recycle
+        self._force_reclaim = self._conn.force_reclaim
         self._closed = False
         self._response_queue = asyncio.Queue(loop=loop)
 
@@ -181,24 +227,32 @@ class AsyncResponseIter:
                                                 message.message))
 
     async def term(self):
+        async with self._conn.conn_condition:
+            self._conn.driver.remove_inflight()
+            self._conn.conn_condition.notify()
         self._closed = True
         if self._force_close:
             await self.close()
-        elif self._recycle:
+        elif self._force_reclaim:
             await self._conn.reclaim()
 
 class Connection:
 
-    def __init__(self, ws, loop, *, force_close=True, recycle=False,
+    def __init__(self, ws, loop, *, force_close=True, force_reclaim=False,
                  driver=None, username=None, password=None):
         self._ws = ws
         self._loop = loop
         self._force_close = force_close
-        self._recycle = recycle
+        self._force_reclaim = force_reclaim
         self._driver = driver
         self._username = username
         self._password = password
         self._closed = False
+        self._conn_condition = asyncio.Condition(loop=loop)
+
+    @property
+    def conn_condition(self):
+        return self._conn_condition
 
     @property
     def closed(self):
@@ -209,8 +263,8 @@ class Connection:
         return self._force_close
 
     @property
-    def recycle(self):
-        return self._recycle
+    def force_reclaim(self):
+        return self._force_reclaim
 
     @property
     def driver(self):
@@ -220,16 +274,16 @@ class Connection:
         if self.driver:
             await self.driver.reclaim(self)
 
-    def submit(self,
-               gremlin,
-               *,
-               bindings=None,
-               lang='gremlin-groovy',
-               aliases=None,
-               op="eval",
-               processor="",
-               session=None,
-               request_id=None):
+    async def submit(self,
+                    gremlin,
+                    *,
+                    bindings=None,
+                    lang='gremlin-groovy',
+                    aliases=None,
+                    op="eval",
+                    processor="",
+                    session=None,
+                    request_id=None):
         if aliases is None:
             aliases = {}
         message = self._prepare_message(gremlin,
@@ -240,16 +294,23 @@ class Connection:
                                         processor,
                                         session,
                                         request_id)
-
-        self._ws.send_bytes(message)
-        return AsyncResponseIter(self._ws, self._loop, self, self._username,
-                                 self._password, processor, session)
+        async with self.conn_condition:
+            while True:
+                if (self.driver.inflight_messages <
+                        self.driver.max_inflight_messages):
+                    self._ws.send_bytes(message)
+                    return AsyncResponseIter(self._ws, self._loop, self,
+                                             self._username, self._password,
+                                             processor, session)
+                else:
+                    await self.driver.message_condition.wait()
 
     async def close(self):
-        await self._ws.close()
-        self._closed = True
-        self.driver._open_connections -= 1
-        self._driver = None
+        async with self.conn_condition:
+            await self._ws.close()
+            self._closed = True
+            self.driver._open_connections -= 1
+            self._driver = None
 
     def _prepare_message(self, gremlin, bindings, lang, aliases, op, processor,
                          session, request_id):
diff --git a/tests/test_driver.py b/tests/test_driver.py
index ad00e94..2499c0b 100644
--- a/tests/test_driver.py
+++ b/tests/test_driver.py
@@ -25,7 +25,8 @@ class TestDriver(unittest.TestCase):
         async def go():
             driver = Driver("http://localhost:8182/", self.loop)
             async with driver.get() as conn:
-                async for msg in conn.submit("1 + 1"):
+                stream = await conn.submit("1 + 1")
+                async for msg in stream:
                     self.assertEqual(msg.data[0], 2)
             await driver.close()
 
-- 
GitLab