diff --git a/goblin/element.py b/goblin/element.py index bc8be9a441bc43bb80acbcd047c3eaefa78f96a4..38c7799d9a6fcca12ba28fe41737bf2419cb8707 100644 --- a/goblin/element.py +++ b/goblin/element.py @@ -32,9 +32,13 @@ class ElementMeta(type): :py:class:`property.Property` with :py:class:`property.PropertyDescriptor`. """ def __new__(cls, name, bases, namespace, **kwds): - if bases: + if name == 'VertexProperty': + element_type = name.lower() + elif bases: element_type = bases[0].__name__.lower() - namespace['__type__'] = element_type + else: + element_type = name.lower() + namespace['__type__'] = element_type if not namespace.get('__label__', None): namespace['__label__'] = inflection.underscore(name) props = {} diff --git a/goblin/manager.py b/goblin/manager.py index 9de16b6d64bbcc9e5a6cbd79a260d55dacdc694d..3f4810449853d5277dd1c0ca19666b644a0f731a 100644 --- a/goblin/manager.py +++ b/goblin/manager.py @@ -3,6 +3,10 @@ class VertexPropertyManager: + @property + def mapper_func(self): + return self._mapper_func + def __call__(self, val): results = [] for v in self: @@ -21,6 +25,7 @@ class ListVertexPropertyManager(list, VertexPropertyManager): self._data_type = data_type self._vertex_prop = vertex_prop self._card = card + self._mapper_func = vertex_prop.__mapping__.mapper_func list.__init__(self, obj) def append(self, val): @@ -35,6 +40,7 @@ class SetVertexPropertyManager(set, VertexPropertyManager): self._data_type = data_type self._vertex_prop = vertex_prop self._card = card + self._mapper_func = vertex_prop.__mapping__.mapper_func set.__init__(self, obj) def add(self, val): diff --git a/goblin/mapper.py b/goblin/mapper.py index 4e89d1f73eb801722d8afdcdfb3dbe473473a491..0d821563bf06f5db19ecc87c69f3b1aa8f0a8dce 100644 --- a/goblin/mapper.py +++ b/goblin/mapper.py @@ -25,7 +25,6 @@ from goblin import exception logger = logging.getLogger(__name__) -#######IMPLEMENT def map_props_to_db(element, mapping): """Convert OGM property names/values to DB property names/values""" property_tuples = [] @@ -35,35 +34,74 @@ def map_props_to_db(element, mapping): if val and isinstance(val, (list, set)): card = None for v in val: - # get metaprops as dic - metaprops = {} + metaprops = get_metaprops(v, v.__mapping__) property_tuples.append( (card, db_name, data_type.to_db(v.value), metaprops)) card = v.cardinality else: if hasattr(val, '__mapping__'): + metaprops = get_metaprops(val, val.__mapping__) val = val.value - property_tuples.append((None, db_name, data_type.to_db(val), None)) + else: + metaprops = None + property_tuples.append( + (None, db_name, data_type.to_db(val), metaprops)) return property_tuples +def get_metaprops(vertex_property, mapping): + props = mapping.ogm_properties + metaprops = {} + for ogm_name, (db_name, data_type) in props.items(): + val = getattr(vertex_property, ogm_name, None) + metaprops[db_name] = data_type.to_db(val) + return metaprops + + def map_vertex_to_ogm(result, element, *, mapping=None): """Map a vertex returned by DB to OGM vertex""" for db_name, value in result['properties'].items(): + metaprop_dict = {} if len(value) > 1: - # parse and assign vertex props + metas - value = [v['value'] for v in value] + values = [] + for v in value: + values.append(v['value']) + metaprops = v.get('properties', None) + if metaprops: + metaprop_dict[v['value']] = metaprops + value = values else: + metaprops = value[0].get('properties', None) value = value[0]['value'] + if metaprops: + metaprop_dict[value] = metaprops name, data_type = mapping.db_properties.get(db_name, (db_name, None)) if data_type: value = data_type.to_ogm(value) setattr(element, name, value) + if metaprop_dict: + vert_prop = getattr(element, name) + vert_prop.mapper_func(metaprop_dict, vert_prop) setattr(element, '__label__', result['label']) setattr(element, 'id', result['id']) return element +def map_vertex_property_to_ogm(result, element, *, mapping=None): + """Map a vertex returned by DB to OGM vertex""" + for val, metaprops in result.items(): + if isinstance(element, (list, set)): + current = element(val) + else: + current = element + for db_name, value in metaprops.items(): + name, data_type = mapping.db_properties.get( + db_name, (db_name, None)) + if data_type: + value = data_type.to_ogm(value) + setattr(current, name, value) + + def map_edge_to_ogm(result, element, *, mapping=None): """Map an edge returned by DB to OGM edge""" for db_name, value in result.get('properties', {}).items(): @@ -100,14 +138,22 @@ def _check_id(rid, eid): # DB <-> OGM Mapping def create_mapping(namespace, properties): """Constructor for :py:class:`Mapping`""" - element_type = namespace.get('__type__', None) - if element_type: - if element_type == 'vertex': - mapping_func = map_vertex_to_ogm - return Mapping(namespace, element_type, mapping_func, properties) - elif element_type == 'edge': - mapping_func = map_edge_to_ogm - return Mapping(namespace, element_type, mapping_func, properties) + element_type = namespace['__type__'] + if element_type == 'vertex': + mapping_func = map_vertex_to_ogm + mapping = Mapping( + namespace, element_type, mapping_func, properties) + elif element_type == 'edge': + mapping_func = map_edge_to_ogm + mapping = Mapping( + namespace, element_type, mapping_func, properties) + elif element_type == 'vertexproperty': + mapping_func = map_vertex_property_to_ogm + mapping = Mapping( + namespace, element_type, mapping_func, properties) + else: + mapping = None + return mapping class Mapping: diff --git a/goblin/session.py b/goblin/session.py index 6ed6f673a5a93d76b24942cb3f2445ad7194a3c1..8dcf0faa4f7f68cc3060012c04f9620c606bb9a0 100644 --- a/goblin/session.py +++ b/goblin/session.py @@ -195,7 +195,7 @@ class Session(connection.AbstractConnection): del edge return result - async def save(self, element): + async def save(self, elem): """ Save an element to the db. @@ -203,13 +203,13 @@ class Session(connection.AbstractConnection): :returns: :py:class:`Element<goblin.element.Element>` object """ - if element.__type__ == 'vertex': - result = await self.save_vertex(element) - elif element.__type__ == 'edge': - result = await self.save_edge(element) + if elem.__type__ == 'vertex': + result = await self.save_vertex(elem) + elif elem.__type__ == 'edge': + result = await self.save_edge(elem) else: raise exception.ElementError( - "Unknown element type: {}".format(element.__type__)) + "Unknown element type: {}".format(elem.__type__)) return result async def save_vertex(self, vertex): @@ -222,7 +222,7 @@ class Session(connection.AbstractConnection): """ result = await self._save_element( vertex, self._check_vertex, - self.traversal_factory.add_vertex, + self._add_vertex, self.update_vertex) self.current[result.id] = result return result @@ -240,7 +240,7 @@ class Session(connection.AbstractConnection): "Edges require both source/target vertices") result = await self._save_element( edge, self._check_edge, - self.traversal_factory.add_edge, + self._add_edge, self.update_edge) self.current[result.id] = result return result @@ -320,43 +320,89 @@ class Session(connection.AbstractConnection): return msg async def _save_element(self, - element, + elem, check_func, create_func, update_func): - if hasattr(element, 'id'): - result = await check_func(element) - if not result: - traversal = create_func(element) + if hasattr(elem, 'id'): + exists = await check_func(elem) + if not exists: + result = await create_func(elem) else: - traversal = await update_func(element) + result = await update_func(elem) else: - traversal = create_func(element) - return await self._simple_traversal(traversal, element) + result = await create_func(elem) + return result + + async def _add_vertex(self, elem): + """Convenience function for generating crud traversals.""" + props = mapper.map_props_to_db(elem, elem.__mapping__) + traversal = self.g.addV(elem.__mapping__.label) + traversal, _, metaprops = self.traversal_factory.add_properties( + traversal, props) + result = await self._simple_traversal(traversal, elem) + if metaprops: + await self._add_metaprops(result, metaprops) + traversal = self.traversal_factory.get_vertex_by_id(elem) + result = await self._simple_traversal(traversal, elem) + return result + + async def _add_edge(self, elem): + """Convenience function for generating crud traversals.""" + props = mapper.map_props_to_db(elem, elem.__mapping__) + traversal = self.g.V(elem.source.id) + traversal = traversal.addE(elem.__mapping__._label) + traversal = traversal.to( + self.g.V(elem.target.id)) + traversal, _, _ = self.traversal_factory.add_properties( + traversal, props) + return await self._simple_traversal(traversal, elem) - async def _check_vertex(self, element): - """Used to check for existence, does not update session element""" - traversal = self.g.V(element.id) + async def _check_vertex(self, vertex): + """Used to check for existence, does not update session vertex""" + traversal = self.g.V(vertex.id) stream = await self.conn.submit(repr(traversal)) return await stream.fetch_data() - async def _check_edge(self, element): - """Used to check for existence, does not update session element""" - traversal = self.g.E(element.id) + async def _check_edge(self, edge): + """Used to check for existence, does not update session edge""" + traversal = self.g.E(edge.id) stream = await self.conn.submit(repr(traversal)) return await stream.fetch_data() - async def _update_vertex_properties(self, element, traversal, props): - traversal, removals = self.traversal_factory.add_properties( + async def _update_vertex_properties(self, vertex, traversal, props): + traversal, removals, metaprops = self.traversal_factory.add_properties( traversal, props) - # traversal, removals = self.traversal_factory.add_vertex_properties(...) for k in removals: - await self.g.V(element.id).properties(k).drop().one_or_none() - return traversal + await self.g.V(vertex.id).properties(k).drop().one_or_none() + result = await self._simple_traversal(traversal, vertex) + if metaprops: + removals = await self._add_metaprops(result, metaprops) + for db_name, key, value in removals: + await self.g.V(vertex.id).properties( + db_name).has(key, value).drop().one_or_none() + traversal = self.traversal_factory.get_vertex_by_id(vertex) + result = await self._simple_traversal(traversal, vertex) + return result - async def _update_edge_properties(self, element, traversal, props): - traversal, removals = self.traversal_factory.add_properties( + async def _update_edge_properties(self, edge, traversal, props): + traversal, removals, _ = self.traversal_factory.add_properties( traversal, props) for k in removals: - await self.g.E(element.id).properties(k).drop().one_or_none() - return traversal + await self.g.E(edge.id).properties(k).drop().one_or_none() + return await self._simple_traversal(traversal, edge) + + async def _add_metaprops(self, result, metaprops): + potential_removals = [] + for metaprop in metaprops: + db_name, (binding, value), metaprops = metaprop + for key, val in metaprops.items(): + if val: + traversal = self.g.V(result.id).properties( + db_name).hasValue(value).property(key, val) + stream = await self.conn.submit( + repr(traversal), bindings=traversal.bindings) + await stream.fetch_data() + else: + potential_removals.append((db_name, key, value)) + return potential_removals diff --git a/goblin/traversal.py b/goblin/traversal.py index 15053c4a50150625f597f9f0da4da887314e21cf..f516a65c55276e12163257628e20ba97a8d39c62 100644 --- a/goblin/traversal.py +++ b/goblin/traversal.py @@ -139,33 +139,16 @@ class TraversalFactory: """Convenience function for generating crud traversals.""" return self.traversal().E(elem.id) - def add_vertex(self, elem): - """Convenience function for generating crud traversals.""" - props = mapper.map_props_to_db(elem, elem.__mapping__) - # vert_props = mapper.map_props_to_db - traversal = self.traversal().addV(elem.__mapping__.label) - traversal, _ = self.add_properties(traversal, props) - # traversal, _ = self.add_vertex_properties(...) - return traversal - - def add_edge(self, elem): - """Convenience function for generating crud traversals.""" - props = mapper.map_props_to_db(elem, elem.__mapping__) - traversal = self.traversal().V(elem.source.id) - traversal = traversal.addE(elem.__mapping__._label) - traversal = traversal.to( - self.traversal().V(elem.target.id)) - traversal, _ = self.add_properties(traversal, props) - return traversal - def add_properties(self, traversal, props): binding = 0 potential_removals = [] + potential_metaprops = [] for card, db_name, val, metaprops in props: if val: key = ('k' + str(binding), db_name) val = ('v' + str(binding), val) if card: + # Maybe use a dict here as a translator if card == cardinality.Cardinality.list: card = process.Cardinality.list elif card == cardinality.Cardinality.set: @@ -176,6 +159,8 @@ class TraversalFactory: else: traversal = traversal.property(key, val) binding += 1 + if metaprops: + potential_metaprops.append((db_name, val, metaprops)) else: potential_removals.append(db_name) - return traversal, potential_removals + return traversal, potential_removals, potential_metaprops diff --git a/tests/conftest.py b/tests/conftest.py index 0040676cd7751f3f803265a2d73c1b450704d33c..288d745ccfa4f7e83067eb3c885ff03454275031 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,8 +20,9 @@ from goblin import create_app, driver, element, properties, Cardinality from gremlin_python import process -# class PlaceName(element.VertexProperty): -# pass +class HistoricalName(element.VertexProperty): + notes = properties.Property(properties.String) + year = properties.Property(properties.Integer) # this is dumb but handy class Person(element.Vertex): @@ -37,10 +38,12 @@ class Person(element.Vertex): class Place(element.Vertex): name = properties.Property(properties.String) zipcode = properties.Property(properties.Integer) + historical_name = HistoricalName(properties.String, card=Cardinality.list) important_numbers = element.VertexProperty( properties.Integer, card=Cardinality.set) + class Knows(element.Edge): __label__ = 'knows' notes = properties.Property(properties.String, default='N/A') @@ -108,6 +111,11 @@ def boolean(): return properties.Boolean() +@pytest.fixture +def historical_name(): + return HistoricalName() + + @pytest.fixture def person(): return Person() @@ -144,6 +152,11 @@ def integer_class(): return properties.Integer +@pytest.fixture +def historical_name_class(): + return HistoricalName + + @pytest.fixture def person_class(): return Person diff --git a/tests/test_mapper.py b/tests/test_mapper.py index ffb86627f500ea2a34b4c75b9f94f44fe1f5a8b0..c6ea2882a65cf731154f5edc92d4b1533f5e46f5 100644 --- a/tests/test_mapper.py +++ b/tests/test_mapper.py @@ -22,26 +22,38 @@ from goblin import exception, properties def test_property_mapping(person, lives_in): db_name, data_type = person.__mapping__._ogm_properties['name'] - assert db_name == 'person__name' + assert db_name == 'person__name' assert isinstance(data_type, properties.String) db_name, data_type = person.__mapping__._ogm_properties['age'] - assert db_name == 'custom__person__age' + assert db_name == 'custom__person__age' assert isinstance(data_type, properties.Integer) db_name, data_type = lives_in.__mapping__._ogm_properties['notes'] - assert db_name == 'lives_in__notes' + assert db_name == 'lives_in__notes' assert isinstance(data_type, properties.String) ogm_name, data_type = person.__mapping__._db_properties['person__name'] - assert ogm_name == 'name' + assert ogm_name == 'name' assert isinstance(data_type, properties.String) ogm_name, data_type = person.__mapping__._db_properties['custom__person__age'] - assert ogm_name == 'age' + assert ogm_name == 'age' assert isinstance(data_type, properties.Integer) ogm_name, data_type = lives_in.__mapping__._db_properties['lives_in__notes'] assert ogm_name == 'notes' assert isinstance(data_type, properties.String) +def test_metaprop_mapping(place): + place.historical_name = ['Iowa City'] + db_name, data_type = place.historical_name( + 'Iowa City').__mapping__._ogm_properties['notes'] + assert db_name == 'historical_name__notes' + assert isinstance(data_type, properties.String) + db_name, data_type = place.historical_name( + 'Iowa City').__mapping__._ogm_properties['year'] + assert db_name == 'historical_name__year' + assert isinstance(data_type, properties.Integer) + + def test_label_creation(place, lives_in): assert place.__mapping__._label == 'place' assert lives_in.__mapping__._label == 'lives_in' diff --git a/tests/test_properties.py b/tests/test_properties.py index 80997e97447d44c8d6ddbbb061c265cb9e3896cc..fae77aad2ebfebc5faacbd1075a3d32146dc41cd 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -134,6 +134,26 @@ def test_cant_set_vertex_prop_on_edge(): vert_prop = element.VertexProperty(properties.String) +def test_meta_property_set_update(place): + assert not place.historical_name + place.historical_name = ['hispania', 'al-andalus'] + place.historical_name('hispania').notes = 'roman rule' + assert place.historical_name('hispania').notes == 'roman rule' + place.historical_name('hispania').year = 300 + assert place.historical_name('hispania').year == 300 + place.historical_name('al-andalus').notes = 'muslim rule' + assert place.historical_name('al-andalus').notes == 'muslim rule' + place.historical_name('al-andalus').year = 700 + assert place.historical_name('al-andalus').year == 700 + + +def test_meta_property_validation(place): + assert not place.historical_name + place.historical_name = ['spain'] + with pytest.raises(exception.ValidationError): + place.historical_name('spain').year = 'hello' + + class TestString: def test_validation(self, string): diff --git a/tests/test_vertex_properties_functional.py b/tests/test_vertex_properties_functional.py index ebe3e7aae79af639941124628e9705e7b7637c3f..b488bff4f3288b01aff83f2343dc0a71f0f99618 100644 --- a/tests/test_vertex_properties_functional.py +++ b/tests/test_vertex_properties_functional.py @@ -57,3 +57,63 @@ async def test_add_update_set_card_property(session, place): place.important_numbers = None result = await session.save(place) assert not result.important_numbers + + +@pytest.mark.asyncio +async def test_add_update_metas(session, place): + async with session: + place.historical_name = ['Detroit'] + place.historical_name('Detroit').notes = 'rock city' + place.historical_name('Detroit').year = 1900 + result = await session.save(place) + assert result.historical_name('Detroit').notes == 'rock city' + assert result.historical_name('Detroit').year == 1900 + + place.historical_name('Detroit').notes = 'comeback city' + place.historical_name('Detroit').year = 2016 + result = await session.save(place) + assert result.historical_name('Detroit').notes == 'comeback city' + assert result.historical_name('Detroit').year == 2016 + + place.historical_name('Detroit').notes = None + place.historical_name('Detroit').year = None + result = await session.save(place) + assert not result.historical_name('Detroit').notes + assert not result.historical_name('Detroit').year + + + + +@pytest.mark.asyncio +async def test_add_update_metas_list_card(session, place): + async with session: + place.historical_name = ['Hispania', 'Al-Andalus'] + place.historical_name('Hispania').notes = 'romans' + place.historical_name('Hispania').year = 200 + place.historical_name('Al-Andalus').notes = 'muslims' + place.historical_name('Al-Andalus').year = 700 + result = await session.save(place) + assert result.historical_name('Hispania').notes == 'romans' + assert result.historical_name('Hispania').year == 200 + assert result.historical_name('Al-Andalus').notes == 'muslims' + assert result.historical_name('Al-Andalus').year == 700 + + place.historical_name('Hispania').notes = 'really old' + place.historical_name('Hispania').year = 200 + place.historical_name('Al-Andalus').notes = 'less old' + place.historical_name('Al-Andalus').year = 700 + result = await session.save(place) + assert result.historical_name('Hispania').notes == 'really old' + assert result.historical_name('Hispania').year == 200 + assert result.historical_name('Al-Andalus').notes == 'less old' + assert result.historical_name('Al-Andalus').year == 700 + + place.historical_name('Hispania').notes = None + place.historical_name('Hispania').year = None + place.historical_name('Al-Andalus').notes = None + place.historical_name('Al-Andalus').year = None + result = await session.save(place) + assert not result.historical_name('Hispania').notes + assert not result.historical_name('Hispania').year + assert not result.historical_name('Al-Andalus').notes + assert not result.historical_name('Al-Andalus').year