From cd9c1afc785d8b235b69f8fd29c41a4f69c6e9ae Mon Sep 17 00:00:00 2001
From: davebshow <davebshow@gmail.com>
Date: Fri, 15 Jul 2016 17:25:10 -0400
Subject: [PATCH] finished basic tests

---
 goblin/app.py               |   4 +-
 goblin/driver/connection.py |   7 +-
 goblin/element.py           |   8 +++
 goblin/mapper.py            |   2 +-
 goblin/session.py           |  17 +++--
 goblin/traversal.py         |   4 +-
 tests/test_session.py       | 129 +++++++++++++++++++++++++++++++-----
 7 files changed, 139 insertions(+), 32 deletions(-)

diff --git a/goblin/app.py b/goblin/app.py
index 790e7ec..06e2c47 100644
--- a/goblin/app.py
+++ b/goblin/app.py
@@ -54,8 +54,8 @@ class Goblin:
         self._config = self.DEFAULT_CONFIG
         self._config.update(config)
         self._vertices = collections.defaultdict(
-            lambda: element.Vertex)
-        self._edges = collections.defaultdict(lambda: element.Edge)
+            lambda: element.GenericVertex)
+        self._edges = collections.defaultdict(lambda: element.GenericEdge)
 
     @property
     def vertices(self):
diff --git a/goblin/driver/connection.py b/goblin/driver/connection.py
index 4c772c4..1453998 100644
--- a/goblin/driver/connection.py
+++ b/goblin/driver/connection.py
@@ -181,6 +181,7 @@ class Connection(AbstractConnection):
         request_id = message['requestId']
         status_code = message['status']['code']
         data = message["result"]["data"]
+        msg = message["status"]["message"]
         response_queue = self._response_queues[request_id]
         if status_code == 407:
             await self._authenticate(self._username, self._password,
@@ -189,12 +190,10 @@ class Connection(AbstractConnection):
         else:
             if data:
                 for result in data:
-                    message = Message(status_code, result,
-                                      message["status"]["message"])
+                    message = Message(status_code, result, msg)
                     response_queue.put_nowait(message)
             else:
-                message = Message(status_code, data,
-                                  message["status"]["message"])
+                message = Message(status_code, data, msg)
                 response_queue.put_nowait(message)
             if status_code == 206:
                 self._loop.create_task(self.receive())
diff --git a/goblin/element.py b/goblin/element.py
index 42b310c..58139e0 100644
--- a/goblin/element.py
+++ b/goblin/element.py
@@ -43,6 +43,10 @@ class Vertex(Element):
     pass
 
 
+class GenericVertex(Vertex):
+    pass
+
+
 class Edge(Element):
     """Base class for user defined Edge classes"""
 
@@ -75,6 +79,10 @@ class Edge(Element):
     target = property(gettarget, settarget, deltarget)
 
 
+class GenericEdge(Edge):
+    pass
+
+
 class VertexPropertyDescriptor:
     """Descriptor that validates user property input and gets/sets properties
        as instance attributes."""
diff --git a/goblin/mapper.py b/goblin/mapper.py
index 1b8d09a..0017754 100644
--- a/goblin/mapper.py
+++ b/goblin/mapper.py
@@ -32,7 +32,7 @@ def map_vertex_to_ogm(result, element, *, mapping=None):
 
 def map_edge_to_ogm(result, element, *, mapping=None):
     """Map an edge returned by DB to OGM edge"""
-    for db_name, value in result.items():
+    for db_name, value in result.get('properties', {}).items():
         name, data_type = mapping.properties.get(db_name, (db_name, None))
         if data_type:
             value = data_type.to_ogm(value)
diff --git a/goblin/session.py b/goblin/session.py
index b9b6a99..d0f4dc8 100644
--- a/goblin/session.py
+++ b/goblin/session.py
@@ -7,6 +7,7 @@ import weakref
 from goblin import mapper
 from goblin import traversal
 from goblin.driver import connection, graph
+from goblin.element import GenericVertex
 
 
 logger = logging.getLogger(__name__)
@@ -88,8 +89,8 @@ class Session(connection.AbstractConnection):
                     current = self.app.vertices[label]()
                 else:
                     current = self.app.edges[label]()
-                    current.source = element.Vertex()
-                    current.target = element.Vertex()
+                    current.source = GenericVertex()
+                    current.target = GenericVertex()
             element = current.__mapping__.mapper_func(result, current)
             response_queue.put_nowait(element)
         response_queue.put_nowait(None)
@@ -155,12 +156,12 @@ class Session(connection.AbstractConnection):
 
     async def update_vertex(self, element):
         props = mapper.map_props_to_db(element, element.__mapping__)
-        traversal = self.traversal().V(element.id)
-        traversal = await self._update_properties(element, traversal, props)
+        traversal = self.g.V(element.id)
+        return await self._update_properties(element, traversal, props)
 
     async def update_edge(self, element):
         props = mapper.map_props_to_db(element, element.__mapping__)
-        traversal = self.traversal().E(element.id)
+        traversal = self.g.E(element.id)
         return await self._update_properties(element, traversal, props)
 
     # Transaction support
@@ -223,7 +224,11 @@ class Session(connection.AbstractConnection):
                     ('k' + str(binding), k),
                     ('v' + str(binding), v))
             else:
