|
1 # -*- coding: utf-8 -*- |
|
2 """This module provides tools to import tabular data. |
|
3 |
|
4 :organization: Logilab |
|
5 :copyright: 2001-2009 LOGILAB S.A. (Paris, FRANCE), license is LGPL v2. |
|
6 :contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr |
|
7 :license: GNU Lesser General Public License, v2.1 - http://www.gnu.org/licenses |
|
8 |
|
9 |
|
10 Example of use (run this with `cubicweb-ctl shell instance import-script.py`): |
|
11 |
|
12 .. sourcecode:: python |
|
13 |
|
14 # define data generators |
|
15 GENERATORS = [] |
|
16 |
|
17 USERS = [('Prenom', 'firstname', ()), |
|
18 ('Nom', 'surname', ()), |
|
19 ('Identifiant', 'login', ()), |
|
20 ] |
|
21 |
|
22 def gen_users(ctl): |
|
23 for row in ctl.get_data('utilisateurs'): |
|
24 entity = mk_entity(row, USERS) |
|
25 entity['upassword'] = u'motdepasse' |
|
26 ctl.check('login', entity['login'], None) |
|
27 ctl.store.add('CWUser', entity) |
|
28 email = {'address': row['email']} |
|
29 ctl.store.add('EmailAddress', email) |
|
30 ctl.store.relate(entity['uid'], 'use_email', email['uid']) |
|
31 ctl.store.rql('SET U in_group G WHERE G name "users", U eid %(x)s', {'x':entity['uid']}) |
|
32 |
|
33 CHK = [('login', check_doubles, 'Utilisateurs Login', |
|
34 'Deux utilisateurs ne devraient pas avoir le même login.'), |
|
35 ] |
|
36 |
|
37 GENERATORS.append( (gen_users, CHK) ) |
|
38 |
|
39 # progress callback |
|
40 def tell(msg): |
|
41 print msg |
|
42 |
|
43 # create controller |
|
44 ctl = CWImportController(RQLObjectStore()) |
|
45 ctl.askerror = True |
|
46 ctl._tell = tell |
|
47 ctl.generators = GENERATORS |
|
48 ctl.store._checkpoint = checkpoint |
|
49 ctl.store._rql = rql |
|
50 ctl.data['utilisateurs'] = lazytable(utf8csvreader(open('users.csv'))) |
|
51 # run |
|
52 ctl.run() |
|
53 sys.exit(0) |
|
54 |
|
55 """ |
|
# Declares the markup language of this module's docstrings for documentation
# tools: reStructuredText, written in English.
__docformat__ = "restructuredtext en"
|
57 |
|
58 import sys, csv, traceback |
|
59 |
|
60 from logilab.common import shellutils |
|
61 |
|
def utf8csvreader(file, encoding='utf-8', separator=',', quote='"'):
    """A csv reader that accepts files with any encoding and outputs
    unicode strings.

    :param file: an iterable of lines, typically an open file object
    :param encoding: encoding used to decode byte-string cells
    :param separator: field delimiter handed to `csv.reader`
    :param quote: quote character handed to `csv.reader`
    :return: a generator yielding one list of unicode cells per CSV row

    Cells that are already text (e.g. rows read from a file opened in text
    mode under Python 3) are passed through unchanged instead of being
    decoded a second time.
    """
    for row in csv.reader(file, delimiter=separator, quotechar=quote):
        # only byte strings need decoding; text has no meaningful .decode()
        yield [item.decode(encoding) if isinstance(item, bytes) else item
               for item in row]
|
67 |
|
def lazytable(reader):
    """The first row is taken to be the header of the table and
    used to output a dict for each row of data.

    >>> data = lazytable(utf8csvreader(open(filename)))

    :param reader: any iterable of rows (sequences of cells); plain lists are
                   accepted as well as iterators
    :return: a generator of dicts mapping each header cell to the cell at the
             same position in a data row; an empty reader yields nothing
    """
    rows = iter(reader)
    try:
        header = next(rows)
    except StopIteration:
        # No header means no data rows. Returning explicitly keeps
        # StopIteration from escaping the generator (a RuntimeError under
        # PEP 479).
        return
    for row in rows:
        yield dict(zip(header, row))
|
77 |
|
78 # base sanitizing functions ##### |
|
79 |
|
def capitalize_if_unicase(txt):
    """Capitalize `txt` only when all of its cased characters share one case.

    u'DUPONT' and u'dupont' both become u'Dupont'. Mixed-case text such as
    u'McDonald' is assumed deliberate and is returned untouched.
    """
    single_case = txt.islower() or txt.isupper()
    return txt.capitalize() if single_case else txt
|
84 |
|
def no_space(txt):
    """Return `txt` with every plain space character (' ') removed."""
    return ''.join(txt.split(' '))
|
87 |
|
def no_uspace(txt):
    """Return `txt` with every no-break space (U+00A0) removed."""
    return u''.join(txt.split(u'\xa0'))
|
90 |
|
def no_dash(txt):
    """Return `txt` with every hyphen-minus character ('-') removed."""
    pieces = txt.split('-')
    return ''.join(pieces)
|
93 |
|
def alldigits(txt):
    """Return `txt` unchanged if it consists only of digits, else u''.

    The empty string does not count as all-digits and also yields u''.
    """
    return txt if txt.isdigit() else u''
|
99 |
|
def strip(txt):
    """Return `txt` without its leading and trailing whitespace."""
    return txt.lstrip().rstrip()
|
102 |
|
103 # base checks ##### |
|
104 |
|
def check_doubles(buckets):
    """Extract the keys that have more than one item in their bucket.

    :param buckets: mapping of key -> list of collected values
    :return: list of ``(key, count)`` pairs, in the mapping's iteration
             order, for every key whose bucket holds two or more values
    """
    doubles = []
    for key in buckets:
        count = len(buckets[key])
        if count > 1:
            doubles.append((key, count))
    return doubles
|
108 |
|
109 # make entity helper ##### |
|
110 |
|
def mk_entity(row, map):
    """Return a dict made from sanitized mapped values.

    :param row: source record, indexed by source column name
    :param map: sequence of ``(src, dest, funcs)`` triples; ``row[src]`` is
                passed through each callable of ``funcs`` in order and the
                result is stored under ``dest``
    :return: a new dict keyed by the ``dest`` names

    >>> row = {'myname': u'dupont'}
    >>> map = [('myname', u'name', (capitalize_if_unicase,))]
    >>> mk_entity(row, map)
    {u'name': u'Dupont'}
    """
    entity = {}
    for src, dest, funcs in map:
        value = row[src]
        for sanitize in funcs:
            value = sanitize(value)
        entity[dest] = value
    return entity
|
125 |
|
126 # object stores |
|
127 |
|
class ObjectStore(object):
    """Store objects in memory for faster testing. Will not
    enforce the constraints of the schema and hence will miss
    some problems.

    >>> store = ObjectStore()
    >>> user = {'login': 'johndoe'}
    >>> store.add('CWUser', user)
    >>> group = {'name': 'unknown'}
    >>> store.add('CWGroup', group)
    >>> store.relate(user['uid'], 'in_group', group['uid'])

    State kept by the store:

    - ``items``: list of added dicts (an item's uid is its index here)
    - ``uids``: mapping uid -> item dict
    - ``types``: mapping entity type -> list of uids, in insertion order
    - ``relations``: set of ``(uid_from, rtype, uid_to)`` triples
    - ``indexes``: mapping index name -> {key: [uids]}
    """

    def __init__(self):
        self.items = []
        self.uids = {}
        self.types = {}
        self.relations = set()
        self.indexes = {}
        # optional callbacks used by rql() and checkpoint(); None disables them
        self._rql = None
        self._checkpoint = None

    def _put(self, type, item):
        """Append `item` and return its uid, i.e. its index in ``items``."""
        self.items.append(item)
        return len(self.items) - 1

    def add(self, type, item):
        """Register the dict `item` as an entity of `type`.

        The uid allocated by ``_put`` is written back into ``item['uid']``.
        """
        assert isinstance(item, dict), item
        uid = item['uid'] = self._put(type, item)
        self.uids[uid] = item
        self.types.setdefault(type, []).append(uid)

    def relate(self, uid_from, rtype, uid_to):
        """Record the relation `rtype` between two previously added items.

        Both uids must be valid indexes into ``items``.
        """
        count = len(self.items)
        # valid uids run from 0 to count - 1, so both ends use strict bounds
        uids_valid = (0 <= uid_from < count and 0 <= uid_to < count)
        assert uids_valid, 'uid error %s %s' % (uid_from, uid_to)
        self.relations.add( (uid_from, rtype, uid_to) )

    def build_index(self, name, type, func):
        """Build the index `name` over every item of `type`.

        ``func(item)`` computes the key an item is filed under; the index maps
        each key to the list of matching uids. A type with no items yields an
        empty index.
        """
        index = {}
        for uid in self.types.get(type, ()):
            index.setdefault(func(self.uids[uid]), []).append(uid)
        self.indexes[name] = index

    def get_many(self, name, key):
        """Return the uids filed under `key` in index `name` (possibly [])."""
        return self.indexes[name].get(key, [])

    def get_one(self, name, key):
        """Return the single uid filed under `key` in index `name`."""
        uids = self.indexes[name].get(key, [])
        assert len(uids) == 1
        return uids[0]

    def find(self, type, key, value):
        """Yield the items of `type` whose `key` attribute equals `value`."""
        for uid in self.types.get(type, ()):
            # Look items up by uid rather than by list position: subclasses
            # may allocate uids that are not indexes into self.items.
            item = self.uids[uid]
            if item[key] == value:
                yield item

    def rql(self, query, args):
        """Run `query` through the ``_rql`` callback, if one is set."""
        if self._rql:
            return self._rql(query, args)

    def checkpoint(self):
        """Call the ``_checkpoint`` callback, if one is set."""
        if self._checkpoint:
            self._checkpoint()
