From b3ebe3ae4b800589086312a4574bdb821b82529e Mon Sep 17 00:00:00 2001
From: davebshow <davebshow@gmail.com>
Date: Mon, 1 Aug 2016 15:35:28 -0400
Subject: [PATCH] add default id properties. pass hashable id callable

---
 goblin/abc.py         |  2 +-
 goblin/app.py         | 12 +++++++++---
 goblin/element.py     |  7 +++----
 goblin/mapper.py      |  8 ++++----
 goblin/properties.py  | 42 ++++++++++++++++++++++++++++++++++++++++++
 goblin/session.py     | 19 +++++++++++++------
 tests/test_session.py |  4 ++--
 7 files changed, 74 insertions(+), 20 deletions(-)

diff --git a/goblin/abc.py b/goblin/abc.py
index 652c14e..3f3dcf7 100644
--- a/goblin/abc.py
+++ b/goblin/abc.py
@@ -35,7 +35,7 @@ class DataType(abc.ABC):
     @abc.abstractmethod
     def validate(self, val):
         """Validate property value"""
-        raise NotImplementedError
+        return val
 
     @abc.abstractmethod
     def to_db(self, val=None):
diff --git a/goblin/app.py b/goblin/app.py
index 520919c..83cbf36 100644
--- a/goblin/app.py
+++ b/goblin/app.py
@@ -27,7 +27,7 @@ from goblin import driver, element, session
 logger = logging.getLogger(__name__)
 
 
-async def create_app(url, loop, **config):
+async def create_app(url, loop, get_hashable_id=None, **config):
     """
     Constructor function for :py:class:`Goblin`. Connect to database and
     build a dictionary of relevant vendor implmentation features.
@@ -63,7 +63,8 @@ async def create_app(url, loop, **config):
             'graph.features().graph().supportsThreadedTransactions()', aliases=aliases)
         msg = await stream.fetch_data()
         features['threaded_transactions'] = msg
-    return Goblin(url, loop, features=features, **config)
+    return Goblin(url, loop, get_hashable_id=get_hashable_id,
+                  features=features, **config)
 
 
 # Main API classes
@@ -83,7 +84,8 @@ class Goblin:
         'translator': process.GroovyTranslator('g')
     }
 
-    def __init__(self, url, loop, *, features=None, **config):
+    def __init__(self, url, loop, *, get_hashable_id=None, features=None,
+                 **config):
         self._url = url
         self._loop = loop
         self._features = features
@@ -92,6 +94,9 @@ class Goblin:
         self._vertices = collections.defaultdict(
             lambda: element.GenericVertex)
         self._edges = collections.defaultdict(lambda: element.GenericEdge)
+        if not get_hashable_id:
+            get_hashable_id = lambda x: x
+        self._get_hashable_id = get_hashable_id
 
     @property
     def vertices(self):
@@ -150,5 +155,6 @@ class Goblin:
         conn = await driver.GremlinServer.open(self.url, self._loop)
         return session.Session(self,
                                conn,
+                               self._get_hashable_id,
                                use_session=use_session,
                                aliases=aliases)
diff --git a/goblin/element.py b/goblin/element.py
index 263d223..23ba710 100644
--- a/goblin/element.py
+++ b/goblin/element.py
@@ -59,7 +59,7 @@ class ElementMeta(type):
 
 class Element(metaclass=ElementMeta):
     """Base class for classes that implement the Element property interface"""
-    pass
+    id = properties.IdProperty(properties.Generic)
 
 
 class Vertex(Element):
@@ -154,11 +154,12 @@ class VertexProperty(Vertex, abc.BaseProperty):
 
     __descriptor__ = VertexPropertyDescriptor
 
-    def __init__(self, data_type, *, default=None, db_name=None,
+    def __init__(self, data_type, *, val=None, default=None, db_name=None,
                  card=None):
         if isinstance(data_type, type):
             data_type = data_type()
         self._data_type = data_type
+        self._val = val
         self._default = default
         self._db_name = db_name
         if card is None:
@@ -179,8 +180,6 @@ class VertexProperty(Vertex, abc.BaseProperty):
     def setvalue(self, val):
         self._val = val
 
-    value = property(getvalue, setvalue)
-
     @property
     def db_name(self):
         return self._db_name
diff --git a/goblin/mapper.py b/goblin/mapper.py
index 0d82156..b69411c 100644
--- a/goblin/mapper.py
+++ b/goblin/mapper.py
@@ -83,7 +83,7 @@ def map_vertex_to_ogm(result, element, *, mapping=None):
             vert_prop = getattr(element, name)
             vert_prop.mapper_func(metaprop_dict, vert_prop)
     setattr(element, '__label__', result['label'])
-    setattr(element, 'id', result['id'])
+    setattr(element, '_id', result['id'])
     return element
 
 
@@ -110,7 +110,7 @@ def map_edge_to_ogm(result, element, *, mapping=None):
             value = data_type.to_ogm(value)
         setattr(element, name, value)
     setattr(element, '__label__', result['label'])
-    setattr(element, 'id', result['id'])
+    setattr(element, '_id', result['id'])
     setattr(element.source, '__label__', result['outVLabel'])
     setattr(element.target, '__label__', result['inVLabel'])
     sid = result['outV']
@@ -123,8 +123,8 @@ def map_edge_to_ogm(result, element, *, mapping=None):
     if _check_id(tid, etid):
         from goblin.element import GenericVertex
         element.target = GenericVertex()
-    setattr(element.source, 'id', sid)
-    setattr(element.target, 'id', tid)
+    setattr(element.source, '_id', sid)
+    setattr(element.target, '_id', tid)
     return element
 
 
diff --git a/goblin/properties.py b/goblin/properties.py
index 91b7572..7ab8d9a 100644
--- a/goblin/properties.py
+++ b/goblin/properties.py
@@ -85,7 +85,49 @@ class Property(abc.BaseProperty):
         return self._default
 
 
+class IdPropertyDescriptor:
+
+    def __init__(self, name, prop):
+        assert name == 'id', 'ID properties must be named "id"'
+        self._data_type = prop.data_type
+
+    def __get__(self, obj, objtype=None):
+        if obj is None:
+            raise exception.ElementError(
+                "Only instantiated elements have ID property")
+        return obj._id
+
+    def __set__(self, obj, val):
+        raise exception.ElementError('ID should not be set manually')
+
+
+class IdProperty(abc.BaseProperty):
+
+    __descriptor__ = IdPropertyDescriptor
+
+    def __init__(self, data_type):
+        if isinstance(data_type, type):
+            data_type = data_type()
+        self._data_type = data_type
+
+    @property
+    def data_type(self):
+        return self._data_type
+
+
 # Data types
+class Generic(abc.DataType):
+
+    def validate(self, val):
+        return super().validate(val)
+
+    def to_db(self, val=None):
+        return super().to_db(val=val)
+
+    def to_ogm(self, val):
+        return super().to_ogm(val)
+
+
 class String(abc.DataType):
     """Simple string datatype"""
 
diff --git a/goblin/session.py b/goblin/session.py
index f48fec9..772a4b9 100644
--- a/goblin/session.py
+++ b/goblin/session.py
@@ -41,7 +41,8 @@ class Session(connection.AbstractConnection):
     :param bool use_session: Support for Gremlin Server session. Not implemented
     """
 
