devtools/dataimport.py
changeset 4527 67ab70e98488
parent 4252 6c4f109c2b03
child 4613 141a4f613f8a
--- a/devtools/dataimport.py	Tue Feb 09 18:49:12 2010 +0100
+++ b/devtools/dataimport.py	Fri Feb 05 17:13:53 2010 +0100
@@ -39,7 +39,7 @@
 
   # create controller
   ctl = CWImportController(RQLObjectStore())
-  ctl.askerror = True
+  ctl.askerror = 1
   ctl.generators = GENERATORS
   ctl.store._checkpoint = checkpoint
   ctl.store._rql = rql
@@ -48,6 +48,9 @@
   ctl.run()
   sys.exit(0)
 
+
+.. BUG a single-column file causes a parsing problem
+.. TODO rollback()
 """
 __docformat__ = "restructuredtext en"
 
@@ -110,16 +113,35 @@
 def mk_entity(row, map):
     """Return a dict made from sanitized mapped values.
 
+    AssertionError is raised when a checker rejects a value, unless the field is marked optional
+
     >>> row = {'myname': u'dupont'}
     >>> map = [('myname', u'name', (capitalize_if_unicase,))]
     >>> mk_entity(row, map)
     {'name': u'Dupont'}
+    >>> row = {'myname': u'dupont', 'optname': u''}
+    >>> map = [('myname', u'name', (capitalize_if_unicase,)),
+    ...        ('optname', u'MARKER', (optional,))]
+    >>> mk_entity(row, map)
+    {'name': u'Dupont'}
     """
     res = {}
+    assert isinstance(row, dict)
+    assert isinstance(map, list)
     for src, dest, funcs in map:
+        assert not (required in funcs and optional in funcs), "optional and required checks are exclusive"
         res[dest] = row[src]
-        for func in funcs:
-            res[dest] = func(res[dest])
+        try:
+            for func in funcs:
+                res[dest] = func(res[dest])
+            if res[dest] is None or res[dest]==False:
+                raise AssertionError('undetermined value')
+        except AssertionError, err:
+            if optional in funcs:
+                # Drop this field: it is marked optional, so a failed check is not an error
+               del res[dest]
+            else:
+               raise AssertionError('error with "%s" field: %s' % (src, err))
     return res
 
 
@@ -163,6 +185,12 @@
         return txt.capitalize()
     return txt
 
+def uppercase(txt):
+    return txt.upper()
+
+def lowercase(txt):
+    return txt.lower()
+
 def no_space(txt):
     return txt.replace(' ','')
 
@@ -172,15 +200,57 @@
 def no_dash(txt):
     return txt.replace('-','')
 
+def decimal(value):
+    """cast to float but with comma replacement
+
+    Commas are replaced by '.' first, so locale formats such as '1,5' are accepted."""
+    value = value.replace(',', '.')
+    try:
+        return float(value)
+    except Exception, err:
+        raise AssertionError(err)
+
+def integer(value):
+    try:
+        return int(value)
+    except Exception, err:
+        raise AssertionError(err)
+
+def strip(txt):
+    return txt.strip()
+
+def yesno(value):
+    return value.lower()[0] in 'yo1'
+
+def isalpha(value):
+    if value.isalpha():
+        return value
+    raise AssertionError("not all characters in the string alphabetic")
+
+def optional(value):
+    """No validation error is raised for this field if you add this checker to the chain."""
+    return value
+
+def required(value):
+    """Raise AssertionError if value is empty.
+
+    This check should usually be placed last in the chain.
+    """
+    if bool(value):
+        return value
+    raise AssertionError("required")
+
+@deprecated('use required(value)')
+def nonempty(value):
+    return required(value)
+
+@deprecated('use integer(value)')
 def alldigits(txt):
     if txt.isdigit():
         return txt
     else:
         return u''
 
-def strip(txt):
-    return txt.strip()
-
 
 # base integrity checking functions ############################################
 
@@ -196,9 +266,9 @@
 # object stores #################################################################
 
 class ObjectStore(object):
-    """Store objects in memory for faster testing. Will not
-    enforce the constraints of the schema and hence will miss
-    some problems.
+    """Store objects in memory for *faster* validation (development mode)
+
+    It does not enforce the constraints of the schema and hence will miss some problems.
 
     >>> store = ObjectStore()
     >>> user = {'login': 'johndoe'}
@@ -207,7 +277,6 @@
     >>> store.add('CWUser', group)
     >>> store.relate(user['eid'], 'in_group', group['eid'])
     """
-
     def __init__(self):
         self.items = []
         self.eids = {}
@@ -228,23 +297,73 @@
         self.types.setdefault(type, []).append(eid)
 
     def relate(self, eid_from, rtype, eid_to):
-        eids_valid = (eid_from < len(self.items) and eid_to <= len(self.items))
-        assert eids_valid, 'eid error %s %s' % (eid_from, eid_to)
-        self.relations.add( (eid_from, rtype, eid_to) )
+        """Add new relation (reverse type support is available)
 
-    def build_index(self, name, type, func):
+        >>> eid_from, eid_to = 1, 2
+        >>> self.relate(eid_from, 'in_group', eid_to)
+        (1, 'in_group', 2)
+        >>> self.relate(eid_from, 'reverse_in_group', eid_to)
+        (2, 'in_group', 1)
+        """
+        if rtype.startswith('reverse_'):
+            eid_from, eid_to = eid_to, eid_from
+            rtype = rtype[8:]
+        relation = eid_from, rtype, eid_to
+        self.relations.add(relation)
+        return relation
+
+    def build_index(self, name, type, func=None):
         index = {}
