From 83e18b60255bfb6566f6d4032c88c1f0a3c6f5ac Mon Sep 17 00:00:00 2001
From: davebshow <davebshow@gmail.com>
Date: Sun, 10 Jul 2016 15:11:17 -0400
Subject: [PATCH] messing around with the mapper

---
 goblin/mapper.py  | 26 +++++++++++++++++++-------
 goblin/query.py   | 25 +++++++------------------
 goblin/session.py | 17 ++++++-----------
 3 files changed, 32 insertions(+), 36 deletions(-)

diff --git a/goblin/mapper.py b/goblin/mapper.py
index 35ae6f7..1258fce 100644
--- a/goblin/mapper.py
+++ b/goblin/mapper.py
@@ -1,5 +1,7 @@
 """Helper functions and class to map between OGM Elements <-> DB Elements"""
 import logging
+import functools
+
 import inflection
 
 
@@ -23,7 +25,7 @@ def map_props_to_db(element, mapping):
     return property_tuples
 
 
-def map_vertex_to_ogm(result, element, mapping):
+def map_vertex_to_ogm(result, element, *, mapping=None):
     """Map a vertex returned by DB to OGM vertex"""
     props = mapping.properties
     for ogm_name, db_name, data_type in props_generator(props):
@@ -34,7 +36,7 @@ def map_vertex_to_ogm(result, element, mapping):
     return element
 
 
-def map_edge_to_ogm(result, element, mapping):
+def map_edge_to_ogm(result, element, *, mapping=None):
     """Map an edge returned by DB to OGM edge"""
     props = mapping.properties
     for ogm_name, db_name, data_type in props_generator(props):
@@ -50,16 +52,23 @@ def map_edge_to_ogm(result, element, mapping):
 # DB <-> OGM Mapping
 def create_mapping(namespace, properties):
     """Constructor for :py:class:`Mapping`"""
-    if namespace.get('__type__', None):
-        return Mapping(namespace, properties)
+    element_type = namespace.get('__type__', None)
+    if element_type:
+        if element_type == 'vertex':
+            mapping_func = map_vertex_to_ogm
+            return Mapping(namespace, element_type, mapping_func, properties)
+        elif element_type == 'edge':
+            mapping_func = map_edge_to_ogm
+            return Mapping(namespace, element_type, mapping_func, properties)
 
 
 class Mapping:
     """This class stores the information necessary to map between an
        OGM element and a DB element"""
-    def __init__(self, namespace, properties):
+    def __init__(self, namespace, element_type, mapper_func, properties):
         self._label = namespace.get('__label__', None) or self._create_label()
-        self._type = namespace['__type__']
+        self._type = element_type
+        self._mapper_func = functools.partial(mapper_func, mapping=self)
         self._properties = {}
         self._map_properties(properties)
 
@@ -67,6 +76,10 @@ class Mapping:
     def label(self):
         return self._label
 
+    @property
+    def mapper_func(self):
+        return self._mapper_func
+
     @property
     def properties(self):
         return self._properties
@@ -87,7 +100,6 @@ class Mapping:
             db_name = '{}__{}'.format(self._label, name)
             self._properties[name] = (db_name, data_type)
 
