From 3baeccb8eb1e4a748951feea2a860bc918fca1e4 Mon Sep 17 00:00:00 2001
From: davebshow <davebshow@gmail.com>
Date: Wed, 6 Jul 2016 00:24:36 -0400
Subject: [PATCH] working on driver. using simple connection in Goblin for now

---
 goblin/api.py               | 69 ++++++++++++++++------------------
 goblin/driver/api.py        |  3 +-
 goblin/driver/connection.py | 75 +++++++++++++++++++------------------
 goblin/driver/pool.py       |  7 ----
 tests/test_driver.py        |  2 +-
 5 files changed, 74 insertions(+), 82 deletions(-)

diff --git a/goblin/api.py b/goblin/api.py
index 51c33a4..ce9d70b 100644
--- a/goblin/api.py
+++ b/goblin/api.py
@@ -21,32 +21,31 @@ async def create_engine(url,
     """Constructor function for :py:class:`Engine`. Connects to database
        and builds a dictionary of relevant vendor implmentation features"""
     features = {}
-    # Will use a driver here
-    driver = gremlin_python_driver.Driver(url, loop)
-    async with driver.get() as conn:
-        # Propbably just use a parser to parse the whole feature list
-        stream = await conn.submit(
-            'graph.features().graph().supportsComputer()')
-        msg = await stream.fetch_data()
-        features['computer'] = msg.data[0]
-        stream = await conn.submit(
-            'graph.features().graph().supportsTransactions()')
-        msg = await stream.fetch_data()
-        features['transactions'] = msg.data[0]
-        stream = await conn.submit(
-            'graph.features().graph().supportsPersistence()')
-        msg = await stream.fetch_data()
-        features['persistence'] = msg.data[0]
-        stream = await conn.submit(
-            'graph.features().graph().supportsConcurrentAccess()')
-        msg = await stream.fetch_data()
-        features['concurrent_access'] = msg.data[0]
-        stream = await conn.submit(
-            'graph.features().graph().supportsThreadedTransactions()')
-        msg = await stream.fetch_data()
-        features['threaded_transactions'] = msg.data[0]
-
-    return Engine(url, loop, driver=driver, **features)
+    # This will be some kind of manager client etc.
+    conn = await driver.GremlinServer.open(url, loop)
+    # Propbably just use a parser to parse the whole feature list
+    stream = await conn.submit(
+        'graph.features().graph().supportsComputer()')
+    msg = await stream.fetch_data()
+    features['computer'] = msg.data[0]
+    stream = await conn.submit(
+        'graph.features().graph().supportsTransactions()')
+    msg = await stream.fetch_data()
+    features['transactions'] = msg.data[0]
+    stream = await conn.submit(
+        'graph.features().graph().supportsPersistence()')
+    msg = await stream.fetch_data()
+    features['persistence'] = msg.data[0]
+    stream = await conn.submit(
+        'graph.features().graph().supportsConcurrentAccess()')
+    msg = await stream.fetch_data()
+    features['concurrent_access'] = msg.data[0]
+    stream = await conn.submit(
+        'graph.features().graph().supportsThreadedTransactions()')
+    msg = await stream.fetch_data()
+    features['threaded_transactions'] = msg.data[0]
+
+    return Engine(url, conn, loop, **features)
 
 
 # Main API classes
@@ -55,16 +54,13 @@ class Engine:
        database connections. Used as a factory to create :py:class:`Session`
        objects. More config coming soon."""
 
-    def __init__(self, url, loop, *, driver=None, force_close=True, **features):
+    def __init__(self, url, conn, loop, *, force_close=True, **features):
         self._url = url
+        self._conn = conn
         self._loop = loop
         self._force_close = force_close
         self._features = features
         self._translator = gremlin_python.GroovyTranslator('g')
-        # This will be a driver
-        if driver is None:
-            driver = gremlin_python_driver.Driver(url, loop)
-        self._driver = driver
 
     @property
     def translator(self):
@@ -75,19 +71,18 @@ class Engine:
         return self._url
 
     @property
-    def driver(self):
-        return self._driver
+    def conn(self):
+        return self._conn
 
     def session(self, *, use_session=False):
         return Session(self, use_session=use_session)
 
     async def execute(self, query, *, bindings=None, session=None):
-        conn = await self.driver.recycle()
-        return await conn.submit(query, bindings=bindings)
+        return await self._conn.submit(query, bindings=bindings)
 
     async def close(self):
-        await self.driver.close()
-        self._driver = None
+        await self.conn.close()
+        self._conn = None
 
 
 class Session:
diff --git a/goblin/driver/api.py b/goblin/driver/api.py
index 5c1a48c..8ca0c20 100644
--- a/goblin/driver/api.py
+++ b/goblin/driver/api.py
@@ -29,6 +29,7 @@ class GremlinServer:
                                      force_release=force_release,
                                      pool=pool, username=username,
                                      password=password)
+
     @classmethod
     async def create_client(cls,
                             url: str,
@@ -36,7 +37,7 @@ class GremlinServer:
                             *,
                             conn_factory: aiohttp.ClientSession=None,
                             max_inflight: int=None,
-                            max_connections: in=None,
+                            max_connections: int=None,
                             force_close: bool=False,
                             force_release: bool=False,
                             pool: pool.Pool=None,
diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py
index 5abf371..ee19abd 100644
--- a/goblin/driver/connection.py
+++ b/goblin/driver/connection.py
@@ -1,12 +1,24 @@
 import abc
 import asyncio
+import collections
+import json
+import logging
+import uuid
+
+
+logger = logging.getLogger(__name__)
+
+
+Message = collections.namedtuple(
+    "Message",
+    ["status_code", "data", "message", "metadata"])
 
 
 class AsyncResponseIter:
 
     def __init__(self, response_queue, loop, conn, username, password,
                  processor, session):
-        self._response_queue = self.response_queue
+        self._response_queue = response_queue
         self._loop = loop
         self._conn = conn
         self._force_close = self._conn.force_close
@@ -35,15 +47,6 @@ class AsyncResponseIter:
             await self._conn.close()
             self._conn = None
 
-    async def term(self):
-        self._conn.remove_inflight()
-        async with self._conn.condition:
-            self._conn.condition.notify()
-        if self._force_close:
-            await self.close()
-        elif self._force_release:
-            await self._conn.release()
-
 
 class AbstractConnection(abc.ABC):
 
@@ -55,10 +58,6 @@ class AbstractConnection(abc.ABC):
     async def close(self):
         raise NotImplementedError
 
-    @abc.abstractproperty
-    def condition(self):
-        return self._condition
-
     @abc.abstractproperty
     def closed(self):
         return self._closed
@@ -86,12 +85,13 @@ class Connection(AbstractConnection):
         self._username = username
         self._password = password
         self._closed = False
-        self._condition = asyncio.Condition(loop=loop)
         self._response_queues = {}
         self._inflight = 0
         if not max_inflight:
             max_inflight = 32
         self._max_inflight = 32
+        self._semaphore = asyncio.Semaphore(self._max_inflight,
+                                            loop=self._loop)
 
     @property
     def max_inflight(self):
@@ -109,8 +109,8 @@ class Connection(AbstractConnection):
         return self._response_queues
 
     @property
-    def condition(self):
-        return super().condition
+    def semaphore(self):
+        return self._semaphore
 
     @property
     def closed(self):
@@ -140,6 +140,8 @@ class Connection(AbstractConnection):
                     request_id=None):
         if aliases is None:
             aliases = {}
+        if request_id is None:
+            request_id = str(uuid.uuid4())
         message = self._prepare_message(gremlin,
                                         bindings,
                                         lang,
@@ -148,30 +150,23 @@ class Connection(AbstractConnection):
                                         processor,
                                         session,
                                         request_id)
-        async with self.condition:
-            while True:
-                if (self.inflight < self.max_inflight):
-                    self._inflight += 1
-                    self.response_queues[request_id] = asyncio.Queue(
-                        loop=self._loop)
-                    self._ws.send_bytes(message)
-                    return AsyncResponseIter(request_id, self._loop, self,
-                                             self._username, self._password,
-                                             processor, session)
-                else:
-                    await self.condition.wait()
+        await self.semaphore.acquire()
+        self._inflight += 1
+        response_queue = asyncio.Queue(loop=self._loop)
+        self.response_queues[request_id] = response_queue
+        self._ws.send_bytes(message)
+        return AsyncResponseIter(response_queue, self._loop, self,
+                                 self._username, self._password,
+                                 processor, session)
 
     async def close(self):
-        async with self.condition:
-            await self._ws.close()
-            self._closed = True
-            self._pool = None
+        await self._ws.close()
+        self._closed = True
+        self._pool = None
         await self._conn_factory.close()
 
     def _prepare_message(self, gremlin, bindings, lang, aliases, op, processor,
                          session, request_id):
-        if request_id is None:
-            request_id = str(uuid.uuid4())
         message = {
             "requestId": request_id,
             "op": op,
@@ -222,7 +217,7 @@ class Connection(AbstractConnection):
         data = await self._ws.receive()
         # parse aiohttp response here
         message = json.loads(data.data.decode("utf-8"))
-        request_id = message['request_id']
+        request_id = message['requestId']
         message = Message(message["status"]["code"],
                           message["result"]["data"],
                           message["status"]["message"],
@@ -242,6 +237,14 @@ class Connection(AbstractConnection):
             raise RuntimeError("{0} {1}".format(message.status_code,
                                                 message.message))
 
+    async def term(self):
+        self.remove_inflight()
+        self.semaphore.release()
+        if self._force_close:
+            await self.close()
+        elif self._force_release:
+            await self.release()
+
     async def __aenter__(self):
         return self
 
diff --git a/goblin/driver/pool.py b/goblin/driver/pool.py
index 36a6a3b..461082c 100644
--- a/goblin/driver/pool.py
+++ b/goblin/driver/pool.py
@@ -1,9 +1,7 @@
 """Simple Async driver for the TinkerPop3 Gremlin Server"""
 import asyncio
 import collections
-import json
 import logging
-import uuid
 
 import aiohttp
 
@@ -11,11 +9,6 @@ import aiohttp
 logger = logging.getLogger(__name__)
 
 
-Message = collections.namedtuple(
-    "Message",
-    ["status_code", "data", "message", "metadata"])
-
-
 class Pool:
 
     def __init__(self, url, loop, *, client_session=None):
diff --git a/tests/test_driver.py b/tests/test_driver.py
index 6845756..041733c 100644
--- a/tests/test_driver.py
+++ b/tests/test_driver.py
@@ -36,7 +36,7 @@ class TestDriver(unittest.TestCase):
         async def go():
             connection = await driver.GremlinServer.open(
                 "http://localhost:8182/", self.loop)
-            stream = await conn.submit("1 + 1")
+            stream = await connection.submit("1 + 1")
             async for msg in stream:
                 self.assertEqual(msg.data[0], 2)
             await connection.close()
-- 
GitLab