--- a/devtools/dataimport.py Fri Dec 18 14:07:16 2009 +0100
+++ b/devtools/dataimport.py Fri Dec 18 14:08:41 2009 +0100
@@ -54,13 +54,40 @@
import sys, csv, traceback
from logilab.common import shellutils
+from logilab.common.deprecation import deprecated
-def utf8csvreader(file, encoding='utf-8', separator=',', quote='"'):
- """A csv reader that accepts files with any encoding and outputs
- unicode strings."""
- for row in csv.reader(file, delimiter=separator, quotechar=quote):
+def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"',
+                  skipfirst=False, withpb=True):
+    """same as ucsvreader but a progress bar is displayed as we iter on rows
+
+    `filepath` is the path of the csv file to read; `encoding`, `separator`,
+    `quote` and `skipfirst` are forwarded to `ucsvreader`. No progress bar is
+    displayed when `withpb` is false. Once every row has been yielded, the
+    number of rows is printed.
+    """
+    stream = open(filepath)
+    # count the lines in python instead of spawning `wc -l <path>`: the shell
+    # command was built by string interpolation, so it broke on paths holding
+    # spaces or shell metacharacters and required `wc` to be installed
+    rowcount = sum(1 for line in stream)
+    stream.seek(0)
+    if skipfirst and rowcount:
+        # do not go below zero on an empty file
+        rowcount -= 1
+    if withpb:
+        pb = shellutils.ProgressBar(rowcount)
+    for urow in ucsvreader(stream, encoding, separator, quote, skipfirst):
+        yield urow
+        if withpb:
+            pb.update()
+    # release the file handle explicitly rather than leaking it
+    stream.close()
+    print ' %s rows imported' % rowcount
+
+def ucsvreader(stream, encoding='utf-8', separator=',', quote='"',
+ skipfirst=False):
+ """A csv reader that accepts files with any encoding and outputs unicode
+ strings
+
+ `stream` is handed to `csv.reader`, using `separator` as delimiter and
+ `quote` as quote character. Each yielded row is a list whose items have
+ been decoded using `encoding`. When `skipfirst` is true, the first row
+ read from `stream` is dropped instead of being yielded.
+ """
+ it = iter(csv.reader(stream, delimiter=separator, quotechar=quote))
+ if skipfirst:
+ it.next()
+ for row in it:
yield [item.decode(encoding) for item in row]
+utf8csvreader = deprecated('use ucsvreader instead')(ucsvreader)
+
+def commit_every(nbit, store, it):
+    """Yield every item of `it`, calling `store.checkpoint()` each time
+    `nbit` more items have been consumed.
+
+    Items remaining after the last full batch are not checkpointed here: the
+    caller has to checkpoint once the iteration is over.
+    """
+    for i, x in enumerate(it):
+        # checkpoint once per full batch; the former `if i % nbit:` test was
+        # inverted and checkpointed on every item *except* batch boundaries
+        if i and i % nbit == 0:
+            store.checkpoint()
+        yield x
def lazytable(reader):
"""The first row is taken to be the header of the table and
used to output a dict for each row of data.
@@ -105,6 +132,10 @@
"""Extract the keys that have more than one item in their bucket."""
return [(key, len(value)) for key,value in buckets.items() if len(value) > 1]
+def check_doubles_not_none(buckets):
+ """Extract the keys that have more than one item in their bucket,
+ ignoring the `None` key.
+
+ Return a list of (key, number of items in the bucket) tuples.
+ """
+ return [(key, len(value)) for key,value in buckets.items() if key is not None and len(value) > 1]
+
# make entity helper #####
def mk_entity(row, map):
@@ -181,24 +212,47 @@
if item[key] == value:
yield item
- def rql(self, query, args):
- if self._rql:
- return self._rql(query, args)
-
def checkpoint(self):
- if self._checkpoint:
- self._checkpoint()
+ pass
class RQLObjectStore(ObjectStore):
"""ObjectStore that works with an actual RQL repository."""
+ _rql = None # bw compat
+
+ def __init__(self, session=None, checkpoint=None):
+ """Initialize the store from a session (or a connection) and / or a
+ checkpoint callable.
+
+ If `session` has no `set_pool` attribute it is handled as a connection:
+ `request()` is called on it and the result is what gets stored as
+ `self.session`. `checkpoint`, when given, is stored as
+ `self.checkpoint`; otherwise the `commit` method of the connection (or
+ of the session) is used.
+
+ NOTE(review): when both arguments are None, `self.session` is not set
+ by this method, so `rql` only works if `_rql` has been set -- confirm
+ against callers.
+ """
+ ObjectStore.__init__(self)
+ if session is not None:
+ if not hasattr(session, 'set_pool'):
+ # connection
+ cnx = session
+ session = session.request()
+ # attach a no-op `set_pool` so that `rql` and `create_entity` can
+ # call it unconditionally
+ session.set_pool = lambda : None
+ checkpoint = checkpoint or cnx.commit
+ self.session = session
+ self.checkpoint = checkpoint or session.commit
+ elif checkpoint is not None:
+ self.checkpoint = checkpoint
+
+ def rql(self, *args):
+ """Execute a query and return its result, forwarding `args` unchanged.
+
+ When `_rql` is not None (kept for backward compatibility) it is called
+ instead of the session; otherwise `set_pool()` is called on the session
+ before `execute(*args)`.
+ """
+ if self._rql is not None:
+ return self._rql(*args)
+ self.session.set_pool()
+ return self.session.execute(*args)
+
+ def create_entity(self, *args, **kwargs):
+ """Create an entity through the session and register it in the store.
+
+ Arguments are forwarded to `self.session.create_entity`. The returned
+ entity is recorded in `self.eids` under its eid, and its eid is
+ appended to the `self.types` list of `args[0]`.
+
+ NOTE(review): `args[0]` is presumably the entity type and must be
+ given positionally, otherwise the lookup below raises IndexError --
+ confirm against callers.
+ """
+ self.session.set_pool()
+ entity = self.session.create_entity(*args, **kwargs)
+ self.eids[entity.eid] = entity
+ self.types.setdefault(args[0], []).append(entity.eid)
+ return entity
def _put(self, type, item):
query = ('INSERT %s X: ' % type) + ', '.join(['X %s %%(%s)s' % (key,key) for key in item])
return self.rql(query, item)[0][0]
def relate(self, eid_from, rtype, eid_to):
- query = 'SET X %s Y WHERE X eid %%(from)s, Y eid %%(to)s' % rtype
- self.rql(query, {'from': int(eid_from), 'to': int(eid_to)})
+ self.rql('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype,
+ {'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y'))
self.relations.add( (eid_from, rtype, eid_to) )
# import controller #####
@@ -235,7 +289,7 @@
for func, checks in self.generators:
self._checks = {}
func_name = func.__name__[4:]
- question = 'Importation de %s' % func_name
+ question = 'Importing %s' % func_name
self.tell(question)
try:
func(self)
@@ -244,8 +298,9 @@
tmp = StringIO.StringIO()
traceback.print_exc(file=tmp)
print tmp.getvalue()
+ # wrap the traceback lines in a list so that they are counted as a single error instead of <nb lines> errors
self.errors[func_name] = ('Erreur lors de la transformation',
- tmp.getvalue().splitlines())
+ [tmp.getvalue().splitlines()])
for key, func, title, help in checks:
buckets = self._checks.get(key)
if buckets:
@@ -253,11 +308,13 @@
if err:
self.errors[title] = (help, err)
self.store.checkpoint()
- errors = sum(len(err[1]) for err in self.errors.values())
- self.tell('Importation terminée. (%i objets, %i types, %i relations et %i erreurs).'
+ self.tell('Import completed: %i entities (%i types), %i relations'
% (len(self.store.eids), len(self.store.types),
- len(self.store.relations), errors))
- if self.errors and self.askerror and confirm('Afficher les erreurs ?'):
+ len(self.store.relations)))
+ nberrors = sum(len(err[1]) for err in self.errors.values())
+ if nberrors:
+ print '%s errors' % nberrors
+ if self.errors and self.askerror and confirm('Display errors?'):
import pprint
pprint.pprint(self.errors)