-    def __init__(self, app, conn, *, use_session=False, aliases=None):
+    def __init__(self, app, conn, get_hashable_id, *, use_session=False,
+                 aliases=None):
         self._app = app
         self._conn = conn
         self._loop = self._app._loop
@@ -49,6 +50,7 @@ class Session(connection.AbstractConnection):
         self._aliases = aliases or dict()
         self._pending = collections.deque()
         self._current = weakref.WeakValueDictionary()
+        self._get_hashable_id = get_hashable_id
         remote_graph = graph.AsyncRemoteGraph(
             self._app.translator, self,
             graph_traversal=traversal.GoblinTraversal)
@@ -137,7 +139,8 @@ class Session(connection.AbstractConnection):
         async for result in async_iter:
             if (isinstance(result, dict) and
                     result.get('type', '') in ['vertex', 'edge']):
-                current = self.current.get(result['id'], None)
+                hashable_id = self._get_hashable_id(result['id'])
+                current = self.current.get(hashable_id, None)
                 if not current:
                     element_type = result['type']
                     label = result['label']
@@ -180,7 +183,8 @@ class Session(connection.AbstractConnection):
         """
         traversal = self.traversal_factory.remove_vertex(vertex)
         result = await self._simple_traversal(traversal, vertex)
-        vertex = self.current.pop(vertex.id)
+        hashable_id = self._get_hashable_id(vertex.id)
+        vertex = self.current.pop(hashable_id)
         del vertex
         return result
 
@@ -192,7 +196,8 @@ class Session(connection.AbstractConnection):
         """
         traversal = self.traversal_factory.remove_edge(edge)
         result = await self._simple_traversal(traversal, edge)
-        edge = self.current.pop(edge.id)
+        hashable_id = self._get_hashable_id(edge.id)
+        edge = self.current.pop(hashable_id)
         del edge
         return result
 
@@ -225,7 +230,8 @@ class Session(connection.AbstractConnection):
             vertex, self._check_vertex,
             self._add_vertex,
             self.update_vertex)
-        self.current[result.id] = result
+        hashable_id = self._get_hashable_id(result.id)
+        self.current[hashable_id] = result
         return result
 
     async def save_edge(self, edge):
@@ -243,7 +249,8 @@ class Session(connection.AbstractConnection):
             edge, self._check_edge,
             self._add_edge,
             self.update_edge)
-        self.current[result.id] = result
+        hashable_id = self._get_hashable_id(result.id)
+        self.current[hashable_id] = result
         return result
 
     async def get_vertex(self, vertex):
diff --git a/tests/test_session.py b/tests/test_session.py
index d60d3cd..544ea46 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -132,7 +132,7 @@ class TestCreationApi:
     @pytest.mark.asyncio
     async def test_get_vertex_doesnt_exist(self, session, person):
         async with session:
-            person.id = 1000000000000000000000000000000000000000000000
+            person._id = 1000000000000000000000000000000000000000000000
             result = await session.get_vertex(person)
             assert not result
 
@@ -144,7 +144,7 @@ class TestCreationApi:
             works_with = knows
             works_with.source = jon
             works_with.target = leif
-            works_with.id = 1000000000000000000000000000000000000000000000
+            works_with._id = 1000000000000000000000000000000000000000000000
             result = await session.get_edge(works_with)
             assert not result
 
-- 
GitLab