diff --git a/goblin/abc.py b/goblin/abc.py index 652c14ef71b91a048f730b3ede54001847392777..3f3dcf7c6002d7fd57e60e8be7726ad9bc7bac1b 100644 --- a/goblin/abc.py +++ b/goblin/abc.py @@ -35,7 +35,7 @@ class DataType(abc.ABC): @abc.abstractmethod def validate(self, val): """Validate property value""" - raise NotImplementedError + return val @abc.abstractmethod def to_db(self, val=None): diff --git a/goblin/app.py b/goblin/app.py index 520919cc6a73fe12bcaaebd28d4d7d47b78597ed..83cbf36e402f19ddfa8168ca63e38f937e942d66 100644 --- a/goblin/app.py +++ b/goblin/app.py @@ -27,7 +27,7 @@ from goblin import driver, element, session logger = logging.getLogger(__name__) -async def create_app(url, loop, **config): +async def create_app(url, loop, get_hashable_id=None, **config): """ Constructor function for :py:class:`Goblin`. Connect to database and build a dictionary of relevant vendor implmentation features. @@ -63,7 +63,8 @@ async def create_app(url, loop, **config): 'graph.features().graph().supportsThreadedTransactions()', aliases=aliases) msg = await stream.fetch_data() features['threaded_transactions'] = msg - return Goblin(url, loop, features=features, **config) + return Goblin(url, loop, get_hashable_id=get_hashable_id, + features=features, **config) # Main API classes @@ -83,7 +84,8 @@ class Goblin: 'translator': process.GroovyTranslator('g') } - def __init__(self, url, loop, *, features=None, **config): + def __init__(self, url, loop, *, get_hashable_id=None, features=None, + **config): self._url = url self._loop = loop self._features = features @@ -92,6 +94,9 @@ class Goblin: self._vertices = collections.defaultdict( lambda: element.GenericVertex) self._edges = collections.defaultdict(lambda: element.GenericEdge) + if not get_hashable_id: + get_hashable_id = lambda x: x + self._get_hashable_id = get_hashable_id @property def vertices(self): @@ -150,5 +155,6 @@ class Goblin: conn = await driver.GremlinServer.open(self.url, self._loop) return session.Session(self, conn, + self._get_hashable_id, use_session=use_session, aliases=aliases) diff --git a/goblin/element.py b/goblin/element.py index 263d223bd04d6d8a17929f2b1e394f8ade3d4e47..23ba710b583ba40a2818ef42337537dbc7bb104b 100644 --- a/goblin/element.py +++ b/goblin/element.py @@ -59,7 +59,7 @@ class ElementMeta(type): class Element(metaclass=ElementMeta): """Base class for classes that implement the Element property interface""" - pass + id = properties.IdProperty(properties.Generic) class Vertex(Element): @@ -154,11 +154,12 @@ class VertexProperty(Vertex, abc.BaseProperty): __descriptor__ = VertexPropertyDescriptor - def __init__(self, data_type, *, default=None, db_name=None, + def __init__(self, data_type, *, val=None, default=None, db_name=None, card=None): if isinstance(data_type, type): data_type = data_type() self._data_type = data_type + self._val = val self._default = default self._db_name = db_name if card is None: @@ -179,8 +180,6 @@ class VertexProperty(Vertex, abc.BaseProperty): def setvalue(self, val): self._val = val - value = property(getvalue, setvalue) - @property def db_name(self): return self._db_name diff --git a/goblin/mapper.py b/goblin/mapper.py index 0d821563bf06f5db19ecc87c69f3b1aa8f0a8dce..b69411c26d5b32a3ea6fdb64c9126e4de48ad1c3 100644 --- a/goblin/mapper.py +++ b/goblin/mapper.py @@ -83,7 +83,7 @@ def map_vertex_to_ogm(result, element, *, mapping=None): vert_prop = getattr(element, name) vert_prop.mapper_func(metaprop_dict, vert_prop) setattr(element, '__label__', result['label']) - setattr(element, 'id', result['id']) + setattr(element, '_id', result['id']) return element @@ -110,7 +110,7 @@ def map_edge_to_ogm(result, element, *, mapping=None): value = data_type.to_ogm(value) setattr(element, name, value) setattr(element, '__label__', result['label']) - setattr(element, 'id', result['id']) + setattr(element, '_id', result['id']) setattr(element.source, '__label__', result['outVLabel']) setattr(element.target, '__label__', result['inVLabel']) sid = result['outV'] @@ -123,8 +123,8 @@ def map_edge_to_ogm(result, element, *, mapping=None): if _check_id(tid, etid): from goblin.element import GenericVertex element.target = GenericVertex() - setattr(element.source, 'id', sid) - setattr(element.target, 'id', tid) + setattr(element.source, '_id', sid) + setattr(element.target, '_id', tid) return element diff --git a/goblin/properties.py b/goblin/properties.py index 91b7572094544447044597524bf98f2b61fa879f..7ab8d9a766e9bc04f3f076f458f4a5c3c36bb027 100644 --- a/goblin/properties.py +++ b/goblin/properties.py @@ -85,7 +85,49 @@ class Property(abc.BaseProperty): return self._default +class IdPropertyDescriptor: + + def __init__(self, name, prop): + assert name == 'id', 'ID properties must be named "id"' + self._data_type = prop.data_type + + def __get__(self, obj, objtype=None): + if obj is None: + raise exception.ElementError( + "Only instantiated elements have ID property") + return obj._id + + def __set__(self, obj, val): + raise exception.ElementError('ID should not be set manually') + + +class IdProperty(abc.BaseProperty): + + __descriptor__ = IdPropertyDescriptor + + def __init__(self, data_type): + if isinstance(data_type, type): + data_type = data_type() + self._data_type = data_type + + @property + def data_type(self): + return self._data_type + + # Data types +class Generic(abc.DataType): + + def validate(self, val): + return super().validate(val) + + def to_db(self, val=None): + return super().to_db(val=val) + + def to_ogm(self, val): + return super().to_ogm(val) + + class String(abc.DataType): """Simple string datatype""" diff --git a/goblin/session.py b/goblin/session.py index f48fec94b87949c6a5c1888575a90230f8e57f7e..772a4b9e81fe3e44d4ff682858b737d838f4ee01 100644 --- a/goblin/session.py +++ b/goblin/session.py @@ -41,7 +41,8 @@ class Session(connection.AbstractConnection): :param bool use_session: Support for Gremlin Server session. Not implemented """ - def __init__(self, app, conn, *, use_session=False, aliases=None): + def __init__(self, app, conn, get_hashable_id, *, use_session=False, + aliases=None): self._app = app self._conn = conn self._loop = self._app._loop @@ -49,6 +50,7 @@ class Session(connection.AbstractConnection): self._aliases = aliases or dict() self._pending = collections.deque() self._current = weakref.WeakValueDictionary() + self._get_hashable_id = get_hashable_id remote_graph = graph.AsyncRemoteGraph( self._app.translator, self, graph_traversal=traversal.GoblinTraversal) @@ -137,7 +139,8 @@ class Session(connection.AbstractConnection): async for result in async_iter: if (isinstance(result, dict) and result.get('type', '') in ['vertex', 'edge']): - current = self.current.get(result['id'], None) + hashable_id = self._get_hashable_id(result['id']) + current = self.current.get(hashable_id, None) if not current: element_type = result['type'] label = result['label'] @@ -180,7 +183,8 @@ class Session(connection.AbstractConnection): """ traversal = self.traversal_factory.remove_vertex(vertex) result = await self._simple_traversal(traversal, vertex) - vertex = self.current.pop(vertex.id) + hashable_id = self._get_hashable_id(vertex.id) + vertex = self.current.pop(hashable_id) del vertex return result @@ -192,7 +196,8 @@ class Session(connection.AbstractConnection): """ traversal = self.traversal_factory.remove_edge(edge) result = await self._simple_traversal(traversal, edge) - edge = self.current.pop(edge.id) + hashable_id = self._get_hashable_id(edge.id) + edge = self.current.pop(hashable_id) del edge return result @@ -225,7 +230,8 @@ class Session(connection.AbstractConnection): vertex, self._check_vertex, self._add_vertex, self.update_vertex) - self.current[result.id] = result + hashable_id = self._get_hashable_id(result.id) + self.current[hashable_id] = result return result async def save_edge(self, edge): @@ -243,7 +249,8 @@ class Session(connection.AbstractConnection): edge, self._check_edge, self._add_edge, self.update_edge) - self.current[result.id] = result + hashable_id = self._get_hashable_id(result.id) + self.current[hashable_id] = result return result async def get_vertex(self, vertex): diff --git a/tests/test_session.py b/tests/test_session.py index d60d3cd6ec530946ba84942e8ebf321a665c32b5..544ea465ee47688f40fa059b12a487b082375214 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -132,7 +132,7 @@ class TestCreationApi: @pytest.mark.asyncio async def test_get_vertex_doesnt_exist(self, session, person): async with session: - person.id = 1000000000000000000000000000000000000000000000 + person._id = 1000000000000000000000000000000000000000000000 result = await session.get_vertex(person) assert not result @@ -144,7 +144,7 @@ class TestCreationApi: works_with = knows works_with.source = jon works_with.target = leif - works_with.id = 1000000000000000000000000000000000000000000000 + works_with._id = 1000000000000000000000000000000000000000000000 result = await session.get_edge(works_with) assert not result