From 58aee6ea50a0992515a5269be0e0501623cbbdfc Mon Sep 17 00:00:00 2001 From: Leifur Halldor Asgeirsson <lasgeirsson@zerofail.com> Date: Mon, 14 Nov 2016 11:45:59 -0500 Subject: [PATCH] implement config_from_module --- goblin/app.py | 3 +++ goblin/driver/cluster.py | 12 ++++++++++-- tests/config_module.py | 4 ++++ tests/test_config.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 45 insertions(+), 2 deletions(-) create mode 100644 tests/config_module.py diff --git a/goblin/app.py b/goblin/app.py index a225be4..9808c64 100644 --- a/goblin/app.py +++ b/goblin/app.py @@ -118,6 +118,9 @@ class Goblin: """ self._cluster.config_from_json(filename) + def config_from_module(self, module): + self._cluster.config_from_module(module) + def register_from_module(self, module, *, package=None): if isinstance(module, str): module = importlib.import_module(module, package) diff --git a/goblin/driver/cluster.py b/goblin/driver/cluster.py index 0cab896..b1ed360 100644 --- a/goblin/driver/cluster.py +++ b/goblin/driver/cluster.py @@ -18,6 +18,7 @@ import asyncio import collections import configparser +import importlib import ssl try: @@ -165,8 +166,15 @@ class Cluster: config['message_serializer'] = my_import(message_serializer) return config - def config_from_module(self, filename): - raise NotImplementedError + def config_from_module(self, module): + if isinstance(module, str): + module = importlib.import_module(module) + config = dict() + for item in dir(module): + if not item.startswith('_') and item.lower() in self.DEFAULT_CONFIG: + config[item.lower()] = getattr(module, item) + config = self._get_message_serializer(config) + self.config.update(config) async def connect(self, processor=None, op=None, aliases=None, session=None): diff --git a/tests/config_module.py b/tests/config_module.py new file mode 100644 index 0000000..5127d61 --- /dev/null +++ b/tests/config_module.py @@ -0,0 +1,4 @@ +SCHEME = 'wss' +HOSTS = ['localhost'] +PORT = 8183 +MESSAGE_SERIALIZER = 'goblin.driver.GraphSON2MessageSerializer' diff --git a/tests/test_config.py b/tests/test_config.py index 72d4bcf..be33b7f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -20,10 +20,20 @@ import pytest from goblin import driver, exception +import config_module + dirname = os.path.dirname(os.path.dirname(__file__)) +@pytest.fixture(params=[0, 1]) +def conf_module(request): + if request.param: + return 'config_module' + else: + return config_module + + def test_cluster_default_config(cluster): assert cluster.config['scheme'] == 'ws' assert cluster.config['hosts'] == ['localhost'] @@ -95,6 +105,15 @@ def test_cluster_config_from_yaml(event_loop, cluster_class): assert issubclass(cluster.config['message_serializer'], driver.GraphSONMessageSerializer) + +def test_cluster_config_from_module(event_loop, cluster_class, conf_module): + cluster = cluster_class(event_loop) + cluster.config_from_module(conf_module) + assert cluster.config['scheme'] == 'wss' + assert cluster.config['hosts'] == ['localhost'] + assert cluster.config['port'] == 8183 + assert cluster.config['message_serializer'] is driver.GraphSON2MessageSerializer + @pytest.mark.asyncio async def test_app_config_from_json(app): app.config_from_file(dirname + '/tests/config/config.json') @@ -126,3 +145,12 @@ async def test_app_config_from_yaml(app): assert issubclass(app.config['message_serializer'], driver.GraphSONMessageSerializer) await app.close() + + +@pytest.mark.asyncio +async def test_app_config_from_module(app, conf_module): + app.config_from_module(conf_module) + assert app.config['scheme'] == 'wss' + assert app.config['hosts'] == ['localhost'] + assert app.config['port'] == 8183 + assert app.config['message_serializer'] is driver.GraphSON2MessageSerializer -- GitLab