52 __docformat__ = "restructuredtext en" |
52 __docformat__ = "restructuredtext en" |
53 |
53 |
54 import sys, csv, traceback |
54 import sys, csv, traceback |
55 |
55 |
56 from logilab.common import shellutils |
56 from logilab.common import shellutils |
57 |
57 from logilab.common.deprecation import deprecated |
def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"',
                  skipfirst=False, withpb=True):
    """Same as `ucsvreader`, but display a progress bar while iterating rows.

    :param filepath: path of the CSV file to read
    :param encoding: encoding forwarded to `ucsvreader`
    :param separator: field delimiter forwarded to `ucsvreader`
    :param quote: quote character forwarded to `ucsvreader`
    :param skipfirst: when true, the first row is skipped and not counted
    :param withpb: when true, a `shellutils.ProgressBar` sized to the row
      count is updated after each yielded row

    Yields each row exactly as `ucsvreader` produces it, then prints the
    number of rows once the file is exhausted.
    """
    # Count the lines in Python instead of running `wc -l <path>`: the path
    # was interpolated unquoted into a command string, which breaks on paths
    # containing spaces or shell metacharacters and requires a POSIX `wc`.
    # Note: unlike `wc -l`, a last line without trailing newline is counted.
    with open(filepath, 'rb') as counted:
        rowcount = sum(1 for _ in counted)
    if skipfirst:
        # never report a negative count for an empty file
        rowcount = max(rowcount - 1, 0)
    if withpb:
        pb = shellutils.ProgressBar(rowcount)
    # the context manager guarantees the file is closed, even when the
    # consumer abandons the generator early
    with open(filepath) as stream:
        for urow in ucsvreader(stream, encoding, separator, quote, skipfirst):
            yield urow
            if withpb:
                pb.update()
    print(' %s rows imported' % rowcount)
|
72 |
|
def ucsvreader(stream, encoding='utf-8', separator=',', quote='"',
               skipfirst=False):
    """A csv reader that accepts files with any encoding and outputs unicode
    strings.

    :param stream: iterable of lines handed to `csv.reader`
    :param encoding: encoding used to decode byte-string cells
    :param separator: field delimiter
    :param quote: quote character
    :param skipfirst: when true, the first row (typically a header) is
      dropped; an empty stream then simply yields nothing

    Yields one list of decoded cells per CSV row.
    """
    rows = csv.reader(stream, delimiter=separator, quotechar=quote)
    if skipfirst:
        # next() with a default: calling `.next()` on an exhausted reader
        # would raise StopIteration from inside this generator
        next(rows, None)
    for row in rows:
        # only byte strings need decoding; cells that are already text
        # are passed through untouched
        yield [item.decode(encoding) if isinstance(item, bytes) else item
               for item in row]
63 |
83 |
|
84 utf8csvreader = deprecated('use ucsvreader instead')(ucsvreader) |
|
85 |
|
def commit_every(nbit, store, it):
    """Iterate over `it`, calling `store.checkpoint()` every `nbit` items.

    :param nbit: number of items between two checkpoints
    :param store: object exposing a `checkpoint()` method
    :param it: iterable whose items are yielded unchanged

    The checkpoint is triggered just before yielding item number `nbit`,
    `2 * nbit`, ... (0-based), i.e. once the consumer has been handed a
    full batch. No checkpoint is issued before the first item nor after the
    last one; the caller is responsible for the final commit.
    """
    for i, x in enumerate(it):
        # `i % nbit` alone is true for every index that is NOT a multiple
        # of nbit, which checkpointed on almost every item
        if i and i % nbit == 0:
            store.checkpoint()
        yield x
64 def lazytable(reader): |
91 def lazytable(reader): |
65 """The first row is taken to be the header of the table and |
92 """The first row is taken to be the header of the table and |
66 used to output a dict for each row of data. |
93 used to output a dict for each row of data. |
67 |
94 |
68 >>> data = lazytable(utf8csvreader(open(filename))) |
95 >>> data = lazytable(utf8csvreader(open(filename))) |
102 # base checks ##### |
129 # base checks ##### |
103 |
130 |
def check_doubles(buckets):
    """Return a list of ``(key, item count)`` pairs, one for each key of
    `buckets` whose bucket holds more than one item.
    """
    doubles = []
    for key, value in buckets.items():
        count = len(value)
        if count > 1:
            doubles.append((key, count))
    return doubles
|
134 |
|
def check_doubles_not_none(buckets):
    """Return a list of ``(key, item count)`` pairs, one for each key of
    `buckets` whose bucket holds more than one item, ignoring the `None` key.
    """
    doubles = []
    for key, value in buckets.items():
        if key is None:
            continue
        count = len(value)
        if count > 1:
            doubles.append((key, count))
    return doubles
107 |
138 |
108 # make entity helper ##### |
139 # make entity helper ##### |
109 |
140 |
110 def mk_entity(row, map): |
141 def mk_entity(row, map): |
111 """Return a dict made from sanitized mapped values. |
142 """Return a dict made from sanitized mapped values. |
179 for idx in self.types[type]: |
210 for idx in self.types[type]: |
180 item = self.items[idx] |
211 item = self.items[idx] |
181 if item[key] == value: |
212 if item[key] == value: |
182 yield item |
213 yield item |
183 |
214 |
184 def rql(self, query, args): |
|
185 if self._rql: |
|
186 return self._rql(query, args) |
|
187 |
|
def checkpoint(self):
    """No-op checkpoint hook.

    This base implementation does nothing; a store that really has
    something to commit supplies its own callable instead.
    """
    return None
