From cd9dcd0a8cff1bfe277b66f1301d90f51078bf04 Mon Sep 17 00:00:00 2001 From: davebshow <davebshow@gmail.com> Date: Fri, 15 Jul 2016 19:19:41 -0400 Subject: [PATCH] added better error handling --- goblin/driver/connection.py | 58 +++++++++++++++++++++++-------------- goblin/exception.py | 14 +++++++++ goblin/mapper.py | 9 ++++-- goblin/properties.py | 12 ++++---- goblin/session.py | 9 +++--- tests/test_driver.py | 4 ++- tests/test_mapper.py | 4 +-- tests/test_session.py | 2 -- 8 files changed, 73 insertions(+), 39 deletions(-) create mode 100644 goblin/exception.py diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py index 1453998..01015b3 100644 --- a/goblin/driver/connection.py +++ b/goblin/driver/connection.py @@ -6,6 +6,10 @@ import json import logging import uuid +import aiohttp + +from goblin import exception + logger = logging.getLogger(__name__) @@ -21,7 +25,7 @@ def error_handler(fn): msg = await fn(self) if msg: if msg.status_code not in [200, 206, 204]: - raise RuntimeError( + raise exception.GremlinServerError( "{0}: {1}".format(msg.status_code, msg.message)) msg = msg.data return msg @@ -176,30 +180,40 @@ class Connection(AbstractConnection): async def receive(self): data = await self._ws.receive() - # parse aiohttp response here - message = json.loads(data.data.decode("utf-8")) - request_id = message['requestId'] - status_code = message['status']['code'] - data = message["result"]["data"] - msg = message["status"]["message"] - response_queue = self._response_queues[request_id] - if status_code == 407: - await self._authenticate(self._username, self._password, - self._processor, self._session) - self._loop.create_task(self.receive()) + if data.tp == aiohttp.MsgType.close: + await ws.close() + elif data.tp == aiohttp.MsgType.error: + raise data.data + elif data.tp == aiohttp.MsgType.closed: + pass else: - if data: - for result in data: - message = Message(status_code, result, msg) - response_queue.put_nowait(message) - else: - message = Message(status_code, data, msg) - response_queue.put_nowait(message) - if status_code == 206: + if data.tp == aiohttp.MsgType.binary: + data = data.data.decode() + elif data.tp == aiohttp.MsgType.text: + data = data.strip() + message = json.loads(data) + request_id = message['requestId'] + status_code = message['status']['code'] + data = message["result"]["data"] + msg = message["status"]["message"] + response_queue = self._response_queues[request_id] + if status_code == 407: + await self._authenticate(self._username, self._password, + self._processor, self._session) self._loop.create_task(self.receive()) else: - response_queue.put_nowait(None) - del self._response_queues[request_id] + if data: + for result in data: + message = Message(status_code, result, msg) + response_queue.put_nowait(message) + else: + message = Message(status_code, data, msg) + response_queue.put_nowait(message) + if status_code == 206: + self._loop.create_task(self.receive()) + else: + response_queue.put_nowait(None) + del self._response_queues[request_id] async def __aenter__(self): return self diff --git a/goblin/exception.py b/goblin/exception.py new file mode 100644 index 0000000..ecce2cc --- /dev/null +++ b/goblin/exception.py @@ -0,0 +1,14 @@ +class MappingError(Exception): + pass + + +class ValidationError(Exception): + pass + + +class ElementError(Exception): + pass + + +class GremlinServerError(Exception): + pass diff --git a/goblin/mapper.py b/goblin/mapper.py index 0017754..8ff92b9 100644 --- a/goblin/mapper.py +++ b/goblin/mapper.py @@ -2,6 +2,7 @@ import logging import functools +from goblin import exception logger = logging.getLogger(__name__) @@ -81,7 +82,7 @@ class Mapping: OGM element and a DB element""" def __init__(self, namespace, element_type, mapper_func, properties): self._label = namespace['__label__'] - self._type = element_type + self._element_type = element_type self._mapper_func = functools.partial(mapper_func, mapping=self) self._properties = {} self._map_properties(properties) @@ -103,7 +104,9 @@ class Mapping: mapping, _ = self._properties[value] return mapping except: - raise Exception("Unknown property") + raise exception.MappingError( + "unrecognized property {} for class: {}".format( + value, self._element_type)) def _map_properties(self, properties): for name, prop in properties.items(): @@ -114,5 +117,5 @@ class Mapping: def __repr__(self): return '<{}(type={}, label={}, properties={})'.format( - self.__class__.__name__, self._type, self._label, + self.__class__.__name__, self._element_type, self._label, self._properties) diff --git a/goblin/properties.py b/goblin/properties.py index 3bc6aef..a49d9ca 100644 --- a/goblin/properties.py +++ b/goblin/properties.py @@ -1,7 +1,7 @@ """Classes to handle proerties and data type definitions""" import logging -from goblin import abc +from goblin import abc, exception logger = logging.getLogger(__name__) @@ -60,8 +60,9 @@ class String(abc.DataType): if val is not None: try: return str(val) - except Exception as e: - raise Exception("Invalid") from e + except ValueError as e: + raise exception.ValidationError( + '{} is not a valid string'.format(val)) from e def to_db(self, val): return super().to_db(val) @@ -78,8 +79,9 @@ class Integer(abc.DataType): if val is not None: try: return int(val) - except Exception as e: - raise Exception("Invalid") from e + except ValueError as e: + raise exception.ValidationError( + '{} is not a valid integer'.format(val)) from e def to_db(self, val): return super().to_db(val) diff --git a/goblin/session.py b/goblin/session.py index d0f4dc8..3643a9e 100644 --- a/goblin/session.py +++ b/goblin/session.py @@ -4,8 +4,7 @@ import collections import logging import weakref -from goblin import mapper -from goblin import traversal +from goblin import exception, mapper, traversal from goblin.driver import connection, graph from goblin.element import GenericVertex @@ -125,7 +124,8 @@ class Session(connection.AbstractConnection): elif element.__type__ == 'edge': result = await self.save_edge(element) else: - raise Exception("Unknown element type") + raise exception.ElementError( + "Unknown element type: {}".format(element.__type__)) return result async def save_vertex(self, element): @@ -138,7 +138,8 @@ class Session(connection.AbstractConnection): async def save_edge(self, element): if not (hasattr(element, 'source') and hasattr(element, 'target')): - raise Exception("Edges require source/target vetices") + raise exception.ElementError( + "Edges require both source/target vertices") result = await self._save_element( element, self._check_edge, self.traversal_factory.add_edge, diff --git a/tests/test_driver.py b/tests/test_driver.py index 7b2de03..6c4391a 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -1,5 +1,7 @@ import pytest +from goblin import exception + @pytest.mark.asyncio async def test_get_close_conn(connection): @@ -43,7 +45,7 @@ async def test_204_empty_stream(connection): async def test_server_error(connection): async with connection: stream = await connection.submit('g. V jla;sdf') - with pytest.raises(Exception): + with pytest.raises(exception.GremlinServerError): async for msg in stream: pass diff --git a/tests/test_mapper.py b/tests/test_mapper.py index c783aed..a8b8454 100644 --- a/tests/test_mapper.py +++ b/tests/test_mapper.py @@ -1,6 +1,6 @@ import pytest -from goblin import properties +from goblin import exception, properties def test_property_mapping(person, lives_in): @@ -35,5 +35,5 @@ def test_getattr_getdbname(person, lives_in): def test_getattr_doesnt_exist(person): - with pytest.raises(Exception): + with pytest.raises(exception.MappingError): db_name = person.__mapping__.doesnt_exits diff --git a/tests/test_session.py b/tests/test_session.py index 2bde914..a340c6e 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -228,7 +228,6 @@ class TestTraversalApi: resp = await session.traversal(person_class).one_or_none() assert isinstance(resp, person_class) - @pytest.mark.asyncio async def test_one_or_none_none(self, session): async with session: @@ -267,7 +266,6 @@ class TestTraversalApi: assert dave.name == 'dave' assert dave.__label__ == 'unregistered' - @pytest.mark.asyncio async def test_unregistered_edge_desialization(self, session): async with session: -- GitLab