From 58aee6ea50a0992515a5269be0e0501623cbbdfc Mon Sep 17 00:00:00 2001
From: Leifur Halldor Asgeirsson <lasgeirsson@zerofail.com>
Date: Mon, 14 Nov 2016 11:45:59 -0500
Subject: [PATCH] implement config_from_module

---
 goblin/app.py            |  3 +++
 goblin/driver/cluster.py | 12 ++++++++++--
 tests/config_module.py   |  4 ++++
 tests/test_config.py     | 28 ++++++++++++++++++++++++++++
 4 files changed, 45 insertions(+), 2 deletions(-)
 create mode 100644 tests/config_module.py

diff --git a/goblin/app.py b/goblin/app.py
index a225be4..9808c64 100644
--- a/goblin/app.py
+++ b/goblin/app.py
@@ -118,6 +118,9 @@ class Goblin:
         """
         self._cluster.config_from_json(filename)
 
+    def config_from_module(self, module):
+        self._cluster.config_from_module(module)
+
     def register_from_module(self, module, *, package=None):
         if isinstance(module, str):
             module = importlib.import_module(module, package)
diff --git a/goblin/driver/cluster.py b/goblin/driver/cluster.py
index 0cab896..b1ed360 100644
--- a/goblin/driver/cluster.py
+++ b/goblin/driver/cluster.py
@@ -18,6 +18,7 @@
 import asyncio
 import collections
 import configparser
+import importlib
 import ssl
 
 try:
@@ -165,8 +166,15 @@ class Cluster:
             config['message_serializer'] = my_import(message_serializer)
         return config
 
-    def config_from_module(self, filename):
-        raise NotImplementedError
+    def config_from_module(self, module):
+        if isinstance(module, str):
+            module = importlib.import_module(module)
+        config = dict()
+        for item in dir(module):
+            if not item.startswith('_') and item.lower() in self.DEFAULT_CONFIG:
+                config[item.lower()] = getattr(module, item)
+        config = self._get_message_serializer(config)
+        self.config.update(config)
 
     async def connect(self, processor=None, op=None, aliases=None,
                       session=None):
diff --git a/tests/config_module.py b/tests/config_module.py
new file mode 100644
index 0000000..5127d61
--- /dev/null
+++ b/tests/config_module.py
@@ -0,0 +1,4 @@
+SCHEME = 'wss'
+HOSTS = ['localhost']
+PORT = 8183
+MESSAGE_SERIALIZER = 'goblin.driver.GraphSON2MessageSerializer'
diff --git a/tests/test_config.py b/tests/test_config.py
index 72d4bcf..be33b7f 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -20,10 +20,20 @@ import pytest
 
 from goblin import driver, exception
 
+import config_module
+
 
 dirname = os.path.dirname(os.path.dirname(__file__))
 
 
+@pytest.fixture(params=[0, 1])
+def conf_module(request):
+    if request.param:
+        return 'config_module'
+    else:
+        return config_module
+
+
 def test_cluster_default_config(cluster):
     assert cluster.config['scheme'] == 'ws'
     assert cluster.config['hosts'] == ['localhost']
@@ -95,6 +105,15 @@ def test_cluster_config_from_yaml(event_loop, cluster_class):
     assert issubclass(cluster.config['message_serializer'],
                       driver.GraphSONMessageSerializer)
 
+
+def test_cluster_config_from_module(event_loop, cluster_class, conf_module):
+    cluster = cluster_class(event_loop)
+    cluster.config_from_module(conf_module)
+    assert cluster.config['scheme'] == 'wss'
+    assert cluster.config['hosts'] == ['localhost']
+    assert cluster.config['port'] == 8183
+    assert cluster.config['message_serializer'] is driver.GraphSON2MessageSerializer
+
 @pytest.mark.asyncio
 async def test_app_config_from_json(app):
     app.config_from_file(dirname + '/tests/config/config.json')
@@ -126,3 +145,12 @@ async def test_app_config_from_yaml(app):
     assert issubclass(app.config['message_serializer'],
                       driver.GraphSONMessageSerializer)
     await app.close()
+
+
+@pytest.mark.asyncio
+async def test_app_config_from_module(app, conf_module):
+    app.config_from_module(conf_module)
+    assert app.config['scheme'] == 'wss'
+    assert app.config['hosts'] == ['localhost']
+    assert app.config['port'] == 8183
+    assert app.config['message_serializer'] is driver.GraphSON2MessageSerializer
-- 
GitLab