+        if func is None or not callable(func):
+            func = lambda x: x['eid']
         for eid in self.types[type]:
             index.setdefault(func(self.eids[eid]), []).append(eid)
+        assert index, "new index '%s' cannot be empty" % name
         self.indexes[name] = index
 
+    def build_rqlindex(self, name, type, key, rql, rql_params=False, func=None):
+        """Build an index from the result of an RQL query.
+
+        The RQL query should return the eid in the first column, e.g.:
+        ctl.store.build_rqlindex('index_name', 'users', 'login', 'Any U WHERE U is CWUser')
+        """
+        rset = self.rql(rql, rql_params or {})
+        for entity in rset.entities():
+            getattr(entity, key) # autopopulate entity with key attribute
+            self.eids[entity.eid] = dict(entity)
+            if entity.eid not in self.types.setdefault(type, []):
+                self.types[type].append(entity.eid)
+        assert self.types[type], "new index type '%s' cannot be empty (0 record found)" % type
+
+        # Build index with specified key
+        func = lambda x: x[key]
+        self.build_index(name, type, func)
+
+    @deprecated('get_many() deprecated. Use fetch() instead')
     def get_many(self, name, key):
-        return self.indexes[name].get(key, [])
+        return self.fetch(name, key, unique=False)
 
+    @deprecated('get_one() deprecated. Use fetch(..., unique=True) instead')
     def get_one(self, name, key):
+        return self.fetch(name, key, unique=True)
+
+    def fetch(self, name, key, unique=False, decorator=None):
+        """
+            decorator is a callable, or an iterable of callables (usually lambda functions), each applied in turn to the list of eids found
+            e.g. decorator=lambda x: x[:1] (only the first value is returned)
+
+            The validation check functions available in _entity can be used here
+        """
         eids = self.indexes[name].get(key, [])
-        assert len(eids) == 1, 'expected a single one got %i' % len(eids)
-        return eids[0]
+        if decorator is not None:
+            if not hasattr(decorator, '__iter__'):
+                decorator = (decorator,)
+            for f in decorator:
+                eids = f(eids)
+        if unique:
+            assert len(eids) == 1, u'expected a single one value for key "%s" in index "%s". Got %i' % (key, name, len(eids))
+            eids = eids[0] # FIXME maybe it's better to keep an iterator here ?
+        return eids
 
     def find(self, type, key, value):
         for idx in self.types[type]:
@@ -252,12 +371,16 @@
             if item[key] == value:
                 yield item
 
+    def rql(self, *args):
+        if self._rql is not None:
+            return self._rql(*args)
+
     def checkpoint(self):
         pass
 
 
 class RQLObjectStore(ObjectStore):
-    """ObjectStore that works with an actual RQL repository."""
+    """ObjectStore that works with an actual RQL repository (production mode)"""
     _rql = None # bw compat
 
     def __init__(self, session=None, checkpoint=None):
@@ -292,9 +415,10 @@
         return self.rql(query, item)[0][0]
 
     def relate(self, eid_from, rtype, eid_to):
+        # if reverse relation is found, eids are exchanged
+        eid_from, rtype, eid_to = super(RQLObjectStore, self).relate(eid_from, rtype, eid_to)
         self.rql('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype,
                   {'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y'))
-        self.relations.add( (eid_from, rtype, eid_to) )
 
 
 # the import controller ########################################################
@@ -308,7 +432,7 @@
     >>> ctl.run()
     """
 
-    def __init__(self, store, askerror=False, catcherrors=None, tell=tell,
+    def __init__(self, store, askerror=0, catcherrors=None, tell=tell,
                  commitevery=50):
         self.store = store
         self.generators = None
@@ -350,7 +474,7 @@
         for func, checks in self.generators:
             self._checks = {}
             func_name = func.__name__[4:]  # XXX
-            self.tell('Importing %s' % func_name)
+            self.tell("Import '%s'..." % func_name)
             try:
                 func(self)
             except:
@@ -365,20 +489,32 @@
                     if err:
                         self.errors[title] = (help, err)
         self.store.checkpoint()
-        self.tell('\nImport completed: %i entities (%i types), %i relations'
+        nberrors = sum(len(err[1]) for err in self.errors.values())
+        self.tell('\nImport completed: %i entities, %i types, %i relations and %i errors'
                   % (len(self.store.eids), len(self.store.types),
-                     len(self.store.relations)))
-        nberrors = sum(len(err[1]) for err in self.errors.values())
-        if nberrors:
-            print '%s errors' % nberrors
-        if self.errors and self.askerror and confirm('Display errors?'):
-            import pprint
-            pprint.pprint(self.errors)
+                     len(self.store.relations), nberrors))
+        if self.errors:
+            if self.askerror==2 or (self.askerror and confirm('Display errors ?')):
+                from pprint import pformat
+                for errkey, error in self.errors.items():
+                    self.tell("\n%s (%s): %d\n" % (error[0], errkey, len(error[1])))
+                    self.tell(pformat(sorted(error[1])))
 
     def get_data(self, key):
         return self.data.get(key)
 
-    def index(self, name, key, value):
+    def index(self, name, key, value, unique=False):
+        """Add value to the list stored under key in the index called name, creating it if needed.
+
+        If unique is set to True, a value already present under that key is not added again
+        """
+        if unique:
+            try:
+                if value in self.store.indexes[name][key]:
+                    return
+            except KeyError:
+                # no entry yet for this index/key, so this is the first occurrence; continue
+                pass
         self.store.indexes.setdefault(name, {}).setdefault(key, []).append(value)
 
     def tell(self, msg):