# HG changeset patch
# User Sylvain Thénault
# Date 1261141721 -3600
# Node ID 47060a66c97f4abc56386556000bc22a21f7c506
# Parent  cb0d0bf255f74d92320d7d4688e02f1cce57b706
dataimport refactoring / improvements, keeping bw compat (for now)

diff -r cb0d0bf255f7 -r 47060a66c97f devtools/dataimport.py
--- a/devtools/dataimport.py	Fri Dec 18 14:07:16 2009 +0100
+++ b/devtools/dataimport.py	Fri Dec 18 14:08:41 2009 +0100
@@ -54,13 +54,40 @@
 import sys, csv, traceback

 from logilab.common import shellutils
+from logilab.common.deprecation import deprecated

-def utf8csvreader(file, encoding='utf-8', separator=',', quote='"'):
-    """A csv reader that accepts files with any encoding and outputs
-    unicode strings."""
-    for row in csv.reader(file, delimiter=separator, quotechar=quote):
+def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"',
+                  skipfirst=False, withpb=True):
+    """same as ucsvreader but a progress bar is displayed as we iterate on rows"""
+    rowcount = int(shellutils.Execute('wc -l %s' % filepath).out.strip().split()[0])
+    if skipfirst:
+        rowcount -= 1
+    if withpb:
+        pb = shellutils.ProgressBar(rowcount)
+    for urow in ucsvreader(file(filepath), encoding, separator, quote, skipfirst):
+        yield urow
+        if withpb:
+            pb.update()
+    print ' %s rows imported' % rowcount
+
+def ucsvreader(stream, encoding='utf-8', separator=',', quote='"',
+               skipfirst=False):
+    """A csv reader that accepts files with any encoding and outputs unicode
+    strings
+    """
+    it = iter(csv.reader(stream, delimiter=separator, quotechar=quote))
+    if skipfirst:
+        it.next()
+    for row in it:
         yield [item.decode(encoding) for item in row]

+utf8csvreader = deprecated('use ucsvreader instead')(ucsvreader)
+
+def commit_every(nbit, store, it):
+    for i, x in enumerate(it):
+        if i % nbit:
+            store.checkpoint()
+        yield x
+
 def lazytable(reader):
     """The first row is taken to be the header of the table and
     used to output a dict for each row of data.
@@ -105,6 +132,10 @@
     """Extract the keys that have more than one item in their bucket."""
     return [(key, len(value)) for key,value in buckets.items() if len(value) > 1]

+def check_doubles_not_none(buckets):
+    """Extract the keys (other than None) that have more than one item in their bucket."""
+    return [(key, len(value)) for key,value in buckets.items() if key is not None and len(value) > 1]
+
 # make entity helper #####

 def mk_entity(row, map):
@@ -181,24 +212,47 @@
             if item[key] == value:
                 yield item

-    def rql(self, query, args):
-        if self._rql:
-            return self._rql(query, args)
-
     def checkpoint(self):
-        if self._checkpoint:
-            self._checkpoint()
+        pass


 class RQLObjectStore(ObjectStore):
     """ObjectStore that works with an actual RQL repository."""
+    _rql = None # bw compat
+
+    def __init__(self, session=None, checkpoint=None):
+        ObjectStore.__init__(self)
+        if session is not None:
+            if not hasattr(session, 'set_pool'):
+                # connection
+                cnx = session
+                session = session.request()
+                session.set_pool = lambda : None
+                checkpoint = checkpoint or cnx.commit
+            self.session = session
+            self.checkpoint = checkpoint or session.commit
+        elif checkpoint is not None:
+            self.checkpoint = checkpoint
+
+    def rql(self, *args):
+        if self._rql is not None:
+            return self._rql(*args)
+        self.session.set_pool()
+        return self.session.execute(*args)
+
+    def create_entity(self, *args, **kwargs):
+        self.session.set_pool()
+        entity = self.session.create_entity(*args, **kwargs)
+        self.eids[entity.eid] = entity
+        self.types.setdefault(args[0], []).append(entity.eid)
+        return entity

     def _put(self, type, item):
         query = ('INSERT %s X: ' % type) + ', '.join(['X %s %%(%s)s' % (key,key) for key in item])
         return self.rql(query, item)[0][0]

     def relate(self, eid_from, rtype, eid_to):
-        query = 'SET X %s Y WHERE X eid %%(from)s, Y eid %%(to)s' % rtype
-        self.rql(query, {'from': int(eid_from), 'to': int(eid_to)})
+        self.rql('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype,
+                 {'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y'))
         self.relations.add( (eid_from, rtype, eid_to) )

 # import controller #####
@@ -235,7 +289,7 @@
         for func, checks in self.generators:
             self._checks = {}
             func_name = func.__name__[4:]
-            question = 'Importation de %s' % func_name
+            question = 'Importing %s' % func_name
             self.tell(question)
             try:
                 func(self)
@@ -244,8 +298,9 @@
                 tmp = StringIO.StringIO()
                 traceback.print_exc(file=tmp)
                 print tmp.getvalue()
+                # use a list to avoid counting each traceback line as a separate error
                 self.errors[func_name] = ('Erreur lors de la transformation',
-                                          tmp.getvalue().splitlines())
+                                          [tmp.getvalue().splitlines()])
             for key, func, title, help in checks:
                 buckets = self._checks.get(key)
                 if buckets:
@@ -253,11 +308,13 @@
                     if err:
                         self.errors[title] = (help, err)
         self.store.checkpoint()
-        errors = sum(len(err[1]) for err in self.errors.values())
-        self.tell('Importation terminée. (%i objets, %i types, %i relations et %i erreurs).'
+        self.tell('Import completed: %i entities (%i types), %i relations'
                   % (len(self.store.eids), len(self.store.types),
-                     len(self.store.relations), errors))
-        if self.errors and self.askerror and confirm('Afficher les erreurs ?'):
+                     len(self.store.relations)))
+        nberrors = sum(len(err[1]) for err in self.errors.values())
+        if nberrors:
+            print '%s errors' % nberrors
+        if self.errors and self.askerror and confirm('Display errors?'):
             import pprint
             pprint.pprint(self.errors)
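Usage sketch (not part of the changeset): the snippet below shows how the refactored helpers might be combined once this patch is applied, assuming the module stays importable as cubicweb.devtools.dataimport and that a CubicWeb session is available. The file name 'people.csv', the 'Person' entity type and its 'name'/'email' attributes are hypothetical illustration values; only ucsvreader_pb, RQLObjectStore, create_entity and checkpoint come from the patched module.

    # -*- coding: utf-8 -*-
    # Sketch only (Python 2, same era as the patch). Assumes a semicolon-separated
    # CSV file whose first line is a header row.
    from cubicweb.devtools.dataimport import ucsvreader_pb, RQLObjectStore

    def import_people(session, csvpath='people.csv'):
        # the refactored store wraps the session directly;
        # checkpoint() then defaults to session.commit
        store = RQLObjectStore(session)
        # ucsvreader_pb yields unicode rows and displays a progress bar;
        # skipfirst=True drops the header line
        for name, email in ucsvreader_pb(csvpath, separator=';', skipfirst=True):
            store.create_entity('Person', name=name, email=email)
        store.checkpoint()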