[dataimport] Uniformize the API across the different stores.
authorVladimir Popescu <vladimir.popescu@logilab.fr>
Tue, 02 Apr 2013 12:11:44 +0200
changeset 8832 26cdfc6dd6f8
parent 8830 7fd6c52ef878
child 8833 39f81e2db2fc
[dataimport] Uniformize the API across the different stores. This is achieved by modifying the ``relate`` method so that it takes an extra ``**kwargs``. More specifically, ``SQLGenObjectStore``'s ``relate`` method needs the the type of the subject entity which is passed through ``**kwargs`` as the ``subjtype`` keyword argument. Actually, it is the ``add_relation`` method of the ``SQLGenObjectStore`` who needs this argument. However, as this method is not called directly (but via the ``relate`` method), the ``subjtype`` argument is passed to ``add_relation`` via ``relate``. The other stores' ``relate`` methods do not need this extra argument, hence for the other stores ``**kwargs`` is empty. In this manner, the API is unified across the different stores.
dataimport.py
--- a/dataimport.py	Mon Dec 17 14:03:56 2012 +0100
+++ b/dataimport.py	Tue Apr 02 12:11:44 2013 +0200
@@ -72,6 +72,7 @@
 import traceback
 import cPickle
 import os.path as osp
+import inspect
 from collections import defaultdict
 from contextlib import contextmanager
 from copy import copy
@@ -323,7 +324,6 @@
     return [(k, len(v)) for k, v in buckets.items()
             if k is not None and len(v) > 1]
 
-
 # sql generator utility functions #############################################
 
 
@@ -506,7 +506,7 @@
         item['eid'] = data['eid']
         return item
 
-    def relate(self, eid_from, rtype, eid_to, inlined=False):
+    def relate(self, eid_from, rtype, eid_to, **kwargs):
         """Add new relation"""
         relation = eid_from, rtype, eid_to
         self.relations.add(relation)
@@ -583,9 +583,9 @@
                                       for k in item)
         return self.rql(query, item)[0][0]
 
-    def relate(self, eid_from, rtype, eid_to, inlined=False):
+    def relate(self, eid_from, rtype, eid_to, **kwargs):
         eid_from, rtype, eid_to = super(RQLObjectStore, self).relate(
-            eid_from, rtype, eid_to)
+            eid_from, rtype, eid_to, **kwargs)
         self.rql('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype,
                  {'x': int(eid_from), 'y': int(eid_to)})
 
@@ -751,20 +751,23 @@
         session = self.session
         self.source.add_entity(session, entity)
         self.source.add_info(session, entity, self.source, None, complete=False)
+        kwargs = dict()
+        if inspect.getargspec(self.add_relation).keywords:
+            kwargs['subjtype'] = entity.__regid__
         for rtype, targeteids in rels.iteritems():
             # targeteids may be a single eid or a list of eids
             inlined = self.rschema(rtype).inlined
             try:
                 for targeteid in targeteids:
                     self.add_relation(session, entity.eid, rtype, targeteid,
-                                      inlined)
+                                      inlined, **kwargs)
             except TypeError:
                 self.add_relation(session, entity.eid, rtype, targeteids,
-                                  inlined)
+                                  inlined, **kwargs)
         self._nb_inserted_entities += 1
         return entity
 
-    def relate(self, eid_from, rtype, eid_to):
+    def relate(self, eid_from, rtype, eid_to, **kwargs):
         assert not rtype.startswith('reverse_')
         self.add_relation(self.session, eid_from, rtype, eid_to,
                           self.rschema(rtype).inlined)
@@ -888,12 +891,12 @@
         """Flush data to the database"""
         self.source.flush()
 
-    def relate(self, subj_eid, rtype, obj_eid, subjtype=None):
+    def relate(self, subj_eid, rtype, obj_eid, **kwargs):
         if subj_eid is None or obj_eid is None:
             return
         # XXX Could subjtype be inferred ?
         self.source.add_relation(self.session, subj_eid, rtype, obj_eid,
-                                 self.rschema(rtype).inlined, subjtype)
+                                 self.rschema(rtype).inlined, **kwargs)
 
     def drop_indexes(self, etype):
         """Drop indexes for a given entity type"""
@@ -1012,13 +1015,13 @@
             _relations_sql.clear()
             _insertdicts.clear()
             _inlined_relations_sql.clear()
-            print 'flush done'
 
     def add_relation(self, session, subject, rtype, object,
-                     inlined=False, subjtype=None):
+                     inlined=False, **kwargs):
         if inlined:
             _sql = self._sql.inlined_relations
             data = {'cw_eid': subject, SQL_PREFIX + rtype: object}
+            subjtype = kwargs.get('subjtype')
             if subjtype is None:
                 # Try to infer it
                 targets = [t.type for t in