"""Gremlin Server websocket protocol implementation."""
import base64
import collections
import logging

import aiohttp

try:
    import ujson as json
except ImportError:
    import json

from gremlin_python.driver import protocol, request, serializer


__author__ = 'David M. Brown (davebshow@gmail.com)'


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# One decoded response item queued onto a result set:
#   status_code -- the Gremlin Server status code of the response frame
#   data        -- a single deserialized result value
#   message     -- the status message text sent by the server
Message = collections.namedtuple('Message', 'status_code data message')


class GremlinServerWSProtocol(protocol.AbstractBaseProtocol):
    """Implementation of the Gremlin Server websocket protocol.

    Serializes outgoing request messages onto a websocket transport and
    routes each incoming response frame to the result set registered under
    the frame's request id.
    """

    def __init__(self, message_serializer, username='', password=''):
        """
        :param message_serializer: Serializer used to encode requests and
            decode results. A serializer *class* is also accepted; it is
            instantiated with no arguments.
        :param str username: Username sent when the server replies with
            status 407 (authentication required).
        :param str password: Password sent when the server replies with
            status 407 (authentication required).
        """
        if isinstance(message_serializer, type):
            message_serializer = message_serializer()
        self._message_serializer = message_serializer
        self._username = username
        self._password = password

    def connection_made(self, transport):
        """Store the transport that subsequent writes are sent over."""
        self._transport = transport

    def write(self, request_id, request_message):
        """Serialize ``request_message`` under ``request_id`` and send it.

        :param request_id: Id the server echoes back in its responses.
        :param request_message: A ``gremlin_python`` ``RequestMessage``.
        """
        message = self._message_serializer.serialize_message(
            request_id, request_message)
        self._transport.write(message)

    async def data_received(self, data, results_dict):
        """Handle one websocket message read from the transport.

        :param data: The websocket message (an ``aiohttp`` ``WSMessage``).
        :param dict results_dict: Maps request ids to the result sets that
            are waiting for responses to those requests.
        :raises: The exception carried by a websocket ERROR message.
        """
        # ``WSMessage.tp`` is a deprecated alias that newer aiohttp versions
        # no longer provide; ``type`` exists wherever ``WSMsgType`` does.
        msg_type = data.type
        if msg_type == aiohttp.WSMsgType.close:
            await self._transport.close()
            return
        if msg_type == aiohttp.WSMsgType.error:
            error = data.data
            if isinstance(error, BaseException):
                raise error
            # Raising a non-exception would only produce a confusing
            # TypeError, so wrap the payload instead.
            raise RuntimeError(
                'Websocket error message received: {!r}'.format(error))
        if msg_type == aiohttp.WSMsgType.closed:
            # The connection is already closed; there is nothing to process.
            return
        if msg_type == aiohttp.WSMsgType.binary:
            payload = data.data.decode('utf-8')
        elif msg_type == aiohttp.WSMsgType.text:
            payload = data.data.strip()
        else:
            # Control frames (ping/pong, closing, ...) carry no Gremlin
            # response. Passing them to the JSON decoder would raise a
            # TypeError on the raw message tuple.
            logger.warning(
                'Ignoring unexpected websocket message type: %s', msg_type)
            return
        self._process_response(json.loads(payload), results_dict)

    def _process_response(self, response, results_dict):
        """Route one decoded Gremlin Server response to its result set."""
        request_id = response['requestId']
        status_code = response['status']['code']
        result_data = response['result']['data']
        status_message = response['status']['message']
        if request_id not in results_dict:
            logger.debug(
                'Dropping response for unknown request id %s '
                '(status %s)', request_id, status_code)
            return
        result_set = results_dict[request_id]
        result_set.aggregate_to = response['result']['meta'].get(
            'aggregateTo', 'list')
        if status_code == 407:
            self._authenticate(request_id)
        elif status_code == 204:
            # No content: signal end of results straight away.
            result_set.queue_result(None)
        else:
            deserialize = self._message_serializer.deserialize_message
            if result_data:
                for item in result_data:
                    result_set.queue_result(
                        Message(status_code, deserialize(item),
                                status_message))
            else:
                result_set.queue_result(
                    Message(status_code, deserialize(result_data),
                            status_message))
            # 206 marks a partial response: more frames follow for this
            # request. Any other code is terminal, so queue the sentinel.
            if status_code != 206:
                result_set.queue_result(None)

    def _authenticate(self, request_id):
        """Answer an authentication challenge (status 407).

        Sends ``\\x00username\\x00password`` (the SASL PLAIN layout),
        base64 encoded, as the ``sasl`` argument of an ``authentication``
        request on the same request id.
        """
        auth = b''.join([b'\x00', self._username.encode('utf-8'),
                         b'\x00', self._password.encode('utf-8')])
        request_message = request.RequestMessage(
            'traversal', 'authentication',
            {'sasl': base64.b64encode(auth).decode()})
        self.write(request_id, request_message)