|
192 |
|
class RQLObjectStore(ObjectStore):
    """ObjectStore that works with an actual RQL repository.

    Entities and relations are written through ``self.rql``. The uid of a new
    entity is the first cell of the first row returned by its INSERT query.
    """

    def _put(self, type, item):
        """Insert `item` as a new `type` entity and return its uid.

        Each key of `item` becomes an ``X key %(key)s`` assignment; values are
        passed as query arguments, never interpolated into the query text.
        """
        assignments = ', '.join(['X %s %%(%s)s' % (key, key) for key in item])
        query = 'INSERT %s X' % type
        if assignments:
            # without this guard an attribute-less item would produce a
            # query ending in a dangling ': '
            query += ': ' + assignments
        return self.rql(query, item)[0][0]

    def relate(self, uid_from, rtype, uid_to):
        """Set relation `rtype` between the two entities and record it."""
        query = 'SET X %s Y WHERE X eid %%(from)s, Y eid %%(to)s' % rtype
        self.rql(query, {'from': int(uid_from), 'to': int(uid_to)})
        self.relations.add( (uid_from, rtype, uid_to) )
|
204 |
|
205 # import controller ##### |
|
206 |
|
class CWImportController(object):
    """Controller of the data import process.

    >>> ctl = CWImportController(store)
    >>> ctl.generators = list_of_data_generators
    >>> ctl.data = dict_of_data_tables
    >>> ctl.run()

    Each generator is a ``(func, checks)`` pair. ``func(ctl)`` performs one
    import step. ``checks`` is a sequence of ``(key, checkfunc, title, help)``
    tuples, each run afterwards on the values collected with `check` under
    ``key``. Progress messages go to the ``_tell`` callback when one has been
    assigned, and are printed otherwise.
    """

    def __init__(self, store):
        self.store = store
        self.generators = None
        self.data = {}
        self.errors = None
        self.askerror = False
        # buckets filled by check(); run() resets them before each generator
        self._checks = {}

    def check(self, type, key, value):
        """Record `value` in bucket `key` of the check group named `type`."""
        self._checks.setdefault(type, {}).setdefault(key, []).append(value)

    def check_map(self, entity, key, map, default):
        """Translate ``entity[key]`` through `map`, falling back to `default`.

        An unmapped value is first recorded with `check` (group `key`) so it
        can be reported, then replaced by `default`.
        """
        try:
            entity[key] = map[entity[key]]
        except KeyError:
            self.check(key, entity[key], None)
            entity[key] = default

    def run(self):
        """Run every generator and its checks, then report the outcome.

        An exception raised by a generator does not abort the import. Its
        traceback is printed and stored in ``self.errors`` under the
        generator's name without its first four characters (e.g. ``gen_``).
        The store is checkpointed after each generator.
        """
        self.errors = {}
        for func, checks in self.generators or ():
            self._checks = {}
            func_name = func.__name__[4:]
            self.tell('Importation de %s' % func_name)
            try:
                func(self)
            except Exception:
                # Exception, not a bare except: sys.exit() (e.g. 'abort' in
                # confirm) and KeyboardInterrupt must still stop the import
                trace = traceback.format_exc()
                print(trace)
                self.errors[func_name] = ('Erreur lors de la transformation',
                                          trace.splitlines())
            for key, checkfunc, title, helpmsg in checks:
                buckets = self._checks.get(key)
                if buckets:
                    err = checkfunc(buckets)
                    if err:
                        self.errors[title] = (helpmsg, err)
            self.store.checkpoint()
        errors = sum(len(err[1]) for err in self.errors.values())
        self.tell('Importation terminée. (%i objets, %i types, %i relations et %i erreurs).'
                  % (len(self.store.uids), len(self.store.types),
                     len(self.store.relations), errors))
        if self.errors and self.askerror and confirm('Afficher les erreurs ?'):
            import pprint
            pprint.pprint(self.errors)

    def get_data(self, key):
        """Return the data table registered under `key`, or None."""
        return self.data.get(key)

    def index(self, name, key, value):
        """Append `value` under `key` in the store index named `name`."""
        self.store.indexes.setdefault(name, {}).setdefault(key, []).append(value)

    def tell(self, msg):
        """Report the progress message `msg`.

        Uses the ``_tell`` callback when one has been assigned, and falls back
        to printing `msg` otherwise.
        """
        callback = getattr(self, '_tell', None)
        if callback is None:
            print(msg)
        else:
            callback(msg)
|
272 |
|
def confirm(question):
    """Ask `question` offering the answers Y / n / abort (default Y).

    Exits the process with status 1 when the answer is 'abort'. Otherwise
    returns True for 'Y' and False for any other answer.
    """
    reply = shellutils.ASK.ask(question, ('Y', 'n', 'abort'), 'Y')
    if reply != 'abort':
        return reply == 'Y'
    sys.exit(1)