diff --git a/goblin/__init__.py b/goblin/__init__.py index 4b264cc3219566e347fa546c8021b6c639dd6243..c6678709b0e327159d62d415f60359fe9c50b3f1 100644 --- a/goblin/__init__.py +++ b/goblin/__init__.py @@ -1,3 +1,3 @@ from goblin.element import Vertex, Edge, VertexProperty -from goblin.app import create_app, App +from goblin.app import create_app, Goblin from goblin.properties import Property, String diff --git a/goblin/app.py b/goblin/app.py index 2e5866599904aa36fbe4505914a69baa6ba03403..790e7ec646808f298e4e39f9da2175a5ff874add 100644 --- a/goblin/app.py +++ b/goblin/app.py @@ -3,8 +3,7 @@ import collections import logging from goblin.gremlin_python import process -from goblin import driver -from goblin import session +from goblin import driver, element, session logger = logging.getLogger(__name__) @@ -36,11 +35,11 @@ async def create_app(url, loop, **config): 'graph.features().graph().supportsThreadedTransactions()') msg = await stream.fetch_data() features['threaded_transactions'] = msg - return App(url, loop, features=features, **config) + return Goblin(url, loop, features=features, **config) # Main API classes -class App: +class Goblin: """Class used to encapsulate database connection configuration and generate database connections. Used as a factory to create :py:class:`Session` objects. More config coming soon.""" @@ -54,8 +53,9 @@ class App: self._features = features self._config = self.DEFAULT_CONFIG self._config.update(config) - self._vertices = {} - self._edges = {} + self._vertices = collections.defaultdict( + lambda: element.Vertex) + self._edges = collections.defaultdict(lambda: element.Edge) @property def vertices(self): @@ -65,6 +65,10 @@ class App: def edges(self): return self._edges + @property + def features(self): + return self._features + def from_file(filepath): pass diff --git a/goblin/driver/__init__.py b/goblin/driver/__init__.py index 8e929ce5afb0b1c62b265c5c46528b332032d31a..26f294091ddb3de7c3c8ae8854fe4784c76d1e6c 100644 --- a/goblin/driver/__init__.py +++ b/goblin/driver/__init__.py @@ -1,2 +1,3 @@ from goblin.driver.api import GremlinServer from goblin.driver.connection import AbstractConnection +from goblin.driver.graph import AsyncRemoteGraph diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py index 2af1fd95b03d63567e3754a6dbfc78f32fe0b2e7..4c772c43017c345e4aa1eb4a1fb5be0e685a98a3 100644 --- a/goblin/driver/connection.py +++ b/goblin/driver/connection.py @@ -1,6 +1,7 @@ import abc import asyncio import collections +import functools import json import logging import uuid @@ -11,7 +12,20 @@ logger = logging.getLogger(__name__) Message = collections.namedtuple( "Message", - ["status_code", "data", "message", "metadata"]) + ["status_code", "data", "message"]) + + +def error_handler(fn): + @functools.wraps(fn) + async def wrapper(self): + msg = await fn(self) + if msg: + if msg.status_code not in [200, 206, 204]: + raise RuntimeError( + "{0}: {1}".format(msg.status_code, msg.message)) + msg = msg.data + return msg + return wrapper class Response: @@ -31,6 +45,7 @@ class Response: else: raise StopAsyncIteration + @error_handler async def fetch_data(self): if self._done: return None @@ -101,7 +116,7 @@ class Connection(AbstractConnection): response_queue = asyncio.Queue(loop=self._loop) self.response_queues[request_id] = response_queue if self._ws.closed: - self._ws = await self.conn_factory.ws_connect(self._url) + self._ws = await self.conn_factory.ws_connect(self.url) self._ws.send_bytes(message) self._loop.create_task(self.receive()) return Response(response_queue, self._loop) @@ -164,29 +179,28 @@ class Connection(AbstractConnection): # parse aiohttp response here message = json.loads(data.data.decode("utf-8")) request_id = message['requestId'] - message = Message(message["status"]["code"], - message["result"]["data"], - message["status"]["message"], - message["result"]["meta"]) + status_code = message['status']['code'] + data = message["result"]["data"] response_queue = self._response_queues[request_id] - if message.status_code in [200, 206, 204]: - if message.data: - for result in message.data: - response_queue.put_nowait(result) - if message.status_code == 206: - self._loop.create_task(self.receive()) - else: - response_queue.put_nowait(None) - del self._response_queues[request_id] - elif message.status_code == 407: + if status_code == 407: await self._authenticate(self._username, self._password, self._processor, self._session) self._loop.create_task(self.receive()) else: - del self._response_queues[request_id] - raise RuntimeError("{0} {1}".format(message.status_code, - message.message)) - + if data: + for result in data: + message = Message(status_code, result, + message["status"]["message"]) + response_queue.put_nowait(message) + else: + message = Message(status_code, data, + message["status"]["message"]) + response_queue.put_nowait(message) + if status_code == 206: + self._loop.create_task(self.receive()) + else: + response_queue.put_nowait(None) + del self._response_queues[request_id] async def __aenter__(self): return self diff --git a/goblin/driver/graph.py b/goblin/driver/graph.py index cd2fc2e932019c97528f7871c13b8b02556fa40b..09d0973a89456260d7ad0f80d297c144508c221e 100644 --- a/goblin/driver/graph.py +++ b/goblin/driver/graph.py @@ -46,3 +46,13 @@ class AsyncRemoteGraph(AsyncGraph): def __repr__(self): return "remotegraph[" + self.remote_connection.url + "]" + + async def close(self): + await self.remote_connection.close() + self.remote_connection = None + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.close() diff --git a/goblin/element.py b/goblin/element.py index 4d5db7e7092ed9015d1a1307a90fa49a2eaedb59..42b310cf1aca5ab1330f472bfc221130d2503b01 100644 --- a/goblin/element.py +++ b/goblin/element.py @@ -1,5 +1,7 @@ import logging +import inflection + from goblin import abc from goblin import mapper from goblin import properties @@ -16,6 +18,8 @@ class ElementMeta(type): def __new__(cls, name, bases, namespace, **kwds): if bases: namespace['__type__'] = bases[0].__name__.lower() + if not namespace.get('__label__', None): + namespace['__label__'] = inflection.underscore(name) props = {} new_namespace = {} for k, v in namespace.items(): @@ -23,8 +27,7 @@ class ElementMeta(type): props[k] = v v = v.__descriptor__(k, v) new_namespace[k] = v - new_namespace['__mapping__'] = mapper.create_mapping(namespace, - props) + new_namespace['__mapping__'] = mapper.create_mapping(namespace, props) logger.warning("Creating new Element class {}: {}".format( name, new_namespace['__mapping__'])) result = type.__new__(cls, name, bases, new_namespace) @@ -44,10 +47,8 @@ class Edge(Element): """Base class for user defined Edge classes""" def __init__(self, source=None, target=None): - if source: - self._source = source - if target: - self._target = target + self._source = source + self._target = target def getsource(self): return self._source diff --git a/goblin/mapper.py b/goblin/mapper.py index 1258fced6936f5e65c70acfc6661b438fc6837fe..1b8d09aa2ed65505dfef956e11a2231a04c7fb63 100644 --- a/goblin/mapper.py +++ b/goblin/mapper.py @@ -2,24 +2,15 @@ import logging import functools -import inflection - logger = logging.getLogger(__name__) -def props_generator(properties): - for ogm_name, (db_name, data_type) in properties.items(): - yield ogm_name, db_name, data_type - - def map_props_to_db(element, mapping): """Convert OGM property names/values to DB property names/values""" property_tuples = [] props = mapping.properties - # What happens if unknown props come back on an element from a database? - # currently they are ignored... - for ogm_name, db_name, data_type in props_generator(props): + for ogm_name, (db_name, data_type) in props.items(): val = getattr(element, ogm_name, None) property_tuples.append((db_name, data_type.to_db(val))) return property_tuples @@ -27,10 +18,13 @@ def map_props_to_db(element, mapping): def map_vertex_to_ogm(result, element, *, mapping=None): """Map a vertex returned by DB to OGM vertex""" - props = mapping.properties - for ogm_name, db_name, data_type in props_generator(props): - val = result['properties'].get(db_name, [{'value': None}])[0]['value'] - setattr(element, ogm_name, data_type.to_ogm(val)) + for db_name, value in result['properties'].items(): + # This will be more complex for vertex properties... + value = value[0]['value'] + name, data_type = mapping.properties.get(db_name, (db_name, None)) + if data_type: + value = data_type.to_ogm(value) + setattr(element, name, value) setattr(element, '__label__', result['label']) setattr(element, 'id', result['id']) return element @@ -38,17 +32,37 @@ def map_vertex_to_ogm(result, element, *, mapping=None): def map_edge_to_ogm(result, element, *, mapping=None): """Map an edge returned by DB to OGM edge""" - props = mapping.properties - for ogm_name, db_name, data_type in props_generator(props): - val = result['properties'].get(db_name, None) - setattr(element, ogm_name, data_type.to_ogm(val)) + for db_name, value in result.items(): + name, data_type = mapping.properties.get(db_name, (db_name, None)) + if data_type: + value = data_type.to_ogm(value) + setattr(element, name, value) setattr(element, '__label__', result['label']) setattr(element, 'id', result['id']) - setattr(element.source, '__label__', result['inVLabel']) - setattr(element.target, '__label__', result['outVLabel']) + setattr(element.source, '__label__', result['outVLabel']) + setattr(element.target, '__label__', result['inVLabel']) + sid = result['outV'] + esid = getattr(element.source, 'id', None) + if _check_id(sid, esid): + from goblin.element import Vertex + element.source = Vertex() + tid = result['inV'] + etid = getattr(element.target, 'id', None) + if _check_id(tid, etid): + from goblin.element import Vertex + element.target = Vertex() + setattr(element.source, 'id', sid) + setattr(element.target, 'id', tid) return element +def _check_id(rid, eid): + if eid and rid != eid: + logger.warning('Edge vertex id has changed') + return True + return False + + # DB <-> OGM Mapping def create_mapping(namespace, properties): """Constructor for :py:class:`Mapping`""" @@ -66,7 +80,7 @@ class Mapping: """This class stores the information necessary to map between an OGM element and a DB element""" def __init__(self, namespace, element_type, mapper_func, properties): - self._label = namespace.get('__label__', None) or self._create_label() + self._label = namespace['__label__'] self._type = element_type self._mapper_func = functools.partial(mapper_func, mapping=self) self._properties = {} @@ -91,13 +105,11 @@ class Mapping: except: raise Exception("Unknown property") - def _create_label(self): - return inflection.underscore(self.__class__.__name__) - def _map_properties(self, properties): for name, prop in properties.items(): data_type = prop.data_type db_name = '{}__{}'.format(self._label, name) + self._properties[db_name] = (name, data_type) self._properties[name] = (db_name, data_type) def __repr__(self): diff --git a/goblin/properties.py b/goblin/properties.py index dcf60296bbf17dbabf00d41f32533113cca53dff..3bc6aef4c639f71c6fce84de55c8e6ce0f33f88e 100644 --- a/goblin/properties.py +++ b/goblin/properties.py @@ -68,3 +68,29 @@ class String(abc.DataType): def to_ogm(self, val): return super().to_ogm(val) + + +class Integer(abc.DataType): + """Simple string datatype""" + + def validate(self, val): + """Need to think about this.""" + if val is not None: + try: + return int(val) + except Exception as e: + raise Exception("Invalid") from e + + def to_db(self, val): + return super().to_db(val) + + def to_ogm(self, val): + return super().to_ogm(val) + + +class Float(abc.DataType): + pass + + +class Bool(abc.DataType): + pass diff --git a/goblin/session.py b/goblin/session.py index 52519423da8a2e23ea55fac9f50bf8340595d70f..b9b6a99e9fe0d9c8e9ae48948458277ec2027124 100644 --- a/goblin/session.py +++ b/goblin/session.py @@ -2,6 +2,7 @@ import asyncio import collections import logging +import weakref from goblin import mapper from goblin import traversal @@ -21,7 +22,7 @@ class Session(connection.AbstractConnection): self._loop = self._app._loop self._use_session = False self._pending = collections.deque() - self._current = {} + self._current = weakref.WeakValueDictionary() remote_graph = graph.AsyncRemoteGraph( self._app.translator, self, graph_traversal=traversal.GoblinTraversal) @@ -46,7 +47,7 @@ class Session(connection.AbstractConnection): async def __aenter__(self): return self - async def __aexit__(self): + async def __aexit__(self, exc_type, exc, tb): await self.close() async def close(self): @@ -84,14 +85,11 @@ class Session(connection.AbstractConnection): element_type = result['type'] label = result['label'] if element_type == 'vertex': - current = self.app.vertices.get(label, None) + current = self.app.vertices[label]() else: - current = self.app.edges.get(label, None) - if not current: - # build generic element here - pass - else: - current = current() + current = self.app.edges[label]() + current.source = element.Vertex() + current.target = element.Vertex() element = current.__mapping__.mapper_func(result, current) response_queue.put_nowait(element) response_queue.put_nowait(None) @@ -109,13 +107,15 @@ class Session(connection.AbstractConnection): async def remove_vertex(self, element): traversal = self.traversal_factory.remove_vertex(element) result = await self._simple_traversal(traversal, element) - del self.current[element.id] + element = self.current.pop(element.id) + del element return result async def remove_edge(self, element): traversal = self.traversal_factory.remove_edge(element) result = await self._simple_traversal(traversal, element) - del self.current[element.id] + element = self.current.pop(element.id) + del element return result async def save(self, element): @@ -131,7 +131,7 @@ class Session(connection.AbstractConnection): result = await self._save_element( element, self._check_vertex, self.traversal_factory.add_vertex, - self.traversal_factory.update_vertex) + self.update_vertex) self.current[result.id] = result return result @@ -141,7 +141,7 @@ class Session(connection.AbstractConnection): result = await self._save_element( element, self._check_edge, self.traversal_factory.add_edge, - self.traversal_factory.update_edge) + self.update_edge) self.current[result.id] = result return result @@ -153,6 +153,16 @@ class Session(connection.AbstractConnection): return await self.traversal_factory.get_edge_by_id( element).one_or_none() + async def update_vertex(self, element): + props = mapper.map_props_to_db(element, element.__mapping__) + traversal = self.traversal().V(element.id) + traversal = await self._update_properties(element, traversal, props) + + async def update_edge(self, element): + props = mapper.map_props_to_db(element, element.__mapping__) + traversal = self.traversal().E(element.id) + return await self._update_properties(element, traversal, props) + # Transaction support def tx(self): raise NotImplementedError @@ -188,7 +198,7 @@ class Session(connection.AbstractConnection): if not result: traversal = create_func(element) else: - traversal = update_func(element) + traversal = await update_func(element) else: traversal = create_func(element) return await self._simple_traversal(traversal, element) @@ -204,3 +214,16 @@ class Session(connection.AbstractConnection): traversal = self.g.E(element.id) stream = await self.conn.submit(repr(traversal)) return await stream.fetch_data() + + async def _update_properties(self, element, traversal, props): + binding = 0 + for k, v in props: + if v: + traversal = traversal.property( + ('k' + str(binding), k), + ('v' + str(binding), v)) + else: + await self.g.V(element.id).properties( + ('k' + str(binding), k)).drop().one_or_none() + binding += 1 + return traversal diff --git a/goblin/traversal.py b/goblin/traversal.py index 2ebdfd04e9ad1b78135c6f86d89ed7180ded2f2b..74c7b9a4f215a3d7ceca9b89855c687559287a5a 100644 --- a/goblin/traversal.py +++ b/goblin/traversal.py @@ -30,7 +30,7 @@ class TraversalResponse: raise StopAsyncIteration -# This is all a hack until we figure out GLV integration... +# This is all until we figure out GLV integration... class GoblinTraversal(graph.AsyncGraphTraversal): async def all(self): @@ -38,14 +38,13 @@ class GoblinTraversal(graph.AsyncGraphTraversal): async def one_or_none(self): async for msg in await self.next(): - return resp + return msg class TraversalFactory: """Helper that wraps a AsyncRemoteGraph""" def __init__(self, graph): self._graph = graph - self._binding = 0 @property def graph(self): @@ -88,22 +87,12 @@ class TraversalFactory: self.traversal().V(element.target.id)) return self._add_properties(traversal, props) - def update_vertex(self, element): - props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.traversal().V(element.id) - return self._add_properties(traversal, props) - - def update_edge(self, element): - props = mapper.map_props_to_db(element, element.__mapping__) - traversal = self.traversal().E(element.id) - return self._add_properties(traversal, props) - def _add_properties(self, traversal, props): + binding = 0 for k, v in props: if v: traversal = traversal.property( - ('k' + str(self._binding), k), - ('v' + str(self._binding), v)) - self._binding += 1 - self._binding = 0 + ('k' + str(binding), k), + ('v' + str(binding), v)) + binding += 1 return traversal diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..b7e478982ccf9ab1963c74e1084dfccb6e42c583 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[aliases] +test=pytest diff --git a/setup.py b/setup.py index dc6daf40c01c25b55100bac7f0f261f2a5c820d5..186a91116295219cec74d6b8887618f292016964 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,8 @@ setup( "inflection==0.3.1" ], test_suite="tests", + setup_requires=['pytest-runner'], + tests_require=['pytest-asyncio', 'pytest'], classifiers=[ 'Development Status :: 4 - Beta', 'Intended Audience :: Developers', diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b6a4167a2735af15e893f3cad273f22f333470 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,132 @@ +import pytest +from goblin import create_app, driver, element, properties +from goblin.gremlin_python import process + + +class Person(element.Vertex): + __label__ = 'person' + name = properties.Property(properties.String) + age = properties.Property(properties.Integer) + + +class Place(element.Vertex): + name = properties.Property(properties.String) + zipcode = properties.Property(properties.Integer) + + +class Knows(element.Edge): + __label__ = 'knows' + notes = properties.Property(properties.String, default='N/A') + + +class LivesIn(element.Edge): + notes = properties.Property(properties.String) + + +@pytest.fixture +def gremlin_server(): + return driver.GremlinServer + + +@pytest.fixture +def unused_server_url(unused_tcp_port): + return 'http://localhost:{}/'.format(unused_tcp_port) + + +@pytest.fixture +def connection(gremlin_server, event_loop): + conn = event_loop.run_until_complete( + gremlin_server.open("http://localhost:8182/", event_loop)) + return conn + + +@pytest.fixture +def remote_graph(connection): + translator = process.GroovyTranslator('g') + return driver.AsyncRemoteGraph(translator, connection) + + +@pytest.fixture +def app(event_loop): + app = event_loop.run_until_complete( + create_app("http://localhost:8182/", event_loop)) + app.register(Person, Place, Knows, LivesIn) + return app + + +@pytest.fixture +def session(event_loop, app): + session = event_loop.run_until_complete(app.session()) + return session + + +# Instance fixtures +@pytest.fixture +def string(): + return properties.String() + + +@pytest.fixture +def integer(): + return properties.Integer() + + +@pytest.fixture +def person(): + return Person() + + +@pytest.fixture +def place(): + return Place() + + +@pytest.fixture +def knows(): + return Knows() + + +@pytest.fixture +def lives_in(): + return LivesIn() + + +@pytest.fixture +def place_name(): + return PlaceName() + + +# Class fixtures +@pytest.fixture +def string_class(): + return properties.String + + +@pytest.fixture +def integer_class(): + return properties.Integer + + +@pytest.fixture +def person_class(): + return Person + + +@pytest.fixture +def place_class(): + return Place + + +@pytest.fixture +def knows_class(): + return Knows + + +@pytest.fixture +def lives_in_class(): + return LivesIn + + +@pytest.fixture +def place_name_class(): + return PlaceName diff --git a/tests/test_app.py b/tests/test_app.py new file mode 100644 index 0000000000000000000000000000000000000000..05ad04e6c5b052002bbeab892a2d47621e3bbf83 --- /dev/null +++ b/tests/test_app.py @@ -0,0 +1,26 @@ +from goblin import element +from goblin.gremlin_python import process + + +def test_registry(app, person, place, knows, lives_in): + assert len(app.vertices) == 2 + assert len(app.edges) == 2 + assert person.__class__ == app.vertices['person'] + assert place.__class__ == app.vertices['place'] + assert knows.__class__ == app.edges['knows'] + assert lives_in.__class__ == app.edges['lives_in'] + + +def test_registry_defaults(app): + vertex = app.vertices['unregistered'] + assert isinstance(vertex(), element.Vertex) + edge = app.edges['unregistered'] + assert isinstance(edge(), element.Edge) + + +def test_features(app): + assert app._features + + +def test_translator(app): + assert isinstance(app.translator, process.GroovyTranslator) diff --git a/tests/test_driver.py b/tests/test_driver.py index 0b486a8299cd48fb6a8d5950a5bb445138a10acf..7b2de032277cfef758a3c172ba14a422a2896e22 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -1,60 +1,73 @@ -import asyncio -import unittest +import pytest -from goblin import driver -from goblin.driver import graph -from goblin.gremlin_python import process +@pytest.mark.asyncio +async def test_get_close_conn(connection): + ws = connection._ws + assert not ws.closed + assert not connection.closed + await connection.close() + assert connection.closed + assert ws.closed -class TestDriver(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) +@pytest.mark.asyncio +async def test_conn_context_manager(connection): + async with connection: + assert not connection.closed + assert connection.closed - def test_open(self): - async def go(): - connection = await driver.GremlinServer.open( - "http://localhost:8182/", self.loop) - async with connection: - self.assertFalse(connection._ws.closed) - self.assertTrue(connection._ws.closed) +@pytest.mark.asyncio +async def test_submit(connection): + async with connection: + stream = await connection.submit("1 + 1") + results = [] + async for msg in stream: + results.append(msg) + assert len(results) == 1 + assert results[0] == 2 - self.loop.run_until_complete(go()) - def test_open_as_ctx_mng(self): +@pytest.mark.asyncio +async def test_204_empty_stream(connection): + resp = False + async with connection: + stream = await connection.submit('g.V().has("unlikely", "even less likely")') + async for msg in stream: + resp = True + assert not resp - async def go(): - async with await driver.GremlinServer.open( - "http://localhost:8182/", self.loop) as connection: - self.assertFalse(connection._ws.closed) - self.assertTrue(connection._ws.closed) - self.loop.run_until_complete(go()) +@pytest.mark.asyncio +async def test_server_error(connection): + async with connection: + stream = await connection.submit('g. V jla;sdf') + with pytest.raises(Exception): + async for msg in stream: + pass - def test_submit(self): - async def go(): - connection = await driver.GremlinServer.open( - "http://localhost:8182/", self.loop) - stream = await connection.submit("1 + 1") - async for msg in stream: - self.assertEqual(msg.data[0], 2) - await connection.close() - - self.loop.run_until_complete(go()) - - def test_async_graph(self): - - async def go(): - translator = process.GroovyTranslator('g') - connection = await driver.GremlinServer.open( - "http://localhost:8182/", self.loop) - g = graph.AsyncRemoteGraph(translator, connection) - traversal = g.traversal() - resp = await traversal.V().next() - async for msg in resp: - print(msg) - await connection.close() - self.loop.run_until_complete(go()) +@pytest.mark.asyncio +async def test_cant_connect(event_loop, gremlin_server, unused_server_url): + with pytest.raises(Exception): + await gremlin_server.open(unused_server_url, event_loop) + + +@pytest.mark.asyncio +async def test_resp_queue_removed_from_conn(connection): + async with connection: + stream = await connection.submit("1 + 1") + async for msg in stream: + pass + assert stream._response_queue not in list( + connection._response_queues.values()) + + +@pytest.mark.asyncio +async def test_stream_done(connection): + async with connection: + stream = await connection.submit("1 + 1") + async for msg in stream: + pass + assert stream._done diff --git a/tests/test_engine.py b/tests/test_engine.py deleted file mode 100644 index 714e843cec56ec77bdb97a72c1a136ff82bbeed0..0000000000000000000000000000000000000000 --- a/tests/test_engine.py +++ /dev/null @@ -1,246 +0,0 @@ -import asyncio -import unittest - -from goblin.app import create_app -from goblin.element import Vertex, Edge, VertexProperty -from goblin.properties import Property, String - - -class TestVertex(Vertex): - __label__ = 'test_vertex' - name = Property(String) - notes = Property(String, default='N/A') - - -class TestEdge(Edge): - __label__ = 'test_edge' - notes = Property(String, default='N/A') - - -class TestEngine(unittest.TestCase): - - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def test_add_vertex(self): - - app = self.loop.run_until_complete( - create_app("http://localhost:8182/", self.loop)) - app.register(TestVertex) - - async def go(): - session = await app.session() - leif = TestVertex() - leif.name = 'leifur' - leif.notes = 'superdev' - session.add(leif) - await session.flush() - current = session._current[leif.id] - self.assertEqual(current.name, 'leifur') - self.assertEqual(current.notes, 'superdev') - self.assertIs(leif, current) - self.assertEqual(leif.id, current.id) - await session.close() - - self.loop.run_until_complete(go()) - - def test_update_vertex(self): - - app = self.loop.run_until_complete( - create_app("http://localhost:8182/", self.loop)) - app.register(TestVertex) - - async def go(): - session = await app.session() - leif = TestVertex() - leif.name = 'leifur' - session.add(leif) - await session.flush() - current = session._current[leif.id] - self.assertEqual(current.name, 'leifur') - self.assertEqual(current.notes, 'N/A') - - leif.name = 'leif' - session.add(leif) - await session.flush() - new_current = session._current[leif.id] - self.assertIs(current, new_current) - self.assertEqual(new_current.name, 'leif') - await session.close() - - - self.loop.run_until_complete(go()) - - def test_add_edge(self): - - app = self.loop.run_until_complete( - create_app("http://localhost:8182/", self.loop)) - app.register(TestVertex, TestEdge) - - async def go(): - session = await app.session() - leif = TestVertex() - leif.name = 'leifur' - jon = TestVertex() - jon.name = 'jonathan' - works_for = TestEdge() - works_for.source = jon - works_for.target = leif - self.assertEqual(works_for.notes, 'N/A') - works_for.notes = 'zerofail' - session.add(leif, jon, works_for) - await session.flush() - current = session._current[works_for.id] - self.assertEqual(current.notes, 'zerofail') - self.assertIs(current, works_for) - self.assertEqual(current.id, works_for.id) - self.assertIs(leif, current.target) - self.assertEqual(leif.id, current.target.id) - self.assertIs(jon, current.source) - self.assertEqual(jon.id, current.source.id) - await session.close() - - self.loop.run_until_complete(go()) - - def test_update_edge(self): - - app = self.loop.run_until_complete( - create_app("http://localhost:8182/", self.loop)) - app.register(TestVertex, TestEdge) - - async def go(): - session = await app.session() - leif = TestVertex() - leif.name = 'leifur' - jon = TestVertex() - jon.name = 'jonathan' - works_for = TestEdge() - works_for.source = jon - works_for.target = leif - session.add(leif, jon, works_for) - await session.flush() - current = session._current[works_for.id] - self.assertEqual(works_for.notes, 'N/A') - works_for.notes = 'zerofail' - session.add(works_for) - await session.flush() - new_current = session._current[works_for.id] - self.assertEqual(new_current.notes, 'zerofail') - await session.close() - - self.loop.run_until_complete(go()) - - - self.loop.run_until_complete(go()) - - def test_query_all(self): - - app = self.loop.run_until_complete( - create_app("http://localhost:8182/", self.loop)) - app.register(TestVertex) - - async def go(): - session = await app.session() - leif = TestVertex() - leif.name = 'leifur' - jon = TestVertex() - jon.name = 'jonathan' - session.add(leif, jon) - await session.flush() - results = [] - stream = await session.traversal(TestVertex).all() - async for msg in stream: - results.append(msg) - print(len(results)) - self.assertEqual(len(session.current), 2) - for result in results: - self.assertIsInstance(result, Vertex) - await session.close() - - self.loop.run_until_complete(go()) - - def test_remove_vertex(self): - - app = self.loop.run_until_complete( - create_app("http://localhost:8182/", self.loop)) - app.register(TestVertex, TestEdge) - - async def go(): - session = await app.session() - leif = TestVertex() - leif.name = 'leifur' - session.add(leif) - await session.flush() - current = session._current[leif.id] - self.assertIs(leif, current) - await session.remove_vertex(leif) - result = await session.get_vertex(leif) - self.assertIsNone(result) - self.assertEqual(len(list(session.current.items())), 0) - await session.close() - - self.loop.run_until_complete(go()) - - def test_remove_edge(self): - - app = self.loop.run_until_complete( - create_app("http://localhost:8182/", self.loop)) - app.register(TestVertex, TestEdge) - - async def go(): - session = await app.session() - leif = TestVertex() - leif.name = 'leifur' - jon = TestVertex() - jon.name = 'jonathan' - works_for = TestEdge() - works_for.source = jon - works_for.target = leif - works_for.notes = 'zerofail' - session.add(leif, jon, works_for) - await session.flush() - current = session._current[works_for.id] - self.assertIs(current, works_for) - await session.remove_edge(works_for) - result = await session.get_edge(works_for) - self.assertIsNone(result) - self.assertEqual(len(list(session.current.items())), 2) - await session.close() - - self.loop.run_until_complete(go()) - - def test_traversal(self): - - app = self.loop.run_until_complete( - create_app("http://localhost:8182/", self.loop)) - app.register(TestVertex, TestEdge) - - async def go(): - session = await app.session() - leif = TestVertex() - leif.name = 'the one and only leifur' - jon = TestVertex() - jon.name = 'the one and only jonathan' - works_for = TestEdge() - works_for.source = jon - works_for.target = leif - works_for.notes = 'the one and only zerofail' - session.add(leif, jon, works_for) - await session.flush() - result = await session.traversal(TestVertex).has( - TestVertex.name, ('v1', 'the one and only leifur'))._in().all() - async for msg in result: - self.assertIs(msg, jon) - result = await session.traversal(TestVertex).has( - TestVertex.name, ('v1', 'the one and only jonathan')).out().all() - async for msg in result: - self.assertIs(msg, leif) - await session.remove_vertex(leif) - await session.remove_vertex(jon) - await session.close() - - self.loop.run_until_complete(go()) diff --git a/tests/test_graph.py b/tests/test_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..0e7d0a22b307fe2f83d74920d9bd9552f283143d --- /dev/null +++ b/tests/test_graph.py @@ -0,0 +1,38 @@ +import pytest +from goblin.gremlin_python import process + + +@pytest.mark.asyncio +async def test_close_graph(remote_graph): + remote_connection = remote_graph.remote_connection + await remote_graph.close() + assert remote_connection.closed + + +@pytest.mark.asyncio +async def test_conn_context_manager(remote_graph): + remote_connection = remote_graph.remote_connection + async with remote_graph: + assert not remote_graph.remote_connection.closed + assert remote_connection.closed + + +@pytest.mark.asyncio +async def test_generate_traversal(remote_graph): + async with remote_graph: + traversal = remote_graph.traversal().V().hasLabel(('v1', 'person')) + assert isinstance(traversal, process.GraphTraversal) + assert traversal.bindings['v1'] == 'person' + + +@pytest.mark.asyncio +async def test_submit_traversal(remote_graph): + async with remote_graph: + g = remote_graph.traversal() + resp = await g.addV('person').property('name', 'leifur').next() + leif = await resp.fetch_data() + assert leif['properties']['name'][0]['value'] == 'leifur' + assert leif['label'] == 'person' + resp = await g.V(leif['id']).drop().next() + none = await resp.fetch_data() + assert none is None diff --git a/tests/test_mapper.py b/tests/test_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..c783aedff8d4ab7197fe316bb03be6c6ae315644 --- /dev/null +++ b/tests/test_mapper.py @@ -0,0 +1,39 @@ +import pytest + +from goblin import properties + + +def test_property_mapping(person, lives_in): + db_name, data_type = person.__mapping__._properties['name'] + assert db_name == 'person__name' + assert isinstance(data_type, properties.String) + db_name, data_type = person.__mapping__._properties['age'] + assert db_name == 'person__age' + assert isinstance(data_type, properties.Integer) + db_name, data_type = lives_in.__mapping__._properties['notes'] + assert db_name == 'lives_in__notes' + assert isinstance(data_type, properties.String) + + +def test_label_creation(place, lives_in): + assert place.__mapping__._label == 'place' + assert lives_in.__mapping__._label == 'lives_in' + + +def test_mapper_func(place, knows): + assert callable(place.__mapping__._mapper_func) + assert callable(knows.__mapping__._mapper_func) + + +def test_getattr_getdbname(person, lives_in): + db_name = person.__mapping__.name + assert db_name == 'person__name' + db_name = person.__mapping__.age + assert db_name == 'person__age' + db_name = lives_in.__mapping__.notes + assert db_name == 'lives_in__notes' + + +def test_getattr_doesnt_exist(person): + with pytest.raises(Exception): + db_name = person.__mapping__.doesnt_exits diff --git a/tests/test_properties.py b/tests/test_properties.py index 2fad4436d29931a6b2bbb340797217acd9af6e36..73b2d48feed4eff83b295a3e61aed5f45c667aa9 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -1,39 +1,61 @@ -import asyncio -import unittest +import pytest -from goblin.engine import create_engine -from goblin.element import Vertex, Edge, VertexProperty -from goblin.properties import Property, String +def test_set_change_property(person, lives_in): + # vertex + assert not person.name + person.name = 'leif' + assert person.name == 'leif' + person.name = 'leifur' + assert person.name == 'leifur' + # edge + assert not lives_in.notes + lives_in.notes = 'notable' + assert lives_in.notes == 'notable' + lives_in.notes = 'more notable' + assert lives_in.notes == 'more notable' -class TestVertexProperty(VertexProperty): - notes = Property(String) +def test_property_default(knows): + assert knows.notes == 'N/A' + knows.notes = 'notable' + assert knows.notes == 'notable' -class TestVertex(Vertex): - __label__ = 'test_vertex' - name = TestVertexProperty(String) - address = Property(String) +def test_validation(person): + person.age = 10 + with pytest.raises(Exception): + person.age = 'hello' -class TestProperties(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) +def test_setattr_validation(person): + setattr(person, 'age', 10) + assert person.age == 10 + with pytest.raises(Exception): + setattr(person, 'age', 'hello') - def tearDown(self): - self.loop.close() - def test_vertex_property(self): +class TestString: - t = TestVertex() - self.assertIsNone(t.name) - t.name = 'leif' - self.assertEqual(t.name._value, 'leif') - self.assertIsNone(t.name.notes) - t.name.notes = 'notes' - self.assertEqual(t.name.notes, 'notes') - t.name = ['leif', 'jon'] - self.assertEqual(t.name[0]._value, 'leif') - self.assertEqual(t.name[1]._value, 'jon') + def test_validation(self, string): + assert string.validate(1) == '1' + + def test_to_db(self, string): + assert string.to_db('hello') == 'hello' + + def test_to_ogm(self, string): + assert string.to_ogm('hello') == 'hello' + + +class TestInteger: + + def test_validation(self, integer): + assert integer.validate('1') == 1 + with pytest.raises(Exception): + integer.validate('hello') + + def test_to_db(self, integer): + assert integer.to_db(1) == 1 + + def test_to_ogm(self, integer): + assert integer.to_db(1) == 1 diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000000000000000000000000000000000000..84e0a15485e9f1023ded39308c43ec64142da760 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,187 @@ +import pytest + + +@pytest.mark.asyncio +async def test_session_close(session): + assert not session.conn.closed + await session.close() + assert session.conn.closed + + +@pytest.mark.asyncio +async def test_session_ctxt_mngr(session): + async with session: + assert not session.conn.closed + assert session.conn.closed + + +class TestCreationApi: + + @pytest.mark.asyncio + async def test_create_vertex(self, session, person_class): + async with session: + jon = person_class() + jon.name = 'jonathan' + jon.age = 38 + leif = person_class() + leif.name = 'leif' + leif.age = 28 + session.add(jon, leif) + assert not hasattr(jon, 'id') + assert not hasattr(leif, 'id') + await session.flush() + assert hasattr(jon, 'id') + assert session.current[jon.id] is jon + assert jon.name == 'jonathan' + assert hasattr(leif, 'id') + assert session.current[leif.id] is leif + assert leif.name == 'leif' + + @pytest.mark.asyncio + async def test_create_edge(self, session, person_class, place_class, + lives_in_class): + async with session: + jon = person_class() + jon.name = 'jonathan' + jon.age = 38 + montreal = place_class() + montreal.name = 'Montreal' + lives_in = lives_in_class(jon, montreal) + session.add(jon, montreal, lives_in) + await session.flush() + assert hasattr(lives_in, 'id') + assert session.current[lives_in.id] is lives_in + assert lives_in.source is jon + assert lives_in.target is montreal + assert lives_in.source.__label__ == 'person' + assert lives_in.target.__label__ == 'place' + + @pytest.mark.asyncio + async def test_create_edge_no_source(self, session, lives_in, person): + async with session: + lives_in.source = person + with pytest.raises(Exception): + await session.save(lives_in) + + @pytest.mark.asyncio + async def test_create_edge_no_target(self, session, lives_in, place): + async with session: + lives_in.target = place + with pytest.raises(Exception): + await session.save(lives_in) + + @pytest.mark.asyncio + async def test_create_edge_no_source_target(self, session, lives_in): + async with session: + with pytest.raises(Exception): + await session.save(lives_in) + + @pytest.mark.asyncio + async def test_get_vertex(self, session, person_class): + async with session: + jon = person_class() + jon.name = 'jonathan' + jon.age = 38 + await session.save(jon) + jid = jon.id + result = await session.get_vertex(jon) + assert result.id == jid + assert result is jon + + @pytest.mark.asyncio + async def test_get_edge(self, session, person_class, place_class, + lives_in_class): + async with session: + jon = person_class() + jon.name = 'jonathan' + jon.age = 38 + montreal = place_class() + montreal.name = 'Montreal' + lives_in = lives_in_class(jon, montreal) + session.add(jon, montreal, lives_in) + await session.flush() + lid = lives_in.id + result = await session.get_edge(lives_in) + assert result.id == lid + assert result is lives_in + + @pytest.mark.asyncio + async def test_get_vertex_doesnt_exist(self, session, person): + async with session: + person.id = 1000000000000000000000000000000000000000000000 + result = await session.get_vertex(person) + assert not result + + @pytest.mark.asyncio + async def test_get_edge_doesnt_exist(self, session, knows, person_class): + async with session: + jon = person_class() + leif = person_class() + works_with = knows + works_with.source = jon + works_with.target = leif + works_with.id = 1000000000000000000000000000000000000000000000 + result = await session.get_edge(works_with) + assert not result + + @pytest.mark.asyncio + async def test_remove_vertex(self, session, person): + async with session: + person.name = 'dave' + person.age = 35 + await session.save(person) + result = await session.g.V(person.id).one_or_none() + assert result is person + rid = result.id + await session.remove_vertex(person) + result = await session.g.V(rid).one_or_none() + assert not result + + @pytest.mark.asyncio + async def test_remove_edge(self, session, person_class, place_class, + lives_in_class): + async with session: + jon = person_class() + jon.name = 'jonathan' + jon.age = 38 + montreal = place_class() + montreal.name = 'Montreal' + lives_in = lives_in_class(jon, montreal) + session.add(jon, montreal, lives_in) + await session.flush() + result = await session.g.E(lives_in.id).one_or_none() + assert result is lives_in + rid = result.id + await session.remove_edge(lives_in) + result = await session.g.E(rid).one_or_none() + assert not result + + def test_update_vertex(self): + pass + + def test_update_edge(self): + pass + + +class TestTraversalApi: + + def test_all(self): + pass + + def test_one_or_none_one(self): + pass + + def test_one_or_none_none(self): + pass + + def test_vertex_deserialization(self): + pass + + def test_edge_desialization(self): + pass + + def test_unregistered_vertex_deserialization(self): + pass + + def test_unregistered_edge_desialization(self): + pass