diff --git a/aiogremlin/client.py b/aiogremlin/client.py index 429c5c142117e1acc8d6273dc0f55350dc1ab86e..4f2139109b8f0d09f18d71c19f028600c71474a2 100644 --- a/aiogremlin/client.py +++ b/aiogremlin/client.py @@ -34,7 +34,7 @@ class GremlinClient: def __init__(self, *, url='http://localhost:8182/', loop=None, lang="gremlin-groovy", op="eval", processor="", - timeout=None, ws_connector=None, connector=None, + timeout=None, ws_connector=None, client_session=None, username="", password=""): self._lang = lang self._op = op @@ -46,10 +46,6 @@ class GremlinClient: self._timeout = timeout self._username = username self._password = password - if connector is None: - connector = aiohttp.TCPConnector(verify_ssl=False, loop=self._loop) - client_session = aiohttp.ClientSession(connector=connector, - loop=self._loop) if ws_connector is None: ws_connector = GremlinConnector(loop=self._loop, client_session=client_session) @@ -211,10 +207,12 @@ class GremlinClientSession(GremlinClient): def __init__(self, *, url='http://localhost:8182/', loop=None, lang="gremlin-groovy", op="eval", processor="session", - session=None, timeout=None, - ws_connector=None): + session=None, timeout=None, client_session=None, + ws_connector=None, username="", password=""): super().__init__(url=url, lang=lang, op=op, processor=processor, - loop=loop, timeout=timeout, ws_connector=ws_connector) + loop=loop, timeout=timeout, ws_connector=ws_connector, + client_session=client_session, username=username, + password=password) if session is None: session = str(uuid.uuid4()) @@ -287,7 +285,6 @@ class GremlinResponse: @asyncio.coroutine def _run(self): - import ipdb; ipdb.set_trace() results = [] while True: message = yield from self._stream.read() @@ -337,6 +334,8 @@ class GremlinResponseStream: writer = GremlinWriter(self._ws) writer.write(op="authentication", username=self._username, password=self._password) + asyncio.Task(self._ws.receive(), loop=self._loop) + message = yield from self._stream.read() except (RequestError, GremlinServerError): yield from self._ws.release() raise diff --git a/aiogremlin/subprotocol.py b/aiogremlin/subprotocol.py index 3bfb1756d127a61e9a8ac83f0be852894b235dbf..8b650a4f287d3cf3bc094a3610654d16e0bc83a6 100644 --- a/aiogremlin/subprotocol.py +++ b/aiogremlin/subprotocol.py @@ -1,5 +1,6 @@ """Implements the Gremlin Server subprotocol.""" +import base64 import collections import uuid @@ -62,12 +63,13 @@ class GremlinWriter: session) if op == "authentication": - message = self._authenticate(username, password, session, processor) + message = self._authenticate( + username, password, session, processor) message = json.dumps(message) if binary: message = self._set_message_header(message, mime_type) self.ws.send(message, binary=binary) - print(message) + # print(message) return self.ws @staticmethod @@ -100,16 +102,17 @@ class GremlinWriter: message["args"].update({"session": session}) return message + @staticmethod def _authenticate(username, password, session, processor): - auth_bytes = "".join(["0", username, "0", password]) - print(auth_bytes) + auth = b"".join([b"\x00", bytes(username, "utf-8"), b"\x00", bytes(password, "utf-8")]) + print("auth:",auth) message = { "requestId": str(uuid.uuid4()), "op": "authentication", "processor": processor, "args": { - "sasl": auth_bytes + "sasl": base64.b64encode(auth) } } if session is None: diff --git a/tests/tests.py b/tests/tests.py index fbff3c84adc618f999ae081c3beb8f9d6f43fa89..ae333979e36b5c88f2ce82ba437992ddc7404590 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -4,9 +4,10 @@ import asyncio import unittest import uuid - +import aiohttp from aiogremlin import (submit, GremlinConnector, GremlinClient, - GremlinClientSession, GremlinServerError) + GremlinClientSession, GremlinServerError, + GremlinClientWebSocketResponse) class SubmitTest(unittest.TestCase): @@ -22,220 +23,242 @@ class SubmitTest(unittest.TestCase): @asyncio.coroutine def go(): - resp = yield from submit("4 + 4", url='https://localhost:8182/', + resp = yield from submit("x + x", url='https://localhost:8182/', bindings={"x": 4}, loop=self.loop, username="stephen", password="password") results = yield from resp.get() return results + results = self.loop.run_until_complete(go()) + self.assertEqual(results[0].data[0], 8) + + def test_rebinding(self): + + @asyncio.coroutine + def go1(): + result = yield from submit("graph2.addVertex()", + url='https://localhost:8182/', + loop=self.loop, username="stephen", + password="password") + resp = yield from result.get() + + try: + self.loop.run_until_complete(go1()) + error = False + except GremlinServerError: + error = True + self.assertTrue(error) + + @asyncio.coroutine + def go2(): + result = yield from submit( + "graph2.addVertex()", url='https://localhost:8182/', + rebindings={"graph2": "graph"}, loop=self.loop, + username="stephen", password="password") + resp = yield from result.get() + self.assertEqual(len(resp), 1) + + try: + self.loop.run_until_complete(go2()) + except GremlinServerError: + print("RELEASE DOES NOT SUPPORT REBINDINGS") + + +class GremlinClientTest(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + connector = aiohttp.TCPConnector(force_close=False, loop=self.loop, + verify_ssl=False) + + client_session = aiohttp.ClientSession( + connector=connector, loop=self.loop, + ws_response_class=GremlinClientWebSocketResponse) + + self.gc = GremlinClient(url="https://localhost:8182/", loop=self.loop, + username="stephen", password="password", + client_session=client_session) + + def tearDown(self): + self.loop.run_until_complete(self.gc.close()) + self.loop.close() + + def test_connection(self): + + @asyncio.coroutine + def go(): + ws = yield from self.gc._connector.ws_connect(self.gc.url) + self.assertFalse(ws.closed) + yield from ws.close() + + self.loop.run_until_complete(go()) + + def test_execute(self): + + @asyncio.coroutine + def go(): + resp = yield from self.gc.execute("x + x", bindings={"x": 4}) + return resp results = self.loop.run_until_complete(go()) self.assertEqual(results[0].data[0], 8) -# def test_rebinding(self): -# -# @asyncio.coroutine -# def go1(): -# result = yield from submit("graph2.addVertex()", loop=self.loop) -# resp = yield from result.get() -# -# try: -# self.loop.run_until_complete(go1()) -# error = False -# except GremlinServerError: -# error = True -# self.assertTrue(error) -# -# @asyncio.coroutine -# def go2(): -# result = yield from submit( -# "graph2.addVertex()", rebindings={"graph2": "graph"}, -# loop=self.loop) -# resp = yield from result.get() -# self.assertEqual(len(resp), 1) -# -# try: -# self.loop.run_until_complete(go2()) -# except GremlinServerError: -# print("RELEASE DOES NOT SUPPORT REBINDINGS") -# -# -# class GremlinClientTest(unittest.TestCase): -# -# def setUp(self): -# self.loop = asyncio.new_event_loop() -# asyncio.set_event_loop(None) -# self.gc = GremlinClient(url="http://localhost:8182/", loop=self.loop) -# -# def tearDown(self): -# self.loop.run_until_complete(self.gc.close()) -# self.loop.close() -# -# def test_connection(self): -# -# @asyncio.coroutine -# def go(): -# ws = yield from self.gc._connector.ws_connect(self.gc.url) -# self.assertFalse(ws.closed) -# yield from ws.close() -# -# self.loop.run_until_complete(go()) -# -# def test_execute(self): -# -# @asyncio.coroutine -# def go(): -# resp = yield from self.gc.execute("x + x", bindings={"x": 4}) -# return resp -# -# results = self.loop.run_until_complete(go()) -# self.assertEqual(results[0].data[0], 8) -# -# def test_sub_waitfor(self): -# sub1 = self.gc.execute("x + x", bindings={"x": 1}) -# sub2 = self.gc.execute("x + x", bindings={"x": 2}) -# sub3 = self.gc.execute("x + x", bindings={"x": 4}) -# coro = asyncio.gather(*[asyncio.async(sub1, loop=self.loop), -# asyncio.async(sub2, loop=self.loop), -# asyncio.async(sub3, loop=self.loop)], -# loop=self.loop) -# # Here I am looking for resource warnings. -# results = self.loop.run_until_complete(coro) -# self.assertIsNotNone(results) -# -# def test_resp_stream(self): -# @asyncio.coroutine -# def stream_coro(): -# results = [] -# resp = yield from self.gc.submit("x + x", bindings={"x": 4}) -# while True: -# f = yield from resp.stream.read() -# if f is None: -# break -# results.append(f) -# self.assertEqual(results[0].data[0], 8) -# self.loop.run_until_complete(stream_coro()) -# -# def test_execute_error(self): -# execute = self.gc.execute("x + x g.asdfas", bindings={"x": 4}) -# try: -# self.loop.run_until_complete(execute) -# error = False -# except: -# error = True -# self.assertTrue(error) -# -# def test_rebinding(self): -# execute = self.gc.execute("graph2.addVertex()") -# try: -# self.loop.run_until_complete(execute) -# error = False -# except GremlinServerError: -# error = True -# self.assertTrue(error) -# -# @asyncio.coroutine -# def go(): -# result = yield from self.gc.execute( -# "graph2.addVertex()", rebindings={"graph2": "graph"}) -# self.assertEqual(len(result), 1) -# -# try: -# self.loop.run_until_complete(go()) -# except GremlinServerError: -# print("RELEASE DOES NOT SUPPORT REBINDINGS") -# -# -# class GremlinClientSessionTest(unittest.TestCase): -# -# def setUp(self): -# self.loop = asyncio.new_event_loop() -# asyncio.set_event_loop(None) -# self.gc = GremlinClientSession(url="http://localhost:8182/", -# loop=self.loop) -# self.script1 = """v = graph.addVertex('name', 'Dave')""" -# -# self.script2 = "v.property('name')" -# -# def tearDown(self): -# self.loop.run_until_complete(self.gc.close()) -# self.loop.close() -# -# def test_session(self): -# -# @asyncio.coroutine -# def go(): -# yield from self.gc.execute(self.script1) -# results = yield from self.gc.execute(self.script2) -# return results -# -# results = self.loop.run_until_complete(go()) -# self.assertEqual(results[0].data[0]['value'], 'Dave') -# -# def test_session_reset(self): -# -# @asyncio.coroutine -# def go(): -# yield from self.gc.execute(self.script1) -# self.gc.reset_session() -# results = yield from self.gc.execute(self.script2) -# return results -# try: -# results = self.loop.run_until_complete(go()) -# error = False -# except GremlinServerError: -# error = True -# self.assertTrue(error) -# -# def test_session_manual_reset(self): -# -# @asyncio.coroutine -# def go(): -# yield from self.gc.execute(self.script1) -# new_sess = str(uuid.uuid4()) -# sess = self.gc.reset_session(session=new_sess) -# self.assertEqual(sess, new_sess) -# self.assertEqual(self.gc.session, new_sess) -# results = yield from self.gc.execute(self.script2) -# return results -# try: -# results = self.loop.run_until_complete(go()) -# error = False -# except GremlinServerError: -# error = True -# self.assertTrue(error) -# -# def test_session_set(self): -# -# @asyncio.coroutine -# def go(): -# yield from self.gc.execute(self.script1) -# new_sess = str(uuid.uuid4()) -# self.gc.session = new_sess -# self.assertEqual(self.gc.session, new_sess) -# results = yield from self.gc.execute(self.script2) -# return results -# try: -# results = self.loop.run_until_complete(go()) -# error = False -# except GremlinServerError: -# error = True -# self.assertTrue(error) -# -# def test_resp_session(self): -# -# @asyncio.coroutine -# def go(): -# session = str(uuid.uuid4()) -# self.gc.session = session -# resp = yield from self.gc.submit("x + x", bindings={"x": 4}) -# while True: -# f = yield from resp.stream.read() -# if f is None: -# break -# self.assertEqual(resp.session, session) -# -# self.loop.run_until_complete(go()) -# + def test_sub_waitfor(self): + sub1 = self.gc.execute("x + x", bindings={"x": 1}) + sub2 = self.gc.execute("x + x", bindings={"x": 2}) + sub3 = self.gc.execute("x + x", bindings={"x": 4}) + coro = asyncio.gather(*[asyncio.async(sub1, loop=self.loop), + asyncio.async(sub2, loop=self.loop), + asyncio.async(sub3, loop=self.loop)], + loop=self.loop) + # Here I am looking for resource warnings. + results = self.loop.run_until_complete(coro) + self.assertIsNotNone(results) + + def test_resp_stream(self): + @asyncio.coroutine + def stream_coro(): + results = [] + resp = yield from self.gc.submit("x + x", bindings={"x": 4}) + while True: + f = yield from resp.stream.read() + if f is None: + break + results.append(f) + self.assertEqual(results[0].data[0], 8) + self.loop.run_until_complete(stream_coro()) + + def test_execute_error(self): + execute = self.gc.execute("x + x g.asdfas", bindings={"x": 4}) + try: + self.loop.run_until_complete(execute) + error = False + except: + error = True + self.assertTrue(error) + + def test_rebinding(self): + execute = self.gc.execute("graph2.addVertex()") + try: + self.loop.run_until_complete(execute) + error = False + except GremlinServerError: + error = True + self.assertTrue(error) + + @asyncio.coroutine + def go(): + result = yield from self.gc.execute( + "graph2.addVertex()", rebindings={"graph2": "graph"}) + self.assertEqual(len(result), 1) + + try: + self.loop.run_until_complete(go()) + except GremlinServerError: + print("RELEASE DOES NOT SUPPORT REBINDINGS") + + +class GremlinClientSessionTest(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + connector = aiohttp.TCPConnector(force_close=False, loop=self.loop, + verify_ssl=False) + + client_session = aiohttp.ClientSession( + connector=connector, loop=self.loop, + ws_response_class=GremlinClientWebSocketResponse) + + self.gc = GremlinClientSession(url="https://localhost:8182/", + loop=self.loop, + username="stephen", password="password", + client_session=client_session) + + self.script1 = """v=graph.addVertex('name', 'Dave')""" + + self.script2 = "v.property('name')" + + def tearDown(self): + self.loop.run_until_complete(self.gc.close()) + self.loop.close() + + def test_session(self): + + @asyncio.coroutine + def go(): + yield from self.gc.execute(self.script1) + results = yield from self.gc.execute(self.script2) + return results + + results = self.loop.run_until_complete(go()) + self.assertEqual(results[0].data[0]['value'], 'Dave') + + # def test_session_reset(self): + # + # @asyncio.coroutine + # def go(): + # yield from self.gc.execute(self.script1) + # self.gc.reset_session() + # results = yield from self.gc.execute(self.script2) + # return results + # try: + # results = self.loop.run_until_complete(go()) + # error = False + # except GremlinServerError: + # error = True + # self.assertTrue(error) + # + # def test_session_manual_reset(self): + # + # @asyncio.coroutine + # def go(): + # yield from self.gc.execute(self.script1) + # new_sess = str(uuid.uuid4()) + # sess = self.gc.reset_session(session=new_sess) + # self.assertEqual(sess, new_sess) + # self.assertEqual(self.gc.session, new_sess) + # results = yield from self.gc.execute(self.script2) + # return results + # try: + # results = self.loop.run_until_complete(go()) + # error = False + # except GremlinServerError: + # error = True + # self.assertTrue(error) + # + # def test_session_set(self): + # + # @asyncio.coroutine + # def go(): + # yield from self.gc.execute(self.script1) + # new_sess = str(uuid.uuid4()) + # self.gc.session = new_sess + # self.assertEqual(self.gc.session, new_sess) + # results = yield from self.gc.execute(self.script2) + # return results + # try: + # results = self.loop.run_until_complete(go()) + # error = False + # except GremlinServerError: + # error = True + # self.assertTrue(error) + # + # def test_resp_session(self): + # + # @asyncio.coroutine + # def go(): + # session = str(uuid.uuid4()) + # self.gc.session = session + # resp = yield from self.gc.submit("x + x", bindings={"x": 4}) + # while True: + # f = yield from resp.stream.read() + # if f is None: + # break + # self.assertEqual(resp.session, session) + # + # self.loop.run_until_complete(go()) + # if __name__ == "__main__": unittest.main()