Commit 91fc71cd authored by davebshow's avatar davebshow
Browse files

working on some refactoring

parent a24ba6ff
from .abc import AbstractFactory, AbstractConnection
from .connection import (WebsocketPool, AiohttpFactory, BaseFactory,
BaseConnection)
from .connection import AiohttpFactory, BaseFactory, BaseConnection
from .client import (create_client, GremlinClient, GremlinResponse,
GremlinResponseStream)
from .exceptions import RequestError, GremlinServerError, SocketClientError
from .pool import WebSocketPool
from .protocol import GremlinWriter
__version__ = "0.0.6"
......@@ -18,32 +18,23 @@ class AbstractFactory(metaclass=ABCMeta):
class AbstractConnection(metaclass=ABCMeta):
@abstractmethod
def feed_pool(self):
pass
@abstractmethod
def release(self):
pass
@property
@abstractmethod
def pool(self):
pass
# @property
# @abstractmethod
# def closed(self):
# pass
@property
@abstractmethod
def closed(self):
def close():
pass
@abstractmethod
def close():
def _close():
pass
@abstractmethod
def send(self):
pass
@abstractmethod
def _receive(self):
pass
# @abstractmethod
# def receive(self):
# pass
......@@ -2,12 +2,14 @@
import asyncio
import ssl
import uuid
import aiohttp
from aiogremlin.connection import WebsocketPool
from aiogremlin.log import client_logger, INFO
from aiogremlin.connection import AiohttpFactory
from aiogremlin.contextmanager import ClientContextManager
from aiogremlin.exceptions import RequestError
from aiogremlin.log import logger, INFO
from aiogremlin.pool import WebSocketPool
from aiogremlin.protocol import gremlin_response_parser, GremlinWriter
......@@ -16,14 +18,14 @@ def create_client(uri='ws://localhost:8182/', loop=None, ssl=None,
protocol=None, lang="gremlin-groovy", op="eval",
processor="", pool=None, factory=None, poolsize=10,
timeout=None, verbose=False, **kwargs):
pool = WebsocketPool(uri,
pool = WebSocketPool(uri,
factory=factory,
poolsize=poolsize,
timeout=timeout,
loop=loop,
verbose=verbose)
yield from pool.init_pool()
yield from pool.fill_pool()
return GremlinClient(uri=uri,
loop=loop,
......@@ -42,7 +44,7 @@ class GremlinClient:
def __init__(self, uri='ws://localhost:8182/', loop=None, ssl=None,
protocol=None, lang="gremlin-groovy", op="eval",
processor="", pool=None, factory=None, poolsize=10,
timeout=None, verbose=True, **kwargs):
timeout=None, verbose=False, **kwargs):
"""
"""
self.uri = uri
......@@ -60,11 +62,15 @@ class GremlinClient:
self.processor = processor or ""
self.poolsize = poolsize
self.timeout = timeout
self.pool = pool or WebsocketPool(uri, factory=factory,
poolsize=poolsize, timeout=timeout, loop=self._loop)
self.factory = factory or self.pool.factory
self._pool = pool
self._factory = factory or AiohttpFactory
if self._pool is None:
self._connected = False
self._conn = asyncio.async(self._connect(), loop=self._loop)
else:
self._connected = self._pool._connected
if verbose:
client_logger.setLevel(INFO)
logger.setLevel(INFO)
@property
def loop(self):
......@@ -72,45 +78,54 @@ class GremlinClient:
@asyncio.coroutine
def close(self):
yield from self.pool.close()
try:
if self._pool:
yield from self._pool.close()
elif self._connected:
yield from self._conn.close()
finally:
self._connected = False
@asyncio.coroutine
def connect(self, **kwargs):
def _connect(self, **kwargs):
"""
"""
loop = kwargs.get("loop", "") or self.loop
connection = yield from self.factory.connect(self.uri, loop=loop,
**kwargs)
loop = kwargs.get("loop", "") or self._loop
connection = yield from self._factory.connect(self.uri, loop=loop)
self._connected = True
return connection
@asyncio.coroutine
def submit(self, gremlin, connection=None, bindings=None, lang=None,
op=None, processor=None, session=None, binary=True):
def _acquire(self, **kwargs):
if self._pool:
conn = yield from self._pool.acquire()
elif self._connected:
conn = self._conn
else:
conn = yield from self._conn
return conn
# Check here for error
# except Error:
# conn = yield from self._connect()
@asyncio.coroutine
def submit(self, gremlin, conn=None, bindings=None, lang=None, op=None,
processor=None, session=None, binary=True):
"""
"""
lang = lang or self.lang
op = op or self.op
processor = processor or self.processor
message = {
"requestId": str(uuid.uuid4()),
"op": op,
"processor": processor,
"args":{
"gremlin": gremlin,
"bindings": bindings,
"language": lang
}
}
if processor == "session":
session = session or str(uuid.uuid4())
message["args"]["session"] = session
client_logger.info(
"Session ID: {}".format(message["args"]["session"]))
if connection is None:
connection = yield from self.pool.connect(self.uri, loop=self.loop)
writer = GremlinWriter(connection)
connection = yield from writer.write(message, binary=binary)
return GremlinResponse(connection, session=session, loop=self._loop)
if conn is None:
conn = yield from self._acquire()
writer = GremlinWriter(conn)
conn = yield from writer.write(gremlin, bindings=bindings,
lang=lang, op=op, processor=processor, session=session,
binary=binary)
return GremlinResponse(conn,
self,
session=session,
loop=self._loop)
@asyncio.coroutine
def execute(self, gremlin, bindings=None, lang=None,
......@@ -127,10 +142,11 @@ class GremlinClient:
class GremlinResponse:
def __init__(self, conn, session=None, loop=None):
def __init__(self, conn, client, session=None, loop=None):
self._loop = loop or asyncio.get_event_loop()
self._client = client
self._session = session
self._stream = GremlinResponseStream(conn, loop=self._loop)
self._stream = GremlinResponseStream(conn, self, loop=self._loop)
@property
def stream(self):
......@@ -156,11 +172,24 @@ class GremlinResponse:
results.append(message)
return results
# aioredis style
def __enter__(self):
raise RuntimeError(
"'yield from' should be used as a context manager expression")
def __exit__(self, *args):
pass
def __iter__(self):
yield from self._pool.create_pool()
return ClientContextManager(self)
class GremlinResponseStream:
def __init__(self, conn, loop=None):
def __init__(self, conn, resp, loop=None):
self._conn = conn
self._resp = resp
self._loop = loop or asyncio.get_event_loop()
data_stream = aiohttp.DataQueue(loop=self._loop)
self._stream = self._conn.parser.set_parser(gremlin_response_parser,
......@@ -170,13 +199,20 @@ class GremlinResponseStream:
def read(self):
# For 3.0.0.M9
# if self._stream.at_eof():
# self._conn.feed_pool()
# self._pool.release(self._conn)
# message = None
# else:
# This will be different 3.0.0.M9
yield from self._conn._receive()
pool = self._resp._client._pool
try:
yield from self._conn.read()
except RequestError:
if pool:
pool.release(self._conn)
print("fed pool")
if self._stream.is_eof():
self._conn.feed_pool()
if pool:
pool.release(self._conn)
message = None
else:
message = yield from self._stream.read()
......
"""
"""
import asyncio
import aiohttp
import base64
import hashlib
import os
from aiohttp import (client, hdrs, DataQueue, StreamParser,
WSServerHandshakeError)
from aiohttp.errors import WSServerHandshakeError
from aiohttp.websocket import WS_KEY, Message
from aiohttp.websocket import WebSocketParser, WebSocketWriter, WebSocketError
from aiohttp.websocket import (MSG_BINARY, MSG_TEXT, MSG_CLOSE, MSG_PING,
MSG_PONG)
from aiohttp.websocket_client import (MsgType, closedMessage,
ClientWebSocketResponse)
from aiogremlin.abc import AbstractFactory, AbstractConnection
from aiogremlin.exceptions import SocketClientError
from aiogremlin.log import INFO, conn_logger
class WebsocketPool:
def __init__(self, uri='ws://localhost:8182/', factory=None, poolsize=10,
max_retries=10, timeout=None, loop=None, verbose=False):
"""
"""
self.uri = uri
self._factory = factory or AiohttpFactory
self.poolsize = poolsize
self.max_retries = max_retries
self.timeout = timeout
self._loop = loop or asyncio.get_event_loop()
self.pool = asyncio.Queue(maxsize=self.poolsize, loop=self._loop)
self.active_conns = set()
self.num_connecting = 0
self._closed = False
if verbose:
conn_logger.setLevel(INFO)
@asyncio.coroutine
def init_pool(self):
for i in range(self.poolsize):
conn = yield from self.factory.connect(self.uri, pool=self,
loop=self._loop)
self._put(conn)
@property
def loop(self):
return self._loop
@property
def factory(self):
return self._factory
@property
def closed(self):
return self._closed
@property
def num_active_conns(self):
return len(self.active_conns)
def feed_pool(self, conn):
if self._closed:
raise RuntimeError("WebsocketPool is closed.")
self.active_conns.discard(conn)
self._put(conn)
@asyncio.coroutine
def close(self):
if not self._closed:
if self.active_conns:
yield from self._close_active_conns()
yield from self._purge_pool()
self._closed = True
@asyncio.coroutine
def _close_active_conns(self):
tasks = [asyncio.async(conn.close(), loop=self.loop) for conn
in self.active_conns]
yield from asyncio.wait(tasks, loop=self.loop)
@asyncio.coroutine
def _purge_pool(self):
while True:
try:
conn = self.pool.get_nowait()
except asyncio.QueueEmpty:
from aiogremlin.log import INFO, logger
# This is temporary until aiohttp pull #367 is merged/released.
@asyncio.coroutine
def ws_connect(url, protocols=(), timeout=10.0, connector=None,
response_class=None, autoclose=True, autoping=True, loop=None):
"""Initiate websocket connection."""
if loop is None:
loop = asyncio.get_event_loop()
sec_key = base64.b64encode(os.urandom(16))
headers = {
hdrs.UPGRADE: hdrs.WEBSOCKET,
hdrs.CONNECTION: hdrs.UPGRADE,
hdrs.SEC_WEBSOCKET_VERSION: '13',
hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(),
}
if protocols:
headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols)
# send request
resp = yield from client.request(
'get', url, headers=headers,
read_until_eof=False,
connector=connector, loop=loop)
# check handshake
if resp.status != 101:
raise WSServerHandshakeError('Invalid response status')
if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket':
raise WSServerHandshakeError('Invalid upgrade header')
if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade':
raise WSServerHandshakeError('Invalid connection header')
# key calculation
key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode()
if key != match:
raise WSServerHandshakeError('Invalid challenge response')
# websocket protocol
protocol = None
if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers:
resp_protocols = [proto.strip() for proto in
resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
for proto in resp_protocols:
if proto in protocols:
protocol = proto
break
else:
yield from conn.close()
@asyncio.coroutine
def connect(self, uri=None, loop=None, num_retries=None):
if self._closed:
raise RuntimeError("WebsocketPool is closed.")
if num_retries is None:
num_retries = self.max_retries
uri = uri or self.uri
loop = loop or self.loop
if not self.pool.empty():
socket = self.pool.get_nowait()
conn_logger.info("Reusing socket: {} at {}".format(socket, uri))
elif self.num_active_conns + self.num_connecting >= self.poolsize:
conn_logger.info("Waiting for socket...")
socket = yield from asyncio.wait_for(self.pool.get(),
self.timeout, loop=loop)
conn_logger.info("Socket acquired: {} at {}".format(socket, uri))
else:
self.num_connecting += 1
try:
socket = yield from self.factory.connect(uri, pool=self,
loop=loop)
finally:
self.num_connecting -= 1
if not socket.closed:
conn_logger.info("New connection on socket: {} at {}".format(
socket, uri))
self.active_conns.add(socket)
# Untested.
elif num_retries > 0:
conn_logger.warning("Got bad socket, retry...")
socket = yield from self.connect(uri, loop, num_retries - 1)
else:
raise RuntimeError("Unable to connect, max retries exceeded.")
return socket
reader = resp.connection.reader.set_parser(WebSocketParser)
writer = WebSocketWriter(resp.connection.writer, use_mask=True)
def _put(self, socket):
try:
self.pool.put_nowait(socket)
except asyncio.QueueFull:
pass
# This should be - not working
# yield from socket.release()
if response_class is None:
response_class = ClientWebSocketResponse
return response_class(
reader, writer, protocol, resp, timeout, autoclose, autoping, loop)
class BaseFactory(AbstractFactory):
......@@ -141,107 +98,163 @@ class AiohttpFactory(BaseFactory):
if pool:
loop = loop or pool.loop
try:
socket = yield from aiohttp.ws_connect(uri, protocols=protocols,
connector=connector, autoclose=False, autoping=True,
loop=loop)
except aiohttp.WSServerHandshakeError as e:
return (yield from ws_connect(
uri, protocols=protocols, connector=connector,
response_class=GremlinClientWebSocketResponse,
autoclose=True, autoping=True, loop=loop))
except WSServerHandshakeError as e:
raise SocketClientError(e.message)
return AiohttpConnection(socket, pool, loop=loop)
class BaseConnection(AbstractConnection):
def __init__(self, socket, pool=None, loop=None):
self.socket = socket
def __init__(self, loop=None):
self._loop = loop or asyncio.get_event_loop()
self._pool = pool
self._parser = aiohttp.StreamParser(
buf=aiohttp.DataQueue(loop=self._loop), loop=self._loop)
self._parser = StreamParser(
buf=DataQueue(loop=self._loop), loop=self._loop)
@property
def parser(self):
return self._parser
def feed_pool(self):
if self.pool:
if self in self.pool.active_conns:
self.pool.feed_pool(self)
@asyncio.coroutine
def release(self):
try:
yield from self.close()
finally:
if self in self.pool.active_conns:
self.pool.active_conns.discard(self)
class GremlinClientWebSocketResponse(BaseConnection, ClientWebSocketResponse):
@property
def pool(self):
return self._pool
class AiohttpConnection(BaseConnection):
@property
def closed(self):
return self.socket.closed
def __init__(self, reader, writer, protocol, response, timeout, autoclose,
autoping, loop):
BaseConnection.__init__(self, loop=loop)
ClientWebSocketResponse.__init__(self, reader, writer, protocol,
response, timeout, autoclose, autoping, loop)
@asyncio.coroutine
def close(self):
if not self.socket._closed:
try:
yield from self.socket.close()
finally:
# Socket should close despite errors.
def close(self, *, code=1000, message=b''):
if not self._closed:
self._closed = True
closed = self._close()
if closed:
return True
while True:
try:
msg = yield from asyncio.wait_for(
self._reader.read(), self._timeout, loop=self._loop)
except asyncio.CancelledError:
self._close_code = 1006
self._response.close(force=True)
raise
except Exception as exc:
self._close_code = 1006
self._exception = exc
self._response.close(force=True)
return True
if msg.tp == MsgType.close:
self._close_code = msg.data
self._response.close(force=True)
return True
else:
return False
@asyncio.coroutine
def _close(self):
try:
self._writer.close(code, message)
except asyncio.CancelledError:
self._close_code = 1006
self._response.close(force=True)
raise
except Exception as exc:
self._close_code = 1006
self._exception = exc
self._response.close(force=True)
return True
if self._closing:
self._response.close(force=True)
return True
# @asyncio.coroutine
def send(self, message, binary=True):
if binary:
method = self.socket.send_bytes
method = self.send_bytes
else:
method = self.socket.send_str
method = self.send_str
try:
method(message)
except RuntimeError:
# Socket closed.
yield from self.release()
raise
except TypeError:
# Bytes/string input error.
yield from self.release()
raise
@asyncio.coroutine
def _receive(self):
def read(self):
"""Implements a dispatcher using the aiohttp websocket protocol."""
try:
message = yield from self.socket.receive()
message = yield from self.receive()
except (asyncio.CancelledError, asyncio.TimeoutError):
yield from self.release()
raise
except RuntimeError:
yield from self.release()
# Hmm maybe don't close here