From a129fb482823349d11e1c317b1464053340cbd46 Mon Sep 17 00:00:00 2001
From: davebshow <davebshow@gmail.com>
Date: Tue, 12 Jul 2016 21:08:11 -0400
Subject: [PATCH] changed config and registry to app module. simplified
 traversal API

---
 goblin/__init__.py          |   2 +-
 goblin/app.py               |  94 ++++++++++++++++++
 goblin/driver/api.py        |   2 -
 goblin/driver/connection.py |  12 +--
 goblin/engine.py            |  81 ----------------
 goblin/session.py           | 189 +++++++++++++++++++++++-------------
 goblin/traversal.py         | 140 +++++++-------------------
 tests/test_engine.py        |  75 +++++++++-----
 8 files changed, 304 insertions(+), 291 deletions(-)
 create mode 100644 goblin/app.py
 delete mode 100644 goblin/engine.py

diff --git a/goblin/__init__.py b/goblin/__init__.py
index a9b21a4..4b264cc 100644
--- a/goblin/__init__.py
+++ b/goblin/__init__.py
@@ -1,3 +1,3 @@
 from goblin.element import Vertex, Edge, VertexProperty
-from goblin.engine import Engine
+from goblin.app import create_app, App
 from goblin.properties import Property, String
diff --git a/goblin/app.py b/goblin/app.py
new file mode 100644
index 0000000..2ba2426
--- /dev/null
+++ b/goblin/app.py
@@ -0,0 +1,94 @@
+"""Main OGM API classes and constructors"""
+import collections
+import logging
+
+from goblin.gremlin_python import process
+from goblin import driver
+from goblin import session
+
+
+logger = logging.getLogger(__name__)
+
+
+# Constructor API
+async def create_app(url, loop, **config):
+    """Constructor function for :py:class:`Engine`. Connects to database
+       and builds a dictionary of relevant vendor implmentation features"""
+    features = {}
+    async with await driver.GremlinServer.open(url, loop) as conn:
+        # Propbably just use a parser to parse the whole feature list
+        stream = await conn.submit(
+            'graph.features().graph().supportsComputer()')
+        msg = await stream.fetch_data()
+        features['computer'] = msg.data[0]
+        stream = await conn.submit(
+            'graph.features().graph().supportsTransactions()')
+        msg = await stream.fetch_data()
+        features['transactions'] = msg.data[0]
+        stream = await conn.submit(
+            'graph.features().graph().supportsPersistence()')
+        msg = await stream.fetch_data()
+        features['persistence'] = msg.data[0]
+        stream = await conn.submit(
+            'graph.features().graph().supportsConcurrentAccess()')
+        msg = await stream.fetch_data()
+        features['concurrent_access'] = msg.data[0]
+        stream = await conn.submit(
+            'graph.features().graph().supportsThreadedTransactions()')
+        msg = await stream.fetch_data()
+        features['threaded_transactions'] = msg.data[0]
+    return App(url, loop, features=features, **config)
+
+
+# Main API classes
+class App:
+    """Class used to encapsulate database connection configuration and generate
+       database connections. Used as a factory to create :py:class:`Session`
+       objects. More config coming soon."""
+    DEFAULT_CONFIG = {
+        'translator': process.GroovyTranslator('g')
+    }
+
+    def __init__(self, url, loop, *, features=None, **config):
+        self._url = url
+        self._loop = loop
+        self._features = features
+        self._config = self.DEFAULT_CONFIG
+        self._config.update(config)
+        self._vertices = {}
+        self._edges = {}
+
+    @property
+    def vertices(self):
+        return self._vertices
+
+    @property
+    def edges(self):
+        return self._edges
+
+    def from_file(filepath):
+        pass
+
+    def from_obj(obj):
+        pass
+
+    @property
+    def translator(self):
+        return self._config['translator']
+
+    @property
+    def url(self):
+        return self._url
+
+    def register(self, *elements):
+        for element in elements:
+            if element.__type__ == 'vertex':
+                self._vertices[element.__label__] = element
+            if element.__type__ == 'edge':
+                self._edges[element.__label__] = element
+
+    async def session(self, *, use_session=False):
+        conn = await driver.GremlinServer.open(self.url, self._loop)
+        return session.Session(self,
+                               conn,
+                               use_session=use_session)
diff --git a/goblin/driver/api.py b/goblin/driver/api.py
index 22ab80d..0898193 100644
--- a/goblin/driver/api.py
+++ b/goblin/driver/api.py
@@ -12,12 +12,10 @@ class GremlinServer:
                    loop: asyncio.BaseEventLoop,
                    *,
                    client_session: aiohttp.ClientSession=None,
