diff --git a/goblin/api.py b/goblin/api.py index 205bf2f0957061b59a4630c40f130fc0fdc8f037..91a1b37508a61251f3d844e9371ba8a6bd33161a 100644 --- a/goblin/api.py +++ b/goblin/api.py @@ -106,27 +106,52 @@ class Session: return result async def save_vertex(self, element): + result = await self._save_element(element, + self._create_vertex, + self._update_vertex, + mapper.map_vertex_to_ogm) + return result + + async def save_edge(self, element): + if not (element.source and element.target): + raise Exception("Edges require source/target vetices") + result = await self._save_element(element, + self._create_edge, + self._update_edge, + mapper.map_edge_to_ogm) + return result + + async def _save_element(self, + element, + create_func, + update_func, + mapper_func): if hasattr(element, 'id'): # Something like # if self._current.get(element.id): # old = self._current[element.id] # element = merge_elements(old, element) - script, bindings = self._get_vertex_by_id(element) + script, bindings = self._get_edge_by_id(element) stream = await self.engine.execute(script, bindings=bindings) result = await stream.fetch_data() await stream.close() if not result.data: - script, bindings = self._create_vertex(element) + script, bindings = create_func(element) else: - script, bindings = self._update_vertex(element) + script, bindings = update_func(element) else: - script, bindings = self._create_vertex(element) + script, bindings = create_func(element) stream = await self.engine.execute(script, bindings=bindings) result = await stream.fetch_data() # Will just release the conn back to pool here await stream.close() - return mapper.map_vertex_to_ogm(result.data[0], element, - element.__mapping__) + return mapper_func(result.data[0], element, element.__mapping__) + + def _get_vertex_by_id(self, element): + traversal = self.engine.g.V(element.id) + script = traversal.translator.traversal_script + bindings = traversal.bindings + return result, bindings def _create_vertex(self, element): props = mapper.map_props_to_db(element, element.__mapping__) @@ -143,37 +168,12 @@ class Session: def _update_vertex(self, element): raise NotImplementedError - def _get_vertex_by_id(self, element): - traversal = self.engine.g.V(element.id) - script = traversal.translator.traversal_script - bindings = traversal.bindings + def _get_edge_by_id(self, element): + t = self.engine.g.E(element.id) + script = t.translator.traversal_script + bindings = t.bindings return result, bindings - async def save_edge(self, element): - if not (element.source and element.target): - raise Exception("Edges require source/target vetices") - if hasattr(element, 'id'): - # Something like - # if self._current.get(element.id): - # old = self._current[element.id] - # element = merge_elements(old, element) - script, bindings = self._get_edge_by_id(element) - stream = await self.engine.execute(script, bindings=bindings) - result = await stream.fetch_data() - await stream.close() - if not result.data: - script, bindings = self._create_edge(element) - else: - script, bindings = self._update_edge(element) - else: - script, bindings = self._create_edge(element) - stream = await self.engine.execute(script, bindings=bindings) - result = await stream.fetch_data() - # Will just release the conn back to pool here - await stream.close() - return mapper.map_edge_to_ogm(result.data[0], element, - element.__mapping__) - def _create_edge(self, element): props = mapper.map_props_to_db(element, element.__mapping__) traversal = self.engine.g.V(element.source.id) @@ -191,14 +191,8 @@ class Session: def _update_edge(self, element): raise NotImplementedError - def _get_edge_by_id(self, element): - t = self.engine.g.E(element.id) - script = t.translator.traversal_script - bindings = t.bindings - return result, bindings - async def commit(self): - await self.flush() + raise NotImplementedError async def rollback(self): raise NotImplementedError