diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py index 468dda631da3ca9999622d07faf2e6c8671f7dba..987c2a1f140337983911e94e1465cdfe91ed51eb 100644 --- a/goblin/driver/connection.py +++ b/goblin/driver/connection.py @@ -46,10 +46,6 @@ class AbstractConnection(abc.ABC): async def submit(self): raise NotImplementedError - @abc.abstractmethod - async def close(self): - raise NotImplementedError - class Connection(AbstractConnection): diff --git a/goblin/driver/graph.py b/goblin/driver/graph.py index 7d6c70e5aa0ab90037f94062929158d1684a3281..cd2fc2e932019c97528f7871c13b8b02556fa40b 100644 --- a/goblin/driver/graph.py +++ b/goblin/driver/graph.py @@ -29,17 +29,20 @@ class AsyncRemoteStrategy(TraversalStrategy): return result -class AsyncGraph(object): +class AsyncGraph: def traversal(self): return GraphTraversalSource(self, self.traversal_strategy, - graph_traversal=AsyncGraphTraversal) + graph_traversal=self.graph_traversal) class AsyncRemoteGraph(AsyncGraph): - def __init__(self, translator, remote_connection): + def __init__(self, translator, remote_connection, *, graph_traversal=None): self.traversal_strategy = AsyncRemoteStrategy() # A single traversal strategy self.translator = translator self.remote_connection = remote_connection + if graph_traversal is None: + graph_traversal = AsyncGraphTraversal + self.graph_traversal = graph_traversal def __repr__(self): return "remotegraph[" + self.remote_connection.url + "]" diff --git a/goblin/session.py b/goblin/session.py index 461e83a5934e89158048f09f09ca23e1039ffc44..a91fc1197a05e3c508185dfb3c151534ce152615 100644 --- a/goblin/session.py +++ b/goblin/session.py @@ -3,13 +3,14 @@ import collections import logging from goblin import mapper -from goblin import query +from goblin import traversal +from goblin.driver import connection logger = logging.getLogger(__name__) -class Session: +class Session(connection.AbstractConnection): """Provides the main API for interacting with the database. Does not necessarily correpsond to a database session.""" @@ -18,7 +19,8 @@ class Session: self._loop = self._engine._loop self._use_session = False self._session = None - self._query = query.Query(self, self.engine.translator, self._loop) + self._traversal_factory = traversal.TraversalFactory( + self, self.engine.translator, self._loop) self._pending = collections.deque() self._current = {} @@ -27,8 +29,8 @@ class Session: return self._engine @property - def query(self): - return self._query + def traversal_factory(self): + return self._traversal_factory @property def current(self): @@ -44,7 +46,9 @@ class Session: await self.save(elem) def traversal(self, element_class): - return self.query.traversal(element_class) + label = element_class.__mapping__.label + return self.traversal_factory.traversal( + element_class).traversal().hasLabel(label) async def save(self, element): if element.__type__ == 'vertex': @@ -56,20 +60,20 @@ class Session: return result async def save_vertex(self, element): - result = await self._save_element(element, - self.query.get_vertex_by_id, - self.query.add_vertex, - self.query.update_vertex) + result = await self._save_element( + element, self.traversal_factory.get_vertex_by_id, + self.traversal_factory.add_vertex, + self.traversal_factory.update_vertex) self.current[result.id] = result return result async def save_edge(self, element): if not (hasattr(element, 'source') and hasattr(element, 'target')): raise Exception("Edges require source/target vetices") - result = await self._save_element(element, - self.query.get_edge_by_id, - self.query.add_edge, - self.query.update_edge) + result = await self._save_element( + element, self.traversal_factory.get_edge_by_id, + self.traversal_factory.add_edge, + self.traversal_factory.update_edge) self.current[result.id] = result return result @@ -93,12 +97,12 @@ class Session: return element.__mapping__.mapper_func(result.data[0], element) async def remove_vertex(self, element): - traversal = self.query.remove_vertex(element) + traversal = self.traversal_factory.remove_vertex(element) result = await self._remove_element(element, traversal) return result async def remove_edge(self, element): - traversal = self.query.remove_edge(element) + traversal = self.traversal_factory.remove_edge(element) result = await self._remove_element(element, traversal) return result @@ -109,7 +113,7 @@ class Session: return result async def get_vertex(self, element): - traversal = self.query.get_vertex_by_id(element) + traversal = self.traversal_factory.get_vertex_by_id(element) stream = await self.execute_traversal(traversal) result = await stream.fetch_data() if result.data: @@ -117,7 +121,7 @@ class Session: return vertex async def get_edge(self, element): - traversal = self.query.get_edge_by_id(element) + traversal = self.traversal_factory.get_edge_by_id(element) stream = await self.execute_traversal(traversal) result = await stream.fetch_data() if result.data: @@ -125,11 +129,19 @@ class Session: return vertex async def execute_traversal(self, traversal): - # Move parsing to query - script, bindings = query.parse_traversal(traversal) + script = repr(traversal) + bindings = traversal.bindings + lang = traversal.graph.translator.target_language + return await self.submit(script, bindings=bindings, lang=lang) + + async def submit(self, + gremlin, + *, + bindings=None, + lang='gremlin-groovy'): if self.engine._features['transactions'] and not self._use_session(): - script = self._wrap_in_tx(script) - stream = await self.engine.submit(script, bindings=bindings, + gremlin = self._wrap_in_tx(gremlin) + stream = await self.engine.submit(gremlin, bindings=bindings, session=self._session) return stream diff --git a/goblin/query.py b/goblin/traversal.py similarity index 54% rename from goblin/query.py rename to goblin/traversal.py index 546186a64bfbe996ce3499ab37ddcc034bf1854a..a0ece2c5c911468f6f5cfac3d44dfc2d6b8dad4a 100644 --- a/goblin/query.py +++ b/goblin/traversal.py @@ -4,19 +4,14 @@ import functools import logging from goblin import mapper -from goblin.gremlin_python import structure, process +from goblin.driver import connection, graph +from goblin.gremlin_python import process logger = logging.getLogger(__name__) -def parse_traversal(traversal): - script = repr(traversal) - bindings = traversal.bindings - return script, bindings - - -class QueryResponse: +class TraversalResponse: def __init__(self, response_queue): self._queue = response_queue @@ -37,93 +32,57 @@ class QueryResponse: # This is all a hack until we figure out GLV integration... -class GoblinTraversal(process.GraphTraversal): - - def __init__(self, graph, traversal_strategies, bytecode, *, query=None, - element_class=None): - super().__init__(graph, traversal_strategies, bytecode) - self._query = query - self._element_class = element_class +class GoblinTraversal(graph.AsyncGraphTraversal): async def all(self): - result = await self._query._all(self) - self._query = None - return result - - @property - def element_class(self): - return self._element_class - - def next(self): - raise NotImplementedError - - def __repr__(self): - return self.graph.translator.translate(self.bytecode) + return await self.next() - def toList(self): - raise NotImplementedError - - def toSet(self): - raise NotImplementedError - - -class NoOpGraph(structure.Graph): - """A silly graph that doesn't have traversal strategies and doesn't use a - connection.""" - def __init__(self, translator): - self.translator = translator - def traversal(self, *, query=None, element_class=None): - traversal = functools.partial( - GoblinTraversal, query=query, element_class=element_class) - return process.GraphTraversalSource(self, - None, - graph_traversal=traversal) - - -class Query: +class Traversal(connection.AbstractConnection): """Provides interface for user generated queries""" - def __init__(self, session, translator, loop): + def __init__(self, element_class, session, translator, loop): + self._element_class = element_class self._session = session self._translator = translator - self._graph = NoOpGraph(self._translator) + self._graph = graph.AsyncRemoteGraph(self._translator, + self, # Traversal implements RC + graph_traversal=GoblinTraversal) self._loop = loop - self._binding = 0 - - @property - def session(self): - return self._session @property - def g(self): - return self.traversal_source + def graph(self): + return self._graph @property - def traversal_source(self): - return self._graph.traversal() + def session(self): + return self._session # Generative query methods... def filter(self, **kwargs): """Add a filter to the query""" raise NotImplementedError - def traversal(self, element_class): - traversal = self._graph.traversal(query=self, - element_class=element_class) - if element_class.__type__ == 'vertex': + def traversal(self): + label = self._element_class.__mapping__.label + traversal = self._graph.traversal() + if self._element_class.__type__ == 'vertex': traversal = traversal.V() - if element_class.__type__ == 'edge': + if self._element_class.__type__ == 'edge': traversal = traversal.E() - return traversal.hasLabel(element_class.__mapping__.label) + return traversal.hasLabel(label) - # Methods that issue a traversal query to server - async def _all(self, traversal): + async def submit(self, + gremlin, + *, + bindings=None, + lang='gremlin-groovy'): """Get all results generated by query""" - async_iter = await self.session.execute_traversal(traversal) + async_iter = await self.session.submit( + gremlin, bindings=bindings, lang=lang) response_queue = asyncio.Queue(loop=self._loop) self._loop.create_task( - self._receive(async_iter, response_queue, traversal.element_class)) - return QueryResponse(response_queue) + self._receive(async_iter, response_queue, self._element_class)) + return TraversalResponse(response_queue) async def _receive(self, async_iter, response_queue, element_class): async for msg in async_iter: @@ -138,39 +97,63 @@ class Query: response_queue.put_nowait(element) response_queue.put_nowait(None) + +class TraversalFactory: + + def __init__(self, session, translator, loop): + self._session = session + self._translator = translator + self._loop = loop + self._binding = 0 + + def traversal(self, element_class): + return Traversal(element_class, + self._session, + self._translator, + self._loop) + # Common CRUD methods that generate traversals def remove_vertex(self, element): - return self.g.V(element.id).drop() + traversal = self.traversal(element.__class__) + return traversal.graph.traversal().V(element.id).drop() def remove_edge(self, element): - return self.g.E(element.id).drop() + traversal = self.traversal(element.__class__) + return traversal.graph.traversal().E(element.id).drop() def get_vertex_by_id(self, element): - return self.g.V(element.id) + traversal = self.traversal(element.__class__) + return traversal.graph.traversal().V(element.id) def get_edge_by_id(self, element): - return self.g.E(element.id) + traversal = self.traversal(element.__class__) + return traversal.graph.traversal().E(element.id) def add_vertex(self, element): props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.g.addV(element.__mapping__.label) + traversal = self.traversal(element.__class__) + traversal = traversal.graph.traversal().addV(element.__mapping__.label) return self._add_properties(traversal, props) def add_edge(self, element): props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.g.V(element.source.id) + base_traversal = self.traversal(element.__class__) + traversal = base_traversal.graph.traversal().V(element.source.id) traversal = traversal.addE(element.__mapping__._label) - traversal = traversal.to(self.g.V(element.target.id)) + traversal = traversal.to( + base_traversal.graph.traversal().V(element.target.id)) return self._add_properties(traversal, props) def update_vertex(self, element): props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.g.V(element.id) + traversal = self.traversal(element.__class__) + traversal = traversal.graph.traversal().V(element.id) return self._add_properties(traversal, props) def update_edge(self, element): props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.g.E(element.id) + traversal = self.traversal(element.__class__) + traversal = traversal.graph.traversal().E(element.id) return self._add_properties(traversal, props) def _add_properties(self, traversal, props):