Commit b3ebe3ae authored by davebshow's avatar davebshow
Browse files

add default id properties. pass hashable id callable

parent 7cd6d59e
......@@ -35,7 +35,7 @@ class DataType(abc.ABC):
@abc.abstractmethod
def validate(self, val):
"""Validate property value"""
raise NotImplementedError
return val
@abc.abstractmethod
def to_db(self, val=None):
......
......@@ -27,7 +27,7 @@ from goblin import driver, element, session
logger = logging.getLogger(__name__)
async def create_app(url, loop, **config):
async def create_app(url, loop, get_hashable_id=None, **config):
"""
Constructor function for :py:class:`Goblin`. Connect to database and
build a dictionary of relevant vendor implmentation features.
......@@ -63,7 +63,8 @@ async def create_app(url, loop, **config):
'graph.features().graph().supportsThreadedTransactions()', aliases=aliases)
msg = await stream.fetch_data()
features['threaded_transactions'] = msg
return Goblin(url, loop, features=features, **config)
return Goblin(url, loop, get_hashable_id=get_hashable_id,
features=features, **config)
# Main API classes
......@@ -83,7 +84,8 @@ class Goblin:
'translator': process.GroovyTranslator('g')
}
def __init__(self, url, loop, *, features=None, **config):
def __init__(self, url, loop, *, get_hashable_id=None, features=None,
**config):
self._url = url
self._loop = loop
self._features = features
......@@ -92,6 +94,9 @@ class Goblin:
self._vertices = collections.defaultdict(
lambda: element.GenericVertex)
self._edges = collections.defaultdict(lambda: element.GenericEdge)
if not get_hashable_id:
get_hashable_id = lambda x: x
self._get_hashable_id = get_hashable_id
@property
def vertices(self):
......@@ -150,5 +155,6 @@ class Goblin:
conn = await driver.GremlinServer.open(self.url, self._loop)
return session.Session(self,
conn,
self._get_hashable_id,
use_session=use_session,
aliases=aliases)
......@@ -59,7 +59,7 @@ class ElementMeta(type):
class Element(metaclass=ElementMeta):
"""Base class for classes that implement the Element property interface"""
pass
id = properties.IdProperty(properties.Generic)
class Vertex(Element):
......@@ -154,11 +154,12 @@ class VertexProperty(Vertex, abc.BaseProperty):
__descriptor__ = VertexPropertyDescriptor
def __init__(self, data_type, *, default=None, db_name=None,
def __init__(self, data_type, *, val=None, default=None, db_name=None,
card=None):
if isinstance(data_type, type):
data_type = data_type()
self._data_type = data_type
self._val = val
self._default = default
self._db_name = db_name
if card is None:
......@@ -179,8 +180,6 @@ class VertexProperty(Vertex, abc.BaseProperty):
def setvalue(self, val):
self._val = val
value = property(getvalue, setvalue)
@property
def db_name(self):
return self._db_name
......
......@@ -83,7 +83,7 @@ def map_vertex_to_ogm(result, element, *, mapping=None):
vert_prop = getattr(element, name)
vert_prop.mapper_func(metaprop_dict, vert_prop)
setattr(element, '__label__', result['label'])
setattr(element, 'id', result['id'])
setattr(element, '_id', result['id'])
return element
......@@ -110,7 +110,7 @@ def map_edge_to_ogm(result, element, *, mapping=None):
value = data_type.to_ogm(value)
setattr(element, name, value)
setattr(element, '__label__', result['label'])
setattr(element, 'id', result['id'])
setattr(element, '_id', result['id'])
setattr(element.source, '__label__', result['outVLabel'])
setattr(element.target, '__label__', result['inVLabel'])
sid = result['outV']
......@@ -123,8 +123,8 @@ def map_edge_to_ogm(result, element, *, mapping=None):
if _check_id(tid, etid):
from goblin.element import GenericVertex
element.target = GenericVertex()
setattr(element.source, 'id', sid)
setattr(element.target, 'id', tid)
setattr(element.source, '_id', sid)
setattr(element.target, '_id', tid)
return element
......
......@@ -85,7 +85,49 @@ class Property(abc.BaseProperty):
return self._default
class IdPropertyDescriptor:
def __init__(self, name, prop):
assert name == 'id', 'ID properties must be named "id"'
self._data_type = prop.data_type
def __get__(self, obj, objtype=None):
if obj is None:
raise exception.ElementError(
"Only instantiated elements have ID property")
return obj._id
def __set__(self, obj, val):
raise exception.ElementError('ID should not be set manually')
class IdProperty(abc.BaseProperty):
__descriptor__ = IdPropertyDescriptor
def __init__(self, data_type):
if isinstance(data_type, type):
data_type = data_type()
self._data_type = data_type
@property
def data_type(self):
return self._data_type
# Data types
class Generic(abc.DataType):
def validate(self, val):
return super().validate(val)
def to_db(self, val=None):
return super().to_db(val=val)
def to_ogm(self, val):
return super().to_ogm(val)
class String(abc.DataType):
"""Simple string datatype"""
......
......@@ -41,7 +41,8 @@ class Session(connection.AbstractConnection):
:param bool use_session: Support for Gremlin Server session. Not implemented
"""
def __init__(self, app, conn, *, use_session=False, aliases=None):
def __init__(self, app, conn, get_hashable_id, *, use_session=False,
aliases=None):
self._app = app
self._conn = conn
self._loop = self._app._loop
......@@ -49,6 +50,7 @@ class Session(connection.AbstractConnection):
self._aliases = aliases or dict()
self._pending = collections.deque()
self._current = weakref.WeakValueDictionary()
self._get_hashable_id = get_hashable_id
remote_graph = graph.AsyncRemoteGraph(
self._app.translator, self,
graph_traversal=traversal.GoblinTraversal)
......@@ -137,7 +139,8 @@ class Session(connection.AbstractConnection):
async for result in async_iter:
if (isinstance(result, dict) and
result.get('type', '') in ['vertex', 'edge']):
current = self.current.get(result['id'], None)
hashable_id = self._get_hashable_id(result['id'])
current = self.current.get(hashable_id, None)
if not current:
element_type = result['type']
label = result['label']
......@@ -180,7 +183,8 @@ class Session(connection.AbstractConnection):
"""
traversal = self.traversal_factory.remove_vertex(vertex)
result = await self._simple_traversal(traversal, vertex)
vertex = self.current.pop(vertex.id)
hashable_id = self._get_hashable_id(vertex.id)
vertex = self.current.pop(hashable_id)
del vertex
return result
......@@ -192,7 +196,8 @@ class Session(connection.AbstractConnection):
"""
traversal = self.traversal_factory.remove_edge(edge)
result = await self._simple_traversal(traversal, edge)
edge = self.current.pop(edge.id)
hashable_id = self._get_hashable_id(edge.id)
edge = self.current.pop(hashable_id)
del edge
return result
......@@ -225,7 +230,8 @@ class Session(connection.AbstractConnection):
vertex, self._check_vertex,
self._add_vertex,
self.update_vertex)
self.current[result.id] = result
hashable_id = self._get_hashable_id(result.id)
self.current[hashable_id] = result
return result
async def save_edge(self, edge):
......@@ -243,7 +249,8 @@ class Session(connection.AbstractConnection):
edge, self._check_edge,
self._add_edge,
self.update_edge)
self.current[result.id] = result
hashable_id = self._get_hashable_id(result.id)
self.current[hashable_id] = result
return result
async def get_vertex(self, vertex):
......
......@@ -132,7 +132,7 @@ class TestCreationApi:
@pytest.mark.asyncio
async def test_get_vertex_doesnt_exist(self, session, person):
async with session:
person.id = 1000000000000000000000000000000000000000000000
person._id = 1000000000000000000000000000000000000000000000
result = await session.get_vertex(person)
assert not result
......@@ -144,7 +144,7 @@ class TestCreationApi:
works_with = knows
works_with.source = jon
works_with.target = leif
works_with.id = 1000000000000000000000000000000000000000000000
works_with._id = 1000000000000000000000000000000000000000000000
result = await session.get_edge(works_with)
assert not result
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment