dataimport.py
changeset 4912 9767cc516b4f
parent 4847 9466604ef448
child 4913 083b4d454192
--- a/dataimport.py	Tue Mar 16 10:54:59 2010 +0100
+++ b/dataimport.py	Wed Mar 10 16:07:24 2010 +0100
@@ -38,19 +38,15 @@
   GENERATORS.append( (gen_users, CHK) )
 
   # create controller
-  ctl = CWImportController(RQLObjectStore())
+  ctl = CWImportController(RQLObjectStore(cnx))
   ctl.askerror = 1
   ctl.generators = GENERATORS
-  ctl.store._checkpoint = checkpoint
-  ctl.store._rql = rql
   ctl.data['utilisateurs'] = lazytable(utf8csvreader(open('users.csv')))
   # run
   ctl.run()
-  sys.exit(0)
 
-
-.. BUG fichier à une colonne pose un problème de parsing
-.. TODO rollback()
+.. BUG file with one column are not parsable
+.. TODO rollback() invocation is not possible yet
 """
 __docformat__ = "restructuredtext en"
 
@@ -98,9 +94,9 @@
     for i, x in enumerate(it):
         yield x
         if nbit is not None and i % nbit:
-            store.checkpoint()
+            store.commit()
     if nbit is not None:
-        store.checkpoint()
+        store.commit()
 
 def lazytable(reader):
     """The first row is taken to be the header of the table and
@@ -115,24 +111,22 @@
 def mk_entity(row, map):
     """Return a dict made from sanitized mapped values.
 
-    ValidationError can be raised on unexpected values found in checkers
+    ValueError can be raised on unexpected values found in checkers
 
     >>> row = {'myname': u'dupont'}
-    >>> map = [('myname', u'name', (capitalize_if_unicase,))]
+    >>> map = [('myname', u'name', (call_transform_method('title'),))]
     >>> mk_entity(row, map)
     {'name': u'Dupont'}
     >>> row = {'myname': u'dupont', 'optname': u''}
-    >>> map = [('myname', u'name', (capitalize_if_unicase,)),
+    >>> map = [('myname', u'name', (call_transform_method('title'),)),
     ...        ('optname', u'MARKER', (optional,))]
     >>> mk_entity(row, map)
-    {'name': u'Dupont'}
+    {'name': u'Dupont', 'optname': None}
     """
     res = {}
     assert isinstance(row, dict)
     assert isinstance(map, list)
     for src, dest, funcs in map:
-        assert not (required in funcs and optional in funcs), \
-               "optional and required checks are exclusive"
         res[dest] = row[src]
         try:
             for func in funcs:
@@ -180,7 +174,22 @@
 # base sanitizing/coercing functions ###########################################
 
 def optional(value):
-    """validation error will not been raised if you add this checker in chain"""
+    """checker to filter optional field
+
+    If value is undefined (ex: empty string), return None that will
+    break the checkers validation chain
+
+    General use is to add 'optional' check in first condition to avoid
+    ValueError by further checkers
+
+    >>> MAPPER = [(u'value', 'value', (optional, int))]
+    >>> row = {'value': u'XXX'}
+    >>> mk_entity(row, MAPPER)
+    {'value': None}
+    >>> row = {'value': u'100'}
+    >>> mk_entity(row, MAPPER)
+    {'value': 100}
+    """
     if value:
         return value
     return None
@@ -254,7 +263,7 @@
         self.relations = set()
         self.indexes = {}
         self._rql = None
-        self._checkpoint = None
+        self._commit = None
 
     def _put(self, type, item):
         self.items.append(item)
@@ -267,79 +276,26 @@
         self.types.setdefault(type, []).append(eid)
 
     def relate(self, eid_from, rtype, eid_to, inlined=False):
-        """Add new relation (reverse type support is available)
-
-        >>> 1,2 = eid_from, eid_to
-        >>> self.relate(eid_from, 'in_group', eid_to)
-        1, 'in_group', 2
-        >>> self.relate(eid_from, 'reverse_in_group', eid_to)
-        2, 'in_group', 1
-        """
-        if rtype.startswith('reverse_'):
-            eid_from, eid_to = eid_to, eid_from
-            rtype = rtype[8:]
+        """Add new relation"""
         relation = eid_from, rtype, eid_to
         self.relations.add(relation)
         return relation
 
-    def build_index(self, name, type, func=None):
-        index = {}
-        if func is None or not callable(func):
-            func = lambda x: x['eid']
-        for eid in self.types[type]:
-            index.setdefault(func(self.eids[eid]), []).append(eid)
-        assert index, "new index '%s' cannot be empty" % name
-        self.indexes[name] = index
-
-    def build_rqlindex(self, name, type, key, rql, rql_params=False, func=None):
-        """build an index by rql query
-
-        rql should return eid in first column
-        ctl.store.build_index('index_name', 'users', 'login', 'Any U WHERE U is CWUser')
-        """
-        rset = self.rql(rql, rql_params or {})
-        for entity in rset.entities():
-            getattr(entity, key) # autopopulate entity with key attribute
-            self.eids[entity.eid] = dict(entity)
-            if entity.eid not in self.types.setdefault(type, []):
-                self.types[type].append(entity.eid)
-        assert self.types[type], "new index type '%s' cannot be empty (0 record found)" % type
+    def commit(self):
+        """this commit method do nothing by default
 
-        # Build index with specified key
-        func = lambda x: x[key]
-        self.build_index(name, type, func)
+        This is voluntary to use the frequent autocommit feature in CubicWeb
+        when you are using hooks or another
 
-    def fetch(self, name, key, unique=False, decorator=None):
-        """
-            decorator is a callable method or an iterator of callable methods (usually a lambda function)
-            decorator=lambda x: x[:1] (first value is returned)
-
-            We can use validation check function available in _entity
+        If you want override commit method, please set it by the
+        constructor
         """
-        eids = self.indexes[name].get(key, [])
-        if decorator is not None:
-            if not hasattr(decorator, '__iter__'):
-                decorator = (decorator,)
-            for f in decorator:
-                eids = f(eids)
-        if unique:
-            assert len(eids) == 1, u'expected a single one value for key "%s" in index "%s". Got %i' % (key, name, len(eids))
-            eids = eids[0] # FIXME maybe it's better to keep an iterator here ?
-        return eids
-
-    def find(self, type, key, value):
-        for idx in self.types[type]:
-            item = self.items[idx]
-            if item[key] == value:
-                yield item
+        pass
 
     def rql(self, *args):
         if self._rql is not None:
             return self._rql(*args)
 
-    def checkpoint(self):
-        pass
-
     @property
     def nb_inserted_entities(self):
         return len(self.eids)
@@ -350,20 +306,81 @@
     def nb_inserted_relations(self):
         return len(self.relations)
 
-    @deprecated('[3.6] get_many() deprecated. Use fetch() instead')
-    def get_many(self, name, key):
-        return self.fetch(name, key, unique=False)
+    @deprecated("[3.7] index support will disappear")
+    def build_index(self, name, type, func=None, can_be_empty=False):
+        """build internal index for further search"""
+        index = {}
+        if func is None or not callable(func):
+            func = lambda x: x['eid']
+        for eid in self.types[type]:
+            index.setdefault(func(self.eids[eid]), []).append(eid)
+        if not can_be_empty:
+            assert index, "new index '%s' cannot be empty" % name
+        self.indexes[name] = index
+
+    @deprecated("[3.7] index support will disappear")
+    def build_rqlindex(self, name, type, key, rql, rql_params=False,
+                       func=None, can_be_empty=False):
+        """build an index by rql query
+
+        rql should return eid in first column
+        ctl.store.build_index('index_name', 'users', 'login', 'Any U WHERE U is CWUser')
+        """
+        self.types[type] = []
+        rset = self.rql(rql, rql_params or {})
+        if not can_be_empty:
+            assert rset, "new index type '%s' cannot be empty (0 record found)" % type
+        for entity in rset.entities():
+            getattr(entity, key) # autopopulate entity with key attribute
+            self.eids[entity.eid] = dict(entity)
+            if entity.eid not in self.types[type]:
+                self.types[type].append(entity.eid)
+
+        # Build index with specified key
+        func = lambda x: x[key]
+        self.build_index(name, type, func, can_be_empty=can_be_empty)
 
-    @deprecated('[3.6] get_one() deprecated. Use fetch(..., unique=True) instead')
-    def get_one(self, name, key):
-        return self.fetch(name, key, unique=True)
+    @deprecated("[3.7] index support will disappear")
+    def fetch(self, name, key, unique=False, decorator=None):
+        """index fetcher method
+
+        decorator is a callable method or an iterator of callable methods (usually a lambda function)
+        decorator=lambda x: x[:1] (first value is returned)
+        decorator=lambda x: x.lower (lowercased value is returned)
+
+        decorator is handy when you want to improve index keys but without
+        changing the original field
+
+        Same check functions can be reused here.
+        """
+        eids = self.indexes[name].get(key, [])
+        if decorator is not None:
+            if not hasattr(decorator, '__iter__'):
+                decorator = (decorator,)
+            for f in decorator:
+                eids = f(eids)
+        if unique:
+            assert len(eids) == 1, u'expected a single one value for key "%s" in index "%s". Got %i' % (key, name, len(eids))
+            eids = eids[0]
+        return eids
+
+    @deprecated("[3.7] index support will disappear")
+    def find(self, type, key, value):
+        for idx in self.types[type]:
+            item = self.items[idx]
+            if item[key] == value:
+                yield item
+
+    @deprecated("[3.7] checkpoint() deprecated. use commit() instead")
+    def checkpoint(self):
+        self.commit()
 
 
 class RQLObjectStore(ObjectStore):
     """ObjectStore that works with an actual RQL repository (production mode)"""
     _rql = None # bw compat
 
-    def __init__(self, session=None, checkpoint=None):
+    def __init__(self, session=None, commit=None):
         ObjectStore.__init__(self)
         if session is not None:
             if not hasattr(session, 'set_pool'):
@@ -371,17 +388,21 @@
                 cnx = session
                 session = session.request()
                 session.set_pool = lambda : None
-                checkpoint = checkpoint or cnx.commit
+                commit = commit or cnx.commit
             else:
                 session.set_pool()
             self.session = session
-            self._checkpoint = checkpoint or session.commit
-        elif checkpoint is not None:
-            self._checkpoint = checkpoint
+            self._commit = commit or session.commit
+        elif commit is not None:
+            self._commit = commit
             # XXX .session
 
+    @deprecated("[3.7] checkpoint() deprecated. use commit() instead")
     def checkpoint(self):
-        self._checkpoint()
+        self.commit()
+
+    def commit(self):
+        self._commit()
         self.session.set_pool()
 
     def rql(self, *args):
@@ -401,7 +422,6 @@
         return self.rql(query, item)[0][0]
 
     def relate(self, eid_from, rtype, eid_to, inlined=False):
-        # if reverse relation is found, eids are exchanged
         eid_from, rtype, eid_to = super(RQLObjectStore, self).relate(
             eid_from, rtype, eid_to)
         self.rql('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype,
@@ -460,14 +480,15 @@
         self.errors = {}
         for func, checks in self.generators:
             self._checks = {}
-            func_name = func.__name__[4:]  # XXX
-            self.tell("Import '%s'..." % func_name)
+            func_name = func.__name__
+            self.tell("Run import function '%s'..." % func_name)
             try:
                 func(self)
             except:
                 if self.catcherrors:
                     self.record_error(func_name, 'While calling %s' % func.__name__)
                 else:
+                    self._print_stats()
                     raise
             for key, func, title, help in checks:
                 buckets = self._checks.get(key)
@@ -475,13 +496,8 @@
                     err = func(buckets)
                     if err:
                         self.errors[title] = (help, err)
-        self.store.checkpoint()
-        nberrors = sum(len(err[1]) for err in self.errors.values())
-        self.tell('\nImport completed: %i entities, %i types, %i relations and %i errors'
-                  % (self.store.nb_inserted_entities,
-                     self.store.nb_inserted_types,
-                     self.store.nb_inserted_relations,
-                     nberrors))
+        self.store.commit()
+        self._print_stats()
         if self.errors:
             if self.askerror == 2 or (self.askerror and confirm('Display errors ?')):
                 from pprint import pformat
@@ -489,6 +505,14 @@
                     self.tell("\n%s (%s): %d\n" % (error[0], errkey, len(error[1])))
                     self.tell(pformat(sorted(error[1])))
 
+    def _print_stats(self):
+        nberrors = sum(len(err[1]) for err in self.errors.values())
+        self.tell('\nImport statistics: %i entities, %i types, %i relations and %i errors'
+                  % (self.store.nb_inserted_entities,
+                     self.store.nb_inserted_types,
+                     self.store.nb_inserted_relations,
+                     nberrors))
+
     def get_data(self, key):
         return self.data.get(key)
 
@@ -667,86 +691,3 @@
         for eschema in entity.e_schema.ancestors() + [entity.e_schema]:
             eids.append(eschema_eid(self.session, eschema))
         return eids
-
-
-################################################################################
-
-utf8csvreader = deprecated('[3.6] use ucsvreader instead')(ucsvreader)
-
-@deprecated('[3.6] use required')
-def nonempty(value):
-    return required(value)
-
-@deprecated("[3.6] use call_check_method('isdigit')")
-def alldigits(txt):
-    if txt.isdigit():
-        return txt
-    else:
-        return u''
-
-@deprecated("[3.7] too specific, will move away, copy me")
-def capitalize_if_unicase(txt):
-    if txt.isupper() or txt.islower():
-        return txt.capitalize()
-    return txt
-
-@deprecated("[3.7] too specific, will move away, copy me")
-def yesno(value):
-    """simple heuristic that returns boolean value
-
-    >>> yesno("Yes")
-    True
-    >>> yesno("oui")
-    True
-    >>> yesno("1")
-    True
-    >>> yesno("11")
-    True
-    >>> yesno("")
-    False
-    >>> yesno("Non")
-    False
-    >>> yesno("blablabla")
-    False
-    """
-    if value:
-        return value.lower()[0] in 'yo1'
-    return False
-
-@deprecated("[3.7] use call_check_method('isalpha')")
-def isalpha(value):
-    if value.isalpha():
-        return value
-    raise ValueError("not all characters in the string alphabetic")
-
-@deprecated("[3.7] use call_transform_method('upper')")
-def uppercase(txt):
-    return txt.upper()
-
-@deprecated("[3.7] use call_transform_method('lower')")
-def lowercase(txt):
-    return txt.lower()
-
-@deprecated("[3.7] use call_transform_method('replace', ' ', '')")
-def no_space(txt):
-    return txt.replace(' ','')
-
-@deprecated("[3.7] use call_transform_method('replace', u'\xa0', '')")
-def no_uspace(txt):
-    return txt.replace(u'\xa0','')
-
-@deprecated("[3.7] use call_transform_method('replace', '-', '')")
-def no_dash(txt):
-    return txt.replace('-','')
-
-@deprecated("[3.7] use call_transform_method('strip')")
-def strip(txt):
-    return txt.strip()
-
-@deprecated("[3.7] use call_transform_method('replace', ',', '.'), float")
-def decimal(value):
-    return comma_float(value)
-
-@deprecated('[3.7] use int builtin')
-def integer(value):
-    return int(value)