--- a/devtools/dataimport.py Mon Dec 21 18:43:16 2009 +0100
+++ b/devtools/dataimport.py Mon Dec 21 18:44:42 2009 +0100
@@ -59,7 +59,7 @@
def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"',
skipfirst=False, withpb=True):
"""same as ucsvreader but a progress bar is displayed as we iter on rows"""
- rowcount = int(shellutils.Execute('wc -l %s' % filepath).out.strip().split()[0])
+ rowcount = int(shellutils.Execute('wc -l "%s"' % filepath).out.strip().split()[0])
if skipfirst:
rowcount -= 1
if withpb:
@@ -85,9 +85,12 @@
def commit_every(nbit, store, it):
for i, x in enumerate(it):
- if i % nbit:
+ yield x
+ if nbit is not None and i % nbit:
store.checkpoint()
- yield x
+ if nbit is not None:
+ store.checkpoint()
+
def lazytable(reader):
"""The first row is taken to be the header of the table and
used to output a dict for each row of data.
@@ -98,10 +101,56 @@
for row in reader:
yield dict(zip(header, row))
+def mk_entity(row, map):
+ """Return a dict made from sanitized mapped values.
+
+ >>> row = {'myname': u'dupont'}
+ >>> map = [('myname', u'name', (capitalize_if_unicase,))]
+ >>> mk_entity(row, map)
+ {'name': u'Dupont'}
+ """
+ res = {}
+ for src, dest, funcs in map:
+ res[dest] = row[src]
+ for func in funcs:
+ res[dest] = func(res[dest])
+ return res
+
+
+# user interactions ############################################################
+
def tell(msg):
print msg
-# base sanitizing functions #####
+def confirm(question):
+ """A confirm function that asks for yes/no/abort and exits on abort."""
+ answer = shellutils.ASK.ask(question, ('Y','n','abort'), 'Y')
+ if answer == 'abort':
+ sys.exit(1)
+ return answer == 'Y'
+
+
+class catch_error(object):
+ """Helper for @contextmanager decorator."""
+
+ def __init__(self, ctl, key='unexpected error', msg=None):
+ self.ctl = ctl
+ self.key = key
+ self.msg = msg
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ if type is not None:
+ if issubclass(type, (KeyboardInterrupt, SystemExit)):
+ return # re-raise
+ if self.ctl.catcherrors:
+ self.ctl.record_error(self.key, msg)
+ return True # silent
+
+
+# base sanitizing functions ####################################################
def capitalize_if_unicase(txt):
if txt.isupper() or txt.islower():
@@ -126,7 +175,8 @@
def strip(txt):
return txt.strip()
-# base checks #####
+
+# base integrity checking functions ############################################
def check_doubles(buckets):
"""Extract the keys that have more than one item in their bucket."""
@@ -136,24 +186,8 @@
"""Extract the keys that have more than one item in their bucket."""
return [(key, len(value)) for key,value in buckets.items() if key is not None and len(value) > 1]
-# make entity helper #####
-def mk_entity(row, map):
- """Return a dict made from sanitized mapped values.
-
- >>> row = {'myname': u'dupont'}
- >>> map = [('myname', u'name', (capitalize_if_unicase,))]
- >>> mk_entity(row, map)
- {'name': u'Dupont'}
- """
- res = {}
- for src, dest, funcs in map:
- res[dest] = row[src]
- for func in funcs:
- res[dest] = func(res[dest])
- return res
-
-# object stores
+# object stores #################################################################
class ObjectStore(object):
"""Store objects in memory for faster testing. Will not
@@ -215,6 +249,7 @@
def checkpoint(self):
pass
+
class RQLObjectStore(ObjectStore):
"""ObjectStore that works with an actual RQL repository."""
_rql = None # bw compat
@@ -255,7 +290,8 @@
{'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y'))
self.relations.add( (eid_from, rtype, eid_to) )
-# import controller #####
+
+# the import controller ########################################################
class CWImportController(object):
"""Controller of the data import process.
@@ -266,12 +302,17 @@
>>> ctl.run()
"""
- def __init__(self, store):
+ def __init__(self, store, askerror=False, catcherrors=None, tell=tell,
+ commitevery=50):
self.store = store
self.generators = None
self.data = {}
self.errors = None
- self.askerror = False
+ self.askerror = askerror
+ if catcherrors is None:
+ catcherrors = askerror
+ self.catcherrors = catcherrors
+ self.commitevery = commitevery # set to None to do a single commit
self._tell = tell
def check(self, type, key, value):
@@ -284,30 +325,41 @@
self.check(key, entity[key], None)
entity[key] = default
+ def record_error(self, key, msg=None, type=None, value=None, tb=None):
+ import StringIO
+ tmp = StringIO.StringIO()
+ if type is None:
+ traceback.print_exc(file=tmp)
+ else:
+ traceback.print_exception(type, value, tb, file=tmp)
+ print tmp.getvalue()
+ # use a list to avoid counting a <nb lines> errors instead of one
+ errorlog = self.errors.setdefault(key, [])
+ if msg is None:
+ errorlog.append(tmp.getvalue().splitlines())
+ else:
+ errorlog.append( (msg, tmp.getvalue().splitlines()) )
+
def run(self):
self.errors = {}
for func, checks in self.generators:
self._checks = {}
- func_name = func.__name__[4:]
- question = 'Importing %s' % func_name
- self.tell(question)
+ func_name = func.__name__[4:] # XXX
+ self.tell('Importing %s' % func_name)
try:
func(self)
except:
- import StringIO
- tmp = StringIO.StringIO()
- traceback.print_exc(file=tmp)
- print tmp.getvalue()
- # use a list to avoid counting a <nb lines> errors instead of one
- self.errors[func_name] = ('Erreur lors de la transformation',
- [tmp.getvalue().splitlines()])
+ if self.catcherrors:
+ self.record_error(func_name, 'While calling %s' % func.__name__)
+ else:
+ raise
for key, func, title, help in checks:
buckets = self._checks.get(key)
if buckets:
err = func(buckets)
if err:
self.errors[title] = (help, err)
- self.store.checkpoint()
+ self.store.checkpoint()
self.tell('\nImport completed: %i entities (%i types), %i relations'
% (len(self.store.eids), len(self.store.types),
len(self.store.relations)))
@@ -327,9 +379,6 @@
def tell(self, msg):
self._tell(msg)
-def confirm(question):
- """A confirm function that asks for yes/no/abort and exits on abort."""
- answer = shellutils.ASK.ask(question, ('Y','n','abort'), 'Y')
- if answer == 'abort':
- sys.exit(1)
- return answer == 'Y'
+ def iter_and_commit(self, datakey):
+ """iter rows, triggering commit every self.commitevery iterations"""
+ return commit_every(self.commitevery, self.store, self.get_data(datakey))