-                   force_close: bool=False,
                    username: str=None,
                    password: str=None) -> connection.Connection:
         if client_session is None:
             client_session = aiohttp.ClientSession(loop=loop)
         ws = await client_session.ws_connect(url)
         return connection.Connection(url, ws, loop, client_session,
-                                     force_close=force_close,
                                      username=username, password=password)
diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py
index 987c2a1..6490341 100644
--- a/goblin/driver/connection.py
+++ b/goblin/driver/connection.py
@@ -49,13 +49,12 @@ class AbstractConnection(abc.ABC):
 
 class Connection(AbstractConnection):
 
-    def __init__(self, url, ws, loop, conn_factory, *, force_close=True,
-                 username=None, password=None):
+    def __init__(self, url, ws, loop, conn_factory, *, username=None,
+                 password=None):
         self._url = url
         self._ws = ws
         self._loop = loop
         self._conn_factory = conn_factory
-        self._force_close = force_close
         self._username = username
         self._password = password
         self._closed = False
@@ -69,10 +68,6 @@ class Connection(AbstractConnection):
     def closed(self):
         return self._closed
 
-    @property
-    def force_close(self):
-        return self._force_close
-
     @property
     def url(self):
         return self._url
@@ -186,9 +181,6 @@ class Connection(AbstractConnection):
             raise RuntimeError("{0} {1}".format(message.status_code,
                                                 message.message))
 
-    async def term(self):
-        if self._force_close:
-            await self.close()
 
     async def __aenter__(self):
         return self
diff --git a/goblin/engine.py b/goblin/engine.py
deleted file mode 100644
index 563f5d4..0000000
--- a/goblin/engine.py
+++ /dev/null
@@ -1,81 +0,0 @@
-"""Main OGM API classes and constructors"""
-import collections
-import logging
-
-from goblin.gremlin_python import process
-from goblin import driver
-from goblin import session
-
-
-logger = logging.getLogger(__name__)
-
-
-# Constructor API
-async def create_engine(url,
-                        loop,
-                        force_close=False):
-    """Constructor function for :py:class:`Engine`. Connects to database
-       and builds a dictionary of relevant vendor implmentation features"""
-    features = {}
-    # This will be some kind of manager client etc.
-    conn = await driver.GremlinServer.open(url, loop)
-    # Propbably just use a parser to parse the whole feature list
-    stream = await conn.submit(
-        'graph.features().graph().supportsComputer()')
-    msg = await stream.fetch_data()
-    features['computer'] = msg.data[0]
-    stream = await conn.submit(
-        'graph.features().graph().supportsTransactions()')
-    msg = await stream.fetch_data()
-    features['transactions'] = msg.data[0]
-    stream = await conn.submit(
-        'graph.features().graph().supportsPersistence()')
-    msg = await stream.fetch_data()
-    features['persistence'] = msg.data[0]
-    stream = await conn.submit(
-        'graph.features().graph().supportsConcurrentAccess()')
-    msg = await stream.fetch_data()
-    features['concurrent_access'] = msg.data[0]
-    stream = await conn.submit(
-        'graph.features().graph().supportsThreadedTransactions()')
-    msg = await stream.fetch_data()
-    features['threaded_transactions'] = msg.data[0]
-
-    return Engine(url, conn, loop, force_close=force_close, **features)
-
-
-# Main API classes
-class Engine(driver.AbstractConnection):
-    """Class used to encapsulate database connection configuration and generate
-       database connections. Used as a factory to create :py:class:`Session`
-       objects. More config coming soon."""
-
-    def __init__(self, url, conn, loop, *, force_close=True, **features):
-        self._url = url
-        self._conn = conn
-        self._loop = loop
-        self._force_close = force_close
-        self._features = features
-        self._translator = process.GroovyTranslator('g')
-
-    @property
-    def translator(self):
-        return self._translator
-
-    @property
-    def url(self):
-        return self._url
-
-    @property
-    def conn(self):
-        return self._conn
-
-    def session(self, *, use_session=False):
-        return session.Session(self, use_session=use_session)
-
-    async def submit(self, query, *, bindings=None, session=None):
-        return await self._conn.submit(query, bindings=bindings)
-
-    async def close(self):
-        await self.conn.close()
-        self._conn = None
diff --git a/goblin/session.py b/goblin/session.py
index f1b4998..9082761 100644
--- a/goblin/session.py
+++ b/goblin/session.py
@@ -1,10 +1,12 @@
 """Main OGM API classes and constructors"""
