devtools/dataimport.py
branch3.5
changeset 2974 3dfe497e5afa
child 3003 2944ee420dca
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/devtools/dataimport.py	Tue Aug 25 11:47:26 2009 +0200
@@ -0,0 +1,278 @@
+# -*- coding: utf-8 -*-
+"""This module provides tools to import tabular data.
+
+:organization: Logilab
+:copyright: 2001-2009 LOGILAB S.A. (Paris, FRANCE), license is LGPL v2.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
+:license: GNU Lesser General Public License, v2.1 - http://www.gnu.org/licenses
+
+
+Example of use (run this with `cubicweb-ctl shell instance import-script.py`):
+
+.. sourcecode:: python
+
+  # define data generators
+  GENERATORS = []
+
+  USERS = [('Prenom', 'firstname', ()),
+           ('Nom', 'surname', ()),
+           ('Identifiant', 'login', ()),
+           ]
+
+  def gen_users(ctl):
+      for row in ctl.get_data('utilisateurs'):
+          entity = mk_entity(row, USERS)
+          entity['upassword'] = u'motdepasse'
+          ctl.check('login', entity['login'], None)
+          ctl.store.add('CWUser', entity)
+          email = {'address': row['email']}
+          ctl.store.add('EmailAddress', email)
+          ctl.store.relate(entity['uid'], 'use_email', email['uid'])
+          ctl.store.rql('SET U in_group G WHERE G name "users", U eid %(x)s', {'x':entity['uid']})
+
+  CHK = [('login', check_doubles, 'Utilisateurs Login',
+          'Deux utilisateurs ne devraient pas avoir le même login.'),
+         ]
+
+  GENERATORS.append( (gen_users, CHK) )
+
+  # progress callback
+  def tell(msg):
+      print msg
+
+  # create controller
+  ctl = CWImportController(RQLObjectStore())
+  ctl.askerror = True
+  ctl._tell = tell
+  ctl.generators = GENERATORS
+  ctl.store._checkpoint = checkpoint
+  ctl.store._rql = rql
+  ctl.data['utilisateurs'] = lazytable(utf8csvreader(open('users.csv')))
+  # run
+  ctl.run()
+  sys.exit(0)
+
+"""
+__docformat__ = "restructuredtext en"
+
+import sys, csv, traceback
+
+from logilab.common import shellutils
+
+def utf8csvreader(file, encoding='utf-8', separator=',', quote='"'):
+    """A csv reader that accepts files with any encoding and outputs
+    unicode strings."""
+    for row in csv.reader(file, delimiter=separator, quotechar=quote):
+        yield [item.decode(encoding) for item in row]
+
+def lazytable(reader):
+    """The first row is taken to be the header of the table and
+    used to output a dict for each row of data.
+
+    >>> data = lazytable(utf8csvreader(open(filename)))
+    """
+    header = reader.next()
+    for row in reader:
+        yield dict(zip(header, row))
+
+# base sanitizing functions #####
+
+def capitalize_if_unicase(txt):
+    if txt.isupper() or txt.islower():
+        return txt.capitalize()
+    return txt
+
+def no_space(txt):
+    return txt.replace(' ','')
+
+def no_uspace(txt):
+    return txt.replace(u'\xa0','')
+
+def no_dash(txt):
+    return txt.replace('-','')
+
+def alldigits(txt):
+    if txt.isdigit():
+        return txt
+    else:
+        return u''
+
+def strip(txt):
+    return txt.strip()
+
+# base checks #####
+
+def check_doubles(buckets):
+    """Extract the keys that have more than one item in their bucket."""
+    return [(key, len(value)) for key,value in buckets.items() if len(value) > 1]
+
+# make entity helper #####
+
+def mk_entity(row, map):
+    """Return a dict made from sanitized mapped values.
+
+    >>> row = {'myname': u'dupont'}
+    >>> map = [('myname', u'name', (capitalize_if_unicase,))]
+    >>> mk_entity(row, map)
+    {'name': u'Dupont'}
+    """
+    res = {}
+    for src, dest, funcs in map:
+        res[dest] = row[src]
+        for func in funcs:
+            res[dest] = func(res[dest])
+    return res
+
+# object stores
+
+class ObjectStore(object):
+    """Store objects in memory for faster testing. Will not
+    enforce the constraints of the schema and hence will miss
+    some problems.
+
+    >>> store = ObjectStore()
+    >>> user = {'login': 'johndoe'}
+    >>> store.add('CWUser', user)
+    >>> group = {'name': 'unknown'}
+    >>> store.add('CWUser', group)
+    >>> store.relate(user['uid'], 'in_group', group['uid'])
+    """
+
+    def __init__(self):
+        self.items = []
+        self.uids = {}
+        self.types = {}
+        self.relations = set()
+        self.indexes = {}
+        self._rql = None
+        self._checkpoint = None
+
+    def _put(self, type, item):
+        self.items.append(item)
+        return len(self.items) - 1
+
+    def add(self, type, item):
+        assert isinstance(item, dict), item
+        uid = item['uid'] = self._put(type, item)
+        self.uids[uid] = item
+        self.types.setdefault(type, []).append(uid)
+
+    def relate(self, uid_from, rtype, uid_to):
+        uids_valid = (uid_from < len(self.items) and uid_to <= len(self.items))
+        assert uids_valid, 'uid error %s %s' % (uid_from, uid_to)
+        self.relations.add( (uid_from, rtype, uid_to) )
+
+    def build_index(self, name, type, func):
+        index = {}
+        for uid in self.types[type]:
+            index.setdefault(func(self.uids[uid]), []).append(uid)
+        self.indexes[name] = index
+
+    def get_many(self, name, key):
+        return self.indexes[name].get(key, [])
+
+    def get_one(self, name, key):
+        uids = self.indexes[name].get(key, [])
+        assert len(uids) == 1
+        return uids[0]
+
+    def find(self, type, key, value):
+        for idx in self.types[type]:
+            item = self.items[idx]
+            if item[key] == value:
+                yield item
+
+    def rql(self, query, args):
+        if self._rql:
+            return self._rql(query, args)
+
+    def checkpoint(self):
+        if self._checkpoint:
+            self._checkpoint()
+
+class RQLObjectStore(ObjectStore):
+    """ObjectStore that works with an actual RQL repository."""
+
+    def _put(self, type, item):
+        query = ('INSERT %s X: ' % type) + ', '.join(['X %s %%(%s)s' % (key,key) for key in item])
+        return self.rql(query, item)[0][0]
+
+    def relate(self, uid_from, rtype, uid_to):
+        query = 'SET X %s Y WHERE X eid %%(from)s, Y eid %%(to)s' % rtype
+        self.rql(query, {'from': int(uid_from), 'to': int(uid_to)})
+        self.relations.add( (uid_from, rtype, uid_to) )
+
+# import controller #####
+
+class CWImportController(object):
+    """Controller of the data import process.
+
+    >>> ctl = CWImportController(store)
+    >>> ctl.generators = list_of_data_generators
+    >>> ctl.data = dict_of_data_tables
+    >>> ctl.run()
+    """
+
+    def __init__(self, store):
+        self.store = store
+        self.generators = None
+        self.data = {}
+        self.errors = None
+        self.askerror = False
+
+    def check(self, type, key, value):
+        self._checks.setdefault(type, {}).setdefault(key, []).append(value)
+
+    def check_map(self, entity, key, map, default):
+        try:
+            entity[key] = map[entity[key]]
+        except KeyError:
+            self.check(key, entity[key], None)
+            entity[key] = default
+
+    def run(self):
+        self.errors = {}
+        for func, checks in self.generators:
+            self._checks = {}
+            func_name = func.__name__[4:]
+            question = 'Importation de %s' % func_name
+            self.tell(question)
+            try:
+                func(self)
+            except:
+                import StringIO
+                tmp = StringIO.StringIO()
+                traceback.print_exc(file=tmp)
+                print tmp.getvalue()
+                self.errors[func_name] = ('Erreur lors de la transformation',
+                                          tmp.getvalue().splitlines())
+            for key, func, title, help in checks:
+                buckets = self._checks.get(key)
+                if buckets:
+                    err = func(buckets)
+                    if err:
+                        self.errors[title] = (help, err)
+            self.store.checkpoint()
+        errors = sum(len(err[1]) for err in self.errors.values())
+        self.tell('Importation terminée. (%i objets, %i types, %i relations et %i erreurs).'
+                  % (len(self.store.uids), len(self.store.types),
+                     len(self.store.relations), errors))
+        if self.errors and self.askerror and confirm('Afficher les erreurs ?'):
+            import pprint
+            pprint.pprint(self.errors)
+
+    def get_data(self, key):
+        return self.data.get(key)
+
+    def index(self, name, key, value):
+        self.store.indexes.setdefault(name, {}).setdefault(key, []).append(value)
+
+    def tell(self, msg):
+        self._tell(msg)
+
+def confirm(question):
+    """A confirm function that asks for yes/no/abort and exits on abort."""
+    answer = shellutils.ASK.ask(question, ('Y','n','abort'), 'Y')
+    if answer == 'abort':
+        sys.exit(1)
+    return answer == 'Y'