diff --git a/goblin/api.py b/goblin/api.py index 43d3ba4c94e6f12c275c34fae135c2e197aeb132..c6145330c45f8900367caadd3343fe4910c21726 100644 --- a/goblin/api.py +++ b/goblin/api.py @@ -45,7 +45,6 @@ class Engine: self._loop = loop self._features = features self._translator = gremlin_python.GroovyTranslator('g') - self._g = gremlin_python.PythonGraphTraversalSource(self._translator) # This will be a pool self._driver = driver.Driver(self._url, self._loop) @@ -53,8 +52,8 @@ class Engine: return Session(self) @property - def g(self): - return self._g + def translator(self): + return self._translator @property def url(self): @@ -73,10 +72,16 @@ class Session: def __init__(self, engine): self._engine = engine + self._g = gremlin_python.PythonGraphTraversalSource( + self.engine.translator) self._pending = collections.deque() self._current = {} self._binding = 0 + @property + def g(self): + return self._g + @property def engine(self): return self._engine @@ -94,7 +99,8 @@ class Session: while self._pending: elem = self._pending.popleft() result = await self.save_element(elem) - self._current[result.id] = result + if result: + self._current[result.id] = result async def save_element(self, element): if element.__type__ == 'vertex': @@ -148,49 +154,46 @@ class Session: return mapper_func(result.data[0], element, element.__mapping__) def _get_vertex_by_id(self, element): - traversal = self.engine.g.V(element.id) - script = traversal.translator.traversal_script - bindings = traversal.bindings - return result, bindings + traversal = self.g.V(element.id) + return self._get_script_bindings(traversal) + + def _get_edge_by_id(self, element): + traversal = self.g.E(element.id) + return self._get_script_bindings(traversal) def _create_vertex(self, element): props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.engine.g.addV(element.__mapping__.label) - for k, v in props: - traversal = traversal.property( - ('k' + str(self._binding), k), ('v' + str(self._binding), v)) - self._binding += 1 - self._binding = 0 - script = traversal.translator.traversal_script - bindings = traversal.bindings - return script, bindings + traversal = self.g.addV(element.__mapping__.label) + traversal = self._add_properties(traversal, props) + return self._get_script_bindings(traversal) def _update_vertex(self, element): raise NotImplementedError - def _get_edge_by_id(self, element): - t = self.engine.g.E(element.id) - script = t.translator.traversal_script - bindings = t.bindings - return result, bindings - def _create_edge(self, element): props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.engine.g.V(element.source.id) + traversal = self.g.V(element.source.id) traversal = traversal.addE(element.__mapping__._label) - traversal = traversal.to(self.engine.g.V(element.target.id)) + traversal = traversal.to(self.g.V(element.target.id)) + traversal = self._add_properties(traversal, props) + return self._get_script_bindings(traversal) + + def _update_edge(self, element): + raise NotImplementedError + + def _add_properties(self, traversal, props): for k, v in props: traversal = traversal.property( ('k' + str(self._binding), k), ('v' + str(self._binding), v)) self._binding += 1 self._binding = 0 + return traversal + + def _get_script_bindings(self, traversal): script = traversal.translator.traversal_script bindings = traversal.bindings return script, bindings - def _update_edge(self, element): - raise NotImplementedError - async def commit(self): raise NotImplementedError diff --git a/goblin/query.py b/goblin/query.py index aa6661f5033458ed0a0e38bcd9b3270f4e8032fc..8d3ab1ed66dc5603fada06b5e6837a6241c85f53 100644 --- a/goblin/query.py +++ b/goblin/query.py @@ -3,12 +3,11 @@ class Query: def __init__(self, session, element_class): self._session = session self._engine = session.engine - self._bindings = {} if element_class.__type__ == 'vertex': - self._traversal = self._engine.g.V().hasLabel( + self._traversal = self._session.g.V().hasLabel( element_class.__mapping__.label) elif element_class.__type__ == 'edge': - self._traversal = self._engine.g.E().hasLabel( + self._traversal = self._session.g.E().hasLabel( element_class.__mapping__.label) else: raise Exception("unknown element type") @@ -20,7 +19,8 @@ class Query: # Methods that issue a query async def all(self): script = self._traversal.translator.traversal_script - stream = await self._engine.execute(script, bindings=self._bindings) + stream = await self._engine.execute( + script, bindings=self._traversal.bindings) # This should return and async iterator wrapper that can see and update # parent session object, but for the demo it works return stream