devtools/dataimport.py
changeset 4136 47060a66c97f
parent 3486 ea6bf6f9ba0c
child 4140 46ddd27a4ca4
equal deleted inserted replaced
4135:cb0d0bf255f7 4136:47060a66c97f
    52 __docformat__ = "restructuredtext en"
    52 __docformat__ = "restructuredtext en"
    53 
    53 
    54 import sys, csv, traceback
    54 import sys, csv, traceback
    55 
    55 
    56 from logilab.common import shellutils
    56 from logilab.common import shellutils
    57 
    57 from logilab.common.deprecation import deprecated
    58 def utf8csvreader(file, encoding='utf-8', separator=',', quote='"'):
    58 
    59     """A csv reader that accepts files with any encoding and outputs
    59 def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"',
    60     unicode strings."""
    60                   skipfirst=False, withpb=True):
    61     for row in csv.reader(file, delimiter=separator, quotechar=quote):
    61     """same as ucsvreader but a progress bar is displayed as we iter on rows"""
       
    62     rowcount = int(shellutils.Execute('wc -l %s' % filepath).out.strip().split()[0])
       
    63     if skipfirst:
       
    64         rowcount -= 1
       
    65     if withpb:
       
    66         pb = shellutils.ProgressBar(rowcount)
       
    67     for urow in ucsvreader(file(filepath), encoding, separator, quote, skipfirst):
       
    68         yield urow
       
    69         if withpb:
       
    70             pb.update()
       
    71     print ' %s rows imported' % rowcount
       
    72 
       
    73 def ucsvreader(stream, encoding='utf-8', separator=',', quote='"',
       
    74                skipfirst=False):
       
    75     """A csv reader that accepts files with any encoding and outputs unicode
       
    76     strings
       
    77     """
       
    78     it = iter(csv.reader(stream, delimiter=separator, quotechar=quote))
       
    79     if skipfirst:
       
    80         it.next()
       
    81     for row in it:
    62         yield [item.decode(encoding) for item in row]
    82         yield [item.decode(encoding) for item in row]
    63 
    83 
       
    84 utf8csvreader = deprecated('use ucsvreader instead')(ucsvreader)
       
    85 
       
    86 def commit_every(nbit, store, it):
       
    87     for i, x in enumerate(it):
       
    88         if i % nbit:
       
    89             store.checkpoint()
       
    90         yield x
    64 def lazytable(reader):
    91 def lazytable(reader):
    65     """The first row is taken to be the header of the table and
    92     """The first row is taken to be the header of the table and
    66     used to output a dict for each row of data.
    93     used to output a dict for each row of data.
    67 
    94 
    68     >>> data = lazytable(utf8csvreader(open(filename)))
    95     >>> data = lazytable(utf8csvreader(open(filename)))
   102 # base checks #####
   129 # base checks #####
   103 
   130 
   104 def check_doubles(buckets):
   131 def check_doubles(buckets):
   105     """Extract the keys that have more than one item in their bucket."""
   132     """Extract the keys that have more than one item in their bucket."""
   106     return [(key, len(value)) for key,value in buckets.items() if len(value) > 1]
   133     return [(key, len(value)) for key,value in buckets.items() if len(value) > 1]
       
   134 
       
   135 def check_doubles_not_none(buckets):
       
   136     """Extract the keys that have more than one item in their bucket."""
       
   137     return [(key, len(value)) for key,value in buckets.items() if key is not None and len(value) > 1]
   107 
   138 
   108 # make entity helper #####
   139 # make entity helper #####
   109 
   140 
   110 def mk_entity(row, map):
   141 def mk_entity(row, map):
   111     """Return a dict made from sanitized mapped values.
   142     """Return a dict made from sanitized mapped values.
   179         for idx in self.types[type]:
   210         for idx in self.types[type]:
   180             item = self.items[idx]
   211             item = self.items[idx]
   181             if item[key] == value:
   212             if item[key] == value:
   182                 yield item
   213                 yield item
   183 
   214 
   184     def rql(self, query, args):
       
   185         if self._rql:
       
   186             return self._rql(query, args)
       
   187 
       
   188     def checkpoint(self):
   215     def checkpoint(self):
   189         if self._checkpoint:
   216         pass
   190             self._checkpoint()
       
   191 
   217 
   192 class RQLObjectStore(ObjectStore):
   218 class RQLObjectStore(ObjectStore):
   193     """ObjectStore that works with an actual RQL repository."""
   219     """ObjectStore that works with an actual RQL repository."""
       
   220     _rql = None # bw compat
       
   221 
       
   222     def __init__(self, session=None, checkpoint=None):
       
   223         ObjectStore.__init__(self)
       
   224         if session is not None:
       
   225             if not hasattr(session, 'set_pool'):
       
   226                 # connection
       
   227                 cnx = session
       
   228                 session = session.request()
       
   229                 session.set_pool = lambda : None
       
   230                 checkpoint = checkpoint or cnx.commit
       
   231             self.session = session
       
   232             self.checkpoint = checkpoint or session.commit
       
   233         elif checkpoint is not None:
       
   234             self.checkpoint = checkpoint
       
   235 
       
   236     def rql(self, *args):
       
   237         if self._rql is not None:
       
   238             return self._rql(*args)
       
   239         self.session.set_pool()
       
   240         return self.session.execute(*args)
       
   241 
       
   242     def create_entity(self, *args, **kwargs):
       
   243         self.session.set_pool()
       
   244         entity = self.session.create_entity(*args, **kwargs)
       
   245         self.eids[entity.eid] = entity
       
   246         self.types.setdefault(args[0], []).append(entity.eid)
       
   247         return entity
   194 
   248 
   195     def _put(self, type, item):
   249     def _put(self, type, item):
   196         query = ('INSERT %s X: ' % type) + ', '.join(['X %s %%(%s)s' % (key,key) for key in item])
   250         query = ('INSERT %s X: ' % type) + ', '.join(['X %s %%(%s)s' % (key,key) for key in item])
   197         return self.rql(query, item)[0][0]
   251         return self.rql(query, item)[0][0]
   198 
   252 
   199     def relate(self, eid_from, rtype, eid_to):
   253     def relate(self, eid_from, rtype, eid_to):
   200         query = 'SET X %s Y WHERE X eid %%(from)s, Y eid %%(to)s' % rtype
   254         self.rql('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype,
   201         self.rql(query, {'from': int(eid_from), 'to': int(eid_to)})
   255                   {'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y'))
   202         self.relations.add( (eid_from, rtype, eid_to) )
   256         self.relations.add( (eid_from, rtype, eid_to) )
   203 
   257 
   204 # import controller #####
   258 # import controller #####
   205 
   259 
   206 class CWImportController(object):
   260 class CWImportController(object):
   233     def run(self):
   287     def run(self):
   234         self.errors = {}
   288         self.errors = {}
   235         for func, checks in self.generators:
   289         for func, checks in self.generators:
   236             self._checks = {}
   290             self._checks = {}
   237             func_name = func.__name__[4:]
   291             func_name = func.__name__[4:]
   238             question = 'Importation de %s' % func_name
   292             question = 'Importing %s' % func_name
   239             self.tell(question)
   293             self.tell(question)
   240             try:
   294             try:
   241                 func(self)
   295                 func(self)
   242             except:
   296             except:
   243                 import StringIO
   297                 import StringIO
   244                 tmp = StringIO.StringIO()
   298                 tmp = StringIO.StringIO()
   245                 traceback.print_exc(file=tmp)
   299                 traceback.print_exc(file=tmp)
   246                 print tmp.getvalue()
   300                 print tmp.getvalue()
       
   301                 # use a list to avoid counting a <nb lines> errors instead of one
   247                 self.errors[func_name] = ('Erreur lors de la transformation',
   302                 self.errors[func_name] = ('Erreur lors de la transformation',
   248                                           tmp.getvalue().splitlines())
   303                                           [tmp.getvalue().splitlines()])
   249             for key, func, title, help in checks:
   304             for key, func, title, help in checks:
   250                 buckets = self._checks.get(key)
   305                 buckets = self._checks.get(key)
   251                 if buckets:
   306                 if buckets:
   252                     err = func(buckets)
   307                     err = func(buckets)
   253                     if err:
   308                     if err:
   254                         self.errors[title] = (help, err)
   309                         self.errors[title] = (help, err)
   255             self.store.checkpoint()
   310             self.store.checkpoint()
   256         errors = sum(len(err[1]) for err in self.errors.values())
   311         self.tell('Import completed: %i entities (%i types), %i relations'
   257         self.tell('Importation terminée. (%i objets, %i types, %i relations et %i erreurs).'
       
   258                   % (len(self.store.eids), len(self.store.types),
   312                   % (len(self.store.eids), len(self.store.types),
   259                      len(self.store.relations), errors))
   313                      len(self.store.relations)))
   260         if self.errors and self.askerror and confirm('Afficher les erreurs ?'):
   314         nberrors = sum(len(err[1]) for err in self.errors.values())
       
   315         if nberrors:
       
   316             print '%s errors' % nberrors
       
   317         if self.errors and self.askerror and confirm('Display errors?'):
   261             import pprint
   318             import pprint
   262             pprint.pprint(self.errors)
   319             pprint.pprint(self.errors)
   263 
   320 
   264     def get_data(self, key):
   321     def get_data(self, key):
   265         return self.data.get(key)
   322         return self.data.get(key)