From 83e18b60255bfb6566f6d4032c88c1f0a3c6f5ac Mon Sep 17 00:00:00 2001 From: davebshow <davebshow@gmail.com> Date: Sun, 10 Jul 2016 15:11:17 -0400 Subject: [PATCH] messing around with the mapper --- goblin/mapper.py | 26 +++++++++++++++++++------- goblin/query.py | 25 +++++++------------------ goblin/session.py | 17 ++++++----------- 3 files changed, 32 insertions(+), 36 deletions(-) diff --git a/goblin/mapper.py b/goblin/mapper.py index 35ae6f7..1258fce 100644 --- a/goblin/mapper.py +++ b/goblin/mapper.py @@ -1,5 +1,7 @@ """Helper functions and class to map between OGM Elements <-> DB Elements""" import logging +import functools + import inflection @@ -23,7 +25,7 @@ def map_props_to_db(element, mapping): return property_tuples -def map_vertex_to_ogm(result, element, mapping): +def map_vertex_to_ogm(result, element, *, mapping=None): """Map a vertex returned by DB to OGM vertex""" props = mapping.properties for ogm_name, db_name, data_type in props_generator(props): @@ -34,7 +36,7 @@ def map_vertex_to_ogm(result, element, mapping): return element -def map_edge_to_ogm(result, element, mapping): +def map_edge_to_ogm(result, element, *, mapping=None): """Map an edge returned by DB to OGM edge""" props = mapping.properties for ogm_name, db_name, data_type in props_generator(props): @@ -50,16 +52,23 @@ def map_edge_to_ogm(result, element, mapping): # DB <-> OGM Mapping def create_mapping(namespace, properties): """Constructor for :py:class:`Mapping`""" - if namespace.get('__type__', None): - return Mapping(namespace, properties) + element_type = namespace.get('__type__', None) + if element_type: + if element_type == 'vertex': + mapping_func = map_vertex_to_ogm + return Mapping(namespace, element_type, mapping_func, properties) + elif element_type == 'edge': + mapping_func = map_edge_to_ogm + return Mapping(namespace, element_type, mapping_func, properties) class Mapping: """This class stores the information necessary to map between an OGM element and a DB element""" - def __init__(self, namespace, properties): + def __init__(self, namespace, element_type, mapper_func, properties): self._label = namespace.get('__label__', None) or self._create_label() - self._type = namespace['__type__'] + self._type = element_type + self._mapper_func = functools.partial(mapper_func, mapping=self) self._properties = {} self._map_properties(properties) @@ -67,6 +76,10 @@ class Mapping: def label(self): return self._label + @property + def mapper_func(self): + return self._mapper_func + @property def properties(self): return self._properties @@ -87,7 +100,6 @@ class Mapping: db_name = '{}__{}'.format(self._label, name) self._properties[name] = (db_name, data_type) - def __repr__(self): return '<{}(type={}, label={}, properties={})'.format( self.__class__.__name__, self._type, self._label, diff --git a/goblin/query.py b/goblin/query.py index 1887e83..69b05f8 100644 --- a/goblin/query.py +++ b/goblin/query.py @@ -37,11 +37,10 @@ class QueryResponse: class GoblinTraversal(gremlin_python.PythonGraphTraversal): - def __init__(self, translator, query, element_class, mapper_func): + def __init__(self, translator, query, element_class): super().__init__(translator, remote_connection=None) self._query = query self._element_class = element_class - self._mapper_func = mapper_func async def all(self): result = await self._query._all(self) @@ -52,10 +51,6 @@ class GoblinTraversal(gremlin_python.PythonGraphTraversal): def element_class(self): return self._element_class - @property - def mapper_func(self): - return self._mapper_func - class Query: """Provides interface for user generated queries""" @@ -87,14 +82,10 @@ class Query: def traversal(self, element_class): if element_class.__type__ == 'vertex': - mapper_func = mapper.map_vertex_to_ogm - traversal = GoblinTraversal(self._translator, self, element_class, - mapper_func) + traversal = GoblinTraversal(self._translator, self, element_class) traversal.translator.addSpawnStep(traversal, "V") if element_class.__type__ == 'edge': - mapper_func = mapper.map_edge_to_ogm - traversal = GoblinTraversal(self._translator, self, element_class, - mapper_func) + traversal = GoblinTraversal(self._translator, self, element_class) traversal.translator.addSpawnStep(traversal, "E") return traversal.hasLabel(element_class.__mapping__.label) @@ -104,21 +95,19 @@ class Query: async_iter = await self.session.execute_traversal(traversal) response_queue = asyncio.Queue(loop=self._loop) self._loop.create_task( - self._receive(async_iter, response_queue, traversal.element_class, - traversal.mapper_func)) + self._receive(async_iter, response_queue, traversal.element_class)) return QueryResponse(response_queue) - async def _receive(self, async_iter, response_queue, element_class, - mapper_func): + async def _receive(self, async_iter, response_queue, element_class): async for msg in async_iter: - # import ipdb; ipdb.set_trace() results = msg.data if results: for result in results: current = self.session.current.get(result['id'], None) if not current: current = element_class() - element = mapper_func(result, current, current.__mapping__) + element = element_class.__mapping__.mapper_func( + result, current) response_queue.put_nowait(element) response_queue.put_nowait(None) diff --git a/goblin/session.py b/goblin/session.py index 6456661..461e83a 100644 --- a/goblin/session.py +++ b/goblin/session.py @@ -59,8 +59,7 @@ class Session: result = await self._save_element(element, self.query.get_vertex_by_id, self.query.add_vertex, - self.query.update_vertex, - mapper.map_vertex_to_ogm) + self.query.update_vertex) self.current[result.id] = result return result @@ -70,8 +69,7 @@ class Session: result = await self._save_element(element, self.query.get_edge_by_id, self.query.add_edge, - self.query.update_edge, - mapper.map_edge_to_ogm) + self.query.update_edge) self.current[result.id] = result return result @@ -79,8 +77,7 @@ class Session: element, get_func, create_func, - update_func, - mapper_func): + update_func): if hasattr(element, 'id'): traversal = get_func(element) stream = await self.execute_traversal(traversal) @@ -93,7 +90,7 @@ class Session: traversal = create_func(element) stream = await self.execute_traversal(traversal) result = await stream.fetch_data() - return mapper_func(result.data[0], element, element.__mapping__) + return element.__mapping__.mapper_func(result.data[0], element) async def remove_vertex(self, element): traversal = self.query.remove_vertex(element) @@ -116,8 +113,7 @@ class Session: stream = await self.execute_traversal(traversal) result = await stream.fetch_data() if result.data: - vertex = mapper.map_vertex_to_ogm(result.data[0], element, - element.__mapping__) + vertex = element.__mapping__.mapper_func(result.data[0], element) return vertex async def get_edge(self, element): @@ -125,8 +121,7 @@ class Session: stream = await self.execute_traversal(traversal) result = await stream.fetch_data() if result.data: - vertex = mapper.map_edge_to_ogm(result.data[0], element, - element.__mapping__) + vertex = element.__mapping__.mapper_func(result.data[0], element) return vertex async def execute_traversal(self, traversal): -- GitLab