devtools/dataimport.py
branch3.5
changeset 2974 3dfe497e5afa
child 3003 2944ee420dca
equal deleted inserted replaced
2973:46a5a94287fa 2974:3dfe497e5afa
       
     1 # -*- coding: utf-8 -*-
       
     2 """This module provides tools to import tabular data.
       
     3 
       
     4 :organization: Logilab
       
     5 :copyright: 2001-2009 LOGILAB S.A. (Paris, FRANCE), license is LGPL v2.
       
     6 :contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
       
     7 :license: GNU Lesser General Public License, v2.1 - http://www.gnu.org/licenses
       
     8 
       
     9 
       
    10 Example of use (run this with `cubicweb-ctl shell instance import-script.py`):
       
    11 
       
    12 .. sourcecode:: python
       
    13 
       
    14   # define data generators
       
    15   GENERATORS = []
       
    16 
       
    17   USERS = [('Prenom', 'firstname', ()),
       
    18            ('Nom', 'surname', ()),
       
    19            ('Identifiant', 'login', ()),
       
    20            ]
       
    21 
       
    22   def gen_users(ctl):
       
    23       for row in ctl.get_data('utilisateurs'):
       
    24           entity = mk_entity(row, USERS)
       
    25           entity['upassword'] = u'motdepasse'
       
    26           ctl.check('login', entity['login'], None)
       
    27           ctl.store.add('CWUser', entity)
       
    28           email = {'address': row['email']}
       
    29           ctl.store.add('EmailAddress', email)
       
    30           ctl.store.relate(entity['uid'], 'use_email', email['uid'])
       
    31           ctl.store.rql('SET U in_group G WHERE G name "users", U eid %(x)s', {'x':entity['uid']})
       
    32 
       
    33   CHK = [('login', check_doubles, 'Utilisateurs Login',
       
    34           'Deux utilisateurs ne devraient pas avoir le même login.'),
       
    35          ]
       
    36 
       
    37   GENERATORS.append( (gen_users, CHK) )
       
    38 
       
    39   # progress callback
       
    40   def tell(msg):
       
    41       print msg
       
    42 
       
    43   # create controller
       
    44   ctl = CWImportController(RQLObjectStore())
       
    45   ctl.askerror = True
       
    46   ctl._tell = tell
       
    47   ctl.generators = GENERATORS
       
    48   ctl.store._checkpoint = checkpoint
       
    49   ctl.store._rql = rql
       
    50   ctl.data['utilisateurs'] = lazytable(utf8csvreader(open('users.csv')))
       
    51   # run
       
    52   ctl.run()
       
    53   sys.exit(0)
       
    54 
       
    55 """
       
    56 __docformat__ = "restructuredtext en"
       
    57 
       
    58 import sys, csv, traceback
       
    59 
       
    60 from logilab.common import shellutils
       
    61 
       
    62 def utf8csvreader(file, encoding='utf-8', separator=',', quote='"'):
       
    63     """A csv reader that accepts files with any encoding and outputs
       
    64     unicode strings."""
       
    65     for row in csv.reader(file, delimiter=separator, quotechar=quote):
       
    66         yield [item.decode(encoding) for item in row]
       
    67 
       
    68 def lazytable(reader):
       
    69     """The first row is taken to be the header of the table and
       
    70     used to output a dict for each row of data.
       
    71 
       
    72     >>> data = lazytable(utf8csvreader(open(filename)))
       
    73     """
       
    74     header = reader.next()
       
    75     for row in reader:
       
    76         yield dict(zip(header, row))
       
    77 
       
    78 # base sanitizing functions #####
       
    79 
       
    80 def capitalize_if_unicase(txt):
       
    81     if txt.isupper() or txt.islower():
       
    82         return txt.capitalize()
       
    83     return txt
       
    84 
       
    85 def no_space(txt):
       
    86     return txt.replace(' ','')
       
    87 
       
    88 def no_uspace(txt):
       
    89     return txt.replace(u'\xa0','')
       
    90 
       
    91 def no_dash(txt):
       
    92     return txt.replace('-','')
       
    93 
       
    94 def alldigits(txt):
       
    95     if txt.isdigit():
       
    96         return txt
       
    97     else:
       
    98         return u''
       
    99 
       
   100 def strip(txt):
       
   101     return txt.strip()
       
   102 
       
   103 # base checks #####
       
   104 
       
   105 def check_doubles(buckets):
       
   106     """Extract the keys that have more than one item in their bucket."""
       
   107     return [(key, len(value)) for key,value in buckets.items() if len(value) > 1]
       
   108 
       
   109 # make entity helper #####
       
   110 
       
   111 def mk_entity(row, map):
       
   112     """Return a dict made from sanitized mapped values.
       
   113 
       
   114     >>> row = {'myname': u'dupont'}
       
   115     >>> map = [('myname', u'name', (capitalize_if_unicase,))]
       
   116     >>> mk_entity(row, map)
       
   117     {'name': u'Dupont'}
       
   118     """
       
   119     res = {}
       
   120     for src, dest, funcs in map:
       
   121         res[dest] = row[src]
       
   122         for func in funcs:
       
   123             res[dest] = func(res[dest])
       
   124     return res
       
   125 
       
   126 # object stores
       
   127 
       
   128 class ObjectStore(object):
       
   129     """Store objects in memory for faster testing. Will not
       
   130     enforce the constraints of the schema and hence will miss
       
   131     some problems.
       
   132 
       
   133     >>> store = ObjectStore()
       
   134     >>> user = {'login': 'johndoe'}
       
   135     >>> store.add('CWUser', user)
       
   136     >>> group = {'name': 'unknown'}
       
   137     >>> store.add('CWUser', group)
       
   138     >>> store.relate(user['uid'], 'in_group', group['uid'])
       
   139     """
       
   140 
       
   141     def __init__(self):
       
   142         self.items = []
       
   143         self.uids = {}
       
   144         self.types = {}
       
   145         self.relations = set()
       
   146         self.indexes = {}
       
   147         self._rql = None
       
   148         self._checkpoint = None
       
   149 
       
   150     def _put(self, type, item):
       
   151         self.items.append(item)
       
   152         return len(self.items) - 1
       
   153 
       
   154     def add(self, type, item):
       
   155         assert isinstance(item, dict), item
       
   156         uid = item['uid'] = self._put(type, item)
       
   157         self.uids[uid] = item
       
   158         self.types.setdefault(type, []).append(uid)
       
   159 
       
   160     def relate(self, uid_from, rtype, uid_to):
       
   161         uids_valid = (uid_from < len(self.items) and uid_to <= len(self.items))
       
   162         assert uids_valid, 'uid error %s %s' % (uid_from, uid_to)
       
   163         self.relations.add( (uid_from, rtype, uid_to) )
       
   164 
       
   165     def build_index(self, name, type, func):
       
   166         index = {}
       
   167         for uid in self.types[type]:
       
   168             index.setdefault(func(self.uids[uid]), []).append(uid)
       
   169         self.indexes[name] = index
       
   170 
       
   171     def get_many(self, name, key):
       
   172         return self.indexes[name].get(key, [])
       
   173 
       
   174     def get_one(self, name, key):
       
   175         uids = self.indexes[name].get(key, [])
       
   176         assert len(uids) == 1
       
   177         return uids[0]
       
   178 
       
   179     def find(self, type, key, value):
       
   180         for idx in self.types[type]:
       
   181             item = self.items[idx]
       
   182             if item[key] == value:
       
   183                 yield item
       
   184 
       
   185     def rql(self, query, args):
       
   186         if self._rql:
       
   187             return self._rql(query, args)
       
   188 
       
   189     def checkpoint(self):
       
   190         if self._checkpoint:
       
   191             self._checkpoint()
       
   192 
       
   193 class RQLObjectStore(ObjectStore):
       
   194     """ObjectStore that works with an actual RQL repository."""
       
   195 
       
   196     def _put(self, type, item):
       
   197         query = ('INSERT %s X: ' % type) + ', '.join(['X %s %%(%s)s' % (key,key) for key in item])
       
   198         return self.rql(query, item)[0][0]
       
   199 
       
   200     def relate(self, uid_from, rtype, uid_to):
       
   201         query = 'SET X %s Y WHERE X eid %%(from)s, Y eid %%(to)s' % rtype
       
   202         self.rql(query, {'from': int(uid_from), 'to': int(uid_to)})
       
   203         self.relations.add( (uid_from, rtype, uid_to) )
       
   204 
       
   205 # import controller #####
       
   206 
       
   207 class CWImportController(object):
       
   208     """Controller of the data import process.
       
   209 
       
   210     >>> ctl = CWImportController(store)
       
   211     >>> ctl.generators = list_of_data_generators
       
   212     >>> ctl.data = dict_of_data_tables
       
   213     >>> ctl.run()
       
   214     """
       
   215 
       
   216     def __init__(self, store):
       
   217         self.store = store
       
   218         self.generators = None
       
   219         self.data = {}
       
   220         self.errors = None
       
   221         self.askerror = False
       
   222 
       
   223     def check(self, type, key, value):
       
   224         self._checks.setdefault(type, {}).setdefault(key, []).append(value)
       
   225 
       
   226     def check_map(self, entity, key, map, default):
       
   227         try:
       
   228             entity[key] = map[entity[key]]
       
   229         except KeyError:
       
   230             self.check(key, entity[key], None)
       
   231             entity[key] = default
       
   232 
       
   233     def run(self):
       
   234         self.errors = {}
       
   235         for func, checks in self.generators:
       
   236             self._checks = {}
       
   237             func_name = func.__name__[4:]
       
   238             question = 'Importation de %s' % func_name
       
   239             self.tell(question)
       
   240             try:
       
   241                 func(self)
       
   242             except:
       
   243                 import StringIO
       
   244                 tmp = StringIO.StringIO()
       
   245                 traceback.print_exc(file=tmp)
       
   246                 print tmp.getvalue()
       
   247                 self.errors[func_name] = ('Erreur lors de la transformation',
       
   248                                           tmp.getvalue().splitlines())
       
   249             for key, func, title, help in checks:
       
   250                 buckets = self._checks.get(key)
       
   251                 if buckets:
       
   252                     err = func(buckets)
       
   253                     if err:
       
   254                         self.errors[title] = (help, err)
       
   255             self.store.checkpoint()
       
   256         errors = sum(len(err[1]) for err in self.errors.values())
       
   257         self.tell('Importation terminée. (%i objets, %i types, %i relations et %i erreurs).'
       
   258                   % (len(self.store.uids), len(self.store.types),
       
   259                      len(self.store.relations), errors))
       
   260         if self.errors and self.askerror and confirm('Afficher les erreurs ?'):
       
   261             import pprint
       
   262             pprint.pprint(self.errors)
       
   263 
       
   264     def get_data(self, key):
       
   265         return self.data.get(key)
       
   266 
       
   267     def index(self, name, key, value):
       
   268         self.store.indexes.setdefault(name, {}).setdefault(key, []).append(value)
       
   269 
       
   270     def tell(self, msg):
       
   271         self._tell(msg)
       
   272 
       
   273 def confirm(question):
       
   274     """A confirm function that asks for yes/no/abort and exits on abort."""
       
   275     answer = shellutils.ASK.ask(question, ('Y','n','abort'), 'Y')
       
   276     if answer == 'abort':
       
   277         sys.exit(1)
       
   278     return answer == 'Y'