devtools/dataimport.py
changeset 4152 30fd1229137d
parent 4140 46ddd27a4ca4
child 4173 cfd5d3270f99
equal deleted inserted replaced
4151:66fe38345a65 4152:30fd1229137d
    57 from logilab.common.deprecation import deprecated
    57 from logilab.common.deprecation import deprecated
    58 
    58 
    59 def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"',
    59 def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"',
    60                   skipfirst=False, withpb=True):
    60                   skipfirst=False, withpb=True):
    61     """same as ucsvreader but a progress bar is displayed as we iter on rows"""
    61     """same as ucsvreader but a progress bar is displayed as we iter on rows"""
    62     rowcount = int(shellutils.Execute('wc -l %s' % filepath).out.strip().split()[0])
    62     rowcount = int(shellutils.Execute('wc -l "%s"' % filepath).out.strip().split()[0])
    63     if skipfirst:
    63     if skipfirst:
    64         rowcount -= 1
    64         rowcount -= 1
    65     if withpb:
    65     if withpb:
    66         pb = shellutils.ProgressBar(rowcount, 50)
    66         pb = shellutils.ProgressBar(rowcount, 50)
    67     for urow in ucsvreader(file(filepath), encoding, separator, quote, skipfirst):
    67     for urow in ucsvreader(file(filepath), encoding, separator, quote, skipfirst):
    83 
    83 
    84 utf8csvreader = deprecated('use ucsvreader instead')(ucsvreader)
    84 utf8csvreader = deprecated('use ucsvreader instead')(ucsvreader)
    85 
    85 
    86 def commit_every(nbit, store, it):
    86 def commit_every(nbit, store, it):
    87     for i, x in enumerate(it):
    87     for i, x in enumerate(it):
    88         if i % nbit:
    88         yield x
       
    89         if nbit is not None and i % nbit:
    89             store.checkpoint()
    90             store.checkpoint()
    90         yield x
    91     if nbit is not None:
       
    92         store.checkpoint()
       
    93 
    91 def lazytable(reader):
    94 def lazytable(reader):
    92     """The first row is taken to be the header of the table and
    95     """The first row is taken to be the header of the table and
    93     used to output a dict for each row of data.
    96     used to output a dict for each row of data.
    94 
    97 
    95     >>> data = lazytable(utf8csvreader(open(filename)))
    98     >>> data = lazytable(utf8csvreader(open(filename)))
    96     """
    99     """
    97     header = reader.next()
   100     header = reader.next()
    98     for row in reader:
   101     for row in reader:
    99         yield dict(zip(header, row))
   102         yield dict(zip(header, row))
   100 
       
   101 def tell(msg):
       
   102     print msg
       
   103 
       
   104 # base sanitizing functions #####
       
   105 
       
   106 def capitalize_if_unicase(txt):
       
   107     if txt.isupper() or txt.islower():
       
   108         return txt.capitalize()
       
   109     return txt
       
   110 
       
   111 def no_space(txt):
       
   112     return txt.replace(' ','')
       
   113 
       
   114 def no_uspace(txt):
       
   115     return txt.replace(u'\xa0','')
       
   116 
       
   117 def no_dash(txt):
       
   118     return txt.replace('-','')
       
   119 
       
   120 def alldigits(txt):
       
   121     if txt.isdigit():
       
   122         return txt
       
   123     else:
       
   124         return u''
       
   125 
       
   126 def strip(txt):
       
   127     return txt.strip()
       
   128 
       
   129 # base checks #####
       
   130 
       
   131 def check_doubles(buckets):
       
   132     """Extract the keys that have more than one item in their bucket."""
       
   133     return [(key, len(value)) for key,value in buckets.items() if len(value) > 1]
       
   134 
       
   135 def check_doubles_not_none(buckets):
       
   136     """Extract the keys that have more than one item in their bucket."""
       
   137     return [(key, len(value)) for key,value in buckets.items() if key is not None and len(value) > 1]
       
   138 
       
   139 # make entity helper #####
       
   140 
   103 
   141 def mk_entity(row, map):
   104 def mk_entity(row, map):
   142     """Return a dict made from sanitized mapped values.
   105     """Return a dict made from sanitized mapped values.
   143 
   106 
   144     >>> row = {'myname': u'dupont'}
   107     >>> row = {'myname': u'dupont'}
   151         res[dest] = row[src]
   114         res[dest] = row[src]
   152         for func in funcs:
   115         for func in funcs:
   153             res[dest] = func(res[dest])
   116             res[dest] = func(res[dest])
   154     return res
   117     return res
   155 
   118 
   156 # object stores
   119 
       
   120 # user interactions ############################################################
       
   121 
       
   122 def tell(msg):
       
   123     print msg
       
   124 
       
   125 def confirm(question):
       
   126     """A confirm function that asks for yes/no/abort and exits on abort."""
       
   127     answer = shellutils.ASK.ask(question, ('Y','n','abort'), 'Y')
       
   128     if answer == 'abort':
       
   129         sys.exit(1)
       
   130     return answer == 'Y'
       
   131 
       
   132 
       
   133 class catch_error(object):
       
   134     """Helper for @contextmanager decorator."""
       
   135 
       
   136     def __init__(self, ctl, key='unexpected error', msg=None):
       
   137         self.ctl = ctl
       
   138         self.key = key
       
   139         self.msg = msg
       
   140 
       
   141     def __enter__(self):
       
   142         return self
       
   143 
       
   144     def __exit__(self, type, value, traceback):
       
   145         if type is not None:
       
   146             if issubclass(type, (KeyboardInterrupt, SystemExit)):
       
   147                 return # re-raise
       
   148             if self.ctl.catcherrors:
       
   149                 self.ctl.record_error(self.key, msg)
       
   150                 return True # silent
       
   151 
       
   152 
       
   153 # base sanitizing functions ####################################################
       
   154 
       
   155 def capitalize_if_unicase(txt):
       
   156     if txt.isupper() or txt.islower():
       
   157         return txt.capitalize()
       
   158     return txt
       
   159 
       
   160 def no_space(txt):
       
   161     return txt.replace(' ','')
       
   162 
       
   163 def no_uspace(txt):
       
   164     return txt.replace(u'\xa0','')
       
   165 
       
   166 def no_dash(txt):
       
   167     return txt.replace('-','')
       
   168 
       
   169 def alldigits(txt):
       
   170     if txt.isdigit():
       
   171         return txt
       
   172     else:
       
   173         return u''
       
   174 
       
   175 def strip(txt):
       
   176     return txt.strip()
       
   177 
       
   178 
       
   179 # base integrity checking functions ############################################
       
   180 
       
   181 def check_doubles(buckets):
       
   182     """Extract the keys that have more than one item in their bucket."""
       
   183     return [(key, len(value)) for key,value in buckets.items() if len(value) > 1]
       
   184 
       
   185 def check_doubles_not_none(buckets):
       
   186     """Extract the keys that have more than one item in their bucket."""
       
   187     return [(key, len(value)) for key,value in buckets.items() if key is not None and len(value) > 1]
       
   188 
       
   189 
       
   190 # object stores #################################################################
   157 
   191 
   158 class ObjectStore(object):
   192 class ObjectStore(object):
   159     """Store objects in memory for faster testing. Will not
   193     """Store objects in memory for faster testing. Will not
   160     enforce the constraints of the schema and hence will miss
   194     enforce the constraints of the schema and hence will miss
   161     some problems.
   195     some problems.
   212             if item[key] == value:
   246             if item[key] == value:
   213                 yield item
   247                 yield item
   214 
   248 
   215     def checkpoint(self):
   249     def checkpoint(self):
   216         pass
   250         pass
       
   251 
   217 
   252 
   218 class RQLObjectStore(ObjectStore):
   253 class RQLObjectStore(ObjectStore):
   219     """ObjectStore that works with an actual RQL repository."""
   254     """ObjectStore that works with an actual RQL repository."""
   220     _rql = None # bw compat
   255     _rql = None # bw compat
   221 
   256 
   253     def relate(self, eid_from, rtype, eid_to):
   288     def relate(self, eid_from, rtype, eid_to):
   254         self.rql('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype,
   289         self.rql('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype,
   255                   {'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y'))
   290                   {'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y'))
   256         self.relations.add( (eid_from, rtype, eid_to) )
   291         self.relations.add( (eid_from, rtype, eid_to) )
   257 
   292 
   258 # import controller #####
   293 
       
   294 # the import controller ########################################################
   259 
   295 
   260 class CWImportController(object):
   296 class CWImportController(object):
   261     """Controller of the data import process.
   297     """Controller of the data import process.
   262 
   298 
   263     >>> ctl = CWImportController(store)
   299     >>> ctl = CWImportController(store)
   264     >>> ctl.generators = list_of_data_generators
   300     >>> ctl.generators = list_of_data_generators
   265     >>> ctl.data = dict_of_data_tables
   301     >>> ctl.data = dict_of_data_tables
   266     >>> ctl.run()
   302     >>> ctl.run()
   267     """
   303     """
   268 
   304 
   269     def __init__(self, store):
   305     def __init__(self, store, askerror=False, catcherrors=None, tell=tell,
       
   306                  commitevery=50):
   270         self.store = store
   307         self.store = store
   271         self.generators = None
   308         self.generators = None
   272         self.data = {}
   309         self.data = {}
   273         self.errors = None
   310         self.errors = None
   274         self.askerror = False
   311         self.askerror = askerror
       
   312         if  catcherrors is None:
       
   313             catcherrors = askerror
       
   314         self.catcherrors = catcherrors
       
   315         self.commitevery = commitevery # set to None to do a single commit
   275         self._tell = tell
   316         self._tell = tell
   276 
   317 
   277     def check(self, type, key, value):
   318     def check(self, type, key, value):
   278         self._checks.setdefault(type, {}).setdefault(key, []).append(value)
   319         self._checks.setdefault(type, {}).setdefault(key, []).append(value)
   279 
   320 
   282             entity[key] = map[entity[key]]
   323             entity[key] = map[entity[key]]
   283         except KeyError:
   324         except KeyError:
   284             self.check(key, entity[key], None)
   325             self.check(key, entity[key], None)
   285             entity[key] = default
   326             entity[key] = default
   286 
   327 
       
   328     def record_error(self, key, msg=None, type=None, value=None, tb=None):
       
   329         import StringIO
       
   330         tmp = StringIO.StringIO()
       
   331         if type is None:
       
   332             traceback.print_exc(file=tmp)
       
   333         else:
       
   334             traceback.print_exception(type, value, tb, file=tmp)
       
   335         print tmp.getvalue()
       
   336         # use a list to avoid counting a <nb lines> errors instead of one
       
   337         errorlog = self.errors.setdefault(key, [])
       
   338         if msg is None:
       
   339             errorlog.append(tmp.getvalue().splitlines())
       
   340         else:
       
   341             errorlog.append( (msg, tmp.getvalue().splitlines()) )
       
   342 
   287     def run(self):
   343     def run(self):
   288         self.errors = {}
   344         self.errors = {}
   289         for func, checks in self.generators:
   345         for func, checks in self.generators:
   290             self._checks = {}
   346             self._checks = {}
   291             func_name = func.__name__[4:]
   347             func_name = func.__name__[4:]  # XXX
   292             question = 'Importing %s' % func_name
   348             self.tell('Importing %s' % func_name)
   293             self.tell(question)
       
   294             try:
   349             try:
   295                 func(self)
   350                 func(self)
   296             except:
   351             except:
   297                 import StringIO
   352                 if self.catcherrors:
   298                 tmp = StringIO.StringIO()
   353                     self.record_error(func_name, 'While calling %s' % func.__name__)
   299                 traceback.print_exc(file=tmp)
   354                 else:
   300                 print tmp.getvalue()
   355                     raise
   301                 # use a list to avoid counting a <nb lines> errors instead of one
       
   302                 self.errors[func_name] = ('Erreur lors de la transformation',
       
   303                                           [tmp.getvalue().splitlines()])
       
   304             for key, func, title, help in checks:
   356             for key, func, title, help in checks:
   305                 buckets = self._checks.get(key)
   357                 buckets = self._checks.get(key)
   306                 if buckets:
   358                 if buckets:
   307                     err = func(buckets)
   359                     err = func(buckets)
   308                     if err:
   360                     if err:
   309                         self.errors[title] = (help, err)
   361                         self.errors[title] = (help, err)
   310             self.store.checkpoint()
   362         self.store.checkpoint()
   311         self.tell('\nImport completed: %i entities (%i types), %i relations'
   363         self.tell('\nImport completed: %i entities (%i types), %i relations'
   312                   % (len(self.store.eids), len(self.store.types),
   364                   % (len(self.store.eids), len(self.store.types),
   313                      len(self.store.relations)))
   365                      len(self.store.relations)))
   314         nberrors = sum(len(err[1]) for err in self.errors.values())
   366         nberrors = sum(len(err[1]) for err in self.errors.values())
   315         if nberrors:
   367         if nberrors:
   325         self.store.indexes.setdefault(name, {}).setdefault(key, []).append(value)
   377         self.store.indexes.setdefault(name, {}).setdefault(key, []).append(value)
   326 
   378 
   327     def tell(self, msg):
   379     def tell(self, msg):
   328         self._tell(msg)
   380         self._tell(msg)
   329 
   381 
   330 def confirm(question):
   382     def iter_and_commit(self, datakey):
   331     """A confirm function that asks for yes/no/abort and exits on abort."""
   383         """iter rows, triggering commit every self.commitevery iterations"""
   332     answer = shellutils.ASK.ask(question, ('Y','n','abort'), 'Y')
   384         return commit_every(self.commitevery, self.store, self.get_data(datakey))
   333     if answer == 'abort':
       
   334         sys.exit(1)
       
   335     return answer == 'Y'