diff --git a/goblin/api.py b/goblin/api.py
index 2771d0e3adfe3e113ea8f8582aef01488bd82cc4..27df06aefa5f1abc71addbcd48994624c7df2d39 100644
--- a/goblin/api.py
+++ b/goblin/api.py
@@ -153,10 +153,16 @@ class Session:
         await stream.close()
         return mapper_func(result.data[0], element, element.__mapping__)
 
+    async def get_vertex_by_id(self, element):
+        pass
+
     def _get_vertex_by_id(self, element):
         traversal = self.g.V(element.id)
         return query.parse_traversal(traversal)
 
+    async def get_edge_by_id(self, element):
+        pass
+
     def _get_edge_by_id(self, element):
         traversal = self.g.E(element.id)
         return query.parse_traversal(traversal)
diff --git a/goblin/query.py b/goblin/query.py
index 7c49515ca66db96626f8189c40bedbfac09f65bc..12af826aad46ce8d62ffe81894df1ec7a7af7d86 100644
--- a/goblin/query.py
+++ b/goblin/query.py
@@ -1,4 +1,5 @@
 """Query API and helpers"""
+from goblin import mapper
 
 
 def parse_traversal(traversal):
@@ -6,17 +7,21 @@ def parse_traversal(traversal):
     bindings = traversal.bindings
     return script, bindings
 
+
 class Query:
 
     def __init__(self, session, element_class):
         self._session = session
         self._engine = session.engine
+        self._element_class = element_class
         if element_class.__type__ == 'vertex':
             self._traversal = self._session.g.V().hasLabel(
                 element_class.__mapping__.label)
+            self._mapper = mapper.map_vertex_to_ogm
         elif element_class.__type__ == 'edge':
             self._traversal = self._session.g.E().hasLabel(
                 element_class.__mapping__.label)
+            self._mapper = mapper.map_edge_to_ogm
         else:
             raise Exception("unknown element type")
 
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 5cdfacc14d07a7c90ea16414bc42bd895f74b670..4081dd6bf1726c94265f6edce035f0f112f9e2bc 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -50,12 +50,11 @@ class TestEngine(unittest.TestCase):
             leif.name = 'leifur'
             jon = TestVertex()
             jon.name = 'jonathan'
-            session.add(leif, jon)
             works_for = TestEdge()
             works_for.source = jon
             works_for.target = leif
             works_for.notes = 'zerofail'
-            session.add(works_for)
+            session.add(leif, jon, works_for)
             await session.flush()
             current = session._current[works_for.id]
             self.assertEqual(current.notes, 'zerofail')