# HG changeset patch
# User Sylvain Thénault
# Date 1261417482 -3600
# Node ID 30fd1229137d5774e5caa64db1a63df49d73d609
# Parent 66fe38345a65c6ebc02cb7987cb311d111d3553b
new catch_error context manager, nicer controller __init__ and new iter_and_commit(datakey) method

diff -r 66fe38345a65 -r 30fd1229137d devtools/dataimport.py
--- a/devtools/dataimport.py	Mon Dec 21 18:43:16 2009 +0100
+++ b/devtools/dataimport.py	Mon Dec 21 18:44:42 2009 +0100
@@ -59,7 +59,7 @@
 def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"',
                   skipfirst=False, withpb=True):
     """same as ucsvreader but a progress bar is displayed as we iter on rows"""
-    rowcount = int(shellutils.Execute('wc -l %s' % filepath).out.strip().split()[0])
+    rowcount = int(shellutils.Execute('wc -l "%s"' % filepath).out.strip().split()[0])
     if skipfirst:
         rowcount -= 1
     if withpb:
@@ -85,9 +85,12 @@
 
 def commit_every(nbit, store, it):
     for i, x in enumerate(it):
-        if i % nbit:
+        yield x
+        if nbit is not None and (i + 1) % nbit == 0:
             store.checkpoint()
-        yield x
+    if nbit is not None:
+        store.checkpoint()
+
 def lazytable(reader):
     """The first row is taken to be the header of the table and
     used to output a dict for each row of data.
@@ -98,10 +101,56 @@
     for row in reader:
         yield dict(zip(header, row))
 
+def mk_entity(row, map):
+    """Return a dict made from sanitized mapped values.
+
+    >>> row = {'myname': u'dupont'}
+    >>> map = [('myname', u'name', (capitalize_if_unicase,))]
+    >>> mk_entity(row, map)
+    {'name': u'Dupont'}
+    """
+    res = {}
+    for src, dest, funcs in map:
+        res[dest] = row[src]
+        for func in funcs:
+            res[dest] = func(res[dest])
+    return res
+
+
+# user interactions ############################################################
+
 def tell(msg):
     print msg
 
-# base sanitizing functions #####
+def confirm(question):
+    """A confirm function that asks for yes/no/abort and exits on abort."""
+    answer = shellutils.ASK.ask(question, ('Y','n','abort'), 'Y')
+    if answer == 'abort':
+        sys.exit(1)
+    return answer == 'Y'
+
+
+class catch_error(object):
+    """Context manager recording errors on the controller when it catches errors."""
+
+    def __init__(self, ctl, key='unexpected error', msg=None):
+        self.ctl = ctl
+        self.key = key
+        self.msg = msg
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        if type is not None:
+            if issubclass(type, (KeyboardInterrupt, SystemExit)):
+                return # re-raise
+            if self.ctl.catcherrors:
+                self.ctl.record_error(self.key, self.msg, type, value, traceback)
+                return True # silent
+
+
+# base sanitizing functions ####################################################
 
 def capitalize_if_unicase(txt):
     if txt.isupper() or txt.islower():
@@ -126,7 +175,8 @@
 def strip(txt):
     return txt.strip()
 
-# base checks #####
+
+# base integrity checking functions ############################################
 
 def check_doubles(buckets):
     """Extract the keys that have more than one item in their bucket."""
@@ -136,24 +186,8 @@
     """Extract the keys that have more than one item in their bucket."""
     return [(key, len(value)) for key,value in buckets.items() if key is not None and len(value) > 1]
 
-# make entity helper #####
 
-def mk_entity(row, map):
-    """Return a dict made from sanitized mapped values.
-
-    >>> row = {'myname': u'dupont'}
-    >>> map = [('myname', u'name', (capitalize_if_unicase,))]
-    >>> mk_entity(row, map)
-    {'name': u'Dupont'}
-    """
-    res = {}
-    for src, dest, funcs in map:
-        res[dest] = row[src]
-        for func in funcs:
-            res[dest] = func(res[dest])
-    return res
-
-# object stores
+# object stores #################################################################
 
 class ObjectStore(object):
     """Store objects in memory for faster testing. Will not
@@ -215,6 +249,7 @@
     def checkpoint(self):
         pass
 
+
 class RQLObjectStore(ObjectStore):
     """ObjectStore that works with an actual RQL repository."""
     _rql = None # bw compat
@@ -255,7 +290,8 @@
                  {'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y'))
         self.relations.add( (eid_from, rtype, eid_to) )
 
-# import controller #####
+
+# the import controller ########################################################
 
 class CWImportController(object):
     """Controller of the data import process.
@@ -266,12 +302,17 @@
     >>> ctl.run()
     """
 
-    def __init__(self, store):
+    def __init__(self, store, askerror=False, catcherrors=None, tell=tell,
+                 commitevery=50):
         self.store = store
         self.generators = None
         self.data = {}
         self.errors = None
-        self.askerror = False
+        self.askerror = askerror
+        if catcherrors is None:
+            catcherrors = askerror
+        self.catcherrors = catcherrors
+        self.commitevery = commitevery # set to None to do a single commit
         self._tell = tell
 
     def check(self, type, key, value):
@@ -284,30 +325,41 @@
             self.check(key, entity[key], None)
             entity[key] = default
 
+    def record_error(self, key, msg=None, type=None, value=None, tb=None):
+        import StringIO
+        tmp = StringIO.StringIO()
+        if type is None:
+            traceback.print_exc(file=tmp)
+        else:
+            traceback.print_exception(type, value, tb, file=tmp)
+        print tmp.getvalue()
+        # one list entry per error, so a multi-line traceback counts as one error
+        errorlog = self.errors.setdefault(key, [])
+        if msg is None:
+            errorlog.append(tmp.getvalue().splitlines())
+        else:
+            errorlog.append( (msg, tmp.getvalue().splitlines()) )
+
     def run(self):
         self.errors = {}
         for func, checks in self.generators:
             self._checks = {}
-            func_name = func.__name__[4:]
-            question = 'Importing %s' % func_name
-            self.tell(question)
+            func_name = func.__name__[4:] # XXX
+            self.tell('Importing %s' % func_name)
             try:
                 func(self)
             except:
-                import StringIO
-                tmp = StringIO.StringIO()
-                traceback.print_exc(file=tmp)
-                print tmp.getvalue()
-                # use a list to avoid counting a errors instead of one
-                self.errors[func_name] = ('Erreur lors de la transformation',
-                                          [tmp.getvalue().splitlines()])
+                if self.catcherrors:
+                    self.record_error(func_name, 'While calling %s' % func.__name__)
+                else:
+                    raise
             for key, func, title, help in checks:
                 buckets = self._checks.get(key)
                 if buckets:
                     err = func(buckets)
                     if err:
                         self.errors[title] = (help, err)
-            self.store.checkpoint()
+        self.store.checkpoint()
         self.tell('\nImport completed: %i entities (%i types), %i relations'
                   % (len(self.store.eids), len(self.store.types), len(self.store.relations)))
 
@@ -327,9 +379,6 @@
     def tell(self, msg):
         self._tell(msg)
 
-def confirm(question):
-    """A confirm function that asks for yes/no/abort and exits on abort."""
-    answer = shellutils.ASK.ask(question, ('Y','n','abort'), 'Y')
-    if answer == 'abort':
-        sys.exit(1)
-    return answer == 'Y'
+    def iter_and_commit(self, datakey):
+        """iter rows, triggering commit every self.commitevery iterations"""
+        return commit_every(self.commitevery, self.store, self.get_data(datakey))