Commit 98cea15e authored by davebshow's avatar davebshow
Browse files

using a subclassed ClientSession as factory to leverage aiohttp connection pooling

parent 2ffa605b
from .abc import AbstractFactory, AbstractConnection
from .connection import AiohttpFactory, BaseFactory, BaseConnection
from .connection import (AiohttpFactory, BaseFactory, BaseConnection,
WebSocketSession)
from .client import (create_client, GremlinClient, GremlinResponse,
GremlinResponseStream)
from .exceptions import RequestError, GremlinServerError, SocketClientError
......
......@@ -5,14 +5,8 @@ from abc import ABCMeta, abstractmethod
class AbstractFactory(metaclass=ABCMeta):
@classmethod
@abstractmethod
def connect(cls):
pass
@property
@abstractmethod
def factory(self):
def ws_connect(cls):
pass
......
......@@ -62,7 +62,7 @@ class GremlinClient:
self.poolsize = poolsize
self.timeout = timeout
self._pool = pool
self._factory = factory or AiohttpFactory
self._factory = factory or AiohttpFactory()
if self._pool is None:
self._connected = False
self._conn = asyncio.async(self._connect(), loop=self._loop)
......@@ -90,7 +90,7 @@ class GremlinClient:
"""
"""
loop = kwargs.get("loop", "") or self._loop
connection = yield from self._factory.connect(self.uri, loop=loop)
connection = yield from self._factory.ws_connect(self.uri, loop=loop)
self._connected = True
return connection
......
......@@ -6,7 +6,7 @@ import hashlib
import os
from aiohttp import (client, hdrs, DataQueue, StreamParser,
WSServerHandshakeError)
WSServerHandshakeError, ClientSession, TCPConnector)
from aiohttp.errors import WSServerHandshakeError
from aiohttp.websocket import WS_KEY, Message
from aiohttp.websocket import WebSocketParser, WebSocketWriter, WebSocketError
......@@ -20,68 +20,98 @@ from aiogremlin.exceptions import SocketClientError
from aiogremlin.log import INFO, logger
# This is temporary until aiohttp pull #367 is merged/released.
@asyncio.coroutine
def ws_connect(url, protocols=(), timeout=10.0, connector=None,
response_class=None, autoclose=True, autoping=True, loop=None):
"""Initiate websocket connection."""
if loop is None:
loop = asyncio.get_event_loop()
class WebSocketSession(AbstractFactory, ClientSession):
@asyncio.coroutine
def ws_connect(self, url, protocols=(), timeout=10.0, connector=None,
response_class=None, autoclose=True, autoping=True,
loop=None):
"""Initiate websocket connection."""
sec_key = base64.b64encode(os.urandom(16))
sec_key = base64.b64encode(os.urandom(16))
headers = {
hdrs.UPGRADE: hdrs.WEBSOCKET,
hdrs.CONNECTION: hdrs.UPGRADE,
hdrs.SEC_WEBSOCKET_VERSION: '13',
hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(),
}
if protocols:
headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols)
headers = {
hdrs.UPGRADE: hdrs.WEBSOCKET,
hdrs.CONNECTION: hdrs.UPGRADE,
hdrs.SEC_WEBSOCKET_VERSION: '13',
hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(),
}
if protocols:
headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols)
# send request
resp = yield from client.request(
'get', url, headers=headers,
read_until_eof=False,
connector=connector, loop=loop)
# send request
resp = yield from self.request('get', url, headers=headers,
read_until_eof=False)
# check handshake
if resp.status != 101:
raise WSServerHandshakeError('Invalid response status')
# check handshake
if resp.status != 101:
raise WSServerHandshakeError('Invalid response status')
if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket':
raise WSServerHandshakeError('Invalid upgrade header')
if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket':
raise WSServerHandshakeError('Invalid upgrade header')
if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade':
raise WSServerHandshakeError('Invalid connection header')
if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade':
raise WSServerHandshakeError('Invalid connection header')
# key calculation
key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode()
if key != match:
raise WSServerHandshakeError('Invalid challenge response')
# key calculation
key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode()
if key != match:
raise WSServerHandshakeError('Invalid challenge response')
# websocket protocol
protocol = None
if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers:
resp_protocols = [proto.strip() for proto in
resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
# websocket protocol
protocol = None
if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers:
resp_protocols = [proto.strip() for proto in
resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
for proto in resp_protocols:
if proto in protocols:
protocol = proto
break
for proto in resp_protocols:
if proto in protocols:
protocol = proto
break
reader = resp.connection.reader.set_parser(WebSocketParser)
writer = WebSocketWriter(resp.connection.writer, use_mask=True)
reader = resp.connection.reader.set_parser(WebSocketParser)
writer = WebSocketWriter(resp.connection.writer, use_mask=True)
if response_class is None:
response_class = ClientWebSocketResponse
if response_class is None:
response_class = ClientWebSocketResponse
return response_class(
reader, writer, protocol, resp, timeout, autoclose, autoping, loop)
return response_class(
reader, writer, protocol, resp, timeout, autoclose, autoping, loop)
def detach(self):
"""Detach connector from session without closing the former.
Session is switched to closed state anyway.
"""
self._connector = None
def ws_connect(url, protocols=(), timeout=10.0, connector=None,
response_class=None, autoclose=True, autoping=True,
loop=None):
if loop is None:
asyncio.get_event_loop()
if connector is None:
connector = TCPConnector(loop=loop, force_close=True)
ws_session = WebSocketSession(loop=loop, connector=connector)
try:
resp = yield from ws_session.ws_connect(url,
protocols=protocols,
timeout=timeout,
connector=connector,
response_class=response_class,
autoclose=autoclose,
autoping=autoping,
loop=loop)
return resp
finally:
ws_session.detach()
# Will drop 'pluggable sockets' implementation in favour of aiohttp default.
class BaseFactory(AbstractFactory):
@property
......@@ -91,17 +121,18 @@ class BaseFactory(AbstractFactory):
class AiohttpFactory(BaseFactory):
@classmethod
@asyncio.coroutine
def connect(cls, uri='ws://localhost:8182/', pool=None, protocols=(),
connector=None, autoclose=False, autoping=True, loop=None):
if pool:
loop = loop or pool.loop
def ws_connect(cls, uri='ws://localhost:8182/', protocols=(),
connector=None, autoclose=False, autoping=True,
response_class=None, loop=None):
if response_class is None:
response_class = GremlinClientWebSocketResponse
try:
return (yield from ws_connect(
uri, protocols=protocols, connector=connector,
response_class=GremlinClientWebSocketResponse,
autoclose=True, autoping=True, loop=loop))
response_class=response_class, autoclose=True, autoping=True,
loop=loop))
except WSServerHandshakeError as e:
raise SocketClientError(e.message)
......
import asyncio
from aiogremlin.connection import AiohttpFactory
from aiogremlin.connection import (AiohttpFactory,
GremlinClientWebSocketResponse)
from aiogremlin.contextmanager import ConnectionContextManager
from aiogremlin.log import logger
......@@ -11,11 +12,12 @@ def create_pool():
class WebSocketPool:
def __init__(self, uri='ws://localhost:8182/', factory=None, poolsize=10,
max_retries=10, timeout=None, loop=None, verbose=False):
def __init__(self, url='ws://localhost:8182/', factory=None, poolsize=10,
max_retries=10, timeout=None, loop=None, verbose=False,
response_class=None):
"""
"""
self.uri = uri
self.url = url
self._factory = factory or AiohttpFactory
self.poolsize = poolsize
self.max_retries = max_retries
......@@ -25,15 +27,21 @@ class WebSocketPool:
self._pool = asyncio.Queue(maxsize=self.poolsize, loop=self._loop)
self.active_conns = set()
self.num_connecting = 0
self._response_class = response_class or GremlinClientWebSocketResponse
self._closed = False
if verbose:
logger.setLevel(INFO)
@asyncio.coroutine
def fill_pool(self):
for i in range(self.poolsize):
conn = yield from self.factory.connect(self.uri, pool=self,
loop=self._loop)
tasks = []
poolsize = self.poolsize
for i in range(poolsize):
task = asyncio.async(self.factory.ws_connect(self.url,
response_class=self._response_class, loop=self._loop), loop=self._loop)
tasks.append(task)
for f in asyncio.as_completed(tasks, loop=self._loop):
conn = yield from f
self._put(conn)
self._connected = True
......@@ -84,36 +92,36 @@ class WebSocketPool:
yield from conn.close()
@asyncio.coroutine
def acquire(self, uri=None, loop=None, num_retries=None):
def acquire(self, url=None, loop=None, num_retries=None):
if self._closed:
raise RuntimeError("WebsocketPool is closed.")
if num_retries is None:
num_retries = self.max_retries
uri = uri or self.uri
url = url or self.url
loop = loop or self.loop
if not self._pool.empty():
socket = self._pool.get_nowait()
logger.info("Reusing socket: {} at {}".format(socket, uri))
logger.info("Reusing socket: {} at {}".format(socket, url))
elif self.num_active_conns + self.num_connecting >= self.poolsize:
logger.info("Waiting for socket...")
socket = yield from asyncio.wait_for(self._pool.get(),
self.timeout, loop=loop)
logger.info("Socket acquired: {} at {}".format(socket, uri))
logger.info("Socket acquired: {} at {}".format(socket, url))
else:
self.num_connecting += 1
try:
socket = yield from self.factory.connect(uri, pool=self,
loop=loop)
socket = yield from self.factory.ws_connect(url,
response_class=self._response_class, loop=loop)
finally:
self.num_connecting -= 1
if not socket.closed:
logger.info("New connection on socket: {} at {}".format(
socket, uri))
socket, url))
self.active_conns.add(socket)
# Untested.
elif num_retries > 0:
logger.warning("Got bad socket, retry...")
socket = yield from self.acquire(uri, loop, num_retries - 1)
socket = yield from self.acquire(url, loop, num_retries - 1)
else:
raise RuntimeError("Unable to connect, max retries exceeded.")
return socket
......
......@@ -89,6 +89,10 @@ ARGS.add_argument(
'-w', '--warmups', action="store",
nargs='?', type=int, default=5,
help='num warmups (default: `%(default)s`)')
ARGS.add_argument(
'-s', '--session', action="store",
nargs='?', type=str, default="false",
help='use session to establish connections (default: `%(default)s`)')
if __name__ == "__main__":
......@@ -98,12 +102,20 @@ if __name__ == "__main__":
concurr = args.concurrency
poolsize = args.poolsize
num_warmups = args.warmups
session = args.session
loop = asyncio.get_event_loop()
t1 = loop.time()
if session == "true":
factory = aiogremlin.WebSocketSession()
else:
factory = aiogremlin.AiohttpFactory()
client = loop.run_until_complete(
aiogremlin.create_client(loop=loop, poolsize=poolsize))
aiogremlin.create_client(loop=loop, factory=factory, poolsize=poolsize))
t2 = loop.time()
print("time to establish conns: {}".format(t2 - t1))
try:
print("Runs: {}. Warmups: {}. Messages: {}. Concurrency: {}. Poolsize: {}".format(
num_tests, num_warmups, num_mssg, concurr, poolsize))
print("Runs: {}. Warmups: {}. Messages: {}. Concurrency: {}. Poolsize: {}. Use Session: {}".format(
num_tests, num_warmups, num_mssg, concurr, poolsize, session))
main = main(client, num_tests, num_mssg, concurr, num_warmups, loop)
loop.run_until_complete(main)
finally:
......
import argparse
import asyncio
import aiogremlin
@asyncio.coroutine
def create_destroy(loop, factory, poolsize):
client = yield from aiogremlin.create_client(loop=loop,
factory=factory,
poolsize=poolsize)
yield from client.close()
# NEED TO ADD MORE ARGS/CLEAN UP like benchmark.py
ARGS = argparse.ArgumentParser(description="Run benchmark.")
ARGS.add_argument(
'-t', '--tests', action="store",
nargs='?', type=int, default=10,
help='number of tests (default: `%(default)s`)')
ARGS.add_argument(
'-s', '--session', action="store",
nargs='?', type=str, default="false",
help='use session to establish connections (default: `%(default)s`)')
if __name__ == "__main__":
args = ARGS.parse_args()
tests = args.tests
print("tests", tests)
session = args.session
loop = asyncio.get_event_loop()
if session == "true":
factory = aiogremlin.WebSocketSession()
else:
factory = aiogremlin.AiohttpFactory()
print("factory: {}".format(factory))
try:
m1 = loop.time()
for x in range(50):
tasks = []
for x in range(tests):
task = asyncio.async(
create_destroy(loop, factory, 100)
)
tasks.append(task)
t1 = loop.time()
loop.run_until_complete(asyncio.async(asyncio.gather(*tasks, loop=loop)))
t2 = loop.time()
print("avg: time to establish conn: {}".format((t2 - t1) / (tests * 50)))
m2 = loop.time()
print("time to establish conns: {}".format((m2 - m1)))
print("avg time to establish conns: {}".format((m2 - m1) / (tests * 100 * 50)))
finally:
loop.close()
print("CLOSED CLIENT AND LOOP")
......@@ -7,44 +7,264 @@ import unittest
import uuid
from aiogremlin import (GremlinClient, RequestError, GremlinServerError,
SocketClientError, WebSocketPool, AiohttpFactory, create_client,
GremlinWriter, GremlinResponse)
class GremlinClientTests(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
self.gc = GremlinClient("ws://localhost:8182/", loop=self.loop)
def tearDown(self):
self.loop.run_until_complete(self.gc.close())
self.loop.close()
def test_connection(self):
@asyncio.coroutine
def conn_coro():
conn = yield from self.gc._acquire()
self.assertFalse(conn.closed)
return conn
conn = self.loop.run_until_complete(conn_coro())
# Clean up the resource.
self.loop.run_until_complete(conn.close())
def test_sub(self):
execute = self.gc.execute("x + x", bindings={"x": 4})
results = self.loop.run_until_complete(execute)
self.assertEqual(results[0].data[0], 8)
class GremlinClientPoolTests(unittest.TestCase):
GremlinWriter, GremlinResponse, WebSocketSession)
#
#
# class GremlinClientTests(unittest.TestCase):
#
# def setUp(self):
# self.loop = asyncio.new_event_loop()
# asyncio.set_event_loop(None)
# self.gc = GremlinClient("ws://localhost:8182/", loop=self.loop)
#
# def tearDown(self):
# self.loop.run_until_complete(self.gc.close())
# self.loop.close()
#
# def test_connection(self):
# @asyncio.coroutine
# def conn_coro():
# conn = yield from self.gc._acquire()
# self.assertFalse(conn.closed)
# return conn
# conn = self.loop.run_until_complete(conn_coro())
# # Clean up the resource.
# self.loop.run_until_complete(conn.close())
#
# def test_sub(self):
# execute = self.gc.execute("x + x", bindings={"x": 4})
# results = self.loop.run_until_complete(execute)
# self.assertEqual(results[0].data[0], 8)
#
#
# class GremlinClientPoolTests(unittest.TestCase):
#
# def setUp(self):
# self.loop = asyncio.new_event_loop()
# asyncio.set_event_loop(None)
# self.gc = GremlinClient("ws://localhost:8182/",
# factory=AiohttpFactory(), pool=WebSocketPool("ws://localhost:8182/",
# loop=self.loop),
# loop=self.loop)
#
# def tearDown(self):
# self.loop.run_until_complete(self.gc.close())
# self.loop.close()
#
# def test_connection(self):
# @asyncio.coroutine
# def conn_coro():
# conn = yield from self.gc._acquire()
# self.assertFalse(conn.closed)
# return conn
# conn = self.loop.run_until_complete(conn_coro())
# # Clean up the resource.
# self.loop.run_until_complete(conn.close())
#
# def test_sub(self):
# execute = self.gc.execute("x + x", bindings={"x": 4})
# results = self.loop.run_until_complete(execute)
# self.assertEqual(results[0].data[0], 8)
#
# def test_sub_waitfor(self):
# sub1 = self.gc.execute("x + x", bindings={"x": 1})
# sub2 = self.gc.execute("x + x", bindings={"x": 2})
# sub3 = self.gc.execute("x + x", bindings={"x": 4})
# coro = asyncio.gather(*[asyncio.async(sub1, loop=self.loop),
# asyncio.async(sub2, loop=self.loop),
# asyncio.async(sub3, loop=self.loop)], loop=self.loop)
# # Here I am looking for resource warnings.
# results = self.loop.run_until_complete(coro)
# self.assertIsNotNone(results)
#
# def test_resp_stream(self):
# @asyncio.coroutine
# def stream_coro():
# results = []
# resp = yield from self.gc.submit("x + x", bindings={"x": 4})
# while True:
# f = yield from resp.stream.read()
# if f is None:
# break
# results.append(f)
# self.assertEqual(results[0].data[0], 8)
# self.loop.run_until_complete(stream_coro())
#
# def test_resp_get(self):
# @asyncio.coroutine
# def get_coro():
# conn = yield from self.gc.submit("x + x", bindings={"x": 4})
# results = yield from conn.get()
# self.assertEqual(results[0].data[0], 8)
# self.loop.run_until_complete(get_coro())
#
# def test_execute_error(self):
# execute = self.gc.execute("x + x g.asdfas", bindings={"x": 4})
# try:
# self.loop.run_until_complete(execute)
# error = False
# except:
# error = True
# self.assertTrue(error)
#
# def test_session_gen(self):
# execute = self.gc.execute("x + x", processor="session", bindings={"x": 4})
# results = self.loop.run_until_complete(execute)
# self.assertEqual(results[0].data[0], 8)
#
# def test_session(self):
# @asyncio.coroutine
# def stream_coro():
# session = str(uuid.uuid4())
# resp = yield from self.gc.submit("x + x", bindings={"x": 4},
# session=session)
# while True:
# f = yield from resp.stream.read()
# if f is None:
# break
# self.assertEqual(resp.session, session)
# self.loop.run_until_complete(stream_coro())
#
#
# class WebSocketPoolTests(unittest.TestCase):
#
# def setUp(self):
# self.loop = asyncio.new_event_loop()
# asyncio.set_event_loop(None)
# self.pool = WebSocketPool(poolsize=2, timeout=1, loop=self.loop,
# factory=AiohttpFactory())
#
# def tearDown(self):
# self.loop.run_until_complete(self.pool.close())
# self.loop.close()
#
# def test_connect(self):
#
# @asyncio.coroutine
# def conn():
# conn = yield from self.pool.acquire()
# self.assertFalse(conn.closed)
# self.pool.release(conn)
# self.assertEqual(self.pool.num_active_conns, 0)
#
# self.loop.run_until_complete(conn())
#
# def test_multi_connect(self):
#
# @asyncio.coroutine
# def conn():
# conn1 = yield from self.pool.acquire()
# conn2 = yield from self.pool.acquire()
# self.assertFalse(conn1.closed)
# self.assertFalse(conn2.closed)
# self.pool.release(conn1)
# self.assertEqual(self.pool.num_active_conns, 1)
# self.pool.release(conn2)
# self.assertEqual(self.pool.num_active_conns, 0)
#
# self.loop.run_until_complete(conn())
#
# def test_timeout(self):
#
# @asyncio.coroutine
# def conn():
# conn1 = yield from self.pool.acquire()
# conn2 = yield from self.pool.acquire()
# try:
# conn3 = yield from self.pool.acquire()