diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py index 14539986757ed008760447a0c70aff4a296e578b..01015b327371a528dec3261fddf0bc00b14775e9 100644 --- a/goblin/driver/connection.py +++ b/goblin/driver/connection.py @@ -6,6 +6,10 @@ import json import logging import uuid +import aiohttp + +from goblin import exception + logger = logging.getLogger(__name__) @@ -21,7 +25,7 @@ def error_handler(fn): msg = await fn(self) if msg: if msg.status_code not in [200, 206, 204]: - raise RuntimeError( + raise exception.GremlinServerError( "{0}: {1}".format(msg.status_code, msg.message)) msg = msg.data return msg @@ -176,30 +180,40 @@ class Connection(AbstractConnection): async def receive(self): data = await self._ws.receive() - # parse aiohttp response here - message = json.loads(data.data.decode("utf-8")) - request_id = message['requestId'] - status_code = message['status']['code'] - data = message["result"]["data"] - msg = message["status"]["message"] - response_queue = self._response_queues[request_id] - if status_code == 407: - await self._authenticate(self._username, self._password, - self._processor, self._session) - self._loop.create_task(self.receive()) + if data.tp == aiohttp.MsgType.close: + await ws.close() + elif data.tp == aiohttp.MsgType.error: + raise data.data + elif data.tp == aiohttp.MsgType.closed: + pass else: - if data: - for result in data: - message = Message(status_code, result, msg) - response_queue.put_nowait(message) - else: - message = Message(status_code, data, msg) - response_queue.put_nowait(message) - if status_code == 206: + if data.tp == aiohttp.MsgType.binary: + data = data.data.decode() + elif data.tp == aiohttp.MsgType.text: + data = data.strip() + message = json.loads(data) + request_id = message['requestId'] + status_code = message['status']['code'] + data = message["result"]["data"] + msg = message["status"]["message"] + response_queue = self._response_queues[request_id] + if status_code == 407: + await self._authenticate(self._username, self._password, + self._processor, self._session) self._loop.create_task(self.receive()) else: - response_queue.put_nowait(None) - del self._response_queues[request_id] + if data: + for result in data: + message = Message(status_code, result, msg) + response_queue.put_nowait(message) + else: + message = Message(status_code, data, msg) + response_queue.put_nowait(message) + if status_code == 206: + self._loop.create_task(self.receive()) + else: + response_queue.put_nowait(None) + del self._response_queues[request_id] async def __aenter__(self): return self diff --git a/goblin/exception.py b/goblin/exception.py new file mode 100644 index 0000000000000000000000000000000000000000..ecce2cc5fafc337f898250a46ec725d43a4eacfa --- /dev/null +++ b/goblin/exception.py @@ -0,0 +1,14 @@ +class MappingError(Exception): + pass + + +class ValidationError(Exception): + pass + + +class ElementError(Exception): + pass + + +class GremlinServerError(Exception): + pass diff --git a/goblin/mapper.py b/goblin/mapper.py index 001775400f495e955b0b0117d3dfbb7834b4a4d1..8ff92b9195467e668f83fe3e906e2749bba99532 100644 --- a/goblin/mapper.py +++ b/goblin/mapper.py @@ -2,6 +2,7 @@ import logging import functools +from goblin import exception logger = logging.getLogger(__name__) @@ -81,7 +82,7 @@ class Mapping: OGM element and a DB element""" def __init__(self, namespace, element_type, mapper_func, properties): self._label = namespace['__label__'] - self._type = element_type + self._element_type = element_type self._mapper_func = functools.partial(mapper_func, mapping=self) self._properties = {} self._map_properties(properties) @@ -103,7 +104,9 @@ class Mapping: mapping, _ = self._properties[value] return mapping except: - raise Exception("Unknown property") + raise exception.MappingError( + "unrecognized property {} for class: {}".format( + value, self._element_type)) def _map_properties(self, properties): for name, prop in properties.items(): @@ -114,5 +117,5 @@ class Mapping: def __repr__(self): return '<{}(type={}, label={}, properties={})'.format( - self.__class__.__name__, self._type, self._label, + self.__class__.__name__, self._element_type, self._label, self._properties) diff --git a/goblin/properties.py b/goblin/properties.py index 3bc6aef4c639f71c6fce84de55c8e6ce0f33f88e..a49d9cad1b8f268a6a78897fcf2c66bf34b4b438 100644 --- a/goblin/properties.py +++ b/goblin/properties.py @@ -1,7 +1,7 @@ """Classes to handle proerties and data type definitions""" import logging -from goblin import abc +from goblin import abc, exception logger = logging.getLogger(__name__) @@ -60,8 +60,9 @@ class String(abc.DataType): if val is not None: try: return str(val) - except Exception as e: - raise Exception("Invalid") from e + except ValueError as e: + raise exception.ValidationError( + '{} is not a valid string'.format(val)) from e def to_db(self, val): return super().to_db(val) @@ -78,8 +79,9 @@ class Integer(abc.DataType): if val is not None: try: return int(val) - except Exception as e: - raise Exception("Invalid") from e + except ValueError as e: + raise exception.ValidationError( + '{} is not a valid integer'.format(val)) from e def to_db(self, val): return super().to_db(val) diff --git a/goblin/session.py b/goblin/session.py index d0f4dc89103ebce317d91f8e20875cd4f6a61341..3643a9eabf08faa374fc516e177f8646f59305e9 100644 --- a/goblin/session.py +++ b/goblin/session.py @@ -4,8 +4,7 @@ import collections import logging import weakref -from goblin import mapper -from goblin import traversal +from goblin import exception, mapper, traversal from goblin.driver import connection, graph from goblin.element import GenericVertex @@ -125,7 +124,8 @@ class Session(connection.AbstractConnection): elif element.__type__ == 'edge': result = await self.save_edge(element) else: - raise Exception("Unknown element type") + raise exception.ElementError( + "Unknown element type: {}".format(element.__type__)) return result async def save_vertex(self, element): @@ -138,7 +138,8 @@ class Session(connection.AbstractConnection): async def save_edge(self, element): if not (hasattr(element, 'source') and hasattr(element, 'target')): - raise Exception("Edges require source/target vetices") + raise exception.ElementError( + "Edges require both source/target vertices") result = await self._save_element( element, self._check_edge, self.traversal_factory.add_edge, diff --git a/tests/test_driver.py b/tests/test_driver.py index 7b2de032277cfef758a3c172ba14a422a2896e22..6c4391adccf2f9f17daf3fc88631037ff7769980 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -1,5 +1,7 @@ import pytest +from goblin import exception + @pytest.mark.asyncio async def test_get_close_conn(connection): @@ -43,7 +45,7 @@ async def test_204_empty_stream(connection): async def test_server_error(connection): async with connection: stream = await connection.submit('g. V jla;sdf') - with pytest.raises(Exception): + with pytest.raises(exception.GremlinServerError): async for msg in stream: pass diff --git a/tests/test_mapper.py b/tests/test_mapper.py index c783aedff8d4ab7197fe316bb03be6c6ae315644..a8b8454566134a1396410c2ce100ec244a261804 100644 --- a/tests/test_mapper.py +++ b/tests/test_mapper.py @@ -1,6 +1,6 @@ import pytest -from goblin import properties +from goblin import exception, properties def test_property_mapping(person, lives_in): @@ -35,5 +35,5 @@ def test_getattr_getdbname(person, lives_in): def test_getattr_doesnt_exist(person): - with pytest.raises(Exception): + with pytest.raises(exception.MappingError): db_name = person.__mapping__.doesnt_exits diff --git a/tests/test_session.py b/tests/test_session.py index 2bde91486f0ac5924628dd4634066e633e1ea7c5..a340c6e79e326d5c0ed560010f4227797ccab480 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -228,7 +228,6 @@ class TestTraversalApi: resp = await session.traversal(person_class).one_or_none() assert isinstance(resp, person_class) - @pytest.mark.asyncio async def test_one_or_none_none(self, session): async with session: @@ -267,7 +266,6 @@ class TestTraversalApi: assert dave.name == 'dave' assert dave.__label__ == 'unregistered' - @pytest.mark.asyncio async def test_unregistered_edge_desialization(self, session): async with session: