57 from logilab.common.deprecation import deprecated |
57 from logilab.common.deprecation import deprecated |
58 |
58 |
59 def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"', |
59 def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"', |
60 skipfirst=False, withpb=True): |
60 skipfirst=False, withpb=True): |
61 """same as ucsvreader but a progress bar is displayed as we iter on rows""" |
61 """same as ucsvreader but a progress bar is displayed as we iter on rows""" |
62 rowcount = int(shellutils.Execute('wc -l %s' % filepath).out.strip().split()[0]) |
62 rowcount = int(shellutils.Execute('wc -l "%s"' % filepath).out.strip().split()[0]) |
63 if skipfirst: |
63 if skipfirst: |
64 rowcount -= 1 |
64 rowcount -= 1 |
65 if withpb: |
65 if withpb: |
66 pb = shellutils.ProgressBar(rowcount, 50) |
66 pb = shellutils.ProgressBar(rowcount, 50) |
67 for urow in ucsvreader(file(filepath), encoding, separator, quote, skipfirst): |
67 for urow in ucsvreader(file(filepath), encoding, separator, quote, skipfirst): |
83 |
83 |
84 utf8csvreader = deprecated('use ucsvreader instead')(ucsvreader) |
84 utf8csvreader = deprecated('use ucsvreader instead')(ucsvreader) |
85 |
85 |
86 def commit_every(nbit, store, it): |
86 def commit_every(nbit, store, it): |
87 for i, x in enumerate(it): |
87 for i, x in enumerate(it): |
88 if i % nbit: |
88 yield x |
|
89 if nbit is not None and i % nbit: |
89 store.checkpoint() |
90 store.checkpoint() |
90 yield x |
91 if nbit is not None: |
|
92 store.checkpoint() |
|
93 |
91 def lazytable(reader): |
94 def lazytable(reader): |
92 """The first row is taken to be the header of the table and |
95 """The first row is taken to be the header of the table and |
93 used to output a dict for each row of data. |
96 used to output a dict for each row of data. |
94 |
97 |
95 >>> data = lazytable(utf8csvreader(open(filename))) |
98 >>> data = lazytable(utf8csvreader(open(filename))) |
96 """ |
99 """ |
97 header = reader.next() |
100 header = reader.next() |
98 for row in reader: |
101 for row in reader: |
99 yield dict(zip(header, row)) |
102 yield dict(zip(header, row)) |
100 |
|
101 def tell(msg): |
|
102 print msg |
|
103 |
|
104 # base sanitizing functions ##### |
|
105 |
|
106 def capitalize_if_unicase(txt): |
|
107 if txt.isupper() or txt.islower(): |
|
108 return txt.capitalize() |
|
109 return txt |
|
110 |
|
111 def no_space(txt): |
|
112 return txt.replace(' ','') |
|
113 |
|
114 def no_uspace(txt): |
|
115 return txt.replace(u'\xa0','') |
|
116 |
|
117 def no_dash(txt): |
|
118 return txt.replace('-','') |
|
119 |
|
120 def alldigits(txt): |
|
121 if txt.isdigit(): |
|
122 return txt |
|
123 else: |
|
124 return u'' |
|
125 |
|
126 def strip(txt): |
|
127 return txt.strip() |
|
128 |
|
129 # base checks ##### |
|
130 |
|
131 def check_doubles(buckets): |
|
132 """Extract the keys that have more than one item in their bucket.""" |
|
133 return [(key, len(value)) for key,value in buckets.items() if len(value) > 1] |
|
134 |
|
135 def check_doubles_not_none(buckets): |
|
136 """Extract the keys that have more than one item in their bucket.""" |
|
137 return [(key, len(value)) for key,value in buckets.items() if key is not None and len(value) > 1] |
|
138 |
|
139 # make entity helper ##### |
|
140 |
103 |
141 def mk_entity(row, map): |
104 def mk_entity(row, map): |
142 """Return a dict made from sanitized mapped values. |
105 """Return a dict made from sanitized mapped values. |
143 |
106 |
144 >>> row = {'myname': u'dupont'} |
107 >>> row = {'myname': u'dupont'} |
151 res[dest] = row[src] |
114 res[dest] = row[src] |
152 for func in funcs: |
115 for func in funcs: |
153 res[dest] = func(res[dest]) |
116 res[dest] = func(res[dest]) |
154 return res |
117 return res |
155 |
118 |
156 # object stores |
119 |
|
120 # user interactions ############################################################ |
|
121 |
|
122 def tell(msg): |
|
123 print msg |
|
124 |
|
125 def confirm(question): |
|
126 """A confirm function that asks for yes/no/abort and exits on abort.""" |
|
127 answer = shellutils.ASK.ask(question, ('Y','n','abort'), 'Y') |
|
128 if answer == 'abort': |
|
129 sys.exit(1) |
|
130 return answer == 'Y' |
|
131 |
|
132 |
|
133 class catch_error(object): |
|
134 """Helper for @contextmanager decorator.""" |
|
135 |
|
136 def __init__(self, ctl, key='unexpected error', msg=None): |
|
137 self.ctl = ctl |
|
138 self.key = key |
|
139 self.msg = msg |
|
140 |
|
141 def __enter__(self): |
|
142 return self |
|
143 |
|
144 def __exit__(self, type, value, traceback): |
|
145 if type is not None: |
|
146 if issubclass(type, (KeyboardInterrupt, SystemExit)): |
|
147 return # re-raise |
|
148 if self.ctl.catcherrors: |
|
149 self.ctl.record_error(self.key, msg) |
|
150 return True # silent |
|
151 |
|
152 |
|
153 # base sanitizing functions #################################################### |
|
154 |
|
155 def capitalize_if_unicase(txt): |
|
156 if txt.isupper() or txt.islower(): |
|
157 return txt.capitalize() |
|
158 return txt |
|
159 |
|
160 def no_space(txt): |
|
161 return txt.replace(' ','') |
|
162 |
|
163 def no_uspace(txt): |
|
164 return txt.replace(u'\xa0','') |
|
165 |
|
166 def no_dash(txt): |
|
167 return txt.replace('-','') |
|
168 |
|
169 def alldigits(txt): |
|
170 if txt.isdigit(): |
|
171 return txt |
|
172 else: |
|
173 return u'' |
|
174 |
|
175 def strip(txt): |
|
176 return txt.strip() |
|
177 |
|
178 |
|
179 # base integrity checking functions ############################################ |
|
180 |
|
181 def check_doubles(buckets): |
|
182 """Extract the keys that have more than one item in their bucket.""" |
|
183 return [(key, len(value)) for key,value in buckets.items() if len(value) > 1] |
|
184 |
|
185 def check_doubles_not_none(buckets): |
|
186 """Extract the keys that have more than one item in their bucket.""" |
|
187 return [(key, len(value)) for key,value in buckets.items() if key is not None and len(value) > 1] |
|
188 |
|
189 |
|
190 # object stores ################################################################# |
157 |
191 |
158 class ObjectStore(object): |
192 class ObjectStore(object): |
159 """Store objects in memory for faster testing. Will not |
193 """Store objects in memory for faster testing. Will not |
160 enforce the constraints of the schema and hence will miss |
194 enforce the constraints of the schema and hence will miss |
161 some problems. |
195 some problems. |
253 def relate(self, eid_from, rtype, eid_to): |
288 def relate(self, eid_from, rtype, eid_to): |
254 self.rql('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype, |
289 self.rql('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype, |
255 {'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y')) |
290 {'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y')) |
256 self.relations.add( (eid_from, rtype, eid_to) ) |
291 self.relations.add( (eid_from, rtype, eid_to) ) |
257 |
292 |
258 # import controller ##### |
293 |
|
294 # the import controller ######################################################## |
259 |
295 |
260 class CWImportController(object): |
296 class CWImportController(object): |
261 """Controller of the data import process. |
297 """Controller of the data import process. |
262 |
298 |
263 >>> ctl = CWImportController(store) |
299 >>> ctl = CWImportController(store) |
264 >>> ctl.generators = list_of_data_generators |
300 >>> ctl.generators = list_of_data_generators |
265 >>> ctl.data = dict_of_data_tables |
301 >>> ctl.data = dict_of_data_tables |
266 >>> ctl.run() |
302 >>> ctl.run() |
267 """ |
303 """ |
268 |
304 |
269 def __init__(self, store): |
305 def __init__(self, store, askerror=False, catcherrors=None, tell=tell, |
|
306 commitevery=50): |
270 self.store = store |
307 self.store = store |
271 self.generators = None |
308 self.generators = None |
272 self.data = {} |
309 self.data = {} |
273 self.errors = None |
310 self.errors = None |
274 self.askerror = False |
311 self.askerror = askerror |
|
312 if catcherrors is None: |
|
313 catcherrors = askerror |
|
314 self.catcherrors = catcherrors |
|
315 self.commitevery = commitevery # set to None to do a single commit |
275 self._tell = tell |
316 self._tell = tell |
276 |
317 |
277 def check(self, type, key, value): |
318 def check(self, type, key, value): |
278 self._checks.setdefault(type, {}).setdefault(key, []).append(value) |
319 self._checks.setdefault(type, {}).setdefault(key, []).append(value) |
279 |
320 |
282 entity[key] = map[entity[key]] |
323 entity[key] = map[entity[key]] |
283 except KeyError: |
324 except KeyError: |
284 self.check(key, entity[key], None) |
325 self.check(key, entity[key], None) |
285 entity[key] = default |
326 entity[key] = default |
286 |
327 |
|
328 def record_error(self, key, msg=None, type=None, value=None, tb=None): |
|
329 import StringIO |
|
330 tmp = StringIO.StringIO() |
|
331 if type is None: |
|
332 traceback.print_exc(file=tmp) |
|
333 else: |
|
334 traceback.print_exception(type, value, tb, file=tmp) |
|
335 print tmp.getvalue() |
|
336 # use a list to avoid counting a <nb lines> errors instead of one |
|
337 errorlog = self.errors.setdefault(key, []) |
|
338 if msg is None: |
|
339 errorlog.append(tmp.getvalue().splitlines()) |
|
340 else: |
|
341 errorlog.append( (msg, tmp.getvalue().splitlines()) ) |
|
342 |
287 def run(self): |
343 def run(self): |
288 self.errors = {} |
344 self.errors = {} |
289 for func, checks in self.generators: |
345 for func, checks in self.generators: |
290 self._checks = {} |
346 self._checks = {} |
291 func_name = func.__name__[4:] |
347 func_name = func.__name__[4:] # XXX |
292 question = 'Importing %s' % func_name |
348 self.tell('Importing %s' % func_name) |
293 self.tell(question) |
|
294 try: |
349 try: |
295 func(self) |
350 func(self) |
296 except: |
351 except: |
297 import StringIO |
352 if self.catcherrors: |
298 tmp = StringIO.StringIO() |
353 self.record_error(func_name, 'While calling %s' % func.__name__) |
299 traceback.print_exc(file=tmp) |
354 else: |
300 print tmp.getvalue() |
355 raise |
301 # use a list to avoid counting a <nb lines> errors instead of one |
|
302 self.errors[func_name] = ('Erreur lors de la transformation', |
|
303 [tmp.getvalue().splitlines()]) |
|
304 for key, func, title, help in checks: |
356 for key, func, title, help in checks: |
305 buckets = self._checks.get(key) |
357 buckets = self._checks.get(key) |
306 if buckets: |
358 if buckets: |
307 err = func(buckets) |
359 err = func(buckets) |
308 if err: |
360 if err: |
309 self.errors[title] = (help, err) |
361 self.errors[title] = (help, err) |
310 self.store.checkpoint() |
362 self.store.checkpoint() |
311 self.tell('\nImport completed: %i entities (%i types), %i relations' |
363 self.tell('\nImport completed: %i entities (%i types), %i relations' |
312 % (len(self.store.eids), len(self.store.types), |
364 % (len(self.store.eids), len(self.store.types), |
313 len(self.store.relations))) |
365 len(self.store.relations))) |
314 nberrors = sum(len(err[1]) for err in self.errors.values()) |
366 nberrors = sum(len(err[1]) for err in self.errors.values()) |
315 if nberrors: |
367 if nberrors: |
325 self.store.indexes.setdefault(name, {}).setdefault(key, []).append(value) |
377 self.store.indexes.setdefault(name, {}).setdefault(key, []).append(value) |
326 |
378 |
327 def tell(self, msg): |
379 def tell(self, msg): |
328 self._tell(msg) |
380 self._tell(msg) |
329 |
381 |
330 def confirm(question): |
382 def iter_and_commit(self, datakey): |
331 """A confirm function that asks for yes/no/abort and exits on abort.""" |
383 """iter rows, triggering commit every self.commitevery iterations""" |
332 answer = shellutils.ASK.ask(question, ('Y','n','abort'), 'Y') |
384 return commit_every(self.commitevery, self.store, self.get_data(datakey)) |
333 if answer == 'abort': |
|
334 sys.exit(1) |
|
335 return answer == 'Y' |
|