# Copyright 2016 ZEROFAIL
#
# This file is part of Goblin.
#
# Goblin is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Goblin is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Goblin.  If not, see <http://www.gnu.org/licenses/>.

"""Main OGM API classes and constructors"""

import asyncio
import collections
import logging
import weakref

from goblin import exception, mapper, traversal
from goblin.driver import connection, graph
from goblin.element import GenericVertex


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class Session(connection.AbstractConnection):
    """
    Provides the main API for interacting with the database. Does not
    necessarily correspond to a database session. Don't instantiate directly,
    instead use :py:meth:`Goblin.session<goblin.app.Goblin.session>`.

    :param goblin.app.Goblin app:
    :param goblin.driver.connection conn:
    :param get_hashable_id: Callable converting a database element id into a
        hashable key used to index the session's element cache
        (:py:attr:`current`).
    :param bool use_session: Support for Gremlin Server session. Not
        implemented; the value is currently ignored.
    :param dict aliases: Aliases forwarded with every request this session
        submits to the server.
    """

    def __init__(self, app, conn, get_hashable_id, *, use_session=False,
                 aliases=None):
        self._app = app
        self._conn = conn
        self._loop = self._app._loop
        # Gremlin Server sessions are not implemented yet, so the
        # ``use_session`` argument is deliberately ignored.
        self._use_session = False
        self._aliases = aliases or {}
        # Elements queued by ``add`` and persisted by ``flush``.
        self._pending = collections.deque()
        # Weak values: an entry disappears once nothing else references the
        # element, so lookups in this cache may legitimately miss.
        self._current = weakref.WeakValueDictionary()
        self._get_hashable_id = get_hashable_id
        remote_graph = graph.AsyncRemoteGraph(
            self._app.translator, self,
            graph_traversal=traversal.GoblinTraversal)
        self._traversal_factory = traversal.TraversalFactory(remote_graph)

    @property
    def app(self):
        """The Goblin application (``None`` once the session is closed)."""
        return self._app

    @property
    def conn(self):
        """The underlying driver connection."""
        return self._conn

    @property
    def traversal_factory(self):
        """Factory used to build traversals bound to this session."""
        return self._traversal_factory

    @property
    def current(self):
        """
        Weak-valued cache of elements saved through this session, keyed by
        the hashable form of their id.
        """
        return self._current

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.close()

    async def close(self):
        """
        Close the underlying db connection and disconnect session from Goblin
        application.
        """
        await self.conn.close()
        self._app = None

    # Traversal API
    @property
    def g(self):
        """
        Get a simple traversal source.

        :returns:
            :py:class:`goblin.gremlin_python.process.GraphTraversalSource`
            object
        """
        return self.traversal_factory.traversal()

    def traversal(self, element_class):
        """
        Get a traversal spawned from an element class.

        :param goblin.element.Element element_class: Element class
            used to spawn traversal.

        :returns: :py:class:`GoblinTraversal<goblin.traversal.GoblinTraversal>`
            object
        """
        return self.traversal_factory.traversal(element_class=element_class)

    async def submit(self,
                     gremlin,
                     *,
                     bindings=None,
                     lang='gremlin-groovy'):
        """
        Submit a query to the Gremlin Server. Any pending elements are
        flushed before the query is sent.

        :param str gremlin: Gremlin script to submit to server.
        :param dict bindings: A mapping of bindings for Gremlin script.
        :param str lang: Language of scripts submitted to the server.
            "gremlin-groovy" by default

        :returns:
            :py:class:`TraversalResponse<goblin.traversal.TraversalResponse>`
            object
        """
        await self.flush()
        async_iter = await self.conn.submit(
            gremlin, bindings=bindings, lang=lang, aliases=self._aliases)
        # NOTE(review): the ``loop`` argument of asyncio.Queue was deprecated
        # in Python 3.8 and removed in 3.10.
        response_queue = asyncio.Queue(loop=self._loop)
        # Results are drained in a background task and handed to the caller
        # through the queue; ``_receive`` finishes by enqueueing ``None``.
        self._loop.create_task(
            self._receive(async_iter, response_queue))
        return traversal.TraversalResponse(response_queue)

    async def _receive(self, async_iter, response_queue):
        """
        Drain ``async_iter`` into ``response_queue``.

        Vertex/edge results (dicts whose ``type`` is ``'vertex'`` or
        ``'edge'``) are mapped onto OGM element instances, reusing a cached
        instance from :py:attr:`current` when one exists. Every other result
        is forwarded unchanged. ``None`` is always enqueued last as an
        end-of-stream marker, even if iteration raises, so queue consumers
        are never left waiting forever.
        """
        try:
            async for result in async_iter:
                if (isinstance(result, dict) and
                        result.get('type', '') in ('vertex', 'edge')):
                    hashable_id = self._get_hashable_id(result['id'])
                    current = self.current.get(hashable_id)
                    if current is None:
                        # Not cached: instantiate the class registered for
                        # this label on the application.
                        label = result['label']
                        if result['type'] == 'vertex':
                            current = self.app.vertices[label]()
                        else:
                            current = self.app.edges[label]()
                            # Endpoints are not known here; start from
                            # generic placeholder vertices.
                            current.source = GenericVertex()
                            current.target = GenericVertex()
                    element = current.__mapping__.mapper_func(result, current)
                    response_queue.put_nowait(element)
                else:
                    response_queue.put_nowait(result)
        finally:
            response_queue.put_nowait(None)

    # Creation API
    def add(self, *elements):
        """
        Add elements to session pending queue.

        :param goblin.element.Element elements: Elements to be added
        """
        self._pending.extend(elements)

    async def flush(self):
        """
        Issue creation/update queries to database for all elements in the
        session pending queue, in the order they were added.
        """
        while self._pending:
            elem = self._pending.popleft()
            await self.save(elem)

    async def remove_vertex(self, vertex):
        """
        Remove a vertex from the db and evict it from the session cache.

        :param goblin.element.Vertex vertex: Vertex to be removed
        """
        traversal = self.traversal_factory.remove_vertex(vertex)
        result = await self._simple_traversal(traversal, vertex)
        # The vertex may not be cached (never saved through this session, or
        # already garbage collected), so eviction must not raise.
        hashable_id = self._get_hashable_id(vertex.id)
        self.current.pop(hashable_id, None)
        return result

    async def remove_edge(self, edge):
        """
        Remove an edge from the db and evict it from the session cache.

        :param goblin.element.Edge edge: Element to be removed
        """
        traversal = self.traversal_factory.remove_edge(edge)
        result = await self._simple_traversal(traversal, edge)
        # The edge may not be cached (never saved through this session, or
        # already garbage collected), so eviction must not raise.
        hashable_id = self._get_hashable_id(edge.id)
        self.current.pop(hashable_id, None)
        return result

    async def save(self, elem):
        """
        Save an element to the db.

        :param goblin.element.Element elem: Vertex or Edge to be saved

        :returns: :py:class:`Element<goblin.element.Element>` object
        :raises goblin.exception.ElementError: if ``elem.__type__`` is
            neither ``'vertex'`` nor ``'edge'``
        """
        if elem.__type__ == 'vertex':
            result = await self.save_vertex(elem)
        elif elem.__type__ == 'edge':
            result = await self.save_edge(elem)
        else:
            raise exception.ElementError(
                "Unknown element type: {}".format(elem.__type__))
        return result

    async def save_vertex(self, vertex):
        """
        Save a vertex to the db and cache it in :py:attr:`current`.

        :param goblin.element.Vertex vertex: Vertex to be saved

        :returns: :py:class:`Vertex<goblin.element.Vertex>` object
        """
        result = await self._save_element(
            vertex, self._check_vertex,
            self._add_vertex,
            self.update_vertex)
        hashable_id = self._get_hashable_id(result.id)
        self.current[hashable_id] = result
        return result

    async def save_edge(self, edge):
        """
        Save an edge to the db and cache it in :py:attr:`current`.

        :param goblin.element.Edge edge: Edge to be saved

        :returns: :py:class:`Edge<goblin.element.Edge>` object
        :raises goblin.exception.ElementError: if the edge lacks a source or
            target vertex
        """
        if not (hasattr(edge, 'source') and hasattr(edge, 'target')):
            raise exception.ElementError(
                "Edges require both source/target vertices")
        result = await self._save_element(
            edge, self._check_edge,
            self._add_edge,
            self.update_edge)
        hashable_id = self._get_hashable_id(result.id)
        self.current[hashable_id] = result
        return result

    async def get_vertex(self, vertex):
        """
        Get a vertex from the db. Vertex must have id.

        :param goblin.element.Vertex vertex: Vertex to be retrieved

        :returns: :py:class:`Vertex<goblin.element.Vertex>` | None
        """
        return await self.traversal_factory.get_vertex_by_id(
            vertex).one_or_none()

    async def get_edge(self, edge):
        """
        Get an edge from the db. Edge must have id.

        :param goblin.element.Edge edge: Edge to be retrieved

        :returns: :py:class:`Edge<goblin.element.Edge>` | None
        """
        return await self.traversal_factory.get_edge_by_id(
            edge).one_or_none()

    async def update_vertex(self, vertex):
        """
        Update a vertex, generally to change/remove property values.

        :param goblin.element.Vertex vertex: Vertex to be updated

        :returns: :py:class:`Vertex<goblin.element.Vertex>` object
        """
        props = mapper.map_props_to_db(vertex, vertex.__mapping__)
        traversal = self.g.V(vertex.id)
        return await self._update_vertex_properties(vertex, traversal, props)

    async def update_edge(self, edge):
        """
        Update an edge, generally to change/remove property values.

        :param goblin.element.Edge edge: Edge to be updated

        :returns: :py:class:`Edge<goblin.element.Edge>` object
        """
        props = mapper.map_props_to_db(edge, edge.__mapping__)
        traversal = self.g.E(edge.id)
        return await self._update_edge_properties(edge, traversal, props)

    # Transaction support
    def tx(self):
        """Not implemented"""
        raise NotImplementedError

    def _wrap_in_tx(self):
        raise NotImplementedError

    async def commit(self):
        """
        Flush pending elements, then raise: transactions are not implemented.

        :raises NotImplementedError: always, after flushing
        """
        await self.flush()
        raise NotImplementedError

    async def rollback(self):
        """Not implemented"""
        raise NotImplementedError

    # Private helper methods for the creation API
    async def _simple_traversal(self, traversal, element):
        """
        Submit ``traversal`` and map the data from the first ``fetch_data``
        call onto ``element``.

        :returns: the mapped element, or ``None`` if no data came back
        """
        stream = await self.conn.submit(
            repr(traversal), bindings=traversal.bindings,
            aliases=self._aliases)
        msg = await stream.fetch_data()
        if msg:
            return element.__mapping__.mapper_func(msg, element)
        return None

    async def _save_element(self,
                            elem,
                            check_func,
                            create_func,
                            update_func):
        """
        Update ``elem`` if it has an id that ``check_func`` finds in the db,
        otherwise create it.
        """
        if hasattr(elem, 'id') and await check_func(elem):
            return await update_func(elem)
        return await create_func(elem)

    async def _add_vertex(self, elem):
        """Convenience function for generating crud traversals."""
        props = mapper.map_props_to_db(elem, elem.__mapping__)
        traversal = self.g.addV(elem.__mapping__.label)
        traversal, _, metaprops = self.traversal_factory.add_properties(
            traversal, props)
        result = await self._simple_traversal(traversal, elem)
        if metaprops:
            # Meta-properties are attached after creation, then the vertex is
            # re-fetched so the returned element includes them.
            await self._add_metaprops(result, metaprops)
            traversal = self.traversal_factory.get_vertex_by_id(elem)
            result = await self._simple_traversal(traversal, elem)
        return result

    async def _add_edge(self, elem):
        """Convenience function for generating crud traversals."""
        props = mapper.map_props_to_db(elem, elem.__mapping__)
        traversal = self.g.V(elem.source.id)
        # NOTE(review): uses ``_label`` while ``_add_vertex`` uses ``label``;
        # confirm both attributes exist on the mapping.
        traversal = traversal.addE(elem.__mapping__._label)
        traversal = traversal.to(
            self.g.V(elem.target.id))
        traversal, _, _ = self.traversal_factory.add_properties(
            traversal, props)
        return await self._simple_traversal(traversal, elem)

    async def _check_vertex(self, vertex):
        """Used to check for existence, does not update session vertex"""
        traversal = self.g.V(vertex.id)
        stream = await self.conn.submit(repr(traversal), aliases=self._aliases)
        return await stream.fetch_data()

    async def _check_edge(self, edge):
        """Used to check for existence, does not update session edge"""
        traversal = self.g.E(edge.id)
        stream = await self.conn.submit(repr(traversal), aliases=self._aliases)
        return await stream.fetch_data()

    async def _update_vertex_properties(self, vertex, traversal, props):
        """
        Apply ``props`` to ``vertex`` through ``traversal``: drop properties
        flagged for removal, run the update, then set meta-properties
        (dropping those with falsy values) and re-fetch the vertex.
        """
        traversal, removals, metaprops = self.traversal_factory.add_properties(
            traversal, props)
        for k in removals:
            await self.g.V(vertex.id).properties(k).drop().one_or_none()
        result = await self._simple_traversal(traversal, vertex)
        if metaprops:
            removals = await self._add_metaprops(result, metaprops)
            for db_name, key, value in removals:
                await self.g.V(vertex.id).properties(
                    db_name).has(key, value).drop().one_or_none()
            traversal = self.traversal_factory.get_vertex_by_id(vertex)
            result = await self._simple_traversal(traversal, vertex)
        return result

    async def _update_edge_properties(self, edge, traversal, props):
        """
        Apply ``props`` to ``edge`` through ``traversal``, first dropping
        properties flagged for removal.
        """
        traversal, removals, _ = self.traversal_factory.add_properties(
            traversal, props)
        for k in removals:
            await self.g.E(edge.id).properties(k).drop().one_or_none()
        return await self._simple_traversal(traversal, edge)

    async def _add_metaprops(self, result, metaprops):
        """
        Set meta-properties on the vertex properties of ``result``.

        :param result: Saved vertex whose ``id`` is used in the traversals.
        :param metaprops: Iterable of ``(db_name, (binding, value), props)``
            tuples as returned by ``TraversalFactory.add_properties``.
        :returns: list of ``(db_name, key, value)`` tuples for meta-properties
            whose new value is falsy; the update path drops these.
        """
        potential_removals = []
        for db_name, (_binding, value), prop_metaprops in metaprops:
            for key, val in prop_metaprops.items():
                # NOTE(review): falsy values such as 0 or '' are treated as
                # removals, not stored; confirm this is intended.
                if val:
                    traversal = self.g.V(result.id).properties(
                        db_name).hasValue(value).property(key, val)
                    stream = await self.conn.submit(
                        repr(traversal), bindings=traversal.bindings,
                        aliases=self._aliases)
                    await stream.fetch_data()
                else:
                    potential_removals.append((db_name, key, value))
        return potential_removals