+import asyncio
 import collections
 import logging
 
 from goblin import mapper
 from goblin import traversal
-from goblin.driver import connection
+from goblin.driver import connection, graph
+from goblin.gremlin_python import process
 
 
 logger = logging.getLogger(__name__)
@@ -14,19 +16,25 @@ class Session(connection.AbstractConnection):
     """Provides the main API for interacting with the database. Does not
        necessarily correpsond to a database session."""
 
-    def __init__(self, engine, *, use_session=False):
-        self._engine = engine
-        self._loop = self._engine._loop
+    def __init__(self, app, conn, *, use_session=False):
+        self._app = app
+        self._conn = conn
+        self._loop = self._app._loop
         self._use_session = False
-        self._session = None
-        self._traversal_factory = traversal.TraversalFactory(
-            self, self.engine.translator, self._loop)
         self._pending = collections.deque()
         self._current = {}
+        remote_graph = graph.AsyncRemoteGraph(
+            self._app.translator, self,
+            graph_traversal=traversal.GoblinTraversal)
+        self._traversal_factory = traversal.TraversalFactory(remote_graph)
 
     @property
-    def engine(self):
-        return self._engine
+    def app(self):
+        return self._app
+
+    @property
+    def conn(self):
+        return self._conn
 
     @property
     def traversal_factory(self):
@@ -36,6 +44,64 @@ class Session(connection.AbstractConnection):
     def current(self):
         return self._current
 
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self):
+        await self.close()
+
+    async def close(self):
+        await self.conn.close()
+        self._traversal_factory = None
+        self._app = None
+
+    # Traversal API
+    @property
+    def g(self):
+        """Returns a simple traversal source"""
+        return self.traversal_factory.traversal()
+
+    def traversal(self, element_class):
+        """Returns a traversal spawned from an element class"""
+        return self.traversal_factory.traversal(element_class=element_class)
+
+    async def submit(self,
+                    gremlin,
+                    *,
+                    bindings=None,
+                    lang='gremlin-groovy'):
+        """Get all results generated by query"""
+        async_iter = await self.conn.submit(
+            gremlin, bindings=bindings, lang=lang)
+        response_queue = asyncio.Queue(loop=self._loop)
+        self._loop.create_task(
+            self._receive(async_iter, response_queue))
+        return traversal.TraversalResponse(response_queue)
+
+    async def _receive(self, async_iter, response_queue):
+        async for msg in async_iter:
+            results = msg.data
+            if results:
+                for result in results:
+                    current = self.current.get(result['id'], None)
+                    if not current:
+                        element_type = result['type']
+                        label = result['label']
+                        if element_type == 'vertex':
+                            current = self.app.vertices.get(label, None)
+                        else:
+                            current = self.app.edges.get(label, None)
+                        if not current:
+                            # build generic element here
+                            pass
+                        else:
+                            current = current()
+                    element = current.__mapping__.mapper_func(
+                        result, current)
+                    response_queue.put_nowait(element)
+        response_queue.put_nowait(None)
+
+    # Creation API
     def add(self, *elements):
         for elem in elements:
             self._pending.append(elem)
@@ -45,16 +111,17 @@ class Session(connection.AbstractConnection):
             elem = self._pending.popleft()
             await self.save(elem)
 
-    @property
-    def g(self):
-        """Returns a simple traversal source"""
-        return self.traversal_factory.traversal().graph.traversal()
+    async def remove_vertex(self, element):
+        traversal = self.traversal_factory.remove_vertex(element)
+        result = await self._simple_traversal(traversal, element)
+        del self.current[element.id]
+        return result
 
-    def traversal(self, element_class):
-        """Returns a traversal spawned from an element class"""
-        label = element_class.__mapping__.label
-        return self.traversal_factory.traversal(
-            element_class=element_class).traversal()
+    async def remove_edge(self, element):
+        traversal = self.traversal_factory.remove_edge(element)
+        result = await self._simple_traversal(traversal, element)
+        del self.current[element.id]
+        return result
 
     async def save(self, element):
         if element.__type__ == 'vertex':
@@ -83,6 +150,37 @@ class Session(connection.AbstractConnection):
         self.current[result.id] = result
         return result
 
+    async def get_vertex(self, element):
+        return await self.traversal_factory.get_vertex_by_id(element).one_or_none()
+
+    async def get_edge(self, element):
+        return await self.traversal_factory.get_edge_by_id(element).one_or_none()
+
+    # Transaction support
+    def tx(self):
+        raise NotImplementedError
+
+    def _wrap_in_tx(self):
+        raise NotImplementedError
+
+    async def commit(self):
+        await self.flush()
+        if self.engine._features['transactions'] and self._use_session():
+            await self.tx()
+        raise NotImplementedError
+
+    async def rollback(self):
+        raise NotImplementedError
+
+    # *metodos especiales privados for creation API
+    async def _simple_traversal(self, traversal, element):
+        stream = await self.conn.submit(
+            repr(traversal), bindings=traversal.bindings)
+        msg = await stream.fetch_data()
+        if msg.data:
+            msg = element.__mapping__.mapper_func(msg.data[0], element)
+            return msg
+
     async def _save_element(self,
                             element,
                             check_func,
@@ -91,64 +189,21 @@ class Session(connection.AbstractConnection):
         if hasattr(element, 'id'):
             result = await check_func(element)
             if not result.data:
-                element = await create_func(element)
+                traversal = create_func(element)
             else:
-                element = await update_func(element)
+                traversal = update_func(element)
         else:
-            element = await create_func(element)
-        return element
-
-    async def remove_vertex(self, element):
-        result = await self.traversal_factory.remove_vertex(element)
-        del self.current[element.id]
-        return result
-
-    async def remove_edge(self, element):
-        result = await self.traversal_factory.remove_edge(element)
-        del self.current[element.id]
-        return result
-
-    async def get_vertex(self, element):
-        return await self.traversal_factory.get_vertex_by_id(element)
-
-    async def get_edge(self, element):
-        return await self.traversal_factory.get_edge_by_id(element)
+            traversal = create_func(element)
+        return await self._simple_traversal(traversal, element)
 
     async def _check_vertex(self, element):
         """Used to check for existence, does not update session element"""
         traversal = self.g.V(element.id)
-        stream = await self.submit(repr(traversal))
+        stream = await self.conn.submit(repr(traversal))
         return await stream.fetch_data()
 
     async def _check_edge(self, element):
         """Used to check for existence, does not update session element"""
         traversal = self.g.E(element.id)
-        stream = await self.submit(repr(traversal))
+        stream = await self.conn.submit(repr(traversal))
         return await stream.fetch_data()
-
-
-    async def submit(self,
-                    gremlin,
-                    *,
-                    bindings=None,
-                    lang='gremlin-groovy'):
-        if self.engine._features['transactions'] and not self._use_session():
-            gremlin = self._wrap_in_tx(gremlin)
-        stream = await self.engine.submit(gremlin, bindings=bindings,
-                                          session=self._session)
-        return stream
-
-    def _wrap_in_tx(self):
-        raise NotImplementedError
-
-    def tx(self):
-        raise NotImplementedError
-
-    async def commit(self):
-        await self.flush()
-        if self.engine._features['transactions'] and self._use_session():
-            await self.tx()
-        raise NotImplementedError
-
-    async def rollback(self):
-        raise NotImplementedError
diff --git a/goblin/traversal.py b/goblin/traversal.py
index 0f326ea..2ebdfd0 100644
--- a/goblin/traversal.py
+++ b/goblin/traversal.py
@@ -5,7 +5,6 @@ import logging
 
 from goblin import mapper
 from goblin.driver import connection, graph
-from goblin.gremlin_python import process
 
 
 logger = logging.getLogger(__name__)
@@ -37,134 +36,67 @@ class GoblinTraversal(graph.AsyncGraphTraversal):
     async def all(self):
         return await self.next()
 
-    async def one(self):
-        # Idk really know how one will work
-        async for element in await self.all():
-            return element
-
-
-class Traversal(connection.AbstractConnection):
-    """Wrapper for AsyncRemoteGraph that functions as a remote connection.
-       Used to generate/submit traversals."""
-    def __init__(self, session, translator, loop, *, element=None,
-                 element_class=None):
-        self._session = session
-        self._translator = translator
-        self._loop = loop
-        self._element = element
-        self._element_class = element_class
-        self._graph = graph.AsyncRemoteGraph(self._translator,
-                                             self,  # Traversal implements RC
-                                             graph_traversal=GoblinTraversal)
+    async def one_or_none(self):
+        async for msg in await self.next():
+            return resp
+
+
+class TraversalFactory:
+    """Helper that wraps a AsyncRemoteGraph"""
+    def __init__(self, graph):
+        self._graph = graph
+        self._binding = 0
 
     @property
     def graph(self):
         return self._graph
 
-    @property
-    def session(self):
-        return self._session
-
-    def traversal(self):
+    def traversal(self, *, element_class=None):
         traversal = self.graph.traversal()
-        if self._element_class:
-            label = self._element_class.__mapping__.label
+        if element_class:
+            label = element_class.__mapping__.label
             traversal = self._graph.traversal()
-            if self._element_class.__type__ == 'vertex':
+            if element_class.__type__ == 'vertex':
                 traversal = traversal.V()
-            if self._element_class.__type__ == 'edge':
+            if element_class.__type__ == 'edge':
                 traversal = traversal.E()
             traversal = traversal.hasLabel(label)
         return traversal
 
-    async def submit(self,
-                    gremlin,
-                    *,
-                    bindings=None,
-                    lang='gremlin-groovy'):
-        """Get all results generated by query"""
-        async_iter = await self.session.submit(
-            gremlin, bindings=bindings, lang=lang)
-        response_queue = asyncio.Queue(loop=self._loop)
-        self._loop.create_task(
-            self._receive(async_iter, response_queue))
-        return TraversalResponse(response_queue)
-
-    async def _receive(self, async_iter, response_queue):
-        async for msg in async_iter:
-            results = msg.data
-            if results:
-                for result in results:
-                    current = self.session.current.get(result['id'], None)
-                    if not current:
-                        if self._element or self._element_class:
-                            current = self._element or self._element_class()
-                        else:
-                            # build generic element here
-                            pass
-                    element = current.__mapping__.mapper_func(
-                        result, current)
-                    response_queue.put_nowait(element)
-        response_queue.put_nowait(None)
-
-
-class TraversalFactory:
-
-    def __init__(self, session, translator, loop):
-        self._session = session
-        self._translator = translator
-        self._loop = loop
-        self._binding = 0
-
-    def traversal(self, *, element=None, element_class=None):
-        return Traversal(self._session,
-                         self._translator,
-                         self._loop,
-                         element=element,
-                         element_class=element_class)
-
-    async def remove_vertex(self, element):
-        traversal = self.traversal(element=element)
-        return await traversal.graph.traversal().V(element.id).drop().one()
+    def remove_vertex(self, element):
+        return self.traversal().V(element.id).drop()
 
-    async def remove_edge(self, element):
-        traversal = self.traversal(element=element)
-        return await traversal.graph.traversal().E(element.id).drop().one()
+    def remove_edge(self, element):
+        return self.traversal().E(element.id).drop()
 
-    async def get_vertex_by_id(self, element):
-        traversal = self.traversal(element=element)
-        return await traversal.graph.traversal().V(element.id).one()
+    def get_vertex_by_id(self, element):
+        return self.traversal().V(element.id)
 
-    async def get_edge_by_id(self, element):
-        traversal = self.traversal(element=element)
-        return await traversal.graph.traversal().E(element.id).one()
+    def get_edge_by_id(self, element):
+        return self.traversal().E(element.id)
 
-    async def add_vertex(self, element):
+    def add_vertex(self, element):
         props = mapper.map_props_to_db(element, element.__mapping__)
-        traversal = self.traversal(element=element)
-        traversal = traversal.graph.traversal().addV(element.__mapping__.label)
-        return await self._add_properties(traversal, props).one()
+        traversal = self.traversal().addV(element.__mapping__.label)
+        return self._add_properties(traversal, props)
 
-    async def add_edge(self, element):
+    def add_edge(self, element):
         props = mapper.map_props_to_db(element, element.__mapping__)
-        base_traversal = self.traversal(element=element)
-        traversal = base_traversal.graph.traversal().V(element.source.id)
+        traversal = self.traversal().V(element.source.id)
         traversal = traversal.addE(element.__mapping__._label)
         traversal = traversal.to(
-            base_traversal.graph.traversal().V(element.target.id))
-        return await self._add_properties(traversal, props).one()
+            self.traversal().V(element.target.id))
+        return self._add_properties(traversal, props)
 
-    async def update_vertex(self, element):
+    def update_vertex(self, element):
         props = mapper.map_props_to_db(element, element.__mapping__)
-        traversal = self.traversal(element=element)
-        traversal = traversal.graph.traversal().V(element.id)
-        return await self._add_properties(traversal, props).one()
+        traversal = self.traversal().V(element.id)
+        return self._add_properties(traversal, props)
 
-    async def update_edge(self, element):
+    def update_edge(self, element):
         props = mapper.map_props_to_db(element, element.__mapping__)
-        traversal = self.traversal(element=element)
-        traversal = traversal.graph.traversal().E(element.id)
-        return await self._add_properties(traversal, props).one()
+        traversal = self.traversal().E(element.id)
+        return self._add_properties(traversal, props)
 
     def _add_properties(self, traversal, props):
         for k, v in props:
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 061293a..f42aa31 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -1,7 +1,7 @@
 import asyncio
 import unittest
 
-from goblin.engine import create_engine
+from goblin.app import create_app
 from goblin.element import Vertex, Edge, VertexProperty
 from goblin.properties import Property, String
 
@@ -28,9 +28,12 @@ class TestEngine(unittest.TestCase):
 
     def test_add_vertex(self):
 
+        app = self.loop.run_until_complete(
+            create_app("http://localhost:8182/", self.loop))
+        app.register(TestVertex)
+
         async def go():
