move devtools.dataimport to the cubicweb level since we don't want cubes using it to depend on cubicweb-dev
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/dataimport.py Tue Mar 09 11:42:06 2010 +0100
@@ -0,0 +1,752 @@
+# -*- coding: utf-8 -*-
+"""This module provides tools to import tabular data.
+
+:organization: Logilab
+:copyright: 2001-2010 LOGILAB S.A. (Paris, FRANCE), license is LGPL v2.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
+:license: GNU Lesser General Public License, v2.1 - http://www.gnu.org/licenses
+
+
+Example of use (run this with `cubicweb-ctl shell instance import-script.py`):
+
+.. sourcecode:: python
+
+ from cubicweb.devtools.dataimport import *
+ # define data generators
+ GENERATORS = []
+
+ USERS = [('Prenom', 'firstname', ()),
+ ('Nom', 'surname', ()),
+ ('Identifiant', 'login', ()),
+ ]
+
+ def gen_users(ctl):
+ for row in ctl.get_data('utilisateurs'):
+ entity = mk_entity(row, USERS)
+ entity['upassword'] = u'motdepasse'
+ ctl.check('login', entity['login'], None)
+ ctl.store.add('CWUser', entity)
+ email = {'address': row['email']}
+ ctl.store.add('EmailAddress', email)
+ ctl.store.relate(entity['eid'], 'use_email', email['eid'])
+ ctl.store.rql('SET U in_group G WHERE G name "users", U eid %(x)s', {'x':entity['eid']})
+
+ CHK = [('login', check_doubles, 'Utilisateurs Login',
+ 'Deux utilisateurs ne devraient pas avoir le même login.'),
+ ]
+
+ GENERATORS.append( (gen_users, CHK) )
+
+ # create controller
+ ctl = CWImportController(RQLObjectStore())
+ ctl.askerror = 1
+ ctl.generators = GENERATORS
+ ctl.store._checkpoint = checkpoint
+ ctl.store._rql = rql
+ ctl.data['utilisateurs'] = lazytable(utf8csvreader(open('users.csv')))
+ # run
+ ctl.run()
+ sys.exit(0)
+
+
+.. BUG fichier à une colonne pose un problème de parsing
+.. TODO rollback()
+"""
+__docformat__ = "restructuredtext en"
+
+import sys
+import csv
+import traceback
+import os.path as osp
+from StringIO import StringIO
+from copy import copy
+
+from logilab.common import shellutils
+from logilab.common.date import strptime
+from logilab.common.decorators import cached
+from logilab.common.deprecation import deprecated
+
+
+def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"',
+ skipfirst=False, withpb=True):
+ """same as ucsvreader but a progress bar is displayed as we iter on rows"""
+ if not osp.exists(filepath):
+ raise Exception("file doesn't exists: %s" % filepath)
+ rowcount = int(shellutils.Execute('wc -l "%s"' % filepath).out.strip().split()[0])
+ if skipfirst:
+ rowcount -= 1
+ if withpb:
+ pb = shellutils.ProgressBar(rowcount, 50)
+ for urow in ucsvreader(file(filepath), encoding, separator, quote, skipfirst):
+ yield urow
+ if withpb:
+ pb.update()
+ print ' %s rows imported' % rowcount
+
+def ucsvreader(stream, encoding='utf-8', separator=',', quote='"',
+ skipfirst=False):
+ """A csv reader that accepts files with any encoding and outputs unicode
+ strings
+ """
+ it = iter(csv.reader(stream, delimiter=separator, quotechar=quote))
+ if skipfirst:
+ it.next()
+ for row in it:
+ yield [item.decode(encoding) for item in row]
+
+def commit_every(nbit, store, it):
+ for i, x in enumerate(it):
+ yield x
+ if nbit is not None and i % nbit:
+ store.checkpoint()
+ if nbit is not None:
+ store.checkpoint()
+
+def lazytable(reader):
+ """The first row is taken to be the header of the table and
+ used to output a dict for each row of data.
+
+ >>> data = lazytable(utf8csvreader(open(filename)))
+ """
+ header = reader.next()
+ for row in reader:
+ yield dict(zip(header, row))
+
+def mk_entity(row, map):
+ """Return a dict made from sanitized mapped values.
+
+ ValidationError can be raised on unexpected values found in checkers
+
+ >>> row = {'myname': u'dupont'}
+ >>> map = [('myname', u'name', (capitalize_if_unicase,))]
+ >>> mk_entity(row, map)
+ {'name': u'Dupont'}
+ >>> row = {'myname': u'dupont', 'optname': u''}
+ >>> map = [('myname', u'name', (capitalize_if_unicase,)),
+ ... ('optname', u'MARKER', (optional,))]
+ >>> mk_entity(row, map)
+ {'name': u'Dupont'}
+ """
+ res = {}
+ assert isinstance(row, dict)
+ assert isinstance(map, list)
+ for src, dest, funcs in map:
+ assert not (required in funcs and optional in funcs), \
+ "optional and required checks are exclusive"
+ res[dest] = row[src]
+ try:
+ for func in funcs:
+ res[dest] = func(res[dest])
+ if res[dest] is None:
+ break
+ except ValueError, err:
+ raise ValueError('error with %r field: %s' % (src, err))
+ return res
+
+
+# user interactions ############################################################
+
+def tell(msg):
+ print msg
+
+def confirm(question):
+ """A confirm function that asks for yes/no/abort and exits on abort."""
+ answer = shellutils.ASK.ask(question, ('Y', 'n', 'abort'), 'Y')
+ if answer == 'abort':
+ sys.exit(1)
+ return answer == 'Y'
+
+
+class catch_error(object):
+ """Helper for @contextmanager decorator."""
+
+ def __init__(self, ctl, key='unexpected error', msg=None):
+ self.ctl = ctl
+ self.key = key
+ self.msg = msg
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ if type is not None:
+ if issubclass(type, (KeyboardInterrupt, SystemExit)):
+ return # re-raise
+ if self.ctl.catcherrors:
+ self.ctl.record_error(self.key, None, type, value, traceback)
+ return True # silent
+
+
+# base sanitizing/coercing functions ###########################################
+
+def optional(value):
+ """validation error will not been raised if you add this checker in chain"""
+ if value:
+ return value
+ return None
+
+def required(value):
+ """raise ValueError is value is empty
+
+ This check should be often found in last position in the chain.
+ """
+ if value:
+ return value
+ raise ValueError("required")
+
+def todatetime(format='%d/%m/%Y'):
+ """return a transformation function to turn string input value into a
+ `datetime.datetime` instance, using given format.
+
+ Follow it by `todate` or `totime` functions from `logilab.common.date` if
+ you want a `date`/`time` instance instead of `datetime`.
+ """
+ def coerce(value):
+ return strptime(value, format)
+ return coerce
+
+def call_transform_method(methodname, *args, **kwargs):
+ """return value returned by calling the given method on input"""
+ def coerce(value):
+ return getattr(value, methodname)(*args, **kwargs)
+ return coerce
+
+def call_check_method(methodname, *args, **kwargs):
+ """check value returned by calling the given method on input is true,
+ else raise ValueError
+ """
+ def check(value):
+ if getattr(value, methodname)(*args, **kwargs):
+ return value
+ raise ValueError('%s not verified on %r' % (methodname, value))
+ return check
+
+# base integrity checking functions ############################################
+
+def check_doubles(buckets):
+ """Extract the keys that have more than one item in their bucket."""
+ return [(k, len(v)) for k, v in buckets.items() if len(v) > 1]
+
+def check_doubles_not_none(buckets):
+ """Extract the keys that have more than one item in their bucket."""
+ return [(k, len(v)) for k, v in buckets.items()
+ if k is not None and len(v) > 1]
+
+
+# object stores #################################################################
+
+class ObjectStore(object):
+ """Store objects in memory for *faster* validation (development mode)
+
+ But it will not enforce the constraints of the schema and hence will miss some problems
+
+ >>> store = ObjectStore()
+ >>> user = {'login': 'johndoe'}
+ >>> store.add('CWUser', user)
+ >>> group = {'name': 'unknown'}
+ >>> store.add('CWUser', group)
+ >>> store.relate(user['eid'], 'in_group', group['eid'])
+ """
+ def __init__(self):
+ self.items = []
+ self.eids = {}
+ self.types = {}
+ self.relations = set()
+ self.indexes = {}
+ self._rql = None
+ self._checkpoint = None
+
+ def _put(self, type, item):
+ self.items.append(item)
+ return len(self.items) - 1
+
+ def add(self, type, item):
+ assert isinstance(item, dict), 'item is not a dict but a %s' % type(item)
+ eid = item['eid'] = self._put(type, item)
+ self.eids[eid] = item
+ self.types.setdefault(type, []).append(eid)
+
+ def relate(self, eid_from, rtype, eid_to, inlined=False):
+ """Add new relation (reverse type support is available)
+
+ >>> 1,2 = eid_from, eid_to
+ >>> self.relate(eid_from, 'in_group', eid_to)
+ 1, 'in_group', 2
+ >>> self.relate(eid_from, 'reverse_in_group', eid_to)
+ 2, 'in_group', 1
+ """
+ if rtype.startswith('reverse_'):
+ eid_from, eid_to = eid_to, eid_from
+ rtype = rtype[8:]
+ relation = eid_from, rtype, eid_to
+ self.relations.add(relation)
+ return relation
+
+ def build_index(self, name, type, func=None):
+ index = {}
+ if func is None or not callable(func):
+ func = lambda x: x['eid']
+ for eid in self.types[type]:
+ index.setdefault(func(self.eids[eid]), []).append(eid)
+ assert index, "new index '%s' cannot be empty" % name
+ self.indexes[name] = index
+
+ def build_rqlindex(self, name, type, key, rql, rql_params=False, func=None):
+ """build an index by rql query
+
+ rql should return eid in first column
+ ctl.store.build_index('index_name', 'users', 'login', 'Any U WHERE U is CWUser')
+ """
+ rset = self.rql(rql, rql_params or {})
+ for entity in rset.entities():
+ getattr(entity, key) # autopopulate entity with key attribute
+ self.eids[entity.eid] = dict(entity)
+ if entity.eid not in self.types.setdefault(type, []):
+ self.types[type].append(entity.eid)
+ assert self.types[type], "new index type '%s' cannot be empty (0 record found)" % type
+
+ # Build index with specified key
+ func = lambda x: x[key]
+ self.build_index(name, type, func)
+
+ def fetch(self, name, key, unique=False, decorator=None):
+ """
+ decorator is a callable method or an iterator of callable methods (usually a lambda function)
+ decorator=lambda x: x[:1] (first value is returned)
+
+ We can use validation check function available in _entity
+ """
+ eids = self.indexes[name].get(key, [])
+ if decorator is not None:
+ if not hasattr(decorator, '__iter__'):
+ decorator = (decorator,)
+ for f in decorator:
+ eids = f(eids)
+ if unique:
+ assert len(eids) == 1, u'expected a single one value for key "%s" in index "%s". Got %i' % (key, name, len(eids))
+ eids = eids[0] # FIXME maybe it's better to keep an iterator here ?
+ return eids
+
+ def find(self, type, key, value):
+ for idx in self.types[type]:
+ item = self.items[idx]
+ if item[key] == value:
+ yield item
+
+ def rql(self, *args):
+ if self._rql is not None:
+ return self._rql(*args)
+
+ def checkpoint(self):
+ pass
+
+ @property
+ def nb_inserted_entities(self):
+ return len(self.eids)
+ @property
+ def nb_inserted_types(self):
+ return len(self.types)
+ @property
+ def nb_inserted_relations(self):
+ return len(self.relations)
+
+ @deprecated('[3.6] get_many() deprecated. Use fetch() instead')
+ def get_many(self, name, key):
+ return self.fetch(name, key, unique=False)
+
+ @deprecated('[3.6] get_one() deprecated. Use fetch(..., unique=True) instead')
+ def get_one(self, name, key):
+ return self.fetch(name, key, unique=True)
+
+
+class RQLObjectStore(ObjectStore):
+ """ObjectStore that works with an actual RQL repository (production mode)"""
+ _rql = None # bw compat
+
+ def __init__(self, session=None, checkpoint=None):
+ ObjectStore.__init__(self)
+ if session is not None:
+ if not hasattr(session, 'set_pool'):
+ # connection
+ cnx = session
+ session = session.request()
+ session.set_pool = lambda : None
+ checkpoint = checkpoint or cnx.commit
+ else:
+ session.set_pool()
+ self.session = session
+ self._checkpoint = checkpoint or session.commit
+ elif checkpoint is not None:
+ self._checkpoint = checkpoint
+ # XXX .session
+
+ def checkpoint(self):
+ self._checkpoint()
+ self.session.set_pool()
+
+ def rql(self, *args):
+ if self._rql is not None:
+ return self._rql(*args)
+ return self.session.execute(*args)
+
+ def create_entity(self, *args, **kwargs):
+ entity = self.session.create_entity(*args, **kwargs)
+ self.eids[entity.eid] = entity
+ self.types.setdefault(args[0], []).append(entity.eid)
+ return entity
+
+ def _put(self, type, item):
+ query = ('INSERT %s X: ' % type) + ', '.join('X %s %%(%s)s' % (k, k)
+ for k in item)
+ return self.rql(query, item)[0][0]
+
+ def relate(self, eid_from, rtype, eid_to, inlined=False):
+ # if reverse relation is found, eids are exchanged
+ eid_from, rtype, eid_to = super(RQLObjectStore, self).relate(
+ eid_from, rtype, eid_to)
+ self.rql('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype,
+ {'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y'))
+
+
+# the import controller ########################################################
+
+class CWImportController(object):
+ """Controller of the data import process.
+
+ >>> ctl = CWImportController(store)
+ >>> ctl.generators = list_of_data_generators
+ >>> ctl.data = dict_of_data_tables
+ >>> ctl.run()
+ """
+
+ def __init__(self, store, askerror=0, catcherrors=None, tell=tell,
+ commitevery=50):
+ self.store = store
+ self.generators = None
+ self.data = {}
+ self.errors = None
+ self.askerror = askerror
+ if catcherrors is None:
+ catcherrors = askerror
+ self.catcherrors = catcherrors
+ self.commitevery = commitevery # set to None to do a single commit
+ self._tell = tell
+
+ def check(self, type, key, value):
+ self._checks.setdefault(type, {}).setdefault(key, []).append(value)
+
+ def check_map(self, entity, key, map, default):
+ try:
+ entity[key] = map[entity[key]]
+ except KeyError:
+ self.check(key, entity[key], None)
+ entity[key] = default
+
+ def record_error(self, key, msg=None, type=None, value=None, tb=None):
+ tmp = StringIO()
+ if type is None:
+ traceback.print_exc(file=tmp)
+ else:
+ traceback.print_exception(type, value, tb, file=tmp)
+ print tmp.getvalue()
+ # use a list to avoid counting a <nb lines> errors instead of one
+ errorlog = self.errors.setdefault(key, [])
+ if msg is None:
+ errorlog.append(tmp.getvalue().splitlines())
+ else:
+ errorlog.append( (msg, tmp.getvalue().splitlines()) )
+
+ def run(self):
+ self.errors = {}
+ for func, checks in self.generators:
+ self._checks = {}
+ func_name = func.__name__[4:] # XXX
+ self.tell("Import '%s'..." % func_name)
+ try:
+ func(self)
+ except:
+ if self.catcherrors:
+ self.record_error(func_name, 'While calling %s' % func.__name__)
+ else:
+ raise
+ for key, func, title, help in checks:
+ buckets = self._checks.get(key)
+ if buckets:
+ err = func(buckets)
+ if err:
+ self.errors[title] = (help, err)
+ self.store.checkpoint()
+ nberrors = sum(len(err[1]) for err in self.errors.values())
+ self.tell('\nImport completed: %i entities, %i types, %i relations and %i errors'
+ % (self.store.nb_inserted_entities,
+ self.store.nb_inserted_types,
+ self.store.nb_inserted_relations,
+ nberrors))
+ if self.errors:
+ if self.askerror == 2 or (self.askerror and confirm('Display errors ?')):
+ from pprint import pformat
+ for errkey, error in self.errors.items():
+ self.tell("\n%s (%s): %d\n" % (error[0], errkey, len(error[1])))
+ self.tell(pformat(sorted(error[1])))
+
+ def get_data(self, key):
+ return self.data.get(key)
+
+ def index(self, name, key, value, unique=False):
+ """create a new index
+
+ If unique is set to True, only first occurence will be kept not the following ones
+ """
+ if unique:
+ try:
+ if value in self.store.indexes[name][key]:
+ return
+ except KeyError:
+ # we're sure that one is the first occurence; so continue...
+ pass
+ self.store.indexes.setdefault(name, {}).setdefault(key, []).append(value)
+
+ def tell(self, msg):
+ self._tell(msg)
+
+ def iter_and_commit(self, datakey):
+ """iter rows, triggering commit every self.commitevery iterations"""
+ return commit_every(self.commitevery, self.store, self.get_data(datakey))
+
+
+
+from datetime import datetime
+from cubicweb.schema import META_RTYPES, VIRTUAL_RTYPES
+
+
+class NoHookRQLObjectStore(RQLObjectStore):
+ """ObjectStore that works with an actual RQL repository (production mode)"""
+ _rql = None # bw compat
+
+ def __init__(self, session, metagen=None, baseurl=None):
+ super(NoHookRQLObjectStore, self).__init__(session)
+ self.source = session.repo.system_source
+ self.rschema = session.repo.schema.rschema
+ self.add_relation = self.source.add_relation
+ if metagen is None:
+ metagen = MetaGenerator(session, baseurl)
+ self.metagen = metagen
+ self._nb_inserted_entities = 0
+ self._nb_inserted_types = 0
+ self._nb_inserted_relations = 0
+ self.rql = session.unsafe_execute
+
+ def create_entity(self, etype, **kwargs):
+ for k, v in kwargs.iteritems():
+ kwargs[k] = getattr(v, 'eid', v)
+ entity, rels = self.metagen.base_etype_dicts(etype)
+ entity = copy(entity)
+ entity._related_cache = {}
+ self.metagen.init_entity(entity)
+ entity.update(kwargs)
+ session = self.session
+ self.source.add_entity(session, entity)
+ self.source.add_info(session, entity, self.source, complete=False)
+ for rtype, targeteids in rels.iteritems():
+ # targeteids may be a single eid or a list of eids
+ inlined = self.rschema(rtype).inlined
+ try:
+ for targeteid in targeteids:
+ self.add_relation(session, entity.eid, rtype, targeteid,
+ inlined)
+ except TypeError:
+ self.add_relation(session, entity.eid, rtype, targeteids,
+ inlined)
+ self._nb_inserted_entities += 1
+ return entity
+
+ def relate(self, eid_from, rtype, eid_to):
+ assert not rtype.startswith('reverse_')
+ self.add_relation(self.session, eid_from, rtype, eid_to,
+ self.rschema(rtype).inlined)
+ self._nb_inserted_relations += 1
+
+ @property
+ def nb_inserted_entities(self):
+ return self._nb_inserted_entities
+ @property
+ def nb_inserted_types(self):
+ return self._nb_inserted_types
+ @property
+ def nb_inserted_relations(self):
+ return self._nb_inserted_relations
+
+ def _put(self, type, item):
+ raise RuntimeError('use create entity')
+
+
+class MetaGenerator(object):
+ def __init__(self, session, baseurl=None):
+ self.session = session
+ self.source = session.repo.system_source
+ self.time = datetime.now()
+ if baseurl is None:
+ config = session.vreg.config
+ baseurl = config['base-url'] or config.default_base_url()
+ if not baseurl[-1] == '/':
+ baseurl += '/'
+ self.baseurl = baseurl
+ # attributes/relations shared by all entities of the same type
+ self.etype_attrs = []
+ self.etype_rels = []
+ # attributes/relations specific to each entity
+ self.entity_attrs = ['eid', 'cwuri']
+ #self.entity_rels = [] XXX not handled (YAGNI?)
+ schema = session.vreg.schema
+ rschema = schema.rschema
+ for rtype in META_RTYPES:
+ if rtype in ('eid', 'cwuri') or rtype in VIRTUAL_RTYPES:
+ continue
+ if rschema(rtype).final:
+ self.etype_attrs.append(rtype)
+ else:
+ self.etype_rels.append(rtype)
+ if not schema._eid_index:
+ # test schema loaded from the fs
+ self.gen_is = self.test_gen_is
+ self.gen_is_instance_of = self.test_gen_is_instanceof
+
+ @cached
+ def base_etype_dicts(self, etype):
+ entity = self.session.vreg['etypes'].etype_class(etype)(self.session)
+ # entity are "surface" copied, avoid shared dict between copies
+ del entity.cw_extra_kwargs
+ for attr in self.etype_attrs:
+ entity[attr] = self.generate(entity, attr)
+ rels = {}
+ for rel in self.etype_rels:
+ rels[rel] = self.generate(entity, rel)
+ return entity, rels
+
+ def init_entity(self, entity):
+ for attr in self.entity_attrs:
+ entity[attr] = self.generate(entity, attr)
+ entity.eid = entity['eid']
+
+ def generate(self, entity, rtype):
+ return getattr(self, 'gen_%s' % rtype)(entity)
+
+ def gen_eid(self, entity):
+ return self.source.create_eid(self.session)
+
+ def gen_cwuri(self, entity):
+ return u'%seid/%s' % (self.baseurl, entity['eid'])
+
+ def gen_creation_date(self, entity):
+ return self.time
+ def gen_modification_date(self, entity):
+ return self.time
+
+ def gen_is(self, entity):
+ return entity.e_schema.eid
+ def gen_is_instance_of(self, entity):
+ eids = []
+ for etype in entity.e_schema.ancestors() + [entity.e_schema]:
+ eids.append(entity.e_schema.eid)
+ return eids
+
+ def gen_created_by(self, entity):
+ return self.session.user.eid
+ def gen_owned_by(self, entity):
+ return self.session.user.eid
+
+ # implementations of gen_is / gen_is_instance_of to use during test where
+ # schema has been loaded from the fs (hence entity type schema eids are not
+ # known)
+ def test_gen_is(self, entity):
+ from cubicweb.hooks.metadata import eschema_eid
+ return eschema_eid(self.session, entity.e_schema)
+ def test_gen_is_instanceof(self, entity):
+ from cubicweb.hooks.metadata import eschema_eid
+ eids = []
+ for eschema in entity.e_schema.ancestors() + [entity.e_schema]:
+ eids.append(eschema_eid(self.session, eschema))
+ return eids
+
+
+################################################################################
+
+utf8csvreader = deprecated('[3.6] use ucsvreader instead')(ucsvreader)
+
+@deprecated('[3.6] use required')
+def nonempty(value):
+ return required(value)
+
+@deprecated("[3.6] use call_check_method('isdigit')")
+def alldigits(txt):
+ if txt.isdigit():
+ return txt
+ else:
+ return u''
+
+@deprecated("[3.7] too specific, will move away, copy me")
+def capitalize_if_unicase(txt):
+ if txt.isupper() or txt.islower():
+ return txt.capitalize()
+ return txt
+
+@deprecated("[3.7] too specific, will move away, copy me")
+def yesno(value):
+ """simple heuristic that returns boolean value
+
+ >>> yesno("Yes")
+ True
+ >>> yesno("oui")
+ True
+ >>> yesno("1")
+ True
+ >>> yesno("11")
+ True
+ >>> yesno("")
+ False
+ >>> yesno("Non")
+ False
+ >>> yesno("blablabla")
+ False
+ """
+ if value:
+ return value.lower()[0] in 'yo1'
+ return False
+
+@deprecated("[3.7] use call_check_method('isalpha')")
+def isalpha(value):
+ if value.isalpha():
+ return value
+ raise ValueError("not all characters in the string alphabetic")
+
+@deprecated("[3.7] use call_transform_method('upper')")
+def uppercase(txt):
+ return txt.upper()
+
+@deprecated("[3.7] use call_transform_method('lower')")
+def lowercase(txt):
+ return txt.lower()
+
+@deprecated("[3.7] use call_transform_method('replace', ' ', '')")
+def no_space(txt):
+ return txt.replace(' ','')
+
+@deprecated("[3.7] use call_transform_method('replace', u'\xa0', '')")
+def no_uspace(txt):
+ return txt.replace(u'\xa0','')
+
+@deprecated("[3.7] use call_transform_method('replace', '-', '')")
+def no_dash(txt):
+ return txt.replace('-','')
+
+@deprecated("[3.7] use call_transform_method('strip')")
+def strip(txt):
+ return txt.strip()
+
+@deprecated("[3.7] use call_transform_method('replace', ',', '.'), float")
+def decimal(value):
+ return comma_float(value)
+
+@deprecated('[3.7] use int builtin')
+def integer(value):
+ return int(value)
--- a/devtools/dataimport.py Tue Mar 09 11:05:29 2010 +0100
+++ b/devtools/dataimport.py Tue Mar 09 11:42:06 2010 +0100
@@ -1,752 +1,4 @@
-# -*- coding: utf-8 -*-
-"""This module provides tools to import tabular data.
-
-:organization: Logilab
-:copyright: 2001-2010 LOGILAB S.A. (Paris, FRANCE), license is LGPL v2.
-:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
-:license: GNU Lesser General Public License, v2.1 - http://www.gnu.org/licenses
-
-
-Example of use (run this with `cubicweb-ctl shell instance import-script.py`):
-
-.. sourcecode:: python
-
- from cubicweb.devtools.dataimport import *
- # define data generators
- GENERATORS = []
-
- USERS = [('Prenom', 'firstname', ()),
- ('Nom', 'surname', ()),
- ('Identifiant', 'login', ()),
- ]
-
- def gen_users(ctl):
- for row in ctl.get_data('utilisateurs'):
- entity = mk_entity(row, USERS)
- entity['upassword'] = u'motdepasse'
- ctl.check('login', entity['login'], None)
- ctl.store.add('CWUser', entity)
- email = {'address': row['email']}
- ctl.store.add('EmailAddress', email)
- ctl.store.relate(entity['eid'], 'use_email', email['eid'])
- ctl.store.rql('SET U in_group G WHERE G name "users", U eid %(x)s', {'x':entity['eid']})
-
- CHK = [('login', check_doubles, 'Utilisateurs Login',
- 'Deux utilisateurs ne devraient pas avoir le même login.'),
- ]
-
- GENERATORS.append( (gen_users, CHK) )
-
- # create controller
- ctl = CWImportController(RQLObjectStore())
- ctl.askerror = 1
- ctl.generators = GENERATORS
- ctl.store._checkpoint = checkpoint
- ctl.store._rql = rql
- ctl.data['utilisateurs'] = lazytable(utf8csvreader(open('users.csv')))
- # run
- ctl.run()
- sys.exit(0)
-
-
-.. BUG fichier à une colonne pose un problème de parsing
-.. TODO rollback()
-"""
-__docformat__ = "restructuredtext en"
-
-import sys
-import csv
-import traceback
-import os.path as osp
-from StringIO import StringIO
-from copy import copy
-
-from logilab.common import shellutils
-from logilab.common.date import strptime
-from logilab.common.decorators import cached
-from logilab.common.deprecation import deprecated
-
-
-def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"',
- skipfirst=False, withpb=True):
- """same as ucsvreader but a progress bar is displayed as we iter on rows"""
- if not osp.exists(filepath):
- raise Exception("file doesn't exists: %s" % filepath)
- rowcount = int(shellutils.Execute('wc -l "%s"' % filepath).out.strip().split()[0])
- if skipfirst:
- rowcount -= 1
- if withpb:
- pb = shellutils.ProgressBar(rowcount, 50)
- for urow in ucsvreader(file(filepath), encoding, separator, quote, skipfirst):
- yield urow
- if withpb:
- pb.update()
- print ' %s rows imported' % rowcount
-
-def ucsvreader(stream, encoding='utf-8', separator=',', quote='"',
- skipfirst=False):
- """A csv reader that accepts files with any encoding and outputs unicode
- strings
- """
- it = iter(csv.reader(stream, delimiter=separator, quotechar=quote))
- if skipfirst:
- it.next()
- for row in it:
- yield [item.decode(encoding) for item in row]
-
-def commit_every(nbit, store, it):
- for i, x in enumerate(it):
- yield x
- if nbit is not None and i % nbit:
- store.checkpoint()
- if nbit is not None:
- store.checkpoint()
-
-def lazytable(reader):
- """The first row is taken to be the header of the table and
- used to output a dict for each row of data.
-
- >>> data = lazytable(utf8csvreader(open(filename)))
- """
- header = reader.next()
- for row in reader:
- yield dict(zip(header, row))
-
-def mk_entity(row, map):
- """Return a dict made from sanitized mapped values.
-
- ValidationError can be raised on unexpected values found in checkers
-
- >>> row = {'myname': u'dupont'}
- >>> map = [('myname', u'name', (capitalize_if_unicase,))]
- >>> mk_entity(row, map)
- {'name': u'Dupont'}
- >>> row = {'myname': u'dupont', 'optname': u''}
- >>> map = [('myname', u'name', (capitalize_if_unicase,)),
- ... ('optname', u'MARKER', (optional,))]
- >>> mk_entity(row, map)
- {'name': u'Dupont'}
- """
- res = {}
- assert isinstance(row, dict)
- assert isinstance(map, list)
- for src, dest, funcs in map:
- assert not (required in funcs and optional in funcs), \
- "optional and required checks are exclusive"
- res[dest] = row[src]
- try:
- for func in funcs:
- res[dest] = func(res[dest])
- if res[dest] is None:
- break
- except ValueError, err:
- raise ValueError('error with %r field: %s' % (src, err))
- return res
-
-
-# user interactions ############################################################
-
-def tell(msg):
- print msg
-
-def confirm(question):
- """A confirm function that asks for yes/no/abort and exits on abort."""
- answer = shellutils.ASK.ask(question, ('Y', 'n', 'abort'), 'Y')
- if answer == 'abort':
- sys.exit(1)
- return answer == 'Y'
-
-
-class catch_error(object):
- """Helper for @contextmanager decorator."""
-
- def __init__(self, ctl, key='unexpected error', msg=None):
- self.ctl = ctl
- self.key = key
- self.msg = msg
-
- def __enter__(self):
- return self
-
- def __exit__(self, type, value, traceback):
- if type is not None:
- if issubclass(type, (KeyboardInterrupt, SystemExit)):
- return # re-raise
- if self.ctl.catcherrors:
- self.ctl.record_error(self.key, None, type, value, traceback)
- return True # silent
-
-
-# base sanitizing/coercing functions ###########################################
-
-def optional(value):
- """validation error will not been raised if you add this checker in chain"""
- if value:
- return value
- return None
-
-def required(value):
- """raise ValueError is value is empty
-
- This check should be often found in last position in the chain.
- """
- if value:
- return value
- raise ValueError("required")
-
-def todatetime(format='%d/%m/%Y'):
- """return a transformation function to turn string input value into a
- `datetime.datetime` instance, using given format.
-
- Follow it by `todate` or `totime` functions from `logilab.common.date` if
- you want a `date`/`time` instance instead of `datetime`.
- """
- def coerce(value):
- return strptime(value, format)
- return coerce
-
-def call_transform_method(methodname, *args, **kwargs):
- """return value returned by calling the given method on input"""
- def coerce(value):
- return getattr(value, methodname)(*args, **kwargs)
- return coerce
-
-def call_check_method(methodname, *args, **kwargs):
- """check value returned by calling the given method on input is true,
- else raise ValueError
- """
- def check(value):
- if getattr(value, methodname)(*args, **kwargs):
- return value
- raise ValueError('%s not verified on %r' % (methodname, value))
- return check
-
-# base integrity checking functions ############################################
-
-def check_doubles(buckets):
- """Extract the keys that have more than one item in their bucket."""
- return [(k, len(v)) for k, v in buckets.items() if len(v) > 1]
-
-def check_doubles_not_none(buckets):
- """Extract the keys that have more than one item in their bucket."""
- return [(k, len(v)) for k, v in buckets.items()
- if k is not None and len(v) > 1]
-
-
-# object stores #################################################################
-
-class ObjectStore(object):
- """Store objects in memory for *faster* validation (development mode)
-
- But it will not enforce the constraints of the schema and hence will miss some problems
-
- >>> store = ObjectStore()
- >>> user = {'login': 'johndoe'}
- >>> store.add('CWUser', user)
- >>> group = {'name': 'unknown'}
- >>> store.add('CWUser', group)
- >>> store.relate(user['eid'], 'in_group', group['eid'])
- """
- def __init__(self):
- self.items = []
- self.eids = {}
- self.types = {}
- self.relations = set()
- self.indexes = {}
- self._rql = None
- self._checkpoint = None
-
- def _put(self, type, item):
- self.items.append(item)
- return len(self.items) - 1
-
- def add(self, type, item):
- assert isinstance(item, dict), 'item is not a dict but a %s' % type(item)
- eid = item['eid'] = self._put(type, item)
- self.eids[eid] = item
- self.types.setdefault(type, []).append(eid)
-
- def relate(self, eid_from, rtype, eid_to, inlined=False):
- """Add new relation (reverse type support is available)
-
- >>> 1,2 = eid_from, eid_to
- >>> self.relate(eid_from, 'in_group', eid_to)
- 1, 'in_group', 2
- >>> self.relate(eid_from, 'reverse_in_group', eid_to)
- 2, 'in_group', 1
- """
- if rtype.startswith('reverse_'):
- eid_from, eid_to = eid_to, eid_from
- rtype = rtype[8:]
- relation = eid_from, rtype, eid_to
- self.relations.add(relation)
- return relation
-
- def build_index(self, name, type, func=None):
- index = {}
- if func is None or not callable(func):
- func = lambda x: x['eid']
- for eid in self.types[type]:
- index.setdefault(func(self.eids[eid]), []).append(eid)
- assert index, "new index '%s' cannot be empty" % name
- self.indexes[name] = index
-
- def build_rqlindex(self, name, type, key, rql, rql_params=False, func=None):
- """build an index by rql query
-
- rql should return eid in first column
- ctl.store.build_index('index_name', 'users', 'login', 'Any U WHERE U is CWUser')
- """
- rset = self.rql(rql, rql_params or {})
- for entity in rset.entities():
- getattr(entity, key) # autopopulate entity with key attribute
- self.eids[entity.eid] = dict(entity)
- if entity.eid not in self.types.setdefault(type, []):
- self.types[type].append(entity.eid)
- assert self.types[type], "new index type '%s' cannot be empty (0 record found)" % type
-
- # Build index with specified key
- func = lambda x: x[key]
- self.build_index(name, type, func)
-
- def fetch(self, name, key, unique=False, decorator=None):
- """
- decorator is a callable method or an iterator of callable methods (usually a lambda function)
- decorator=lambda x: x[:1] (first value is returned)
-
- We can use validation check function available in _entity
- """
- eids = self.indexes[name].get(key, [])
- if decorator is not None:
- if not hasattr(decorator, '__iter__'):
- decorator = (decorator,)
- for f in decorator:
- eids = f(eids)
- if unique:
- assert len(eids) == 1, u'expected a single one value for key "%s" in index "%s". Got %i' % (key, name, len(eids))
- eids = eids[0] # FIXME maybe it's better to keep an iterator here ?
- return eids
-
- def find(self, type, key, value):
- for idx in self.types[type]:
- item = self.items[idx]
- if item[key] == value:
- yield item
-
- def rql(self, *args):
- if self._rql is not None:
- return self._rql(*args)
-
- def checkpoint(self):
- pass
-
- @property
- def nb_inserted_entities(self):
- return len(self.eids)
- @property
- def nb_inserted_types(self):
- return len(self.types)
- @property
- def nb_inserted_relations(self):
- return len(self.relations)
-
- @deprecated('[3.6] get_many() deprecated. Use fetch() instead')
- def get_many(self, name, key):
- return self.fetch(name, key, unique=False)
-
- @deprecated('[3.6] get_one() deprecated. Use fetch(..., unique=True) instead')
- def get_one(self, name, key):
- return self.fetch(name, key, unique=True)
-
-
-class RQLObjectStore(ObjectStore):
- """ObjectStore that works with an actual RQL repository (production mode)"""
- _rql = None # bw compat
-
- def __init__(self, session=None, checkpoint=None):
- ObjectStore.__init__(self)
- if session is not None:
- if not hasattr(session, 'set_pool'):
- # connection
- cnx = session
- session = session.request()
- session.set_pool = lambda : None
- checkpoint = checkpoint or cnx.commit
- else:
- session.set_pool()
- self.session = session
- self._checkpoint = checkpoint or session.commit
- elif checkpoint is not None:
- self._checkpoint = checkpoint
- # XXX .session
-
- def checkpoint(self):
- self._checkpoint()
- self.session.set_pool()
-
- def rql(self, *args):
- if self._rql is not None:
- return self._rql(*args)
- return self.session.execute(*args)
-
- def create_entity(self, *args, **kwargs):
- entity = self.session.create_entity(*args, **kwargs)
- self.eids[entity.eid] = entity
- self.types.setdefault(args[0], []).append(entity.eid)
- return entity
-
- def _put(self, type, item):
- query = ('INSERT %s X: ' % type) + ', '.join('X %s %%(%s)s' % (k, k)
- for k in item)
- return self.rql(query, item)[0][0]
-
- def relate(self, eid_from, rtype, eid_to, inlined=False):
- # if reverse relation is found, eids are exchanged
- eid_from, rtype, eid_to = super(RQLObjectStore, self).relate(
- eid_from, rtype, eid_to)
- self.rql('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype,
- {'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y'))
-
-
-# the import controller ########################################################
-
-class CWImportController(object):
- """Controller of the data import process.
-
- >>> ctl = CWImportController(store)
- >>> ctl.generators = list_of_data_generators
- >>> ctl.data = dict_of_data_tables
- >>> ctl.run()
- """
-
- def __init__(self, store, askerror=0, catcherrors=None, tell=tell,
- commitevery=50):
- self.store = store
- self.generators = None
- self.data = {}
- self.errors = None
- self.askerror = askerror
- if catcherrors is None:
- catcherrors = askerror
- self.catcherrors = catcherrors
- self.commitevery = commitevery # set to None to do a single commit
- self._tell = tell
-
- def check(self, type, key, value):
- self._checks.setdefault(type, {}).setdefault(key, []).append(value)
-
- def check_map(self, entity, key, map, default):
- try:
- entity[key] = map[entity[key]]
- except KeyError:
- self.check(key, entity[key], None)
- entity[key] = default
-
- def record_error(self, key, msg=None, type=None, value=None, tb=None):
- tmp = StringIO()
- if type is None:
- traceback.print_exc(file=tmp)
- else:
- traceback.print_exception(type, value, tb, file=tmp)
- print tmp.getvalue()
- # use a list to avoid counting a <nb lines> errors instead of one
- errorlog = self.errors.setdefault(key, [])
- if msg is None:
- errorlog.append(tmp.getvalue().splitlines())
- else:
- errorlog.append( (msg, tmp.getvalue().splitlines()) )
-
- def run(self):
- self.errors = {}
- for func, checks in self.generators:
- self._checks = {}
- func_name = func.__name__[4:] # XXX
- self.tell("Import '%s'..." % func_name)
- try:
- func(self)
- except:
- if self.catcherrors:
- self.record_error(func_name, 'While calling %s' % func.__name__)
- else:
- raise
- for key, func, title, help in checks:
- buckets = self._checks.get(key)
- if buckets:
- err = func(buckets)
- if err:
- self.errors[title] = (help, err)
- self.store.checkpoint()
- nberrors = sum(len(err[1]) for err in self.errors.values())
- self.tell('\nImport completed: %i entities, %i types, %i relations and %i errors'
- % (self.store.nb_inserted_entities,
- self.store.nb_inserted_types,
- self.store.nb_inserted_relations,
- nberrors))
- if self.errors:
- if self.askerror == 2 or (self.askerror and confirm('Display errors ?')):
- from pprint import pformat
- for errkey, error in self.errors.items():
- self.tell("\n%s (%s): %d\n" % (error[0], errkey, len(error[1])))
- self.tell(pformat(sorted(error[1])))
-
- def get_data(self, key):
- return self.data.get(key)
-
- def index(self, name, key, value, unique=False):
- """create a new index
-
- If unique is set to True, only first occurence will be kept not the following ones
- """
- if unique:
- try:
- if value in self.store.indexes[name][key]:
- return
- except KeyError:
- # we're sure that one is the first occurence; so continue...
- pass
- self.store.indexes.setdefault(name, {}).setdefault(key, []).append(value)
-
- def tell(self, msg):
- self._tell(msg)
-
- def iter_and_commit(self, datakey):
- """iter rows, triggering commit every self.commitevery iterations"""
- return commit_every(self.commitevery, self.store, self.get_data(datakey))
-
-
-
-from datetime import datetime
-from cubicweb.schema import META_RTYPES, VIRTUAL_RTYPES
-
-
-class NoHookRQLObjectStore(RQLObjectStore):
- """ObjectStore that works with an actual RQL repository (production mode)"""
- _rql = None # bw compat
-
- def __init__(self, session, metagen=None, baseurl=None):
- super(NoHookRQLObjectStore, self).__init__(session)
- self.source = session.repo.system_source
- self.rschema = session.repo.schema.rschema
- self.add_relation = self.source.add_relation
- if metagen is None:
- metagen = MetaGenerator(session, baseurl)
- self.metagen = metagen
- self._nb_inserted_entities = 0
- self._nb_inserted_types = 0
- self._nb_inserted_relations = 0
- self.rql = session.unsafe_execute
-
- def create_entity(self, etype, **kwargs):
- for k, v in kwargs.iteritems():
- kwargs[k] = getattr(v, 'eid', v)
- entity, rels = self.metagen.base_etype_dicts(etype)
- entity = copy(entity)
- entity._related_cache = {}
- self.metagen.init_entity(entity)
- entity.update(kwargs)
- session = self.session
- self.source.add_entity(session, entity)
- self.source.add_info(session, entity, self.source, complete=False)
- for rtype, targeteids in rels.iteritems():
- # targeteids may be a single eid or a list of eids
- inlined = self.rschema(rtype).inlined
- try:
- for targeteid in targeteids:
- self.add_relation(session, entity.eid, rtype, targeteid,
- inlined)
- except TypeError:
- self.add_relation(session, entity.eid, rtype, targeteids,
- inlined)
- self._nb_inserted_entities += 1
- return entity
-
- def relate(self, eid_from, rtype, eid_to):
- assert not rtype.startswith('reverse_')
- self.add_relation(self.session, eid_from, rtype, eid_to,
- self.rschema(rtype).inlined)
- self._nb_inserted_relations += 1
-
- @property
- def nb_inserted_entities(self):
- return self._nb_inserted_entities
- @property
- def nb_inserted_types(self):
- return self._nb_inserted_types
- @property
- def nb_inserted_relations(self):
- return self._nb_inserted_relations
-
- def _put(self, type, item):
- raise RuntimeError('use create entity')
-
-
-class MetaGenerator(object):
- def __init__(self, session, baseurl=None):
- self.session = session
- self.source = session.repo.system_source
- self.time = datetime.now()
- if baseurl is None:
- config = session.vreg.config
- baseurl = config['base-url'] or config.default_base_url()
- if not baseurl[-1] == '/':
- baseurl += '/'
- self.baseurl = baseurl
- # attributes/relations shared by all entities of the same type
- self.etype_attrs = []
- self.etype_rels = []
- # attributes/relations specific to each entity
- self.entity_attrs = ['eid', 'cwuri']
- #self.entity_rels = [] XXX not handled (YAGNI?)
- schema = session.vreg.schema
- rschema = schema.rschema
- for rtype in META_RTYPES:
- if rtype in ('eid', 'cwuri') or rtype in VIRTUAL_RTYPES:
- continue
- if rschema(rtype).final:
- self.etype_attrs.append(rtype)
- else:
- self.etype_rels.append(rtype)
- if not schema._eid_index:
- # test schema loaded from the fs
- self.gen_is = self.test_gen_is
- self.gen_is_instance_of = self.test_gen_is_instanceof
-
- @cached
- def base_etype_dicts(self, etype):
- entity = self.session.vreg['etypes'].etype_class(etype)(self.session)
- # entity are "surface" copied, avoid shared dict between copies
- del entity.cw_extra_kwargs
- for attr in self.etype_attrs:
- entity[attr] = self.generate(entity, attr)
- rels = {}
- for rel in self.etype_rels:
- rels[rel] = self.generate(entity, rel)
- return entity, rels
-
- def init_entity(self, entity):
- for attr in self.entity_attrs:
- entity[attr] = self.generate(entity, attr)
- entity.eid = entity['eid']
-
- def generate(self, entity, rtype):
- return getattr(self, 'gen_%s' % rtype)(entity)
-
- def gen_eid(self, entity):
- return self.source.create_eid(self.session)
-
- def gen_cwuri(self, entity):
- return u'%seid/%s' % (self.baseurl, entity['eid'])
-
- def gen_creation_date(self, entity):
- return self.time
- def gen_modification_date(self, entity):
- return self.time
-
- def gen_is(self, entity):
- return entity.e_schema.eid
- def gen_is_instance_of(self, entity):
- eids = []
- for etype in entity.e_schema.ancestors() + [entity.e_schema]:
- eids.append(entity.e_schema.eid)
- return eids
-
- def gen_created_by(self, entity):
- return self.session.user.eid
- def gen_owned_by(self, entity):
- return self.session.user.eid
-
- # implementations of gen_is / gen_is_instance_of to use during test where
- # schema has been loaded from the fs (hence entity type schema eids are not
- # known)
- def test_gen_is(self, entity):
- from cubicweb.hooks.metadata import eschema_eid
- return eschema_eid(self.session, entity.e_schema)
- def test_gen_is_instanceof(self, entity):
- from cubicweb.hooks.metadata import eschema_eid
- eids = []
- for eschema in entity.e_schema.ancestors() + [entity.e_schema]:
- eids.append(eschema_eid(self.session, eschema))
- return eids
-
-
-################################################################################
-
-utf8csvreader = deprecated('[3.6] use ucsvreader instead')(ucsvreader)
-
-@deprecated('[3.6] use required')
-def nonempty(value):
- return required(value)
-
-@deprecated("[3.6] use call_check_method('isdigit')")
-def alldigits(txt):
- if txt.isdigit():
- return txt
- else:
- return u''
-
-@deprecated("[3.7] too specific, will move away, copy me")
-def capitalize_if_unicase(txt):
- if txt.isupper() or txt.islower():
- return txt.capitalize()
- return txt
-
-@deprecated("[3.7] too specific, will move away, copy me")
-def yesno(value):
- """simple heuristic that returns boolean value
-
- >>> yesno("Yes")
- True
- >>> yesno("oui")
- True
- >>> yesno("1")
- True
- >>> yesno("11")
- True
- >>> yesno("")
- False
- >>> yesno("Non")
- False
- >>> yesno("blablabla")
- False
- """
- if value:
- return value.lower()[0] in 'yo1'
- return False
-
-@deprecated("[3.7] use call_check_method('isalpha')")
-def isalpha(value):
- if value.isalpha():
- return value
- raise ValueError("not all characters in the string alphabetic")
-
-@deprecated("[3.7] use call_transform_method('upper')")
-def uppercase(txt):
- return txt.upper()
-
-@deprecated("[3.7] use call_transform_method('lower')")
-def lowercase(txt):
- return txt.lower()
-
-@deprecated("[3.7] use call_transform_method('replace', ' ', '')")
-def no_space(txt):
- return txt.replace(' ','')
-
-@deprecated("[3.7] use call_transform_method('replace', u'\xa0', '')")
-def no_uspace(txt):
- return txt.replace(u'\xa0','')
-
-@deprecated("[3.7] use call_transform_method('replace', '-', '')")
-def no_dash(txt):
- return txt.replace('-','')
-
-@deprecated("[3.7] use call_transform_method('strip')")
-def strip(txt):
- return txt.strip()
-
-@deprecated("[3.7] use call_transform_method('replace', ',', '.'), float")
-def decimal(value):
- return comma_float(value)
-
-@deprecated('[3.7] use int builtin')
-def integer(value):
- return int(value)
+# pylint: disable-msg=W0614,W0401
+from warnings import warn
+warn('moved to cubicweb.dataimport', DeprecationWarning, stacklevel=2)
+from cubicweb.dataimport import *