Commit d1340c45 authored by davebshow's avatar davebshow
Browse files

simplify!

parent a7b92662
......@@ -3,5 +3,4 @@ from .connection import WebsocketPool, AiohttpFactory
from .client import GremlinClient
from .contextmanager import GremlinContext
from .exceptions import RequestError, GremlinServerError, SocketClientError
from .tasks import async, Group, Chain, Chord
__version__ = "0.0.1"
"""Abstract classes for creating pluggable websocket clients."""
from abc import ABCMeta, abstractmethod
......
......@@ -5,11 +5,9 @@ import json
import ssl
import uuid
from .connection import WebsocketPool
from .log import client_logger
from .protocol import gremlin_response_parser, GremlinWriter
from .response import GremlinResponse
from .tasks import async
from aiogremlin.connection import WebsocketPool
from aiogremlin.log import client_logger
from aiogremlin.protocol import gremlin_response_parser, GremlinWriter
class GremlinBase:
......@@ -118,7 +116,6 @@ class GremlinClient(GremlinBase):
message = yield from self.recv(connection)
if message is None:
break
message = GremlinResponse(message)
if consumer:
message = consumer(message)
if asyncio.iscoroutine(message):
......
......@@ -4,9 +4,9 @@ import asyncio
import aiohttp
from .abc import AbstractFactory, AbstractConnection
from .exceptions import SocketClientError
from .log import INFO, conn_logger
from aiogremlin.abc import AbstractFactory, AbstractConnection
from aiogremlin.exceptions import SocketClientError
from aiogremlin.log import INFO, conn_logger
class WebsocketPool:
......@@ -165,9 +165,6 @@ class BaseConnection(AbstractConnection):
class AiohttpConnection(BaseConnection):
def __init__(self, socket, pool=None):
super().__init__(socket, pool=pool)
@property
def closed(self):
return self.socket.closed
......@@ -178,6 +175,7 @@ class AiohttpConnection(BaseConnection):
try:
yield from self.socket.close()
finally:
# Socket should close despite errors.
self._closed = True
@asyncio.coroutine
......@@ -198,7 +196,8 @@ class AiohttpConnection(BaseConnection):
raise
@asyncio.coroutine
def recv(self):
def receive(self):
"""Implements a dispatcher using the aiohttp websocket protocol."""
while True:
try:
message = yield from self.socket.receive()
......
from contextlib import contextmanager
from .client import GremlinBase, GremlinClient
from .connection import WebsocketPool
from aiogremlin.client import GremlinBase, GremlinClient
from aiogremlin.connection import WebsocketPool
class GremlinContext(GremlinBase):
......
"""
gizmo.exceptions
This module defines exceptions for the Gremlin Server.
Gremlin Server exceptions.
"""
class SocketClientError(IOError): pass
......@@ -10,12 +8,8 @@ class SocketClientError(IOError): pass
class StatusException(IOError):
def __init__(self, value, result):
"""
Handle all exceptions returned from the Gremlin Server as per:
"""Handle all exceptions returned from the Gremlin Server as per:
https://github.com/apache/incubator-tinkerpop/blob/ddd0b36bed9a2b1ce5b335b1753d881f0614a6c4/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/message/ResponseStatusCode.java
:param value: ResultCode:
:param value: message:
"""
self.value = value
self.response = {
......
import logging
# logging.basicConfig(level=logging.DEBUG)
INFO = logging.INFO
......@@ -10,4 +10,3 @@ logging.basicConfig(
client_logger = logging.getLogger("aiogremlin.client")
conn_logger = logging.getLogger("aiogremlin.connection")
task_logger = logging.getLogger("aiogremlin.task")
"""
Implements a very simple "protocol" for the Gremlin server.
"""
"""Implements a very simple "protocol" for the Gremlin server."""
import asyncio
import collections
import json
from .exceptions import RequestError, GremlinServerError
from aiogremlin.exceptions import RequestError, GremlinServerError
Message = collections.namedtuple("Message", ["status_code", "data", "message",
"metadata"])
@asyncio.coroutine
def gremlin_response_parser(connection):
message = yield from connection.recv()
message = yield from connection.receive()
message = json.loads(message)
status_code = message["status"]["code"]
if status_code == 200:
message = Message(message["status"]["code"],
message["result"]["data"],
message["result"]["meta"],
message["status"]["message"])
if message.status_code == 200:
return message
elif status_code == 299:
elif message.status_code == 299:
connection.feed_pool()
# Return None
else:
try:
message = message["status"]["message"]
if status_code < 500:
raise RequestError(status_code, message)
raise RequestError(message.status_code, message.message)
else:
raise GremlinServerError(status_code, message)
raise GremlinServerError(message.status_code, message.message)
finally:
yield from connection.release()
......
"""
THIS MODULE WILL BE REMOVED
"""
class GremlinResponse(list):
def __init__(self, message):
"""
A subclass of list that parses and flattens the Gremlin Server's
response a bit. Make standard usecase easier for end user to process.
:param message: Message from Gremlin Server.
"""
super().__init__()
data = message["result"].get("data", "")
if data:
for datum in data:
if isinstance(datum, dict):
try:
datum = parse_struct(datum)
except (KeyError, IndexError):
pass
self.append(datum)
self.meta = message["result"]["meta"]
self.request_id = message["requestId"]
self.status_code = message["status"]["code"]
self.message = message["status"]["message"]
self.attrs = message["status"]["attributes"]
def parse_struct(struct):
"""
Flatten out Gremlin Vertex and Edges a bit.
:param struct: Vertex or Edge.
:return: dict
"""
output = {}
for k, v in struct.items():
if k != "properties":
output[k] = v
# TODO - Make sure no info is being lost here.
properties = {k: [val["value"] for val in v] for (k, v) in
struct["properties"].items()}
output.update(properties)
return output
"""
THIS MODULE WILL BE COMPLETELY REFACTORED
"""
import asyncio
import itertools
from .log import task_logger, INFO
def async(coro, *args, **kwargs):
return Task(coro, *args, **kwargs)
class BaseTask:
def __init__(self, **kwargs):
self._loop = kwargs.get("loop", "") or asyncio.get_event_loop()
self.coro = None
self._result = None
verbose = kwargs.get("verbose", False)
if verbose:
task_logger.setLevel(INFO)
@property
def loop(self):
return self._loop
@property
def result(self):
return self._result
def __call__(self):
self.task = asyncio.async(self.coro, loop=self.loop)
task_logger.info("Task scheduled: {}".format(self))
return self.task
def execute(self):
if not hasattr(self, "task"):
self.__call__()
task_logger.info("Execute task: {}".format(self.loop))
self.loop.run_until_complete(self.task)
task_logger.info("Completed task: {}".format(self.task))
return self._result
def get(self):
return self.execute()
@asyncio.coroutine
def _dequeue(self, queue):
self._result = []
while not queue.empty():
t = queue.get_nowait()
result = yield from t()
self._result.append(result)
return self._result
class Task(BaseTask):
def __init__(self, coro, *args, **kwargs):
super().__init__(**kwargs)
self.coro = self.set_raise_result(coro(*args, **kwargs))
@asyncio.coroutine
def set_raise_result(self, coro):
result = yield from coro
try:
self._result = list(itertools.chain.from_iterable(result))
except TypeError:
self._result = result
return self._result
class Group(BaseTask):
def __init__(self, *args, **kwargs):
super().__init__(**kwargs)
if len(args) == 1:
args = args[0]
coro = asyncio.wait([t.coro for t in args], loop=self._loop,
return_when=asyncio.FIRST_EXCEPTION)
self.coro = self.set_raise_result(coro)
@asyncio.coroutine
def set_raise_result(self, tasks):
done, pending = yield from tasks
result = [f.result() for f in done]
self._result = result
return self._result
class Chain(BaseTask):
def __init__(self, *args, **kwargs):
super().__init__(**kwargs)
if len(args) == 1:
args = args[0]
task_queue = asyncio.Queue(loop=self._loop)
for t in args:
task_queue.put_nowait(t)
self.coro = self._dequeue(task_queue)
class Chord(BaseTask):
def __init__(self, itrbl, callback, **kwargs):
super().__init__(**kwargs)
g = Group(itrbl, loop=self._loop)
task_queue = asyncio.Queue(loop=self._loop)
task_queue.put_nowait(g)
task_queue.put_nowait(callback)
self.coro = self._dequeue(task_queue)
......@@ -5,7 +5,7 @@ import asyncio
import itertools
import websockets
import unittest
from aiogremlin import (GremlinClient, async, Group, Chain, Chord, RequestError,
from aiogremlin import (GremlinClient, RequestError,
GremlinServerError, SocketClientError, WebsocketPool, AiohttpFactory)
......@@ -21,248 +21,20 @@ class GremlinClientTests(unittest.TestCase):
self.loop.run_until_complete(self.gc.close())
self.loop.close()
@asyncio.coroutine
def consumer_coro1(self, x):
yield from asyncio.sleep(0.25, loop=self.loop)
return x[0] ** 0
@asyncio.coroutine
def consumer_coro2(self, x):
yield from asyncio.sleep(0.5, loop=self.loop)
return x[0] ** 1
def test_connection(self):
@asyncio.coroutine
def conn_coro():
conn = yield from self.gc.connect()
self.assertFalse(conn.closed)
self.loop.run_until_complete(conn_coro())
def test_task(self):
t = async(self.gc.submit, "x + x", bindings={"x": 2},
consumer=lambda x : x, loop=self.loop)
message = t.execute()
self.assertEqual(4, message[0])
def test_task_error(self):
t = async(self.gc.submit, "x + x g.adasdfd", bindings={"x": 2},
consumer=lambda x : x[0] ** 2, loop=self.loop)
try:
t.execute()
error = False
except:
error = True
self.assertTrue(error)
def test_submittask(self):
t = self.gc.s("x + x", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
t()
message = t.get()
self.assertEqual(16, message[0])
def test_group(self):
t = self.gc.s("x + x", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
slow = self.gc.s("x + x", bindings={"x": 2},
consumer=self.consumer_coro1)
g = Group(slow, t, loop=self.loop)
results = g.execute()
self.assertEqual(len(results), 2)
results = list(itertools.chain.from_iterable(results))
self.assertTrue(16 in results)
self.assertTrue(1 in results)
def test_group_error(self):
t = self.gc.s("x + x g.sdfa", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
slow = self.gc.s("x + x", bindings={"x": 2},
consumer=self.consumer_coro1)
g = Group(slow, t, loop=self.loop)
try:
g.execute()
error = False
except:
error = True
self.assertTrue(error)
def test_group_of_groups(self):
fast = self.gc.s("x + x", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
fast1 = self.gc.s("x + x", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
slow = self.gc.s("x + x", bindings={"x": 2}, consumer=self.consumer_coro1)
slow1 = self.gc.s("x + x", bindings={"x": 2}, consumer=self.consumer_coro1)
g = Group(fast, fast1, loop=self.loop)
g1 = Group(slow, slow1, loop=self.loop)
results = Group(g, g1, loop=self.loop).execute()
self.assertEqual(len(results), 2)
self.assertEqual(len(results[0]), 2)
self.assertEqual(len(results[1]), 2)
results = list(itertools.chain.from_iterable(results))
results = list(itertools.chain.from_iterable(results))
self.assertTrue(1 in results)
self.assertTrue(16 in results)
results.remove(1)
results.remove(16)
self.assertTrue(1 in results)
self.assertTrue(16 in results)
def test_group_itrbl_arg(self):
t = self.gc.s("x + x", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
slow = self.gc.s("x + x", bindings={"x": 2},
consumer=self.consumer_coro1)
g = Group([slow, t], loop=self.loop)
results = g.execute()
self.assertEqual(len(results), 2)
results = list(itertools.chain.from_iterable(results))
self.assertTrue(1 in results)
self.assertTrue(16 in results)
def test_chain(self):
t = self.gc.s("x + x", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
slow = self.gc.s("x + x", bindings={"x": 2},
consumer=self.consumer_coro1)
results = Chain(slow, t, loop=self.loop).execute()
self.assertEqual(results[0][0], 1)
self.assertEqual(results[1][0], 16)
def test_chain_error(self):
t = self.gc.s("x + x g.sadf", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
slow = self.gc.s("x + x", bindings={"x": 2},
consumer=self.consumer_coro1)
try:
Chain(slow, t, loop=self.loop).execute()
error = False
except:
error = True
self.assertTrue(error)
def test_chains_in_group(self):
slow = self.gc.s("x + x", bindings={"x": 2},
consumer=self.consumer_coro2)
slow1 = self.gc.s("x + x", bindings={"x": 2},
consumer=self.consumer_coro1)
slow_chain = Chain(slow, slow1, loop=self.loop)
t = self.gc.s("x + x", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
results = Group(slow_chain, t, loop=self.loop).execute()
self.assertEqual(slow_chain.result[0][0], 4)
self.assertEqual(slow_chain.result[1][0], 1)
self.assertEqual(t.result[0], 16)
def test_chains_in_group_error(self):
slow = self.gc.s("x + x g.edfsa", bindings={"x": 2},
consumer=self.consumer_coro2)
slow1 = self.gc.s("x + x g.eafwa", bindings={"x": 2},
consumer=self.consumer_coro1)
slow_chain = Chain(slow, slow1, loop=self.loop)
t = self.gc.s("x + x", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
try:
Group(slow_chain, t, loop=self.loop).execute()
error = False
except:
error = True
self.assertTrue(error)
def test_chain_itrbl_arg(self):
t = self.gc.s("x + x", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
slow = self.gc.s("x + x", bindings={"x": 2},
consumer=self.consumer_coro1)
results = Chain([slow, t], loop=self.loop).execute()
self.assertEqual(results[0][0], 1)
self.assertEqual(results[1][0], 16)
def test_group_chain(self):
results = []
slow = self.gc.s("x + x", bindings={"x": 2}, consumer=self.consumer_coro1)
slow1 = self.gc.s("x + x", bindings={"x": 2}, consumer=self.consumer_coro1)
slow_group = Group(slow, slow1, loop=self.loop)
fast = self.gc.s("x + x", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
fast1 = self.gc.s("x + x", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
fast_group = Group(fast, fast1, loop=self.loop)
results = Chain(slow_group, fast_group, loop=self.loop).execute()
self.assertEqual(results[0][0][0], 1)
self.assertEqual(results[0][1][0], 1)
self.assertEqual(results[1][0][0], 16)
self.assertEqual(results[1][1][0], 16)
def test_chord(self):
slow1 = self.gc.s("x + x", bindings={"x": 2},
consumer=self.consumer_coro1)
slow2 = self.gc.s("x + x", bindings={"x": 2},
consumer=self.consumer_coro2)
t = self.gc.s("x + x", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
results = Chord([slow2, slow1], t, loop=self.loop).execute()
flat = list(itertools.chain.from_iterable(results[0]))
self.assertTrue(1 in flat)
self.assertTrue(4 in flat)
self.assertEqual(results[1][0], 16)
def test_chord_group_error(self):
slow1 = self.gc.s("x + x g.asdf", bindings={"x": 2},
consumer=self.consumer_coro1)
slow2 = self.gc.s("x + x", bindings={"x": 2},
consumer=self.consumer_coro2)
t = self.gc.s("x + x", bindings={"x": 2},
consumer=lambda x : x[0] ** 2)
try:
Chord([slow2, slow1], t, loop=self.loop).execute()
error = False
except:
error = True
self.assertTrue(error)
def test_z_e2e(self):
t = self.gc.s("g.V().remove(); g.E().remove();", collect=False)
t1 = self.gc.s("g.addVertex('uniqueId', x)", bindings={"x": "joe"},
collect=False)
t2 = self.gc.s("g.addVertex('uniqueId', x)", bindings={"x": "maria"},
collect=False)
t3 = self.gc.s("g.addVertex('uniqueId', x)", bindings={"x": "jill"},
collect=False)
t4 = self.gc.s("g.addVertex('uniqueId', x)", bindings={"x": "jack"},
collect=False)
g1 = Group(t1, t2, t3, t4, loop=self.loop)
t5 = self.gc.s("""
joe = g.V().has('uniqueId', 'joe').next();
maria = g.V().has('uniqueId', 'maria').next();
joe.addEdge('marriedTo', maria);""")
t6 = self.gc.s("""
jill = g.V().has('uniqueId', 'jill').next();
jack = g.V().has('uniqueId', 'jack').next();
jill.addEdge('marriedTo', jack);""")
t7 = self.gc.s("""
jill = g.V().has('uniqueId', 'jill').next();
joe = g.V().has('uniqueId', 'joe').next();
jill.addEdge('hasSibling', joe);""")
g2 = Group(t5, t6, t7, loop=self.loop)
t8 = self.gc.s("g.V();", consumer=lambda x: print(x))
t9 = self.gc.s("g.E();", consumer=lambda x: print(x))
t10 = self.gc.s("g.V().count();", consumer=lambda x: self.assertEqual(x[0], 4))
t11 = self.gc.s("g.E().count();", consumer=lambda x: self.assertEqual(x[0], 3))
c = Chain(t, g1, g2, t8, t9, t10, t11, t, loop=self.loop)
results = c.execute()
print(results)
return conn
conn = self.loop.run_until_complete(conn_coro())
# Clean up the resource.
self.loop.run_until_complete(conn.close())
def test_sub(self):
@asyncio.coroutine
def sub_coro():
results = []
results = yield from self.gc.submit("x + x", bindings={"x": 4})
self.assertEqual(results[0][0], 8)
self.loop.run_until_complete(sub_coro())
sub = self.gc.submit("x + x", bindings={"x": 4})
results = self.loop.run_until_complete(sub)
self.assertEqual(results[0].data[0], 8)
def test_recv(self):
@asyncio.coroutine
......@@ -273,19 +45,15 @@ class GremlinClientTests(unittest.TestCase):
f = yield from self.gc.recv(websocket)
if f is None:
break
else:
results.append(f)