-            engine = await create_engine("http://localhost:8182/", self.loop)
-            session = engine.session()
+            session = await app.session()
             leif = TestVertex()
             leif.name = 'leifur'
             leif.notes = 'superdev'
@@ -41,16 +44,18 @@ class TestEngine(unittest.TestCase):
             self.assertEqual(current.notes, 'superdev')
             self.assertIs(leif, current)
             self.assertEqual(leif.id, current.id)
-            await engine.close()
-            print(engine)
+            await session.close()
 
         self.loop.run_until_complete(go())
 
     def test_update_vertex(self):
 
+        app = self.loop.run_until_complete(
+            create_app("http://localhost:8182/", self.loop))
+        app.register(TestVertex)
+
         async def go():
-            engine = await create_engine("http://localhost:8182/", self.loop)
-            session = engine.session()
+            session = await app.session()
             leif = TestVertex()
             leif.name = 'leifur'
             session.add(leif)
@@ -65,16 +70,19 @@ class TestEngine(unittest.TestCase):
             new_current = session._current[leif.id]
             self.assertIs(current, new_current)
             self.assertEqual(new_current.name, 'leif')
-            await engine.close()
+            await session.close()
 
 
         self.loop.run_until_complete(go())
 
     def test_add_edge(self):
 
+        app = self.loop.run_until_complete(
+            create_app("http://localhost:8182/", self.loop))
+        app.register(TestVertex, TestEdge)
+
         async def go():
-            engine = await create_engine("http://localhost:8182/", self.loop)
-            session = engine.session()
+            session = await app.session()
             leif = TestVertex()
             leif.name = 'leifur'
             jon = TestVertex()
@@ -94,15 +102,18 @@ class TestEngine(unittest.TestCase):
             self.assertEqual(leif.id, current.target.id)
             self.assertIs(jon, current.source)
             self.assertEqual(jon.id, current.source.id)
-            await engine.close()
+            await session.close()
 
         self.loop.run_until_complete(go())
 
     def test_update_edge(self):
 
+        app = self.loop.run_until_complete(
+            create_app("http://localhost:8182/", self.loop))
+        app.register(TestVertex, TestEdge)
+
         async def go():
-            engine = await create_engine("http://localhost:8182/", self.loop)
-            session = engine.session()
+            session = await app.session()
             leif = TestVertex()
             leif.name = 'leifur'
             jon = TestVertex()
@@ -119,7 +130,7 @@ class TestEngine(unittest.TestCase):
             await session.flush()
             new_current = session._current[works_for.id]
             self.assertEqual(new_current.notes, 'zerofail')
-            await engine.close()
+            await session.close()
 
         self.loop.run_until_complete(go())
 
@@ -128,9 +139,12 @@ class TestEngine(unittest.TestCase):
 
     def test_query_all(self):
 
+        app = self.loop.run_until_complete(
+            create_app("http://localhost:8182/", self.loop))
+        app.register(TestVertex)
+
         async def go():
-            engine = await create_engine("http://localhost:8182/", self.loop)
-            session = engine.session()
+            session = await app.session()
             leif = TestVertex()
             leif.name = 'leifur'
             jon = TestVertex()
@@ -145,15 +159,18 @@ class TestEngine(unittest.TestCase):
             self.assertEqual(len(session.current), 2)
             for result in results:
                 self.assertIsInstance(result, Vertex)
-            await engine.close()
+            await session.close()
 
         self.loop.run_until_complete(go())
 
     def test_remove_vertex(self):
 
+        app = self.loop.run_until_complete(
+            create_app("http://localhost:8182/", self.loop))
+        app.register(TestVertex, TestEdge)
+
         async def go():
-            engine = await create_engine("http://localhost:8182/", self.loop)
-            session = engine.session()
+            session = await app.session()
             leif = TestVertex()
             leif.name = 'leifur'
             session.add(leif)
@@ -164,15 +181,18 @@ class TestEngine(unittest.TestCase):
             result = await session.get_vertex(leif)
             self.assertIsNone(result)
             self.assertEqual(len(list(session.current.items())), 0)
-            await engine.close()
+            await session.close()
 
         self.loop.run_until_complete(go())
 
     def test_remove_edge(self):
 
+        app = self.loop.run_until_complete(
+            create_app("http://localhost:8182/", self.loop))
+        app.register(TestVertex, TestEdge)
+
         async def go():
-            engine = await create_engine("http://localhost:8182/", self.loop)
-            session = engine.session()
+            session = await app.session()
             leif = TestVertex()
             leif.name = 'leifur'
             jon = TestVertex()
@@ -189,15 +209,18 @@ class TestEngine(unittest.TestCase):
             result = await session.get_edge(works_for)
             self.assertIsNone(result)
             self.assertEqual(len(list(session.current.items())), 2)
-            await engine.close()
+            await session.close()
 
         self.loop.run_until_complete(go())
 
     def test_traversal(self):
 
+        app = self.loop.run_until_complete(
+            create_app("http://localhost:8182/", self.loop))
+        app.register(TestVertex, TestEdge)
+
         async def go():
-            engine = await create_engine("http://localhost:8182/", self.loop)
-            session = engine.session()
+            session = await app.session()
             leif = TestVertex()
             leif.name = 'the one and only leifur'
             jon = TestVertex()
@@ -218,6 +241,6 @@ class TestEngine(unittest.TestCase):
                 self.assertIs(msg, leif)
             await session.remove_vertex(leif)
             await session.remove_vertex(jon)
-            await engine.close()
+            await session.close()
 
         self.loop.run_until_complete(go())
-- 
GitLab