diff --git a/goblin/api.py b/goblin/api.py index 0a0b5c23902c4986232e9be1f74c8d5e9137fcdb..54f92e9517857dbb5ebb06c6512ff58fef162952 100644 --- a/goblin/api.py +++ b/goblin/api.py @@ -6,6 +6,7 @@ from goblin import gremlin_python from goblin import driver from goblin import mapper from goblin import meta +from goblin import traversal from goblin import query @@ -49,7 +50,7 @@ async def create_engine(url, # Main API classes -class Engine: +class Engine(driver.AbstractConnection): """Class used to encapsulate database connection configuration and generate database connections. Used as a factory to create :py:class:`Session` objects. More config coming soon.""" @@ -77,7 +78,7 @@ class Engine: def session(self, *, use_session=False): return Session(self, use_session=use_session) - async def execute(self, query, *, bindings=None, session=None): + async def submit(self, query, *, bindings=None, session=None): return await self._conn.submit(query, bindings=bindings) async def close(self): @@ -93,20 +94,18 @@ class Session: self._engine = engine self._use_session = False self._session = None - self._g = gremlin_python.PythonGraphTraversalSource( - self.engine.translator) + self._traversal = traversal.TraversalSource(self.engine.translator) self._pending = collections.deque() self._current = {} - self._binding = 0 - - @property - def g(self): - return self._g @property def engine(self): return self._engine + @property + def traversal(self): + return self._traversal + @property def current(self): return self._current @@ -134,9 +133,9 @@ class Session: async def save_vertex(self, element): result = await self._save_element(element, - self._get_vertex_by_id, - self._create_vertex, - self._update_vertex, + self.traversal.get_vertex_by_id, + self.traversal.add_vertex, + self.traversal.update_vertex, mapper.map_vertex_to_ogm) self.current[result.id] = result return result @@ -145,9 +144,9 @@ class Session: if not (hasattr(element, 'source') and hasattr(element, 'target')): raise Exception("Edges require source/target vetices") result = await self._save_element(element, - self._get_edge_by_id, - self._create_edge, - self._update_edge, + self.traversal.get_edge_by_id, + self.traversal.add_edge, + self.traversal.update_edge, mapper.map_edge_to_ogm) self.current[result.id] = result return result @@ -159,8 +158,8 @@ class Session: update_func, mapper_func): if hasattr(element, 'id'): - traversal = get_func(element.id) - stream = await self._execute_traversal(traversal) + traversal = get_func(element) + stream = await self.execute_traversal(traversal) result = await stream.fetch_data() if not result.data: traversal = create_func(element) @@ -168,88 +167,50 @@ class Session: traversal = update_func(element) else: traversal = create_func(element) - stream = await self._execute_traversal(traversal) + stream = await self.execute_traversal(traversal) result = await stream.fetch_data() return mapper_func(result.data[0], element, element.__mapping__) - async def delete_vertex(self, element): - traversal = self.g.V(element.id).drop() - result = await self._delete_element(element, traversal) + async def remove_vertex(self, element): + traversal = self.traversal.remove_vertex(element) + result = await self._remove_element(element, traversal) return result - async def delete_edge(self, element): - traversal = self.g.E(element.id).drop() - result = await self._delete_element(element, traversal) + async def remove_edge(self, element): + traversal = self.traversal.remove_edge(element) + result = await self._remove_element(element, traversal) return result - async def _delete_element(self, element, traversal): - stream = await self._execute_traversal(traversal) + async def _remove_element(self, element, traversal): + stream = await self.execute_traversal(traversal) result = await stream.fetch_data() del self.current[element.id] return result async def get_vertex(self, element): - traversal = self._get_vertex_by_id(element.id) - stream = await self._execute_traversal(traversal) + traversal = self.traversal.get_vertex_by_id(element) + stream = await self.execute_traversal(traversal) result = await stream.fetch_data() if result.data: vertex = mapper.map_vertex_to_ogm(result.data[0], element, element.__mapping__) return vertex - def _get_vertex_by_id(self, vid): - return self.g.V(vid) - async def get_edge(self, element): - traversal = self._get_edge_by_id(element.id) - stream = await self._execute_traversal(traversal) + traversal = self.traversal.get_edge_by_id(element) + stream = await self.execute_traversal(traversal) result = await stream.fetch_data() if result.data: - edge = mapper.map_edge_to_ogm(result.data[0], element, - element.__mapping__) - return edge - - def _get_edge_by_id(self, eid): - return self.g.E(eid) - - def _create_vertex(self, element): - props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.g.addV(element.__mapping__.label) - return self._add_properties(traversal, props) - - def _create_edge(self, element): - props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.g.V(element.source.id) - traversal = traversal.addE(element.__mapping__._label) - traversal = traversal.to(self.g.V(element.target.id)) - return self._add_properties(traversal, props) - - def _update_vertex(self, element): - props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.g.V(element.id) - return self._add_properties(traversal, props) - - def _update_edge(self, element): - props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.g.E(element.id) - return self._add_properties(traversal, props) - - def _add_properties(self, traversal, props): - for k, v in props: - if v: - traversal = traversal.property( - ('k' + str(self._binding), k), - ('v' + str(self._binding), v)) - self._binding += 1 - self._binding = 0 - return traversal - - async def _execute_traversal(self, traversal): + vertex = mapper.map_edge_to_ogm(result.data[0], element, + element.__mapping__) + return vertex + + async def execute_traversal(self, traversal): script, bindings = query.parse_traversal(traversal) if self.engine._features['transactions'] and not self._use_session(): script = self._wrap_in_tx(script) - stream = await self.engine.execute(script, bindings=bindings, - session=self._session) + stream = await self.engine.submit(script, bindings=bindings, + session=self._session) return stream def _wrap_in_tx(self): diff --git a/goblin/driver/__init__.py b/goblin/driver/__init__.py index 9e3904495f4036aa880a5b07f44ef2608208130a..8e929ce5afb0b1c62b265c5c46528b332032d31a 100644 --- a/goblin/driver/__init__.py +++ b/goblin/driver/__init__.py @@ -1 +1,2 @@ from goblin.driver.api import GremlinServer +from goblin.driver.connection import AbstractConnection diff --git a/goblin/query.py b/goblin/query.py index 03c45378eb9e290165a04b8503122d7fb5999ae3..d378e429cd2201148e6381100ca485b8d7ae6313 100644 --- a/goblin/query.py +++ b/goblin/query.py @@ -47,11 +47,11 @@ class Query: self._engine = session.engine self._element_class = element_class if element_class.__type__ == 'vertex': - self._traversal = self._session.g.V().hasLabel( + self._traversal = self.session.traversal.g.V().hasLabel( element_class.__mapping__.label) self._mapper = mapper.map_vertex_to_ogm elif element_class.__type__ == 'edge': - self._traversal = self._session.g.E().hasLabel( + self._traversal = self.session.traversal.g.E().hasLabel( element_class.__mapping__.label) self._mapper = mapper.map_edge_to_ogm else: @@ -69,5 +69,5 @@ class Query: # Methods that issue a query async def all(self): """Get all results generated by query""" - async_iter = await self.session._execute_traversal(self._traversal) + async_iter = await self.session.execute_traversal(self._traversal) return AsyncQueryResponseIter(async_iter, self) diff --git a/goblin/traversal.py b/goblin/traversal.py new file mode 100644 index 0000000000000000000000000000000000000000..0257c510da7aa490baec6a744335f6084543db54 --- /dev/null +++ b/goblin/traversal.py @@ -0,0 +1,64 @@ +"""Class used to produce traversals""" +from goblin import gremlin_python +from goblin import mapper + + +class TraversalSource: + """A wrapper for :py:class:gremlin_python.PythonGraphTraversalSource that + generates commonly used traversals""" + def __init__(self, translator): + self._traversal_source = gremlin_python.PythonGraphTraversalSource( + translator) + self._binding = 0 + + @property + def g(self): + return self.traversal_source + + @property + def traversal_source(self): + return self._traversal_source + + def remove_vertex(self, element): + return self.g.V(element.id).drop() + + def remove_edge(self, element): + return self.g.E(element.id).drop() + + def get_vertex_by_id(self, element): + return self.g.V(element.id) + + def get_edge_by_id(self, element): + return self.g.E(element.id) + + def add_vertex(self, element): + props = mapper.map_props_to_db(element, element.__mapping__) + traversal = self.g.addV(element.__mapping__.label) + return self._add_properties(traversal, props) + + def add_edge(self, element): + props = mapper.map_props_to_db(element, element.__mapping__) + traversal = self.g.V(element.source.id) + traversal = traversal.addE(element.__mapping__._label) + traversal = traversal.to(self.g.V(element.target.id)) + return self._add_properties(traversal, props) + + def update_vertex(self, element): + props = mapper.map_props_to_db(element, element.__mapping__) + traversal = self.g.V(element.id) + return self._add_properties(traversal, props) + + def update_edge(self, element): + props = mapper.map_props_to_db(element, element.__mapping__) + traversal = self.g.E(element.id) + return self._add_properties(traversal, props) + + def _add_properties(self, traversal, props): + for k, v in props: + if v: + traversal = traversal.property( + ('k' + str(self._binding), k), + ('v' + str(self._binding), v)) + self._binding += 1 + self._binding = 0 + return traversal diff --git a/tests/test_engine.py b/tests/test_engine.py index 597d49b2514e4e6054d0bbd44060af687d3990fc..a084622c3ff54e4dada0bd94bb99531fc1bd1bf0 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -148,7 +148,7 @@ class TestEngine(unittest.TestCase): self.loop.run_until_complete(go()) - def test_delete_vertex(self): + def test_remove_vertex(self): async def go(): engine = await create_engine("http://localhost:8182/", self.loop) @@ -159,7 +159,7 @@ class TestEngine(unittest.TestCase): await session.flush() current = session._current[leif.id] self.assertIs(leif, current) - await session.delete_vertex(leif) + await session.remove_vertex(leif) result = await session.get_vertex(leif) self.assertIsNone(result) self.assertEqual(len(list(session.current.items())), 0) @@ -167,7 +167,7 @@ class TestEngine(unittest.TestCase): self.loop.run_until_complete(go()) - def test_delete_edge(self): + def test_remove_edge(self): async def go(): engine = await create_engine("http://localhost:8182/", self.loop) @@ -184,7 +184,7 @@ class TestEngine(unittest.TestCase): await session.flush() current = session._current[works_for.id] self.assertIs(current, works_for) - await session.delete_edge(works_for) + await session.remove_edge(works_for) result = await session.get_edge(works_for) self.assertIsNone(result) self.assertEqual(len(list(session.current.items())), 2)