diff --git a/goblin/engine.py b/goblin/engine.py index 50a1f9df4adbede70801e8d6f53b79708267e9c8..563f5d469b3940dd8085378c1a2fcddf3b971622 100644 --- a/goblin/engine.py +++ b/goblin/engine.py @@ -13,9 +13,7 @@ logger = logging.getLogger(__name__) # Constructor API async def create_engine(url, loop, - maxsize=256, - force_close=False, - force_release=True): + force_close=False): """Constructor function for :py:class:`Engine`. Connects to database and builds a dictionary of relevant vendor implmentation features""" features = {} @@ -43,7 +41,7 @@ async def create_engine(url, msg = await stream.fetch_data() features['threaded_transactions'] = msg.data[0] - return Engine(url, conn, loop, **features) + return Engine(url, conn, loop, force_close=force_close, **features) # Main API classes diff --git a/goblin/session.py b/goblin/session.py index a91fc1197a05e3c508185dfb3c151534ce152615..f1b4998a4786aabc509f04e19abae163498d27a3 100644 --- a/goblin/session.py +++ b/goblin/session.py @@ -45,10 +45,16 @@ class Session(connection.AbstractConnection): elem = self._pending.popleft() await self.save(elem) + @property + def g(self): + """Returns a simple traversal source""" + return self.traversal_factory.traversal().graph.traversal() + def traversal(self, element_class): + """Returns a traversal spawned from an element class""" label = element_class.__mapping__.label return self.traversal_factory.traversal( - element_class).traversal().hasLabel(label) + element_class=element_class).traversal() async def save(self, element): if element.__type__ == 'vertex': @@ -61,7 +67,7 @@ class Session(connection.AbstractConnection): async def save_vertex(self, element): result = await self._save_element( - element, self.traversal_factory.get_vertex_by_id, + element, self._check_vertex, self.traversal_factory.add_vertex, self.traversal_factory.update_vertex) self.current[result.id] = result @@ -71,7 +77,7 @@ class Session(connection.AbstractConnection): if not (hasattr(element, 'source') and hasattr(element, 'target')): raise Exception("Edges require source/target vetices") result = await self._save_element( - element, self.traversal_factory.get_edge_by_id, + element, self._check_edge, self.traversal_factory.add_edge, self.traversal_factory.update_edge) self.current[result.id] = result @@ -79,60 +85,47 @@ class Session(connection.AbstractConnection): async def _save_element(self, element, - get_func, + check_func, create_func, update_func): if hasattr(element, 'id'): - traversal = get_func(element) - stream = await self.execute_traversal(traversal) - result = await stream.fetch_data() + result = await check_func(element) if not result.data: - traversal = create_func(element) + element = await create_func(element) else: - traversal = update_func(element) + element = await update_func(element) else: - traversal = create_func(element) - stream = await self.execute_traversal(traversal) - result = await stream.fetch_data() - return element.__mapping__.mapper_func(result.data[0], element) + element = await create_func(element) + return element async def remove_vertex(self, element): - traversal = self.traversal_factory.remove_vertex(element) - result = await self._remove_element(element, traversal) + result = await self.traversal_factory.remove_vertex(element) + del self.current[element.id] return result async def remove_edge(self, element): - traversal = self.traversal_factory.remove_edge(element) - result = await self._remove_element(element, traversal) - return result - - async def _remove_element(self, element, traversal): - stream = await self.execute_traversal(traversal) - result = await stream.fetch_data() + result = await self.traversal_factory.remove_edge(element) del self.current[element.id] return result async def get_vertex(self, element): - traversal = self.traversal_factory.get_vertex_by_id(element) - stream = await self.execute_traversal(traversal) - result = await stream.fetch_data() - if result.data: - vertex = element.__mapping__.mapper_func(result.data[0], element) - return vertex + return await self.traversal_factory.get_vertex_by_id(element) async def get_edge(self, element): - traversal = self.traversal_factory.get_edge_by_id(element) - stream = await self.execute_traversal(traversal) - result = await stream.fetch_data() - if result.data: - vertex = element.__mapping__.mapper_func(result.data[0], element) - return vertex - - async def execute_traversal(self, traversal): - script = repr(traversal) - bindings = traversal.bindings - lang = traversal.graph.translator.target_language - return await self.submit(script, bindings=bindings, lang=lang) + return await self.traversal_factory.get_edge_by_id(element) + + async def _check_vertex(self, element): + """Used to check for existence, does not update session element""" + traversal = self.g.V(element.id) + stream = await self.submit(repr(traversal)) + return await stream.fetch_data() + + async def _check_edge(self, element): + """Used to check for existence, does not update session element""" + traversal = self.g.E(element.id) + stream = await self.submit(repr(traversal)) + return await stream.fetch_data() + async def submit(self, gremlin, diff --git a/goblin/traversal.py b/goblin/traversal.py index a0ece2c5c911468f6f5cfac3d44dfc2d6b8dad4a..0c2dba305c8cb5fcf77361fcbc71e81ff93fe81b 100644 --- a/goblin/traversal.py +++ b/goblin/traversal.py @@ -37,17 +37,25 @@ class GoblinTraversal(graph.AsyncGraphTraversal): async def all(self): return await self.next() + async def one(self): + # Idk really know how one will work + async for element in await self.all(): + return element + class Traversal(connection.AbstractConnection): - """Provides interface for user generated queries""" - def __init__(self, element_class, session, translator, loop): - self._element_class = element_class + """Wrapper for AsyncRemoteGraph that functions as a remote connection. + Used to generate/submit traversals.""" + def __init__(self, session, translator, loop, *, element=None, + element_class=None): self._session = session self._translator = translator + self._loop = loop + self._element = element + self._element_class = element_class self._graph = graph.AsyncRemoteGraph(self._translator, self, # Traversal implements RC graph_traversal=GoblinTraversal) - self._loop = loop @property def graph(self): @@ -57,19 +65,18 @@ class Traversal(connection.AbstractConnection): def session(self): return self._session - # Generative query methods... - def filter(self, **kwargs): - """Add a filter to the query""" - raise NotImplementedError - def traversal(self): - label = self._element_class.__mapping__.label - traversal = self._graph.traversal() - if self._element_class.__type__ == 'vertex': - traversal = traversal.V() - if self._element_class.__type__ == 'edge': - traversal = traversal.E() - return traversal.hasLabel(label) + if self._element_class: + label = self._element_class.__mapping__.label + traversal = self._graph.traversal() + if self._element_class.__type__ == 'vertex': + traversal = traversal.V() + if self._element_class.__type__ == 'edge': + traversal = traversal.E() + traversal = traversal.hasLabel(label) + else: + traversal = self.graph.traversal() + return traversal async def submit(self, gremlin, @@ -81,18 +88,22 @@ class Traversal(connection.AbstractConnection): gremlin, bindings=bindings, lang=lang) response_queue = asyncio.Queue(loop=self._loop) self._loop.create_task( - self._receive(async_iter, response_queue, self._element_class)) + self._receive(async_iter, response_queue)) return TraversalResponse(response_queue) - async def _receive(self, async_iter, response_queue, element_class): + async def _receive(self, async_iter, response_queue): async for msg in async_iter: results = msg.data if results: for result in results: current = self.session.current.get(result['id'], None) if not current: - current = element_class() - element = element_class.__mapping__.mapper_func( + if self._element or self._element_class: + current = self._element or self._element_class() + else: + # build generic element here + pass + element = current.__mapping__.mapper_func( result, current) response_queue.put_nowait(element) response_queue.put_nowait(None) @@ -106,55 +117,55 @@ class TraversalFactory: self._loop = loop self._binding = 0 - def traversal(self, element_class): - return Traversal(element_class, - self._session, + def traversal(self, *, element=None, element_class=None): + return Traversal(self._session, self._translator, - self._loop) + self._loop, + element=element, + element_class=element_class) - # Common CRUD methods that generate traversals - def remove_vertex(self, element): - traversal = self.traversal(element.__class__) - return traversal.graph.traversal().V(element.id).drop() + async def remove_vertex(self, element): + traversal = self.traversal(element=element) + return await traversal.graph.traversal().V(element.id).drop().one() - def remove_edge(self, element): - traversal = self.traversal(element.__class__) - return traversal.graph.traversal().E(element.id).drop() + async def remove_edge(self, element): + traversal = self.traversal(element=element) + return await traversal.graph.traversal().E(element.id).drop().one() - def get_vertex_by_id(self, element): - traversal = self.traversal(element.__class__) - return traversal.graph.traversal().V(element.id) + async def get_vertex_by_id(self, element): + traversal = self.traversal(element=element) + return await traversal.graph.traversal().V(element.id).one() - def get_edge_by_id(self, element): - traversal = self.traversal(element.__class__) - return traversal.graph.traversal().E(element.id) + async def get_edge_by_id(self, element): + traversal = self.traversal(element=element) + return await traversal.graph.traversal().E(element.id).one() - def add_vertex(self, element): + async def add_vertex(self, element): props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.traversal(element.__class__) + traversal = self.traversal(element=element) traversal = traversal.graph.traversal().addV(element.__mapping__.label) - return self._add_properties(traversal, props) + return await self._add_properties(traversal, props).one() - def add_edge(self, element): + async def add_edge(self, element): props = mapper.map_props_to_db(element, element.__mapping__) - base_traversal = self.traversal(element.__class__) + base_traversal = self.traversal(element=element) traversal = base_traversal.graph.traversal().V(element.source.id) traversal = traversal.addE(element.__mapping__._label) traversal = traversal.to( base_traversal.graph.traversal().V(element.target.id)) - return self._add_properties(traversal, props) + return await self._add_properties(traversal, props).one() - def update_vertex(self, element): + async def update_vertex(self, element): props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.traversal(element.__class__) + traversal = self.traversal(element=element) traversal = traversal.graph.traversal().V(element.id) - return self._add_properties(traversal, props) + return await self._add_properties(traversal, props).one() - def update_edge(self, element): + async def update_edge(self, element): props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.traversal(element.__class__) + traversal = self.traversal(element=element) traversal = traversal.graph.traversal().E(element.id) - return self._add_properties(traversal, props) + return await self._add_properties(traversal, props).one() def _add_properties(self, traversal, props): for k, v in props: