From a129fb482823349d11e1c317b1464053340cbd46 Mon Sep 17 00:00:00 2001 From: davebshow <davebshow@gmail.com> Date: Tue, 12 Jul 2016 21:08:11 -0400 Subject: [PATCH] changed config and registry to app module. simplified traversal API --- goblin/__init__.py | 2 +- goblin/app.py | 94 ++++++++++++++++++ goblin/driver/api.py | 2 - goblin/driver/connection.py | 12 +-- goblin/engine.py | 81 ---------------- goblin/session.py | 189 +++++++++++++++++++++++------------- goblin/traversal.py | 140 +++++++------------------- tests/test_engine.py | 75 +++++++++----- 8 files changed, 304 insertions(+), 291 deletions(-) create mode 100644 goblin/app.py delete mode 100644 goblin/engine.py diff --git a/goblin/__init__.py b/goblin/__init__.py index a9b21a4..4b264cc 100644 --- a/goblin/__init__.py +++ b/goblin/__init__.py @@ -1,3 +1,3 @@ from goblin.element import Vertex, Edge, VertexProperty -from goblin.engine import Engine +from goblin.app import create_app, App from goblin.properties import Property, String diff --git a/goblin/app.py b/goblin/app.py new file mode 100644 index 0000000..2ba2426 --- /dev/null +++ b/goblin/app.py @@ -0,0 +1,94 @@ +"""Main OGM API classes and constructors""" +import collections +import logging + +from goblin.gremlin_python import process +from goblin import driver +from goblin import session + + +logger = logging.getLogger(__name__) + + +# Constructor API +async def create_app(url, loop, **config): + """Constructor function for :py:class:`Engine`. Connects to database + and builds a dictionary of relevant vendor implmentation features""" + features = {} + async with await driver.GremlinServer.open(url, loop) as conn: + # Propbably just use a parser to parse the whole feature list + stream = await conn.submit( + 'graph.features().graph().supportsComputer()') + msg = await stream.fetch_data() + features['computer'] = msg.data[0] + stream = await conn.submit( + 'graph.features().graph().supportsTransactions()') + msg = await stream.fetch_data() + features['transactions'] = msg.data[0] + stream = await conn.submit( + 'graph.features().graph().supportsPersistence()') + msg = await stream.fetch_data() + features['persistence'] = msg.data[0] + stream = await conn.submit( + 'graph.features().graph().supportsConcurrentAccess()') + msg = await stream.fetch_data() + features['concurrent_access'] = msg.data[0] + stream = await conn.submit( + 'graph.features().graph().supportsThreadedTransactions()') + msg = await stream.fetch_data() + features['threaded_transactions'] = msg.data[0] + return App(url, loop, features=features, **config) + + +# Main API classes +class App: + """Class used to encapsulate database connection configuration and generate + database connections. Used as a factory to create :py:class:`Session` + objects. More config coming soon.""" + DEFAULT_CONFIG = { + 'translator': process.GroovyTranslator('g') + } + + def __init__(self, url, loop, *, features=None, **config): + self._url = url + self._loop = loop + self._features = features + self._config = self.DEFAULT_CONFIG + self._config.update(config) + self._vertices = {} + self._edges = {} + + @property + def vertices(self): + return self._vertices + + @property + def edges(self): + return self._edges + + def from_file(filepath): + pass + + def from_obj(obj): + pass + + @property + def translator(self): + return self._config['translator'] + + @property + def url(self): + return self._url + + def register(self, *elements): + for element in elements: + if element.__type__ == 'vertex': + self._vertices[element.__label__] = element + if element.__type__ == 'edge': + self._edges[element.__label__] = element + + async def session(self, *, use_session=False): + conn = await driver.GremlinServer.open(self.url, self._loop) + return session.Session(self, + conn, + use_session=use_session) diff --git a/goblin/driver/api.py b/goblin/driver/api.py index 22ab80d..0898193 100644 --- a/goblin/driver/api.py +++ b/goblin/driver/api.py @@ -12,12 +12,10 @@ class GremlinServer: loop: asyncio.BaseEventLoop, *, client_session: aiohttp.ClientSession=None, - force_close: bool=False, username: str=None, password: str=None) -> connection.Connection: if client_session is None: client_session = aiohttp.ClientSession(loop=loop) ws = await client_session.ws_connect(url) return connection.Connection(url, ws, loop, client_session, - force_close=force_close, username=username, password=password) diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py index 987c2a1..6490341 100644 --- a/goblin/driver/connection.py +++ b/goblin/driver/connection.py @@ -49,13 +49,12 @@ class AbstractConnection(abc.ABC): class Connection(AbstractConnection): - def __init__(self, url, ws, loop, conn_factory, *, force_close=True, - username=None, password=None): + def __init__(self, url, ws, loop, conn_factory, *, username=None, + password=None): self._url = url self._ws = ws self._loop = loop self._conn_factory = conn_factory - self._force_close = force_close self._username = username self._password = password self._closed = False @@ -69,10 +68,6 @@ class Connection(AbstractConnection): def closed(self): return self._closed - @property - def force_close(self): - return self._force_close - @property def url(self): return self._url @@ -186,9 +181,6 @@ class Connection(AbstractConnection): raise RuntimeError("{0} {1}".format(message.status_code, message.message)) - async def term(self): - if self._force_close: - await self.close() async def __aenter__(self): return self diff --git a/goblin/engine.py b/goblin/engine.py deleted file mode 100644 index 563f5d4..0000000 --- a/goblin/engine.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Main OGM API classes and constructors""" -import collections -import logging - -from goblin.gremlin_python import process -from goblin import driver -from goblin import session - - -logger = logging.getLogger(__name__) - - -# Constructor API -async def create_engine(url, - loop, - force_close=False): - """Constructor function for :py:class:`Engine`. Connects to database - and builds a dictionary of relevant vendor implmentation features""" - features = {} - # This will be some kind of manager client etc. - conn = await driver.GremlinServer.open(url, loop) - # Propbably just use a parser to parse the whole feature list - stream = await conn.submit( - 'graph.features().graph().supportsComputer()') - msg = await stream.fetch_data() - features['computer'] = msg.data[0] - stream = await conn.submit( - 'graph.features().graph().supportsTransactions()') - msg = await stream.fetch_data() - features['transactions'] = msg.data[0] - stream = await conn.submit( - 'graph.features().graph().supportsPersistence()') - msg = await stream.fetch_data() - features['persistence'] = msg.data[0] - stream = await conn.submit( - 'graph.features().graph().supportsConcurrentAccess()') - msg = await stream.fetch_data() - features['concurrent_access'] = msg.data[0] - stream = await conn.submit( - 'graph.features().graph().supportsThreadedTransactions()') - msg = await stream.fetch_data() - features['threaded_transactions'] = msg.data[0] - - return Engine(url, conn, loop, force_close=force_close, **features) - - -# Main API classes -class Engine(driver.AbstractConnection): - """Class used to encapsulate database connection configuration and generate - database connections. Used as a factory to create :py:class:`Session` - objects. More config coming soon.""" - - def __init__(self, url, conn, loop, *, force_close=True, **features): - self._url = url - self._conn = conn - self._loop = loop - self._force_close = force_close - self._features = features - self._translator = process.GroovyTranslator('g') - - @property - def translator(self): - return self._translator - - @property - def url(self): - return self._url - - @property - def conn(self): - return self._conn - - def session(self, *, use_session=False): - return session.Session(self, use_session=use_session) - - async def submit(self, query, *, bindings=None, session=None): - return await self._conn.submit(query, bindings=bindings) - - async def close(self): - await self.conn.close() - self._conn = None diff --git a/goblin/session.py b/goblin/session.py index f1b4998..9082761 100644 --- a/goblin/session.py +++ b/goblin/session.py @@ -1,10 +1,12 @@ """Main OGM API classes and constructors""" +import asyncio import collections import logging from goblin import mapper from goblin import traversal -from goblin.driver import connection +from goblin.driver import connection, graph +from goblin.gremlin_python import process logger = logging.getLogger(__name__) @@ -14,19 +16,25 @@ class Session(connection.AbstractConnection): """Provides the main API for interacting with the database. Does not necessarily correpsond to a database session.""" - def __init__(self, engine, *, use_session=False): - self._engine = engine - self._loop = self._engine._loop + def __init__(self, app, conn, *, use_session=False): + self._app = app + self._conn = conn + self._loop = self._app._loop self._use_session = False - self._session = None - self._traversal_factory = traversal.TraversalFactory( - self, self.engine.translator, self._loop) self._pending = collections.deque() self._current = {} + remote_graph = graph.AsyncRemoteGraph( + self._app.translator, self, + graph_traversal=traversal.GoblinTraversal) + self._traversal_factory = traversal.TraversalFactory(remote_graph) @property - def engine(self): - return self._engine + def app(self): + return self._app + + @property + def conn(self): + return self._conn @property def traversal_factory(self): @@ -36,6 +44,64 @@ class Session(connection.AbstractConnection): def current(self): return self._current + async def __aenter__(self): + return self + + async def __aexit__(self): + await self.close() + + async def close(self): + await self.conn.close() + self._traversal_factory = None + self._app = None + + # Traversal API + @property + def g(self): + """Returns a simple traversal source""" + return self.traversal_factory.traversal() + + def traversal(self, element_class): + """Returns a traversal spawned from an element class""" + return self.traversal_factory.traversal(element_class=element_class) + + async def submit(self, + gremlin, + *, + bindings=None, + lang='gremlin-groovy'): + """Get all results generated by query""" + async_iter = await self.conn.submit( + gremlin, bindings=bindings, lang=lang) + response_queue = asyncio.Queue(loop=self._loop) + self._loop.create_task( + self._receive(async_iter, response_queue)) + return traversal.TraversalResponse(response_queue) + + async def _receive(self, async_iter, response_queue): + async for msg in async_iter: + results = msg.data + if results: + for result in results: + current = self.current.get(result['id'], None) + if not current: + element_type = result['type'] + label = result['label'] + if element_type == 'vertex': + current = self.app.vertices.get(label, None) + else: + current = self.app.edges.get(label, None) + if not current: + # build generic element here + pass + else: + current = current() + element = current.__mapping__.mapper_func( + result, current) + response_queue.put_nowait(element) + response_queue.put_nowait(None) + + # Creation API def add(self, *elements): for elem in elements: self._pending.append(elem) @@ -45,16 +111,17 @@ class Session(connection.AbstractConnection): elem = self._pending.popleft() await self.save(elem) - @property - def g(self): - """Returns a simple traversal source""" - return self.traversal_factory.traversal().graph.traversal() + async def remove_vertex(self, element): + traversal = self.traversal_factory.remove_vertex(element) + result = await self._simple_traversal(traversal, element) + del self.current[element.id] + return result - def traversal(self, element_class): - """Returns a traversal spawned from an element class""" - label = element_class.__mapping__.label - return self.traversal_factory.traversal( - element_class=element_class).traversal() + async def remove_edge(self, element): + traversal = self.traversal_factory.remove_edge(element) + result = await self._simple_traversal(traversal, element) + del self.current[element.id] + return result async def save(self, element): if element.__type__ == 'vertex': @@ -83,6 +150,37 @@ class Session(connection.AbstractConnection): self.current[result.id] = result return result + async def get_vertex(self, element): + return await self.traversal_factory.get_vertex_by_id(element).one_or_none() + + async def get_edge(self, element): + return await self.traversal_factory.get_edge_by_id(element).one_or_none() + + # Transaction support + def tx(self): + raise NotImplementedError + + def _wrap_in_tx(self): + raise NotImplementedError + + async def commit(self): + await self.flush() + if self.engine._features['transactions'] and self._use_session(): + await self.tx() + raise NotImplementedError + + async def rollback(self): + raise NotImplementedError + + # *metodos especiales privados for creation API + async def _simple_traversal(self, traversal, element): + stream = await self.conn.submit( + repr(traversal), bindings=traversal.bindings) + msg = await stream.fetch_data() + if msg.data: + msg = element.__mapping__.mapper_func(msg.data[0], element) + return msg + async def _save_element(self, element, check_func, @@ -91,64 +189,21 @@ class Session(connection.AbstractConnection): if hasattr(element, 'id'): result = await check_func(element) if not result.data: - element = await create_func(element) + traversal = create_func(element) else: - element = await update_func(element) + traversal = update_func(element) else: - element = await create_func(element) - return element - - async def remove_vertex(self, element): - result = await self.traversal_factory.remove_vertex(element) - del self.current[element.id] - return result - - async def remove_edge(self, element): - result = await self.traversal_factory.remove_edge(element) - del self.current[element.id] - return result - - async def get_vertex(self, element): - return await self.traversal_factory.get_vertex_by_id(element) - - async def get_edge(self, element): - return await self.traversal_factory.get_edge_by_id(element) + traversal = create_func(element) + return await self._simple_traversal(traversal, element) async def _check_vertex(self, element): """Used to check for existence, does not update session element""" traversal = self.g.V(element.id) - stream = await self.submit(repr(traversal)) + stream = await self.conn.submit(repr(traversal)) return await stream.fetch_data() async def _check_edge(self, element): """Used to check for existence, does not update session element""" traversal = self.g.E(element.id) - stream = await self.submit(repr(traversal)) + stream = await self.conn.submit(repr(traversal)) return await stream.fetch_data() - - - async def submit(self, - gremlin, - *, - bindings=None, - lang='gremlin-groovy'): - if self.engine._features['transactions'] and not self._use_session(): - gremlin = self._wrap_in_tx(gremlin) - stream = await self.engine.submit(gremlin, bindings=bindings, - session=self._session) - return stream - - def _wrap_in_tx(self): - raise NotImplementedError - - def tx(self): - raise NotImplementedError - - async def commit(self): - await self.flush() - if self.engine._features['transactions'] and self._use_session(): - await self.tx() - raise NotImplementedError - - async def rollback(self): - raise NotImplementedError diff --git a/goblin/traversal.py b/goblin/traversal.py index 0f326ea..2ebdfd0 100644 --- a/goblin/traversal.py +++ b/goblin/traversal.py @@ -5,7 +5,6 @@ import logging from goblin import mapper from goblin.driver import connection, graph -from goblin.gremlin_python import process logger = logging.getLogger(__name__) @@ -37,134 +36,67 @@ class GoblinTraversal(graph.AsyncGraphTraversal): async def all(self): return await self.next() - async def one(self): - # Idk really know how one will work - async for element in await self.all(): - return element - - -class Traversal(connection.AbstractConnection): - """Wrapper for AsyncRemoteGraph that functions as a remote connection. - Used to generate/submit traversals.""" - def __init__(self, session, translator, loop, *, element=None, - element_class=None): - self._session = session - self._translator = translator - self._loop = loop - self._element = element - self._element_class = element_class - self._graph = graph.AsyncRemoteGraph(self._translator, - self, # Traversal implements RC - graph_traversal=GoblinTraversal) + async def one_or_none(self): + async for msg in await self.next(): + return resp + + +class TraversalFactory: + """Helper that wraps a AsyncRemoteGraph""" + def __init__(self, graph): + self._graph = graph + self._binding = 0 @property def graph(self): return self._graph - @property - def session(self): - return self._session - - def traversal(self): + def traversal(self, *, element_class=None): traversal = self.graph.traversal() - if self._element_class: - label = self._element_class.__mapping__.label + if element_class: + label = element_class.__mapping__.label traversal = self._graph.traversal() - if self._element_class.__type__ == 'vertex': + if element_class.__type__ == 'vertex': traversal = traversal.V() - if self._element_class.__type__ == 'edge': + if element_class.__type__ == 'edge': traversal = traversal.E() traversal = traversal.hasLabel(label) return traversal - async def submit(self, - gremlin, - *, - bindings=None, - lang='gremlin-groovy'): - """Get all results generated by query""" - async_iter = await self.session.submit( - gremlin, bindings=bindings, lang=lang) - response_queue = asyncio.Queue(loop=self._loop) - self._loop.create_task( - self._receive(async_iter, response_queue)) - return TraversalResponse(response_queue) - - async def _receive(self, async_iter, response_queue): - async for msg in async_iter: - results = msg.data - if results: - for result in results: - current = self.session.current.get(result['id'], None) - if not current: - if self._element or self._element_class: - current = self._element or self._element_class() - else: - # build generic element here - pass - element = current.__mapping__.mapper_func( - result, current) - response_queue.put_nowait(element) - response_queue.put_nowait(None) - - -class TraversalFactory: - - def __init__(self, session, translator, loop): - self._session = session - self._translator = translator - self._loop = loop - self._binding = 0 - - def traversal(self, *, element=None, element_class=None): - return Traversal(self._session, - self._translator, - self._loop, - element=element, - element_class=element_class) - - async def remove_vertex(self, element): - traversal = self.traversal(element=element) - return await traversal.graph.traversal().V(element.id).drop().one() + def remove_vertex(self, element): + return self.traversal().V(element.id).drop() - async def remove_edge(self, element): - traversal = self.traversal(element=element) - return await traversal.graph.traversal().E(element.id).drop().one() + def remove_edge(self, element): + return self.traversal().E(element.id).drop() - async def get_vertex_by_id(self, element): - traversal = self.traversal(element=element) - return await traversal.graph.traversal().V(element.id).one() + def get_vertex_by_id(self, element): + return self.traversal().V(element.id) - async def get_edge_by_id(self, element): - traversal = self.traversal(element=element) - return await traversal.graph.traversal().E(element.id).one() + def get_edge_by_id(self, element): + return self.traversal().E(element.id) - async def add_vertex(self, element): + def add_vertex(self, element): props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.traversal(element=element) - traversal = traversal.graph.traversal().addV(element.__mapping__.label) - return await self._add_properties(traversal, props).one() + traversal = self.traversal().addV(element.__mapping__.label) + return self._add_properties(traversal, props) - async def add_edge(self, element): + def add_edge(self, element): props = mapper.map_props_to_db(element, element.__mapping__) - base_traversal = self.traversal(element=element) - traversal = base_traversal.graph.traversal().V(element.source.id) + traversal = self.traversal().V(element.source.id) traversal = traversal.addE(element.__mapping__._label) traversal = traversal.to( - base_traversal.graph.traversal().V(element.target.id)) - return await self._add_properties(traversal, props).one() + self.traversal().V(element.target.id)) + return self._add_properties(traversal, props) - async def update_vertex(self, element): + def update_vertex(self, element): props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.traversal(element=element) - traversal = traversal.graph.traversal().V(element.id) - return await self._add_properties(traversal, props).one() + traversal = self.traversal().V(element.id) + return self._add_properties(traversal, props) - async def update_edge(self, element): + def update_edge(self, element): props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.traversal(element=element) - traversal = traversal.graph.traversal().E(element.id) - return await self._add_properties(traversal, props).one() + traversal = self.traversal().E(element.id) + return self._add_properties(traversal, props) def _add_properties(self, traversal, props): for k, v in props: diff --git a/tests/test_engine.py b/tests/test_engine.py index 061293a..f42aa31 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,7 +1,7 @@ import asyncio import unittest -from goblin.engine import create_engine +from goblin.app import create_app from goblin.element import Vertex, Edge, VertexProperty from goblin.properties import Property, String @@ -28,9 +28,12 @@ class TestEngine(unittest.TestCase): def test_add_vertex(self): + app = self.loop.run_until_complete( + create_app("http://localhost:8182/", self.loop)) + app.register(TestVertex) + async def go(): - engine = await create_engine("http://localhost:8182/", self.loop) - session = engine.session() + session = await app.session() leif = TestVertex() leif.name = 'leifur' leif.notes = 'superdev' @@ -41,16 +44,18 @@ class TestEngine(unittest.TestCase): self.assertEqual(current.notes, 'superdev') self.assertIs(leif, current) self.assertEqual(leif.id, current.id) - await engine.close() - print(engine) + await session.close() self.loop.run_until_complete(go()) def test_update_vertex(self): + app = self.loop.run_until_complete( + create_app("http://localhost:8182/", self.loop)) + app.register(TestVertex) + async def go(): - engine = await create_engine("http://localhost:8182/", self.loop) - session = engine.session() + session = await app.session() leif = TestVertex() leif.name = 'leifur' session.add(leif) @@ -65,16 +70,19 @@ class TestEngine(unittest.TestCase): new_current = session._current[leif.id] self.assertIs(current, new_current) self.assertEqual(new_current.name, 'leif') - await engine.close() + await session.close() self.loop.run_until_complete(go()) def test_add_edge(self): + app = self.loop.run_until_complete( + create_app("http://localhost:8182/", self.loop)) + app.register(TestVertex, TestEdge) + async def go(): - engine = await create_engine("http://localhost:8182/", self.loop) - session = engine.session() + session = await app.session() leif = TestVertex() leif.name = 'leifur' jon = TestVertex() @@ -94,15 +102,18 @@ class TestEngine(unittest.TestCase): self.assertEqual(leif.id, current.target.id) self.assertIs(jon, current.source) self.assertEqual(jon.id, current.source.id) - await engine.close() + await session.close() self.loop.run_until_complete(go()) def test_update_edge(self): + app = self.loop.run_until_complete( + create_app("http://localhost:8182/", self.loop)) + app.register(TestVertex, TestEdge) + async def go(): - engine = await create_engine("http://localhost:8182/", self.loop) - session = engine.session() + session = await app.session() leif = TestVertex() leif.name = 'leifur' jon = TestVertex() @@ -119,7 +130,7 @@ class TestEngine(unittest.TestCase): await session.flush() new_current = session._current[works_for.id] self.assertEqual(new_current.notes, 'zerofail') - await engine.close() + await session.close() self.loop.run_until_complete(go()) @@ -128,9 +139,12 @@ class TestEngine(unittest.TestCase): def test_query_all(self): + app = self.loop.run_until_complete( + create_app("http://localhost:8182/", self.loop)) + app.register(TestVertex) + async def go(): - engine = await create_engine("http://localhost:8182/", self.loop) - session = engine.session() + session = await app.session() leif = TestVertex() leif.name = 'leifur' jon = TestVertex() @@ -145,15 +159,18 @@ class TestEngine(unittest.TestCase): self.assertEqual(len(session.current), 2) for result in results: self.assertIsInstance(result, Vertex) - await engine.close() + await session.close() self.loop.run_until_complete(go()) def test_remove_vertex(self): + app = self.loop.run_until_complete( + create_app("http://localhost:8182/", self.loop)) + app.register(TestVertex, TestEdge) + async def go(): - engine = await create_engine("http://localhost:8182/", self.loop) - session = engine.session() + session = await app.session() leif = TestVertex() leif.name = 'leifur' session.add(leif) @@ -164,15 +181,18 @@ class TestEngine(unittest.TestCase): result = await session.get_vertex(leif) self.assertIsNone(result) self.assertEqual(len(list(session.current.items())), 0) - await engine.close() + await session.close() self.loop.run_until_complete(go()) def test_remove_edge(self): + app = self.loop.run_until_complete( + create_app("http://localhost:8182/", self.loop)) + app.register(TestVertex, TestEdge) + async def go(): - engine = await create_engine("http://localhost:8182/", self.loop) - session = engine.session() + session = await app.session() leif = TestVertex() leif.name = 'leifur' jon = TestVertex() @@ -189,15 +209,18 @@ class TestEngine(unittest.TestCase): result = await session.get_edge(works_for) self.assertIsNone(result) self.assertEqual(len(list(session.current.items())), 2) - await engine.close() + await session.close() self.loop.run_until_complete(go()) def test_traversal(self): + app = self.loop.run_until_complete( + create_app("http://localhost:8182/", self.loop)) + app.register(TestVertex, TestEdge) + async def go(): - engine = await create_engine("http://localhost:8182/", self.loop) - session = engine.session() + session = await app.session() leif = TestVertex() leif.name = 'the one and only leifur' jon = TestVertex() @@ -218,6 +241,6 @@ class TestEngine(unittest.TestCase): self.assertIs(msg, leif) await session.remove_vertex(leif) await session.remove_vertex(jon) - await engine.close() + await session.close() self.loop.run_until_complete(go()) -- GitLab