-                await self.g.V(element.id).properties(
+                if element.__type__ == 'vertex':
+                    traversal_source = self.g.V(element.id)
+                else:
+                    traversal_source = self.g.E(element.id)
+                await traversal_source.properties(
                     ('k' + str(binding), k)).drop().one_or_none()
             binding += 1
         return traversal
diff --git a/goblin/traversal.py b/goblin/traversal.py
index 74c7b9a..8362c13 100644
--- a/goblin/traversal.py
+++ b/goblin/traversal.py
@@ -37,8 +37,10 @@ class GoblinTraversal(graph.AsyncGraphTraversal):
         return await self.next()
 
     async def one_or_none(self):
+        result = None
         async for msg in await self.next():
-            return msg
+            result = msg
+        return result
 
 
 class TraversalFactory:
diff --git a/tests/test_session.py b/tests/test_session.py
index 84e0a15..2bde914 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -1,5 +1,7 @@
 import pytest
 
+from goblin import element
+
 
 @pytest.mark.asyncio
 async def test_session_close(session):
@@ -156,32 +158,123 @@ class TestCreationApi:
             result = await session.g.E(rid).one_or_none()
             assert not result
 
-    def test_update_vertex(self):
-        pass
+    @pytest.mark.asyncio
+    async def test_update_vertex(self, session, person):
+        async with session:
+            person.name = 'dave'
+            person.age = 35
+            result = await session.save(person)
+            assert result.age == 35
+            person.name = 'david'
+            person.age = None
+            result = await session.save(person)
+            assert result is person
+            assert result.name == 'david'
+            assert not result.age
 
-    def test_update_edge(self):
-        pass
+    @pytest.mark.asyncio
+    async def test_update_edge(self, session, person_class, knows):
+        async with session:
+            dave = person_class()
+            leif = person_class()
+            knows.source = dave
+            knows.target = leif
+            knows.notes = 'online'
+            session.add(dave, leif)
+            await session.flush()
+            result = await session.save(knows)
+            assert knows.notes == 'online'
+            knows.notes = None
+            result = await session.save(knows)
+            assert result is knows
+            assert not result.notes
 
 
 class TestTraversalApi:
 
-    def test_all(self):
-        pass
+    @pytest.mark.asyncio
+    async def test_traversal_source_generation(self, session, person_class,
+                                               knows_class):
+        async with session:
+            traversal = session.traversal(person_class)
+            assert repr(traversal) == 'g.V().hasLabel("person")'
+            traversal = session.traversal(knows_class)
+            assert repr(traversal) == 'g.E().hasLabel("knows")'
+
+
+    @pytest.mark.asyncio
+    async def test_all(self, session, person_class):
+        async with session:
+            dave = person_class()
+            leif = person_class()
+            jon = person_class()
+            session.add(dave, leif, jon)
+            await session.flush()
+            resp = await session.traversal(person_class).all()
+            results = []
+            async for msg in resp:
+                assert isinstance(msg, person_class)
+                results.append(msg)
+            assert len(results) > 2
+
+    @pytest.mark.asyncio
+    async def test_one_or_none_one(self, session, person_class):
+        async with session:
+            dave = person_class()
+            leif = person_class()
+            jon = person_class()
+            session.add(dave, leif, jon)
+            await session.flush()
+            resp = await session.traversal(person_class).one_or_none()
+            assert isinstance(resp, person_class)
 
-    def test_one_or_none_one(self):
-        pass
 
-    def test_one_or_none_none(self):
-        pass
+    @pytest.mark.asyncio
+    async def test_one_or_none_none(self, session):
+        async with session:
+            none = await session.g.V().hasLabel(
+                'a very unlikey label').one_or_none()
+            assert not none
+
+    @pytest.mark.asyncio
+    async def test_vertex_deserialization(self, session, person_class):
+        async with session:
+            resp = await session.g.addV('person').property(
+                person_class.name, 'leif').property('birthplace', 'detroit').one_or_none()
+            assert isinstance(resp, person_class)
+            assert resp.name == 'leif'
+            assert resp.birthplace == 'detroit'
 
-    def test_vertex_deserialization(self):
-        pass
+    @pytest.mark.asyncio
+    async def test_edge_desialization(self, session, knows_class):
+        async with session:
+            p1 = await session.g.addV('person').one_or_none()
+            p2 = await session.g.addV('person').one_or_none()
+            e1 = await session.g.V(p1.id).addE('knows').to(
+                session.g.V(p2.id)).property(
+                    knows_class.notes, 'somehow').property(
+                    'how_long', 1).one_or_none()
+            assert isinstance(e1, knows_class)
+            assert e1.notes == 'somehow'
+            assert e1.how_long == 1
 
-    def test_edge_desialization(self):
-        pass
+    @pytest.mark.asyncio
+    async def test_unregistered_vertex_deserialization(self, session):
+        async with session:
+            dave = await session.g.addV(
+                'unregistered').property('name', 'dave').one_or_none()
+            assert isinstance(dave, element.GenericVertex)
+            assert dave.name == 'dave'
+            assert dave.__label__ == 'unregistered'
 
-    def test_unregistered_vertex_deserialization(self):
-        pass
 
-    def test_unregistered_edge_desialization(self):
-        pass
+    @pytest.mark.asyncio
+    async def test_unregistered_edge_desialization(self, session):
+        async with session:
+            p1 = await session.g.addV('person').one_or_none()
+            p2 = await session.g.addV('person').one_or_none()
+            e1 = await session.g.V(p1.id).addE('unregistered').to(
+                session.g.V(p2.id)).property('how_long', 1).one_or_none()
+            assert isinstance(e1, element.GenericEdge)
+            assert e1.how_long == 1
+            assert e1.__label__ == 'unregistered'
-- 
GitLab