Commit 91fc71cd authored by davebshow's avatar davebshow
Browse files

working on some refactoring

parent a24ba6ff
from .abc import AbstractFactory, AbstractConnection from .abc import AbstractFactory, AbstractConnection
from .connection import (WebsocketPool, AiohttpFactory, BaseFactory, from .connection import AiohttpFactory, BaseFactory, BaseConnection
from .client import (create_client, GremlinClient, GremlinResponse, from .client import (create_client, GremlinClient, GremlinResponse,
GremlinResponseStream) GremlinResponseStream)
from .exceptions import RequestError, GremlinServerError, SocketClientError from .exceptions import RequestError, GremlinServerError, SocketClientError
from .pool import WebSocketPool
from .protocol import GremlinWriter from .protocol import GremlinWriter
__version__ = "0.0.6" __version__ = "0.0.6"
...@@ -18,32 +18,23 @@ class AbstractFactory(metaclass=ABCMeta): ...@@ -18,32 +18,23 @@ class AbstractFactory(metaclass=ABCMeta):
class AbstractConnection(metaclass=ABCMeta): class AbstractConnection(metaclass=ABCMeta):
@abstractmethod # @property
def feed_pool(self): # @abstractmethod
pass # def closed(self):
# pass
def release(self):
def pool(self):
@abstractmethod @abstractmethod
def closed(self): def close():
pass pass
@abstractmethod @abstractmethod
def close(): def _close():
pass pass
@abstractmethod @abstractmethod
def send(self): def send(self):
pass pass
@abstractmethod # @abstractmethod
def _receive(self): # def receive(self):
pass # pass
...@@ -2,12 +2,14 @@ ...@@ -2,12 +2,14 @@
import asyncio import asyncio
import ssl import ssl
import uuid
import aiohttp import aiohttp
from aiogremlin.connection import WebsocketPool from aiogremlin.connection import AiohttpFactory
from aiogremlin.log import client_logger, INFO from aiogremlin.contextmanager import ClientContextManager
from aiogremlin.exceptions import RequestError
from aiogremlin.log import logger, INFO
from aiogremlin.pool import WebSocketPool
from aiogremlin.protocol import gremlin_response_parser, GremlinWriter from aiogremlin.protocol import gremlin_response_parser, GremlinWriter
...@@ -16,14 +18,14 @@ def create_client(uri='ws://localhost:8182/', loop=None, ssl=None, ...@@ -16,14 +18,14 @@ def create_client(uri='ws://localhost:8182/', loop=None, ssl=None,
protocol=None, lang="gremlin-groovy", op="eval", protocol=None, lang="gremlin-groovy", op="eval",
processor="", pool=None, factory=None, poolsize=10, processor="", pool=None, factory=None, poolsize=10,
timeout=None, verbose=False, **kwargs): timeout=None, verbose=False, **kwargs):
pool = WebsocketPool(uri, pool = WebSocketPool(uri,
factory=factory, factory=factory,
poolsize=poolsize, poolsize=poolsize,
timeout=timeout, timeout=timeout,
loop=loop, loop=loop,
verbose=verbose) verbose=verbose)
yield from pool.init_pool() yield from pool.fill_pool()
return GremlinClient(uri=uri, return GremlinClient(uri=uri,
loop=loop, loop=loop,
...@@ -42,7 +44,7 @@ class GremlinClient: ...@@ -42,7 +44,7 @@ class GremlinClient:
def __init__(self, uri='ws://localhost:8182/', loop=None, ssl=None, def __init__(self, uri='ws://localhost:8182/', loop=None, ssl=None,
protocol=None, lang="gremlin-groovy", op="eval", protocol=None, lang="gremlin-groovy", op="eval",
processor="", pool=None, factory=None, poolsize=10, processor="", pool=None, factory=None, poolsize=10,
timeout=None, verbose=True, **kwargs): timeout=None, verbose=False, **kwargs):
""" """
""" """
self.uri = uri self.uri = uri
...@@ -60,11 +62,15 @@ class GremlinClient: ...@@ -60,11 +62,15 @@ class GremlinClient:
self.processor = processor or "" self.processor = processor or ""
self.poolsize = poolsize self.poolsize = poolsize
self.timeout = timeout self.timeout = timeout
self.pool = pool or WebsocketPool(uri, factory=factory, self._pool = pool
poolsize=poolsize, timeout=timeout, loop=self._loop) self._factory = factory or AiohttpFactory
self.factory = factory or self.pool.factory if self._pool is None:
self._connected = False
self._conn = asyncio.async(self._connect(), loop=self._loop)
self._connected = self._pool._connected
if verbose: if verbose:
client_logger.setLevel(INFO) logger.setLevel(INFO)
@property @property
def loop(self): def loop(self):
...@@ -72,45 +78,54 @@ class GremlinClient: ...@@ -72,45 +78,54 @@ class GremlinClient:
@asyncio.coroutine @asyncio.coroutine
def close(self): def close(self):
yield from self.pool.close() try:
if self._pool:
yield from self._pool.close()
elif self._connected:
yield from self._conn.close()
self._connected = False
@asyncio.coroutine @asyncio.coroutine
def connect(self, **kwargs): def _connect(self, **kwargs):
""" """
""" """
loop = kwargs.get("loop", "") or self.loop loop = kwargs.get("loop", "") or self._loop
connection = yield from self.factory.connect(self.uri, loop=loop, connection = yield from self._factory.connect(self.uri, loop=loop)
**kwargs) self._connected = True
return connection return connection
@asyncio.coroutine @asyncio.coroutine
def submit(self, gremlin, connection=None, bindings=None, lang=None, def _acquire(self, **kwargs):
op=None, processor=None, session=None, binary=True): if self._pool:
conn = yield from self._pool.acquire()
elif self._connected:
conn = self._conn
conn = yield from self._conn
return conn
# Check here for error
# except Error:
# conn = yield from self._connect()
def submit(self, gremlin, conn=None, bindings=None, lang=None, op=None,
processor=None, session=None, binary=True):
""" """
""" """
lang = lang or self.lang lang = lang or self.lang
op = op or self.op op = op or self.op
processor = processor or self.processor processor = processor or self.processor
message = { if conn is None:
"requestId": str(uuid.uuid4()), conn = yield from self._acquire()
"op": op, writer = GremlinWriter(conn)
"processor": processor, conn = yield from writer.write(gremlin, bindings=bindings,
"args":{ lang=lang, op=op, processor=processor, session=session,
"gremlin": gremlin, binary=binary)
"bindings": bindings, return GremlinResponse(conn,
"language": lang self,
} session=session,
} loop=self._loop)
if processor == "session":
session = session or str(uuid.uuid4())
message["args"]["session"] = session
"Session ID: {}".format(message["args"]["session"]))
if connection is None:
connection = yield from self.pool.connect(self.uri, loop=self.loop)
writer = GremlinWriter(connection)
connection = yield from writer.write(message, binary=binary)
return GremlinResponse(connection, session=session, loop=self._loop)
@asyncio.coroutine @asyncio.coroutine
def execute(self, gremlin, bindings=None, lang=None, def execute(self, gremlin, bindings=None, lang=None,
...@@ -127,10 +142,11 @@ class GremlinClient: ...@@ -127,10 +142,11 @@ class GremlinClient:
class GremlinResponse: class GremlinResponse:
def __init__(self, conn, session=None, loop=None): def __init__(self, conn, client, session=None, loop=None):
self._loop = loop or asyncio.get_event_loop() self._loop = loop or asyncio.get_event_loop()
self._client = client
self._session = session self._session = session
self._stream = GremlinResponseStream(conn, loop=self._loop) self._stream = GremlinResponseStream(conn, self, loop=self._loop)
@property @property
def stream(self): def stream(self):
...@@ -156,11 +172,24 @@ class GremlinResponse: ...@@ -156,11 +172,24 @@ class GremlinResponse:
results.append(message) results.append(message)
return results return results
# aioredis style
def __enter__(self):
raise RuntimeError(
"'yield from' should be used as a context manager expression")
def __exit__(self, *args):
def __iter__(self):
yield from self._pool.create_pool()
return ClientContextManager(self)
class GremlinResponseStream: class GremlinResponseStream:
def __init__(self, conn, loop=None): def __init__(self, conn, resp, loop=None):
self._conn = conn self._conn = conn
self._resp = resp
self._loop = loop or asyncio.get_event_loop() self._loop = loop or asyncio.get_event_loop()
data_stream = aiohttp.DataQueue(loop=self._loop) data_stream = aiohttp.DataQueue(loop=self._loop)
self._stream = self._conn.parser.set_parser(gremlin_response_parser, self._stream = self._conn.parser.set_parser(gremlin_response_parser,
...@@ -170,13 +199,20 @@ class GremlinResponseStream: ...@@ -170,13 +199,20 @@ class GremlinResponseStream:
def read(self): def read(self):
# For 3.0.0.M9 # For 3.0.0.M9
# if self._stream.at_eof(): # if self._stream.at_eof():
# self._conn.feed_pool() # self._pool.release(self._conn)
# message = None # message = None
# else: # else:
# This will be different 3.0.0.M9 # This will be different 3.0.0.M9
yield from self._conn._receive() pool = self._resp._client._pool
yield from
except RequestError:
if pool:
print("fed pool")
if self._stream.is_eof(): if self._stream.is_eof():
self._conn.feed_pool() if pool:
message = None message = None
else: else:
message = yield from message = yield from
""" """
""" """
import asyncio import asyncio
import base64
import aiohttp import hashlib
import os
from aiohttp import (client, hdrs, DataQueue, StreamParser,
from aiohttp.errors import WSServerHandshakeError
from aiohttp.websocket import WS_KEY, Message
from aiohttp.websocket import WebSocketParser, WebSocketWriter, WebSocketError
from aiohttp.websocket import (MSG_BINARY, MSG_TEXT, MSG_CLOSE, MSG_PING,
from aiohttp.websocket_client import (MsgType, closedMessage,
from import AbstractFactory, AbstractConnection from import AbstractFactory, AbstractConnection
from aiogremlin.exceptions import SocketClientError from aiogremlin.exceptions import SocketClientError
from aiogremlin.log import INFO, conn_logger from aiogremlin.log import INFO, logger
class WebsocketPool: # This is temporary until aiohttp pull #367 is merged/released.
def __init__(self, uri='ws://localhost:8182/', factory=None, poolsize=10, def ws_connect(url, protocols=(), timeout=10.0, connector=None,
max_retries=10, timeout=None, loop=None, verbose=False): response_class=None, autoclose=True, autoping=True, loop=None):
""" """Initiate websocket connection."""
""" if loop is None:
self.uri = uri loop = asyncio.get_event_loop()
self._factory = factory or AiohttpFactory
self.poolsize = poolsize sec_key = base64.b64encode(os.urandom(16))
self.max_retries = max_retries
self.timeout = timeout headers = {
self._loop = loop or asyncio.get_event_loop() hdrs.UPGRADE: hdrs.WEBSOCKET,
self.pool = asyncio.Queue(maxsize=self.poolsize, loop=self._loop) hdrs.CONNECTION: hdrs.UPGRADE,
self.active_conns = set() hdrs.SEC_WEBSOCKET_VERSION: '13',
self.num_connecting = 0 hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(),
self._closed = False }
if verbose: if protocols:
conn_logger.setLevel(INFO) headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols)
@asyncio.coroutine # send request
def init_pool(self): resp = yield from client.request(
for i in range(self.poolsize): 'get', url, headers=headers,
conn = yield from self.factory.connect(self.uri, pool=self, read_until_eof=False,
loop=self._loop) connector=connector, loop=loop)
# check handshake
@property if resp.status != 101:
def loop(self): raise WSServerHandshakeError('Invalid response status')
return self._loop
if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket':
@property raise WSServerHandshakeError('Invalid upgrade header')
def factory(self):
return self._factory if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade':
raise WSServerHandshakeError('Invalid connection header')
def closed(self): # key calculation
return self._closed key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode()
@property if key != match:
def num_active_conns(self): raise WSServerHandshakeError('Invalid challenge response')
return len(self.active_conns)
# websocket protocol
def feed_pool(self, conn): protocol = None
if self._closed: if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers:
raise RuntimeError("WebsocketPool is closed.") resp_protocols = [proto.strip() for proto in
self.active_conns.discard(conn) resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
for proto in resp_protocols:
@asyncio.coroutine if proto in protocols:
def close(self): protocol = proto
if not self._closed:
if self.active_conns:
yield from self._close_active_conns()
yield from self._purge_pool()
self._closed = True
def _close_active_conns(self):
tasks = [asyncio.async(conn.close(), loop=self.loop) for conn
in self.active_conns]
yield from asyncio.wait(tasks, loop=self.loop)
def _purge_pool(self):
while True:
conn = self.pool.get_nowait()
except asyncio.QueueEmpty:
break break
yield from conn.close()
@asyncio.coroutine reader = resp.connection.reader.set_parser(WebSocketParser)
def connect(self, uri=None, loop=None, num_retries=None): writer = WebSocketWriter(resp.connection.writer, use_mask=True)
if self._closed:
raise RuntimeError("WebsocketPool is closed.")
if num_retries is None:
num_retries = self.max_retries
uri = uri or self.uri
loop = loop or self.loop
if not self.pool.empty():
socket = self.pool.get_nowait()"Reusing socket: {} at {}".format(socket, uri))
elif self.num_active_conns + self.num_connecting >= self.poolsize:"Waiting for socket...")
socket = yield from asyncio.wait_for(self.pool.get(),
self.timeout, loop=loop)"Socket acquired: {} at {}".format(socket, uri))
self.num_connecting += 1
socket = yield from self.factory.connect(uri, pool=self,
self.num_connecting -= 1
if not socket.closed:"New connection on socket: {} at {}".format(
socket, uri))
# Untested.
elif num_retries > 0:
conn_logger.warning("Got bad socket, retry...")
socket = yield from self.connect(uri, loop, num_retries - 1)
raise RuntimeError("Unable to connect, max retries exceeded.")
return socket
def _put(self, socket): if response_class is None:
try: response_class = ClientWebSocketResponse
except asyncio.QueueFull: return response_class(
pass reader, writer, protocol, resp, timeout, autoclose, autoping, loop)
# This should be - not working
# yield from socket.release()
class BaseFactory(AbstractFactory): class BaseFactory(AbstractFactory):
...@@ -141,107 +98,163 @@ class AiohttpFactory(BaseFactory): ...@@ -141,107 +98,163 @@ class AiohttpFactory(BaseFactory):
if pool: if pool:
loop = loop or pool.loop loop = loop or pool.loop
try: try:
socket = yield from aiohttp.ws_connect(uri, protocols=protocols, return (yield from ws_connect(
connector=connector, autoclose=False, autoping=True, uri, protocols=protocols, connector=connector,
loop=loop) response_class=GremlinClientWebSocketResponse,
except aiohttp.WSServerHandshakeError as e: autoclose=True, autoping=True, loop=loop))
except WSServerHandshakeError as e:
raise SocketClientError(e.message) raise SocketClientError(e.message)
return AiohttpConnection(socket, pool, loop=loop)
class BaseConnection(AbstractConnection): class BaseConnection(AbstractConnection):
def __init__(self, socket, pool=None, loop=None): def __init__(self, loop=None):
self.socket = socket
self._loop = loop or asyncio.get_event_loop() self._loop = loop or asyncio.get_event_loop()
self._pool = pool self._parser = StreamParser(
self._parser = aiohttp.StreamParser( buf=DataQueue(loop=self._loop), loop=self._loop)
buf=aiohttp.DataQueue(loop=self._loop), loop=self._loop)
@property @property
def parser(self): def parser(self):
return self._parser return self._parser
def feed_pool(self):
if self.pool:
if self in self.pool.active_conns:
@asyncio.coroutine class GremlinClientWebSocketResponse(BaseConnection, ClientWebSocketResponse):
def release(self):
yield from self.close()
if self in self.pool.active_conns:
def pool(self):
return self._pool
def __init__(self, reader, writer, protocol, response, timeout, autoclose,
autoping, loop):
BaseConnection.__init__(self, loop=loop)
ClientWebSocketResponse.__init__(self, reader, writer, protocol,
response, timeout, autoclose, autoping, loop)
class AiohttpConnection(BaseConnection): @asyncio.coroutine
def close(self, *, code=1000, message=b''):
if not self._closed:
self._closed = True
closed = self._close()
if closed:
return True
while True: