diff --git a/goblin/mapper.py b/goblin/mapper.py index 3521153ea6f87290774c39fe96ebd4ef8208e583..35ae6f7385540836ff9a194dfa32aaedeb446ae5 100644 --- a/goblin/mapper.py +++ b/goblin/mapper.py @@ -7,14 +7,16 @@ logger = logging.getLogger(__name__) def props_generator(properties): - for prop in properties: - yield prop['ogm_name'], prop['db_name'], prop['data_type'] + for ogm_name, (db_name, data_type) in properties.items(): + yield ogm_name, db_name, data_type def map_props_to_db(element, mapping): """Convert OGM property names/values to DB property names/values""" property_tuples = [] props = mapping.properties + # What happens if unknown props come back on an element from a database? + # currently they are ignored... for ogm_name, db_name, data_type in props_generator(props): val = getattr(element, ogm_name, None) property_tuples.append((db_name, data_type.to_db(val))) @@ -58,7 +60,7 @@ class Mapping: def __init__(self, namespace, properties): self._label = namespace.get('__label__', None) or self._create_label() self._type = namespace['__type__'] - self._properties = [] + self._properties = {} self._map_properties(properties) @property @@ -69,6 +71,13 @@ class Mapping: def properties(self): return self._properties + def __getattr__(self, value): + try: + mapping, _ = self._properties[value] + return mapping + except: + raise Exception("Unknown property") + def _create_label(self): return inflection.underscore(self.__class__.__name__) @@ -76,9 +85,8 @@ class Mapping: for name, prop in properties.items(): data_type = prop.data_type db_name = '{}__{}'.format(self._label, name) - mapping = {'ogm_name': name, 'db_name': db_name, - 'data_type': data_type} - self._properties.append(mapping) + self._properties[name] = (db_name, data_type) + def __repr__(self): return '<{}(type={}, label={}, properties={})'.format( diff --git a/tests/test_engine.py b/tests/test_engine.py index 5501420670f3f16dedc4fc70ba1ea0950e7a7a46..061293a2e940a4be4da4571ae23afb821bc54299 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -209,11 +209,11 @@ class TestEngine(unittest.TestCase): session.add(leif, jon, works_for) await session.flush() result = await session.traversal(TestVertex).has( - ('k1', 'test_vertex__name'), ('v1', 'the one and only leifur'))._in().all() + TestVertex.__mapping__.name, ('v1', 'the one and only leifur'))._in().all() async for msg in result: self.assertIs(msg, jon) result = await session.traversal(TestVertex).has( - ('k1', 'test_vertex__name'), ('v1', 'the one and only jonathan')).out().all() + TestVertex.__mapping__.name, ('v1', 'the one and only jonathan')).out().all() async for msg in result: self.assertIs(msg, leif) await session.remove_vertex(leif)