|
191 |
217 |
class RQLObjectStore(ObjectStore):
    """ObjectStore that works with an actual RQL repository."""

    # backward compatibility: when set to a callable, `rql` delegates to it
    # instead of going through the session
    _rql = None

    def __init__(self, session=None, checkpoint=None):
        """Bind the store to `session` and pick the checkpoint callable.

        :param session: a session, or an object without a `set_pool`
          attribute, which is treated as a connection: a request is obtained
          from it and given a no-op `set_pool`
        :param checkpoint: optional callable used as `self.checkpoint`;
          defaults to the connection's (or else the session's) `commit`
        """
        ObjectStore.__init__(self)
        if session is None:
            # no repository access: only honour an explicit checkpoint
            if checkpoint is not None:
                self.checkpoint = checkpoint
            return
        if not hasattr(session, 'set_pool'):
            # we were given a connection rather than a session
            connection = session
            session = connection.request()
            session.set_pool = lambda: None
            if not checkpoint:
                checkpoint = connection.commit
        self.session = session
        if not checkpoint:
            checkpoint = session.commit
        self.checkpoint = checkpoint

    def rql(self, *args):
        """Execute an RQL query with `args` and return its result.

        Uses the legacy `_rql` callable when one is set, the session
        otherwise.
        """
        legacy_rql = self._rql
        if legacy_rql is not None:
            return legacy_rql(*args)
        session = self.session
        session.set_pool()
        return session.execute(*args)

    def create_entity(self, *args, **kwargs):
        """Create an entity through the session and register it.

        The entity is recorded in `self.eids` under its eid, and the eid is
        appended to the `self.types` list of the first positional argument.
        Returns the created entity.
        """
        session = self.session
        session.set_pool()
        entity = session.create_entity(*args, **kwargs)
        eid = entity.eid
        self.eids[eid] = entity
        self.types.setdefault(args[0], []).append(eid)
        return entity

    def _put(self, type, item):
        """Run an INSERT query built from the keys of `item` and return the
        first cell of the first result row.
        """
        restrictions = ', '.join('X %s %%(%s)s' % (key, key) for key in item)
        query = 'INSERT %s X: ' % type + restrictions
        rset = self.rql(query, item)
        return rset[0][0]

    def relate(self, eid_from, rtype, eid_to):
        """Set relation `rtype` from `eid_from` to `eid_to` in the
        repository and record the triple in `self.relations`.
        """
        query = 'SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype
        query_args = {'x': int(eid_from), 'y': int(eid_to)}
        self.rql(query, query_args, ('x', 'y'))
        self.relations.add((eid_from, rtype, eid_to))
203 |
257 |
204 # import controller ##### |
258 # import controller ##### |
205 |
259 |
206 class CWImportController(object): |
260 class CWImportController(object): |
def run(self):
    """Run every registered generator and its checks, then report.

    For each ``(func, checks)`` pair of `self.generators`:

    * `self._checks` is reset and `func(self)` is called; if it raises, the
      traceback is printed and recorded in `self.errors` under the
      generator name (function name minus its 4-character prefix);
    * each ``(key, check, title, help)`` of `checks` is applied to the
      buckets stored under `key` in `self._checks`, and a non-empty result
      is recorded in `self.errors` under `title`.

    Finally the store is checkpointed, a summary is emitted through
    `self.tell`, and errors are optionally pretty-printed.
    """
    self.errors = {}
    for func, checks in self.generators:
        self._checks = {}
        func_name = func.__name__[4:]
        self.tell('Importing %s' % func_name)
        try:
            func(self)
        except Exception:
            # `except Exception` rather than a bare except, so that
            # KeyboardInterrupt / SystemExit still abort the import
            tb_text = traceback.format_exc()
            print(tb_text)
            # use a list to avoid counting <nb lines> errors instead of one
            self.errors[func_name] = ('Erreur lors de la transformation',
                                      [tb_text.splitlines()])
        for key, check, title, help_text in checks:
            buckets = self._checks.get(key)
            if buckets:
                err = check(buckets)
                if err:
                    self.errors[title] = (help_text, err)
    # NOTE(review): single checkpoint once all generators have run --
    # nesting reconstructed from context, confirm against the file
    self.store.checkpoint()
    self.tell('Import completed: %i entities (%i types), %i relations'
              % (len(self.store.eids), len(self.store.types),
                 len(self.store.relations)))
    nberrors = sum(len(err[1]) for err in self.errors.values())
    if nberrors:
        print('%s errors' % nberrors)
    if self.errors and self.askerror and confirm('Display errors?'):
        import pprint
        pprint.pprint(self.errors)
263 |
320 |
def get_data(self, key):
    """Look `key` up in `self.data` through its `get` method and return
    the result.
    """
    data = self.data
    return data.get(key)