new catch_error context manager, nicer controller __init__ and new iter_and_commit(datakey) method
authorSylvain Thénault <sylvain.thenault@logilab.fr>
Mon, 21 Dec 2009 18:44:42 +0100
changeset 4152 30fd1229137d
parent 4151 66fe38345a65
child 4153 1e0b30474454
new catch_error context manager, nicer controller __init__ and new iter_and_commit(datakey) method
devtools/dataimport.py
--- a/devtools/dataimport.py	Mon Dec 21 18:43:16 2009 +0100
+++ b/devtools/dataimport.py	Mon Dec 21 18:44:42 2009 +0100
@@ -59,7 +59,7 @@
 def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"',
                   skipfirst=False, withpb=True):
     """same as ucsvreader but a progress bar is displayed as we iter on rows"""
-    rowcount = int(shellutils.Execute('wc -l %s' % filepath).out.strip().split()[0])
+    rowcount = int(shellutils.Execute('wc -l "%s"' % filepath).out.strip().split()[0])
     if skipfirst:
         rowcount -= 1
     if withpb:
@@ -85,9 +85,12 @@
 
 def commit_every(nbit, store, it):
     for i, x in enumerate(it):
-        if i % nbit:
+        yield x
+        if nbit is not None and i % nbit:
             store.checkpoint()
-        yield x
+    if nbit is not None:
+        store.checkpoint()
+
 def lazytable(reader):
     """The first row is taken to be the header of the table and
     used to output a dict for each row of data.
@@ -98,10 +101,56 @@
     for row in reader:
         yield dict(zip(header, row))
 
+def mk_entity(row, map):
+    """Return a dict made from sanitized mapped values.
+
+    >>> row = {'myname': u'dupont'}
+    >>> map = [('myname', u'name', (capitalize_if_unicase,))]
+    >>> mk_entity(row, map)
+    {'name': u'Dupont'}
+    """
+    res = {}
+    for src, dest, funcs in map:
+        res[dest] = row[src]
+        for func in funcs:
+            res[dest] = func(res[dest])
+    return res
+
+
+# user interactions ############################################################
+
 def tell(msg):
     print msg
 
-# base sanitizing functions #####
+def confirm(question):
+    """A confirm function that asks for yes/no/abort and exits on abort."""
+    answer = shellutils.ASK.ask(question, ('Y','n','abort'), 'Y')
+    if answer == 'abort':
+        sys.exit(1)
+    return answer == 'Y'
+
+
+class catch_error(object):
+    """Helper for @contextmanager decorator."""
+
+    def __init__(self, ctl, key='unexpected error', msg=None):
+        self.ctl = ctl
+        self.key = key
+        self.msg = msg
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        if type is not None:
+            if issubclass(type, (KeyboardInterrupt, SystemExit)):
+                return # re-raise
+            if self.ctl.catcherrors:
+                self.ctl.record_error(self.key, msg)
+                return True # silent
+
+
+# base sanitizing functions ####################################################
 
 def capitalize_if_unicase(txt):
     if txt.isupper() or txt.islower():
@@ -126,7 +175,8 @@
 def strip(txt):
     return txt.strip()
 
-# base checks #####
+
+# base integrity checking functions ############################################
 
 def check_doubles(buckets):
     """Extract the keys that have more than one item in their bucket."""
@@ -136,24 +186,8 @@
     """Extract the keys that have more than one item in their bucket."""
     return [(key, len(value)) for key,value in buckets.items() if key is not None and len(value) > 1]
 
-# make entity helper #####
 
-def mk_entity(row, map):
-    """Return a dict made from sanitized mapped values.
-
-    >>> row = {'myname': u'dupont'}
-    >>> map = [('myname', u'name', (capitalize_if_unicase,))]
-    >>> mk_entity(row, map)
-    {'name': u'Dupont'}
-    """
-    res = {}
-    for src, dest, funcs in map:
-        res[dest] = row[src]
-        for func in funcs:
-            res[dest] = func(res[dest])
-    return res
-
-# object stores
+# object stores #################################################################
 
 class ObjectStore(object):
     """Store objects in memory for faster testing. Will not
@@ -215,6 +249,7 @@
     def checkpoint(self):
         pass
 
+
 class RQLObjectStore(ObjectStore):
     """ObjectStore that works with an actual RQL repository."""
     _rql = None # bw compat
@@ -255,7 +290,8 @@
                   {'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y'))
         self.relations.add( (eid_from, rtype, eid_to) )
 
-# import controller #####
+
+# the import controller ########################################################
 
 class CWImportController(object):
     """Controller of the data import process.
@@ -266,12 +302,17 @@
     >>> ctl.run()
     """
 
-    def __init__(self, store):
+    def __init__(self, store, askerror=False, catcherrors=None, tell=tell,
+                 commitevery=50):
         self.store = store
         self.generators = None
         self.data = {}
         self.errors = None
-        self.askerror = False
+        self.askerror = askerror
+        if  catcherrors is None:
+            catcherrors = askerror
+        self.catcherrors = catcherrors
+        self.commitevery = commitevery # set to None to do a single commit
         self._tell = tell
 
     def check(self, type, key, value):
@@ -284,30 +325,41 @@
             self.check(key, entity[key], None)
             entity[key] = default
 
+    def record_error(self, key, msg=None, type=None, value=None, tb=None):
+        import StringIO
+        tmp = StringIO.StringIO()
+        if type is None:
+            traceback.print_exc(file=tmp)
+        else:
+            traceback.print_exception(type, value, tb, file=tmp)
+        print tmp.getvalue()
+        # use a list to avoid counting a <nb lines> errors instead of one
+        errorlog = self.errors.setdefault(key, [])
+        if msg is None:
+            errorlog.append(tmp.getvalue().splitlines())
+        else:
+            errorlog.append( (msg, tmp.getvalue().splitlines()) )
+
     def run(self):
         self.errors = {}
         for func, checks in self.generators:
             self._checks = {}
-            func_name = func.__name__[4:]
-            question = 'Importing %s' % func_name
-            self.tell(question)
+            func_name = func.__name__[4:]  # XXX
+            self.tell('Importing %s' % func_name)
             try:
                 func(self)
             except:
-                import StringIO
-                tmp = StringIO.StringIO()
-                traceback.print_exc(file=tmp)
-                print tmp.getvalue()
-                # use a list to avoid counting a <nb lines> errors instead of one
-                self.errors[func_name] = ('Erreur lors de la transformation',
-                                          [tmp.getvalue().splitlines()])
+                if self.catcherrors:
+                    self.record_error(func_name, 'While calling %s' % func.__name__)
+                else:
+                    raise
             for key, func, title, help in checks:
                 buckets = self._checks.get(key)
                 if buckets:
                     err = func(buckets)
                     if err:
                         self.errors[title] = (help, err)
-            self.store.checkpoint()
+        self.store.checkpoint()
         self.tell('\nImport completed: %i entities (%i types), %i relations'
                   % (len(self.store.eids), len(self.store.types),
                      len(self.store.relations)))
@@ -327,9 +379,6 @@
     def tell(self, msg):
         self._tell(msg)
 
-def confirm(question):
-    """A confirm function that asks for yes/no/abort and exits on abort."""
-    answer = shellutils.ASK.ask(question, ('Y','n','abort'), 'Y')
-    if answer == 'abort':
-        sys.exit(1)
-    return answer == 'Y'
+    def iter_and_commit(self, datakey):
+        """iter rows, triggering commit every self.commitevery iterations"""
+        return commit_every(self.commitevery, self.store, self.get_data(datakey))