Commit 2357f109 authored by davebshow's avatar davebshow
Browse files

updated to reflect aiohttp 0.16.0 release

parent 8fe1c881
......@@ -5,7 +5,7 @@ import ssl
import aiohttp
from aiogremlin.connection import (GremlinFactory, WebSocketSession,
from aiogremlin.connection import (GremlinFactory,
GremlinClientWebSocketResponse)
from aiogremlin.exceptions import RequestError
from aiogremlin.log import logger, INFO
......@@ -23,7 +23,7 @@ def create_client(*, url='ws://localhost:8182/', loop=None,
timeout=None, verbose=False, fill_pool=True, connector=None):
if factory is None:
factory = WebSocketSession(
factory = aiohttp.ClientSession(
connector=connector,
ws_response_class=GremlinClientWebSocketResponse,
loop=loop)
......@@ -79,7 +79,8 @@ class GremlinClient:
else:
self._connected = False
self._conn = asyncio.async(self._connect(), loop=self._loop)
self._factory = factory or GremlinFactory(connector=self._connector)
self._factory = factory or GremlinFactory(connector=self._connector,
loop=self._loop)
if verbose:
logger.setLevel(INFO)
......@@ -101,8 +102,7 @@ class GremlinClient:
def _connect(self):
"""
"""
connection = yield from self._factory.ws_connect(self.url,
loop=self._loop)
connection = yield from self._factory.ws_connect(self.url)
self._connected = True
return connection
......
......@@ -5,21 +5,13 @@ import base64
import hashlib
import os
from aiohttp import (client, hdrs, DataQueue, StreamParser,
WSServerHandshakeError, ClientSession, TCPConnector)
from aiohttp.errors import WSServerHandshakeError
from aiohttp.websocket import WS_KEY, Message
from aiohttp.websocket import WebSocketParser, WebSocketWriter, WebSocketError
from aiohttp.websocket import (MSG_BINARY, MSG_TEXT, MSG_CLOSE, MSG_PING,
MSG_PONG)
from aiohttp.websocket_client import (MsgType, closedMessage,
ClientWebSocketResponse)
import aiohttp
from aiohttp.websocket_client import ClientWebSocketResponse
from aiogremlin.exceptions import SocketClientError
from aiogremlin.log import INFO, logger
__all__ = ('WebSocketSession', 'GremlinFactory',
'GremlinClientWebSocketResponse')
__all__ = ('GremlinFactory', 'GremlinClientWebSocketResponse')
class GremlinClientWebSocketResponse(ClientWebSocketResponse):
......@@ -29,7 +21,8 @@ class GremlinClientWebSocketResponse(ClientWebSocketResponse):
ClientWebSocketResponse.__init__(self, reader, writer, protocol,
response, timeout, autoclose,
autoping, loop)
self._parser = StreamParser(buf=DataQueue(loop=loop), loop=loop)
self._parser = aiohttp.StreamParser(buf=aiohttp.DataQueue(loop=loop),
loop=loop)
@property
def parser(self):
......@@ -55,7 +48,7 @@ class GremlinClientWebSocketResponse(ClientWebSocketResponse):
self._response.close(force=True)
return True
if msg.tp == MsgType.close:
if msg.tp == aiohttp.MsgType.close:
self._close_code = msg.data
self._response.close(force=True)
return True
......@@ -97,139 +90,32 @@ class GremlinClientWebSocketResponse(ClientWebSocketResponse):
@asyncio.coroutine
def receive(self):
msg = yield from super().receive()
if msg.tp == MsgType.binary:
if msg.tp == aiohttp.MsgType.binary:
self.parser.feed_data(msg.data.decode())
elif msg.tp == MsgType.text:
elif msg.tp == aiohttp.MsgType.text:
self.parser.feed_data(msg.data.strip())
else:
if msg.tp == MsgType.close:
if msg.tp == aiohttp.MsgType.close:
yield from ws.close()
elif msg.tp == MsgType.error:
elif msg.tp == aiohttp.MsgType.error:
raise msg.data
elif msg.tp == MsgType.closed:
elif msg.tp == aiohttp.MsgType.closed:
pass
# Basically cut and paste from aiohttp until merge/release of #374
class WebSocketSession(ClientSession):
def __init__(self, *, connector=None, loop=None,
cookies=None, headers=None, auth=None,
ws_response_class=GremlinClientWebSocketResponse):
super().__init__(connector=connector, loop=loop,
cookies=cookies, headers=headers, auth=auth)
self._ws_response_class = ws_response_class
@asyncio.coroutine
def ws_connect(self, url, *,
protocols=(),
timeout=10.0,
autoclose=True,
autoping=True,
loop=None):
"""Initiate websocket connection."""
sec_key = base64.b64encode(os.urandom(16))
headers = {
hdrs.UPGRADE: hdrs.WEBSOCKET,
hdrs.CONNECTION: hdrs.UPGRADE,
hdrs.SEC_WEBSOCKET_VERSION: '13',
hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(),
}
if protocols:
headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols)
# send request
resp = yield from self.request('get', url, headers=headers,
read_until_eof=False)
# check handshake
if resp.status != 101:
raise WSServerHandshakeError('Invalid response status')
if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket':
raise WSServerHandshakeError('Invalid upgrade header')
if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade':
raise WSServerHandshakeError('Invalid connection header')
# key calculation
key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
match = base64.b64encode(
hashlib.sha1(sec_key + WS_KEY).digest()).decode()
if key != match:
raise WSServerHandshakeError('Invalid challenge response')
# websocket protocol
protocol = None
if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers:
resp_protocols = [
proto.strip() for proto in
resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
for proto in resp_protocols:
if proto in protocols:
protocol = proto
break
reader = resp.connection.reader.set_parser(WebSocketParser)
writer = WebSocketWriter(resp.connection.writer, use_mask=True)
return self._ws_response_class(
reader, writer, protocol, resp, timeout, autoclose, autoping, loop)
def detach(self):
"""Detach connector from session without closing the former.
Session is switched to closed state anyway.
"""
self._connector = None
# Cut and paste from aiohttp until merge/release of #374
def ws_connect(url, *, protocols=(), timeout=10.0, connector=None,
ws_response_class=None, autoclose=True, autoping=True,
loop=None):
if loop is None:
asyncio.get_event_loop()
if connector is None:
connector = TCPConnector(loop=loop, force_close=True)
if ws_response_class is None:
ws_response_class = GremlinClientWebSocketResponse
ws_session = WebSocketSession(loop=loop, connector=connector,
ws_response_class=ws_response_class)
try:
resp = yield from ws_session.ws_connect(
url,
protocols=protocols,
timeout=timeout,
autoclose=autoclose,
autoping=autoping,
loop=loop)
return resp
finally:
ws_session.detach()
class GremlinFactory:
def __init__(self, connector=None, ws_response_class=None):
def __init__(self, connector=None, loop=None):
self._connector = connector
if ws_response_class is None:
ws_response_class = GremlinClientWebSocketResponse
self._ws_response_class = ws_response_class
self._loop = loop or asyncio.get_event_loop()
@asyncio.coroutine
def ws_connect(self, url='ws://localhost:8182/', protocols=(),
autoclose=False, autoping=True, loop=None):
autoclose=False, autoping=True):
try:
return (yield from ws_connect(
return (yield from aiohttp.ws_connect(
url, protocols=protocols, connector=self._connector,
ws_response_class=self._ws_response_class, autoclose=True,
autoping=True, loop=loop))
except WSServerHandshakeError as e:
ws_response_class=GremlinClientWebSocketResponse,
autoclose=True, autoping=True, loop=self._loop))
except aiohttp.WSServerHandshakeError as e:
raise SocketClientError(e.message)
......@@ -18,14 +18,13 @@ class WebSocketPool:
self.url = url
if ws_response_class is None:
ws_response_class = GremlinClientWebSocketResponse
self._factory = factory or GremlinFactory(
connector=connector,
ws_response_class=ws_response_class)
self.poolsize = poolsize
self.max_retries = max_retries
self.timeout = timeout
self._connected = False
self._loop = loop or asyncio.get_event_loop()
self._factory = factory or GremlinFactory(connector=connector,
loop=self._loop)
self._pool = asyncio.Queue(maxsize=self.poolsize, loop=self._loop)
self.active_conns = set()
self.num_connecting = 0
......@@ -38,9 +37,7 @@ class WebSocketPool:
tasks = []
poolsize = self.poolsize
for i in range(poolsize):
coro = self.factory.ws_connect(
self.url,
loop=self._loop)
coro = self.factory.ws_connect(self.url)
task = asyncio.async(coro, loop=self._loop)
tasks.append(task)
for f in asyncio.as_completed(tasks, loop=self._loop):
......@@ -72,6 +69,10 @@ class WebSocketPool:
@asyncio.coroutine
def close(self):
try:
self._factory.close()
except AttributeError:
pass
if not self._closed:
if self.active_conns:
yield from self._close_active_conns()
......@@ -109,9 +110,7 @@ class WebSocketPool:
else:
self.num_connecting += 1
try:
socket = yield from self.factory.ws_connect(
url,
loop=loop)
socket = yield from self.factory.ws_connect(url)
finally:
self.num_connecting -= 1
if not socket.closed:
......
......@@ -12,7 +12,7 @@ setup(
long_description=open("README.txt").read(),
packages=["aiogremlin", "tests"],
install_requires=[
"aiohttp==0.15.3"
"aiohttp==0.16.0"
],
test_suite="tests",
classifiers=[
......
......@@ -5,10 +5,12 @@ import asyncio
import itertools
import unittest
import uuid
import aiohttp
from aiogremlin import (GremlinClient, RequestError, GremlinServerError,
SocketClientError, WebSocketPool, GremlinFactory,
create_client, GremlinWriter, GremlinResponse,
WebSocketSession)
GremlinClientWebSocketResponse)
class GremlinClientTests(unittest.TestCase):
......@@ -48,7 +50,7 @@ class GremlinClientPoolTests(unittest.TestCase):
asyncio.set_event_loop(None)
pool = WebSocketPool("ws://localhost:8182/", loop=self.loop)
self.gc = GremlinClient(url="ws://localhost:8182/",
factory=GremlinFactory(),
factory=GremlinFactory(loop=self.loop),
pool=pool,
loop=self.loop)
......@@ -142,7 +144,7 @@ class WebSocketPoolTests(unittest.TestCase):
poolsize=2,
timeout=1,
loop=self.loop,
factory=GremlinFactory())
factory=GremlinFactory(loop=self.loop))
def tearDown(self):
self.loop.run_until_complete(self.pool.close())
......@@ -239,7 +241,7 @@ class ContextMngrTest(unittest.TestCase):
self.pool = WebSocketPool("ws://localhost:8182/",
poolsize=1,
loop=self.loop,
factory=GremlinFactory(),
factory=GremlinFactory(loop=self.loop),
max_retries=0)
def tearDown(self):
......@@ -338,14 +340,18 @@ class GremlinClientPoolSessionTests(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
pool = WebSocketPool("ws://localhost:8182/",
loop=self.loop,
factory=WebSocketSession(loop=self.loop))
pool = WebSocketPool(
"ws://localhost:8182/",
loop=self.loop,
factory=aiohttp.ClientSession(
loop=self.loop,
ws_response_class=GremlinClientWebSocketResponse))
self.gc = GremlinClient("ws://localhost:8182/",
pool=pool,
loop=self.loop)
def tearDown(self):
self.gc._pool._factory.close()
self.loop.run_until_complete(self.gc.close())
self.loop.close()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment