Commit 329fbaf5 authored by davebshow's avatar davebshow
Browse files

integrated context manager, finally have ClientWebSocketResponse subclass that I wanted

parent 91fc71cd
......@@ -18,10 +18,10 @@ class AbstractFactory(metaclass=ABCMeta):
class AbstractConnection(metaclass=ABCMeta):
# @property
# @abstractmethod
# def closed(self):
# pass
@property
@abstractmethod
def closed(self):
pass
@abstractmethod
def close():
......@@ -35,6 +35,6 @@ class AbstractConnection(metaclass=ABCMeta):
def send(self):
pass
# @abstractmethod
# def receive(self):
# pass
@abstractmethod
def receive(self):
pass
......@@ -6,7 +6,6 @@ import ssl
import aiohttp
from aiogremlin.connection import AiohttpFactory
from aiogremlin.contextmanager import ClientContextManager
from aiogremlin.exceptions import RequestError
from aiogremlin.log import logger, INFO
from aiogremlin.pool import WebSocketPool
......@@ -102,11 +101,14 @@ class GremlinClient:
elif self._connected:
conn = self._conn
else:
conn = yield from self._conn
try:
self._conn = yield from self._conn
except TypeError:
self._conn = yield from self._connect()
except Exception:
raise RuntimeError("Unable to acquire connection.")
conn = self._conn
return conn
# Check here for error
# except Error:
# conn = yield from self._connect()
@asyncio.coroutine
def submit(self, gremlin, conn=None, bindings=None, lang=None, op=None,
......@@ -119,11 +121,11 @@ class GremlinClient:
if conn is None:
conn = yield from self._acquire()
writer = GremlinWriter(conn)
conn = yield from writer.write(gremlin, bindings=bindings,
conn = writer.write(gremlin, bindings=bindings,
lang=lang, op=op, processor=processor, session=session,
binary=binary)
return GremlinResponse(conn,
self,
pool=self._pool,
session=session,
loop=self._loop)
......@@ -142,11 +144,10 @@ class GremlinClient:
class GremlinResponse:
def __init__(self, conn, client, session=None, loop=None):
def __init__(self, conn, pool=None, session=None, loop=None):
self._loop = loop or asyncio.get_event_loop()
self._client = client
self._session = session
self._stream = GremlinResponseStream(conn, self, loop=self._loop)
self._stream = GremlinResponseStream(conn, pool=pool, loop=self._loop)
@property
def stream(self):
......@@ -172,24 +173,12 @@ class GremlinResponse:
results.append(message)
return results
# aioredis style
def __enter__(self):
raise RuntimeError(
"'yield from' should be used as a context manager expression")
def __exit__(self, *args):
pass
def __iter__(self):
yield from self._pool.create_pool()
return ClientContextManager(self)
class GremlinResponseStream:
def __init__(self, conn, resp, loop=None):
def __init__(self, conn, pool=None, loop=None):
self._conn = conn
self._resp = resp
self._pool = pool
self._loop = loop or asyncio.get_event_loop()
data_stream = aiohttp.DataQueue(loop=self._loop)
self._stream = self._conn.parser.set_parser(gremlin_response_parser,
......@@ -203,16 +192,14 @@ class GremlinResponseStream:
# message = None
# else:
# This will be different 3.0.0.M9
pool = self._resp._client._pool
try:
yield from self._conn.read()
yield from self._conn.receive()
except RequestError:
if pool:
pool.release(self._conn)
print("fed pool")
if self._pool:
self._pool.release(self._conn)
if self._stream.is_eof():
if pool:
pool.release(self._conn)
if self._pool:
self._pool.release(self._conn)
message = None
else:
message = yield from self._stream.read()
......
......@@ -126,10 +126,14 @@ class GremlinClientWebSocketResponse(BaseConnection, ClientWebSocketResponse):
ClientWebSocketResponse.__init__(self, reader, writer, protocol,
response, timeout, autoclose, autoping, loop)
@property
def closed(self):
"""Required by ABC."""
return self._closed
@asyncio.coroutine
def close(self, *, code=1000, message=b''):
if not self._closed:
self._closed = True
closed = self._close()
if closed:
return True
......@@ -155,6 +159,7 @@ class GremlinClientWebSocketResponse(BaseConnection, ClientWebSocketResponse):
return False
def _close(self):
self._closed = True
try:
self._writer.close(code, message)
except asyncio.CancelledError:
......@@ -171,7 +176,6 @@ class GremlinClientWebSocketResponse(BaseConnection, ClientWebSocketResponse):
self._response.close(force=True)
return True
# @asyncio.coroutine
def send(self, message, binary=True):
if binary:
method = self.send_bytes
......@@ -187,74 +191,46 @@ class GremlinClientWebSocketResponse(BaseConnection, ClientWebSocketResponse):
raise
@asyncio.coroutine
def read(self):
"""Implements a dispatcher using the aiohttp websocket protocol."""
def receive(self):
if self._waiting:
raise RuntimeError('Concurrent call to receive() is not allowed')
self._waiting = True
try:
message = yield from self.receive()
except (asyncio.CancelledError, asyncio.TimeoutError):
# Hmm maybe don't close here
yield from self.close()
raise
if message.tp == MsgType.binary:
try:
self.parser.feed_data(message.data.decode())
except Exception:
# Hmm maybe don't close here
yield from self.close()
raise
elif message.tp == MsgType.text:
try:
self.parser.feed_data(message.data.strip())
except Exception:
# Hmm maybe don't close here
yield from self.close()
raise
else:
if message.tp == MsgType.close:
raise RuntimeError("Socket connection closed by server.")
elif message.tp == MsgType.error:
raise SocketClientError(self.socket.exception())
elif message.tp == MsgType.closed:
raise RuntimeError("Socket closed.")
# @asyncio.coroutine
# def receive(self):
# if self._waiting:
# raise RuntimeError('Concurrent call to receive() is not allowed')
#
# self._waiting = True
# try:
# while True:
# if self._closed:
# return closedMessage
#
# try:
# msg = yield from self._reader.read()
# except (asyncio.CancelledError, asyncio.TimeoutError):
# raise
# except WebSocketError as exc:
# self._close_code = exc.code
# yield from self.close(code=exc.code)
# return Message(MsgType.error, exc, None)
# except Exception as exc:
# self._exception = exc
# self._closing = True
# self._close_code = 1006
# yield from self.close()
# return Message(MsgType.error, exc, None)
#
# if msg.tp == MsgType.close:
# self._closing = True
# self._close_code = msg.data
# if not self._closed and self._autoclose:
# yield from self.close()
# return msg
# elif not self._closed:
# if msg.tp == MsgType.ping and self._autoping:
# self._writer.pong(msg.data)
# elif msg.tp == MsgType.pong and self._autoping:
# continue
# else:
# return msg
# finally:
# self._waiting = False
while True:
if self._closed:
return closedMessage
try:
msg = yield from self._reader.read()
except (asyncio.CancelledError, asyncio.TimeoutError):
raise
except WebSocketError as exc:
self._close_code = exc.code
yield from self.close(code=exc.code)
raise
except Exception as exc:
self._exception = exc
self._closing = True
self._close_code = 1006
yield from self.close()
raise
if msg.tp == MsgType.close:
self._closing = True
self._close_code = msg.data
if not self._closed and self._autoclose:
yield from self.close()
raise RuntimeError("Socket connection closed by server.")
elif not self._closed:
if msg.tp == MsgType.ping and self._autoping:
self._writer.pong(msg.data)
elif msg.tp == MsgType.pong and self._autoping:
continue
else:
if msg.tp == MsgType.binary:
self.parser.feed_data(msg.data.decode())
elif msg.tp == MsgType.text:
self.parser.feed_data(msg.data.strip())
break
finally:
self._waiting = False
class ClientContextManager:
__slots__ = ("_client")
def __init__(self, client):
self._client = client
def __enter__(self):
return self._client
def __exit__(self, *args):
try:
yield from self._client.close()
finally:
self._client = None
class ConnectionContextManager:
__slots__ = ("_conn", "_pool")
......@@ -27,12 +10,8 @@ class ConnectionContextManager:
return self._conn
def __exit__(self, exception_type, exception_value, traceback):
print("in __exit__")
import ipdb; ipdb.set_trace()
print("agains")
try:
print("hello")
yield from self._conn.release()
self._conn._close()
finally:
self._conn = None
self._pool = None
......@@ -50,7 +50,6 @@ class GremlinWriter:
def __init__(self, connection):
self._connection = connection
@asyncio.coroutine
def write(self, gremlin, bindings=None, lang="gremlin-groovy", op="eval",
processor="", session=None, binary=True,
mime_type="application/json"):
......
......@@ -12,8 +12,7 @@ setup(
long_description=open("README.txt").read(),
packages=["aiogremlin", "tests"],
install_requires=[
"aiohttp==0.15.3",
"ujson==1.33"
"aiohttp==0.15.3"
],
test_suite="tests",
classifiers=[
......
......@@ -9,7 +9,34 @@ from aiogremlin import (GremlinClient, RequestError, GremlinServerError,
SocketClientError, WebSocketPool, AiohttpFactory, create_client,
GremlinWriter, GremlinResponse)
#
class GremlinClientTests(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
self.gc = GremlinClient("ws://localhost:8182/", loop=self.loop)
def tearDown(self):
self.loop.run_until_complete(self.gc.close())
self.loop.close()
def test_connection(self):
@asyncio.coroutine
def conn_coro():
conn = yield from self.gc._acquire()
self.assertFalse(conn.closed)
return conn
conn = self.loop.run_until_complete(conn_coro())
# Clean up the resource.
self.loop.run_until_complete(conn.close())
def test_sub(self):
execute = self.gc.execute("x + x", bindings={"x": 4})
results = self.loop.run_until_complete(execute)
self.assertEqual(results[0].data[0], 8)
class GremlinClientPoolTests(unittest.TestCase):
def setUp(self):
......@@ -192,37 +219,42 @@ class WebSocketPoolTests(unittest.TestCase):
self.assertFalse(conn2.closed)
self.loop.run_until_complete(conn())
#
# class ContextMngrTest(unittest.TestCase):
#
# def setUp(self):
# self.loop = asyncio.new_event_loop()
# asyncio.set_event_loop(None)
# self.pool = WebSocketPool(poolsize=1, loop=self.loop,
# factory=AiohttpFactory)
#
# def tearDown(self):
# self.loop.run_until_complete(self.pool.close())
# self.loop.close()
#
# def test_connection_manager(self):
# results = []
# @asyncio.coroutine
# def go():
#
# # import ipdb; ipdb.set_trace()
# with (yield from self.pool) as conn:
# writer = GremlinWriter(conn)
# conn = yield from writer.write("1 + 1")
# resp = GremlinResponse(conn, loop=self.loop)
# while True:
# mssg = yield from resp.stream.read()
# if mssg is None:
# break
# results.append(mssg)
# conn = self.pool._pool.get_nowait()
# self.assertTrue(conn.closed)
# self.loop.run_until_complete(go())
class ContextMngrTest(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
self.pool = WebSocketPool(poolsize=1, loop=self.loop,
factory=AiohttpFactory, max_retries=0)
def tearDown(self):
# self.loop.run_until_complete(self.pool.close())
self.loop.close()
def test_connection_manager(self):
results = []
@asyncio.coroutine
def go():
with (yield from self.pool) as conn:
writer = GremlinWriter(conn)
conn = writer.write("1 + 1")
resp = GremlinResponse(conn, self.pool, loop=self.loop)
while True:
mssg = yield from resp.stream.read()
if mssg is None:
break
results.append(mssg)
conn = self.pool._pool.get_nowait()
self.assertTrue(conn.closed)
writer = GremlinWriter(conn)
try:
conn = yield from writer.write("1 + 1")
error = False
except RuntimeError:
error = True
self.assertTrue(error)
self.loop.run_until_complete(go())
class CreateClientTests(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment