Commit a216b4ae authored by davebshow's avatar davebshow
Browse files

params fixes

parent a717c621
......@@ -22,12 +22,105 @@ __all__ = ('WebSocketSession', 'GremlinFactory',
'GremlinClientWebSocketResponse')
class GremlinClientWebSocketResponse(ClientWebSocketResponse):
def __init__(self, reader, writer, protocol, response, timeout, autoclose,
autoping, loop):
ClientWebSocketResponse.__init__(self, reader, writer, protocol,
response, timeout, autoclose,
autoping, loop)
self._parser = StreamParser(buf=DataQueue(loop=loop), loop=loop)
@property
def parser(self):
return self._parser
@property
def closed(self):
"""Required by ABC."""
return self._closed
@asyncio.coroutine
def close(self, *, code=1000, message=b''):
if not self._closed:
do_close = self._close()
if do_close:
return True
while True:
try:
msg = yield from asyncio.wait_for(
self._reader.read(), self._timeout, loop=self._loop)
except asyncio.CancelledError:
self._close_code = 1006
self._response.close(force=True)
raise
except Exception as exc:
self._close_code = 1006
self._exception = exc
self._response.close(force=True)
return True
if msg.tp == MsgType.close:
self._close_code = msg.data
self._response.close(force=True)
return True
else:
return False
def _close(self, code=1000, message=b''):
self._closed = True
try:
self._writer.close(code, message)
except asyncio.CancelledError:
self._close_code = 1006
self._response.close(force=True)
raise
except Exception as exc:
self._close_code = 1006
self._exception = exc
self._response.close(force=True)
return True
if self._closing:
self._response.close(force=True)
return True
def send(self, message, binary=True):
if binary:
method = self.send_bytes
else:
method = self.send_str
try:
method(message)
except RuntimeError:
# Socket closed.
raise
except TypeError:
# Bytes/string input error.
raise
@asyncio.coroutine
def receive(self):
msg = yield from super().receive()
if msg.tp == MsgType.binary:
self.parser.feed_data(msg.data.decode())
elif msg.tp == MsgType.text:
self.parser.feed_data(msg.data.strip())
else:
if msg.tp == MsgType.close:
yield from ws.close()
elif msg.tp == MsgType.error:
raise msg[1]
elif msg.tp == MsgType.closed:
pass
# Basically cut and paste from aiohttp until merge/release of #374
class WebSocketSession(ClientSession):
def __init__(self, *, connector=None, loop=None,
cookies=None, headers=None, auth=None,
ws_response_class=None):
ws_response_class=GremlinClientWebSocketResponse):
super().__init__(connector=connector, loop=loop,
cookies=cookies, headers=headers, auth=auth)
......@@ -40,7 +133,6 @@ class WebSocketSession(ClientSession):
timeout=10.0,
autoclose=True,
autoping=True,
ws_response_class=None,
loop=None):
"""Initiate websocket connection."""
......@@ -91,11 +183,7 @@ class WebSocketSession(ClientSession):
reader = resp.connection.reader.set_parser(WebSocketParser)
writer = WebSocketWriter(resp.connection.writer, use_mask=True)
if ws_response_class is None:
ws_response_class = (self._ws_response_class or
ClientWebSocketResponse)
return ws_response_class(
return self._ws_response_class(
reader, writer, protocol, resp, timeout, autoclose, autoping, loop)
def detach(self):
......@@ -113,14 +201,16 @@ def ws_connect(url, *, protocols=(), timeout=10.0, connector=None,
asyncio.get_event_loop()
if connector is None:
connector = TCPConnector(loop=loop, force_close=True)
if ws_response_class is None:
ws_response_class = GremlinClientWebSocketResponse
ws_session = WebSocketSession(loop=loop, connector=connector)
ws_session = WebSocketSession(loop=loop, connector=connector,
ws_response_class=ws_response_class)
try:
resp = yield from ws_session.ws_connect(
url,
protocols=protocols,
timeout=timeout,
ws_response_class=ws_response_class,
autoclose=autoclose,
autoping=autoping,
loop=loop)
......@@ -132,114 +222,19 @@ def ws_connect(url, *, protocols=(), timeout=10.0, connector=None,
class GremlinFactory:
def __init__(self, connector=None):
def __init__(self, connector=None, ws_response_class=None):
self._connector = connector
if ws_response_class is None:
ws_response_class = GremlinClientWebSocketResponse
self._ws_response_class = ws_response_class
@asyncio.coroutine
def ws_connect(self, url='ws://localhost:8182/', protocols=(),
connector=None, autoclose=False, autoping=True,
ws_response_class=None, loop=None):
if connector is None:
connector = self._connector
if ws_response_class is None:
ws_response_class = GremlinClientWebSocketResponse
autoclose=False, autoping=True, loop=None):
try:
return (yield from ws_connect(
url, protocols=protocols, connector=connector,
ws_response_class=ws_response_class, autoclose=True,
url, protocols=protocols, connector=self._connector,
ws_response_class=self._ws_response_class, autoclose=True,
autoping=True, loop=loop))
except WSServerHandshakeError as e:
raise SocketClientError(e.message)
class GremlinClientWebSocketResponse(ClientWebSocketResponse):
def __init__(self, reader, writer, protocol, response, timeout, autoclose,
autoping, loop):
ClientWebSocketResponse.__init__(self, reader, writer, protocol,
response, timeout, autoclose,
autoping, loop)
self._parser = StreamParser(buf=DataQueue(loop=loop), loop=loop)
@property
def parser(self):
return self._parser
@property
def closed(self):
"""Required by ABC."""
return self._closed
@asyncio.coroutine
def close(self, *, code=1000, message=b''):
if not self._closed:
do_close = self._close()
if do_close:
return True
while True:
try:
msg = yield from asyncio.wait_for(
self._reader.read(), self._timeout, loop=self._loop)
except asyncio.CancelledError:
self._close_code = 1006
self._response.close(force=True)
raise
except Exception as exc:
self._close_code = 1006
self._exception = exc
self._response.close(force=True)
return True
if msg.tp == MsgType.close:
self._close_code = msg.data
self._response.close(force=True)
return True
else:
return False
def _close(self, code=1000, message=b''):
self._closed = True
try:
self._writer.close(code, message)
except asyncio.CancelledError:
self._close_code = 1006
self._response.close(force=True)
raise
except Exception as exc:
self._close_code = 1006
self._exception = exc
self._response.close(force=True)
return True
if self._closing:
self._response.close(force=True)
return True
def send(self, message, binary=True):
if binary:
method = self.send_bytes
else:
method = self.send_str
try:
method(message)
except RuntimeError:
# Socket closed.
raise
except TypeError:
# Bytes/string input error.
raise
@asyncio.coroutine
def receive(self):
msg = yield from super().receive()
if msg.tp == MsgType.binary:
self.parser.feed_data(msg.data.decode())
elif msg.tp == MsgType.text:
self.parser.feed_data(msg.data.strip())
else:
if msg.tp == MsgType.close:
yield from ws.close()
elif msg.tp == MsgType.error:
raise msg[1]
elif msg.tp == MsgType.closed:
pass
......@@ -16,7 +16,11 @@ class WebSocketPool:
"""
"""
self.url = url
self._factory = factory or GremlinFactory(connector=connector)
self._ws_response_class = (ws_response_class or
GremlinClientWebSocketResponse)
self._factory = factory or GremlinFactory(
connector=connector,
ws_response_class=self._ws_response_class)
self.poolsize = poolsize
self.max_retries = max_retries
self.timeout = timeout
......@@ -25,8 +29,6 @@ class WebSocketPool:
self._pool = asyncio.Queue(maxsize=self.poolsize, loop=self._loop)
self.active_conns = set()
self.num_connecting = 0
self._response_class = (ws_response_class or
GremlinClientWebSocketResponse)
self._closed = False
if verbose:
logger.setLevel(INFO)
......@@ -38,7 +40,6 @@ class WebSocketPool:
for i in range(poolsize):
coro = self.factory.ws_connect(
self.url,
ws_response_class=self._response_class,
loop=self._loop)
task = asyncio.async(coro, loop=self._loop)
tasks.append(task)
......@@ -110,7 +111,6 @@ class WebSocketPool:
try:
socket = yield from self.factory.ws_connect(
url,
ws_response_class=self._response_class,
loop=loop)
finally:
self.num_connecting -= 1
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment