devtools/dataimport.py
changeset 4136 47060a66c97f
parent 3486 ea6bf6f9ba0c
child 4140 46ddd27a4ca4
--- a/devtools/dataimport.py	Fri Dec 18 14:07:16 2009 +0100
+++ b/devtools/dataimport.py	Fri Dec 18 14:08:41 2009 +0100
@@ -54,13 +54,40 @@
 import sys, csv, traceback
 
 from logilab.common import shellutils
+from logilab.common.deprecation import deprecated
 
-def utf8csvreader(file, encoding='utf-8', separator=',', quote='"'):
-    """A csv reader that accepts files with any encoding and outputs
-    unicode strings."""
-    for row in csv.reader(file, delimiter=separator, quotechar=quote):
+def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"',
+                  skipfirst=False, withpb=True):
+    """same as ucsvreader but a progress bar is displayed as we iter on rows"""
+    rowcount = int(shellutils.Execute('wc -l %s' % filepath).out.strip().split()[0])
+    if skipfirst:
+        rowcount -= 1
+    if withpb:
+        pb = shellutils.ProgressBar(rowcount)
+    for urow in ucsvreader(file(filepath), encoding, separator, quote, skipfirst):
+        yield urow
+        if withpb:
+            pb.update()
+    print ' %s rows imported' % rowcount
+
+def ucsvreader(stream, encoding='utf-8', separator=',', quote='"',
+               skipfirst=False):
+    """A csv reader that accepts files with any encoding and outputs unicode
+    strings
+    """
+    it = iter(csv.reader(stream, delimiter=separator, quotechar=quote))
+    if skipfirst:
+        it.next()
+    for row in it:
         yield [item.decode(encoding) for item in row]
 
+utf8csvreader = deprecated('use ucsvreader instead')(ucsvreader)
+
+def commit_every(nbit, store, it):
+    for i, x in enumerate(it):
+        if i % nbit:
+            store.checkpoint()
+        yield x
 def lazytable(reader):
     """The first row is taken to be the header of the table and
     used to output a dict for each row of data.
@@ -105,6 +132,10 @@
     """Extract the keys that have more than one item in their bucket."""
     return [(key, len(value)) for key,value in buckets.items() if len(value) > 1]
 
+def check_doubles_not_none(buckets):
+    """Extract the keys that have more than one item in their bucket."""
+    return [(key, len(value)) for key,value in buckets.items() if key is not None and len(value) > 1]
+
 # make entity helper #####
 
 def mk_entity(row, map):
@@ -181,24 +212,47 @@
             if item[key] == value:
                 yield item
 
-    def rql(self, query, args):
-        if self._rql:
-            return self._rql(query, args)
-
     def checkpoint(self):
-        if self._checkpoint:
-            self._checkpoint()
+        pass
 
 class RQLObjectStore(ObjectStore):
     """ObjectStore that works with an actual RQL repository."""
+    _rql = None # bw compat
+
+    def __init__(self, session=None, checkpoint=None):
+        ObjectStore.__init__(self)
+        if session is not None:
+            if not hasattr(session, 'set_pool'):
+                # connection
+                cnx = session
+                session = session.request()
+                session.set_pool = lambda : None
+                checkpoint = checkpoint or cnx.commit
+            self.session = session
+            self.checkpoint = checkpoint or session.commit
+        elif checkpoint is not None:
+            self.checkpoint = checkpoint
+
+    def rql(self, *args):
+        if self._rql is not None:
+            return self._rql(*args)
+        self.session.set_pool()
+        return self.session.execute(*args)
+
+    def create_entity(self, *args, **kwargs):
+        self.session.set_pool()
+        entity = self.session.create_entity(*args, **kwargs)
+        self.eids[entity.eid] = entity
+        self.types.setdefault(args[0], []).append(entity.eid)
+        return entity
 
     def _put(self, type, item):
         query = ('INSERT %s X: ' % type) + ', '.join(['X %s %%(%s)s' % (key,key) for key in item])
         return self.rql(query, item)[0][0]
 
     def relate(self, eid_from, rtype, eid_to):
-        query = 'SET X %s Y WHERE X eid %%(from)s, Y eid %%(to)s' % rtype
-        self.rql(query, {'from': int(eid_from), 'to': int(eid_to)})
+        self.rql('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype,
+                  {'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y'))
         self.relations.add( (eid_from, rtype, eid_to) )
 
 # import controller #####
@@ -235,7 +289,7 @@
         for func, checks in self.generators:
             self._checks = {}
             func_name = func.__name__[4:]
-            question = 'Importation de %s' % func_name
+            question = 'Importing %s' % func_name
             self.tell(question)
             try:
                 func(self)
@@ -244,8 +298,9 @@
                 tmp = StringIO.StringIO()
                 traceback.print_exc(file=tmp)
                 print tmp.getvalue()
+                # use a list to avoid counting a <nb lines> errors instead of one
                 self.errors[func_name] = ('Erreur lors de la transformation',
-                                          tmp.getvalue().splitlines())
+                                          [tmp.getvalue().splitlines()])
             for key, func, title, help in checks:
                 buckets = self._checks.get(key)
                 if buckets:
@@ -253,11 +308,13 @@
                     if err:
                         self.errors[title] = (help, err)
             self.store.checkpoint()
-        errors = sum(len(err[1]) for err in self.errors.values())
-        self.tell('Importation terminée. (%i objets, %i types, %i relations et %i erreurs).'
+        self.tell('Import completed: %i entities (%i types), %i relations'
                   % (len(self.store.eids), len(self.store.types),
-                     len(self.store.relations), errors))
-        if self.errors and self.askerror and confirm('Afficher les erreurs ?'):
+                     len(self.store.relations)))
+        nberrors = sum(len(err[1]) for err in self.errors.values())
+        if nberrors:
+            print '%s errors' % nberrors
+        if self.errors and self.askerror and confirm('Display errors?'):
             import pprint
             pprint.pprint(self.errors)