-
     def __repr__(self):
         return '<{}(type={}, label={}, properties={})'.format(
             self.__class__.__name__, self._type, self._label,
diff --git a/goblin/query.py b/goblin/query.py
index 1887e83..69b05f8 100644
--- a/goblin/query.py
+++ b/goblin/query.py
@@ -37,11 +37,10 @@ class QueryResponse:
 
 class GoblinTraversal(gremlin_python.PythonGraphTraversal):
 
-    def __init__(self, translator, query, element_class, mapper_func):
+    def __init__(self, translator, query, element_class):
         super().__init__(translator, remote_connection=None)
         self._query = query
         self._element_class = element_class
-        self._mapper_func = mapper_func
 
     async def all(self):
         result = await self._query._all(self)
@@ -52,10 +51,6 @@ class GoblinTraversal(gremlin_python.PythonGraphTraversal):
     def element_class(self):
         return self._element_class
 
-    @property
-    def mapper_func(self):
-        return self._mapper_func
-
 
 class Query:
     """Provides interface for user generated queries"""
@@ -87,14 +82,10 @@ class Query:
     def traversal(self, element_class):
 
         if element_class.__type__ == 'vertex':
-            mapper_func = mapper.map_vertex_to_ogm
-            traversal = GoblinTraversal(self._translator, self, element_class,
-                                        mapper_func)
+            traversal = GoblinTraversal(self._translator, self, element_class)
             traversal.translator.addSpawnStep(traversal, "V")
         if element_class.__type__ == 'edge':
-            mapper_func = mapper.map_edge_to_ogm
-            traversal = GoblinTraversal(self._translator, self, element_class,
-                                        mapper_func)
+            traversal = GoblinTraversal(self._translator, self, element_class)
             traversal.translator.addSpawnStep(traversal, "E")
         return traversal.hasLabel(element_class.__mapping__.label)
 
@@ -104,21 +95,19 @@ class Query:
         async_iter = await self.session.execute_traversal(traversal)
         response_queue = asyncio.Queue(loop=self._loop)
         self._loop.create_task(
-            self._receive(async_iter, response_queue, traversal.element_class,
-                          traversal.mapper_func))
+            self._receive(async_iter, response_queue, traversal.element_class))
         return QueryResponse(response_queue)
 
-    async def _receive(self, async_iter, response_queue, element_class,
-                       mapper_func):
+    async def _receive(self, async_iter, response_queue, element_class):
         async for msg in async_iter:
-            # import ipdb; ipdb.set_trace()
             results = msg.data
             if results:
                 for result in results:
                     current = self.session.current.get(result['id'], None)
                     if not current:
                         current = element_class()
-                    element = mapper_func(result, current, current.__mapping__)
+                    element = element_class.__mapping__.mapper_func(
+                        result, current)
                     response_queue.put_nowait(element)
         response_queue.put_nowait(None)
 
diff --git a/goblin/session.py b/goblin/session.py
index 6456661..461e83a 100644
--- a/goblin/session.py
+++ b/goblin/session.py
@@ -59,8 +59,7 @@ class Session:
         result = await self._save_element(element,
                                           self.query.get_vertex_by_id,
                                           self.query.add_vertex,
-                                          self.query.update_vertex,
-                                          mapper.map_vertex_to_ogm)
+                                          self.query.update_vertex)
         self.current[result.id] = result
         return result
 
@@ -70,8 +69,7 @@ class Session:
         result = await self._save_element(element,
                                           self.query.get_edge_by_id,
                                           self.query.add_edge,
-                                          self.query.update_edge,
-                                          mapper.map_edge_to_ogm)
+                                          self.query.update_edge)
         self.current[result.id] = result
         return result
 
@@ -79,8 +77,7 @@ class Session:
                             element,
                             get_func,
                             create_func,
-                            update_func,
-                            mapper_func):
+                            update_func):
         if hasattr(element, 'id'):
             traversal = get_func(element)
             stream = await self.execute_traversal(traversal)
@@ -93,7 +90,7 @@ class Session:
             traversal = create_func(element)
         stream = await self.execute_traversal(traversal)
         result = await stream.fetch_data()
-        return mapper_func(result.data[0], element, element.__mapping__)
+        return element.__mapping__.mapper_func(result.data[0], element)
 
     async def remove_vertex(self, element):
         traversal = self.query.remove_vertex(element)
@@ -116,8 +113,7 @@ class Session:
         stream = await self.execute_traversal(traversal)
         result = await stream.fetch_data()
         if result.data:
-            vertex = mapper.map_vertex_to_ogm(result.data[0], element,
-                                              element.__mapping__)
+            vertex = element.__mapping__.mapper_func(result.data[0], element)
             return vertex
 
     async def get_edge(self, element):
@@ -125,8 +121,7 @@ class Session:
         stream = await self.execute_traversal(traversal)
         result = await stream.fetch_data()
         if result.data:
-            vertex = mapper.map_edge_to_ogm(result.data[0], element,
-                                              element.__mapping__)
+            vertex = element.__mapping__.mapper_func(result.data[0], element)
             return vertex
 
     async def execute_traversal(self, traversal):
-- 
GitLab