diff --git a/goblin/app.py b/goblin/app.py index 89d05ca0bf125199da6c417086c3737f0fff8212..c46d6912ae86c1d26ba2bdc26a437257cabe0f29 100644 --- a/goblin/app.py +++ b/goblin/app.py @@ -18,9 +18,9 @@ """Goblin application class and class constructor""" import collections +import importlib import logging -from gremlin_python import process from goblin import driver, element, session @@ -120,8 +120,15 @@ class Goblin: """ self._cluster.config_from_json(filename) - def register_from_module(self, modulename): - raise NotImplementedError + def register_from_module(self, module, *, package=None): + if isinstance(module, str): + module = importlib.import_module(module, package) + elements = list() + for item_name in dir(module): + item = getattr(module, item_name) + if isinstance(item, element.ElementMeta): + elements.append(item) + self.register(*elements) async def session(self, *, use_session=False, processor='', op='eval', aliases=None): diff --git a/tests/register_models.py b/tests/register_models.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6eb78fd4d99db194f984ba989697eb512b750e --- /dev/null +++ b/tests/register_models.py @@ -0,0 +1,21 @@ +from goblin import element + + +class TestRegisterVertex1(element.Vertex): + pass + + +class TestRegisterVertex2(element.Vertex): + pass + + +class TestRegisterEdge1(element.Edge): + pass + + +class TestRegisterEdge2(element.Edge): + pass + + +class NotAModelShouldNotBeRegistered: + pass diff --git a/tests/test_app.py b/tests/test_app.py index 89067f1a3dcdab29121562b75350b88ece3cc8d5..73f823797876acbe0134d8738b660db069f295db 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -17,11 +17,33 @@ import pytest +import goblin from goblin import element from goblin.driver import serializer from gremlin_python import process +def test_register_from_module(app): + import register_models + app.register_from_module(register_models) + vertices, edges = app._vertices.values(), app._edges.values() + assert register_models.TestRegisterVertex1 in vertices + assert register_models.TestRegisterVertex2 in vertices + assert register_models.TestRegisterEdge1 in edges + assert register_models.TestRegisterEdge2 in edges + + +def test_register_from_module_string(app): + app.register_from_module('register_models', package=__package__) + vertices, edges = app._vertices.values(), app._edges.values() + + import register_models + assert register_models.TestRegisterVertex1 in vertices + assert register_models.TestRegisterVertex2 in vertices + assert register_models.TestRegisterEdge1 in edges + assert register_models.TestRegisterEdge2 in edges + + @pytest.mark.asyncio async def test_registry(app, person, place, knows, lives_in): assert len(app.vertices) == 2