--- a/devtools/dataimport.py Tue Mar 09 11:05:29 2010 +0100
+++ b/devtools/dataimport.py Tue Mar 09 11:42:06 2010 +0100
@@ -1,752 +1,4 @@
-# -*- coding: utf-8 -*-
-"""This module provides tools to import tabular data.
-
-:organization: Logilab
-:copyright: 2001-2010 LOGILAB S.A. (Paris, FRANCE), license is LGPL v2.
-:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
-:license: GNU Lesser General Public License, v2.1 - http://www.gnu.org/licenses
-
-
-Example of use (run this with `cubicweb-ctl shell instance import-script.py`):
-
-.. sourcecode:: python
-
- from cubicweb.devtools.dataimport import *
- # define data generators
- GENERATORS = []
-
- USERS = [('Prenom', 'firstname', ()),
- ('Nom', 'surname', ()),
- ('Identifiant', 'login', ()),
- ]
-
- def gen_users(ctl):
- for row in ctl.get_data('utilisateurs'):
- entity = mk_entity(row, USERS)
- entity['upassword'] = u'motdepasse'
- ctl.check('login', entity['login'], None)
- ctl.store.add('CWUser', entity)
- email = {'address': row['email']}
- ctl.store.add('EmailAddress', email)
- ctl.store.relate(entity['eid'], 'use_email', email['eid'])
- ctl.store.rql('SET U in_group G WHERE G name "users", U eid %(x)s', {'x':entity['eid']})
-
- CHK = [('login', check_doubles, 'Utilisateurs Login',
- 'Deux utilisateurs ne devraient pas avoir le même login.'),
- ]
-
- GENERATORS.append( (gen_users, CHK) )
-
- # create controller
- ctl = CWImportController(RQLObjectStore())
- ctl.askerror = 1
- ctl.generators = GENERATORS
- ctl.store._checkpoint = checkpoint
- ctl.store._rql = rql
- ctl.data['utilisateurs'] = lazytable(utf8csvreader(open('users.csv')))
- # run
- ctl.run()
- sys.exit(0)
-
-
-.. BUG fichier à une colonne pose un problème de parsing
-.. TODO rollback()
-"""
-__docformat__ = "restructuredtext en"
-
-import sys
-import csv
-import traceback
-import os.path as osp
-from StringIO import StringIO
-from copy import copy
-
-from logilab.common import shellutils
-from logilab.common.date import strptime
-from logilab.common.decorators import cached
-from logilab.common.deprecation import deprecated
-
-
-def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"',
- skipfirst=False, withpb=True):
- """same as ucsvreader but a progress bar is displayed as we iter on rows"""
- if not osp.exists(filepath):
- raise Exception("file doesn't exists: %s" % filepath)
- rowcount = int(shellutils.Execute('wc -l "%s"' % filepath).out.strip().split()[0])
- if skipfirst:
- rowcount -= 1
- if withpb:
- pb = shellutils.ProgressBar(rowcount, 50)
- for urow in ucsvreader(file(filepath), encoding, separator, quote, skipfirst):
- yield urow
- if withpb:
- pb.update()
- print ' %s rows imported' % rowcount
-
-def ucsvreader(stream, encoding='utf-8', separator=',', quote='"',
- skipfirst=False):
- """A csv reader that accepts files with any encoding and outputs unicode
- strings
- """
- it = iter(csv.reader(stream, delimiter=separator, quotechar=quote))
- if skipfirst:
- it.next()
- for row in it:
- yield [item.decode(encoding) for item in row]
-
-def commit_every(nbit, store, it):
- for i, x in enumerate(it):
- yield x
- if nbit is not None and i % nbit:
- store.checkpoint()
- if nbit is not None:
- store.checkpoint()
-
-def lazytable(reader):
- """The first row is taken to be the header of the table and
- used to output a dict for each row of data.
-
- >>> data = lazytable(utf8csvreader(open(filename)))
- """
- header = reader.next()
- for row in reader:
- yield dict(zip(header, row))
-
-def mk_entity(row, map):
- """Return a dict made from sanitized mapped values.
-
- ValidationError can be raised on unexpected values found in checkers
-
- >>> row = {'myname': u'dupont'}
- >>> map = [('myname', u'name', (capitalize_if_unicase,))]
- >>> mk_entity(row, map)
- {'name': u'Dupont'}
- >>> row = {'myname': u'dupont', 'optname': u''}
- >>> map = [('myname', u'name', (capitalize_if_unicase,)),
- ... ('optname', u'MARKER', (optional,))]
- >>> mk_entity(row, map)
- {'name': u'Dupont'}
- """
- res = {}
- assert isinstance(row, dict)
- assert isinstance(map, list)
- for src, dest, funcs in map:
- assert not (required in funcs and optional in funcs), \
- "optional and required checks are exclusive"
- res[dest] = row[src]
- try:
- for func in funcs:
- res[dest] = func(res[dest])
- if res[dest] is None:
- break
- except ValueError, err:
- raise ValueError('error with %r field: %s' % (src, err))
- return res
-
-
-# user interactions ############################################################
-
-def tell(msg):
- print msg
-
-def confirm(question):
- """A confirm function that asks for yes/no/abort and exits on abort."""
- answer = shellutils.ASK.ask(question, ('Y', 'n', 'abort'), 'Y')
- if answer == 'abort':
- sys.exit(1)
- return answer == 'Y'
-
-
-class catch_error(object):
- """Helper for @contextmanager decorator."""
-
- def __init__(self, ctl, key='unexpected error', msg=None):
- self.ctl = ctl
- self.key = key
- self.msg = msg
-
- def __enter__(self):
- return self
-
- def __exit__(self, type, value, traceback):
- if type is not None:
- if issubclass(type, (KeyboardInterrupt, SystemExit)):
- return # re-raise
- if self.ctl.catcherrors:
- self.ctl.record_error(self.key, None, type, value, traceback)
- return True # silent
-
-
-# base sanitizing/coercing functions ###########################################
-
-def optional(value):
- """validation error will not been raised if you add this checker in chain"""
- if value:
- return value
- return None
-
-def required(value):
- """raise ValueError is value is empty
-
- This check should be often found in last position in the chain.
- """
- if value:
- return value
- raise ValueError("required")
-
-def todatetime(format='%d/%m/%Y'):
- """return a transformation function to turn string input value into a
- `datetime.datetime` instance, using given format.
-
- Follow it by `todate` or `totime` functions from `logilab.common.date` if
- you want a `date`/`time` instance instead of `datetime`.
- """
- def coerce(value):
- return strptime(value, format)
- return coerce
-
-def call_transform_method(methodname, *args, **kwargs):
- """return value returned by calling the given method on input"""
- def coerce(value):
- return getattr(value, methodname)(*args, **kwargs)
- return coerce
-
-def call_check_method(methodname, *args, **kwargs):
- """check value returned by calling the given method on input is true,
- else raise ValueError
- """
- def check(value):
- if getattr(value, methodname)(*args, **kwargs):
- return value
- raise ValueError('%s not verified on %r' % (methodname, value))
- return check
-
-# base integrity checking functions ############################################
-
-def check_doubles(buckets):
- """Extract the keys that have more than one item in their bucket."""
- return [(k, len(v)) for k, v in buckets.items() if len(v) > 1]
-
-def check_doubles_not_none(buckets):
- """Extract the keys that have more than one item in their bucket."""
- return [(k, len(v)) for k, v in buckets.items()
- if k is not None and len(v) > 1]
-
-
-# object stores #################################################################
-
-class ObjectStore(object):
- """Store objects in memory for *faster* validation (development mode)
-
- But it will not enforce the constraints of the schema and hence will miss some problems
-
- >>> store = ObjectStore()
- >>> user = {'login': 'johndoe'}
- >>> store.add('CWUser', user)
- >>> group = {'name': 'unknown'}
- >>> store.add('CWUser', group)
- >>> store.relate(user['eid'], 'in_group', group['eid'])
- """
- def __init__(self):
- self.items = []
- self.eids = {}
- self.types = {}
- self.relations = set()
- self.indexes = {}
- self._rql = None
- self._checkpoint = None
-
- def _put(self, type, item):
- self.items.append(item)
- return len(self.items) - 1
-
- def add(self, type, item):
- assert isinstance(item, dict), 'item is not a dict but a %s' % type(item)
- eid = item['eid'] = self._put(type, item)
- self.eids[eid] = item
- self.types.setdefault(type, []).append(eid)
-
- def relate(self, eid_from, rtype, eid_to, inlined=False):
- """Add new relation (reverse type support is available)
-
- >>> 1,2 = eid_from, eid_to
- >>> self.relate(eid_from, 'in_group', eid_to)
- 1, 'in_group', 2
- >>> self.relate(eid_from, 'reverse_in_group', eid_to)
- 2, 'in_group', 1
- """
- if rtype.startswith('reverse_'):
- eid_from, eid_to = eid_to, eid_from
- rtype = rtype[8:]
- relation = eid_from, rtype, eid_to
- self.relations.add(relation)
- return relation
-
- def build_index(self, name, type, func=None):
- index = {}
- if func is None or not callable(func):
- func = lambda x: x['eid']
- for eid in self.types[type]:
- index.setdefault(func(self.eids[eid]), []).append(eid)
- assert index, "new index '%s' cannot be empty" % name
- self.indexes[name] = index
-
- def build_rqlindex(self, name, type, key, rql, rql_params=False, func=None):
- """build an index by rql query
-
- rql should return eid in first column
- ctl.store.build_index('index_name', 'users', 'login', 'Any U WHERE U is CWUser')
- """
- rset = self.rql(rql, rql_params or {})
- for entity in rset.entities():
- getattr(entity, key) # autopopulate entity with key attribute
- self.eids[entity.eid] = dict(entity)
- if entity.eid not in self.types.setdefault(type, []):
- self.types[type].append(entity.eid)
- assert self.types[type], "new index type '%s' cannot be empty (0 record found)" % type
-
- # Build index with specified key
- func = lambda x: x[key]
- self.build_index(name, type, func)
-
- def fetch(self, name, key, unique=False, decorator=None):
- """
- decorator is a callable method or an iterator of callable methods (usually a lambda function)
- decorator=lambda x: x[:1] (first value is returned)
-
- We can use validation check function available in _entity
- """
- eids = self.indexes[name].get(key, [])
- if decorator is not None:
- if not hasattr(decorator, '__iter__'):
- decorator = (decorator,)
- for f in decorator:
- eids = f(eids)
- if unique:
- assert len(eids) == 1, u'expected a single one value for key "%s" in index "%s". Got %i' % (key, name, len(eids))
- eids = eids[0] # FIXME maybe it's better to keep an iterator here ?
- return eids
-
- def find(self, type, key, value):
- for idx in self.types[type]:
- item = self.items[idx]
- if item[key] == value:
- yield item
-
- def rql(self, *args):
- if self._rql is not None:
- return self._rql(*args)
-
- def checkpoint(self):
- pass
-
- @property
- def nb_inserted_entities(self):
- return len(self.eids)
- @property
- def nb_inserted_types(self):
- return len(self.types)
- @property
- def nb_inserted_relations(self):
- return len(self.relations)
-
- @deprecated('[3.6] get_many() deprecated. Use fetch() instead')
- def get_many(self, name, key):
- return self.fetch(name, key, unique=False)
-
- @deprecated('[3.6] get_one() deprecated. Use fetch(..., unique=True) instead')
- def get_one(self, name, key):
- return self.fetch(name, key, unique=True)
-
-
-class RQLObjectStore(ObjectStore):
- """ObjectStore that works with an actual RQL repository (production mode)"""
- _rql = None # bw compat
-
- def __init__(self, session=None, checkpoint=None):
- ObjectStore.__init__(self)
- if session is not None:
- if not hasattr(session, 'set_pool'):
- # connection
- cnx = session
- session = session.request()
- session.set_pool = lambda : None
- checkpoint = checkpoint or cnx.commit
- else:
- session.set_pool()
- self.session = session
- self._checkpoint = checkpoint or session.commit
- elif checkpoint is not None:
- self._checkpoint = checkpoint
- # XXX .session
-
- def checkpoint(self):
- self._checkpoint()
- self.session.set_pool()
-
- def rql(self, *args):
- if self._rql is not None:
- return self._rql(*args)
- return self.session.execute(*args)
-
- def create_entity(self, *args, **kwargs):
- entity = self.session.create_entity(*args, **kwargs)
- self.eids[entity.eid] = entity
- self.types.setdefault(args[0], []).append(entity.eid)
- return entity
-
- def _put(self, type, item):
- query = ('INSERT %s X: ' % type) + ', '.join('X %s %%(%s)s' % (k, k)
- for k in item)
- return self.rql(query, item)[0][0]
-
- def relate(self, eid_from, rtype, eid_to, inlined=False):
- # if reverse relation is found, eids are exchanged
- eid_from, rtype, eid_to = super(RQLObjectStore, self).relate(
- eid_from, rtype, eid_to)
- self.rql('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype,
- {'x': int(eid_from), 'y': int(eid_to)}, ('x', 'y'))
-
-
-# the import controller ########################################################
-
-class CWImportController(object):
- """Controller of the data import process.
-
- >>> ctl = CWImportController(store)
- >>> ctl.generators = list_of_data_generators
- >>> ctl.data = dict_of_data_tables
- >>> ctl.run()
- """
-
- def __init__(self, store, askerror=0, catcherrors=None, tell=tell,
- commitevery=50):
- self.store = store
- self.generators = None
- self.data = {}
- self.errors = None
- self.askerror = askerror
- if catcherrors is None:
- catcherrors = askerror
- self.catcherrors = catcherrors
- self.commitevery = commitevery # set to None to do a single commit
- self._tell = tell
-
- def check(self, type, key, value):
- self._checks.setdefault(type, {}).setdefault(key, []).append(value)
-
- def check_map(self, entity, key, map, default):
- try:
- entity[key] = map[entity[key]]
- except KeyError:
- self.check(key, entity[key], None)
- entity[key] = default
-
- def record_error(self, key, msg=None, type=None, value=None, tb=None):
- tmp = StringIO()
- if type is None:
- traceback.print_exc(file=tmp)
- else:
- traceback.print_exception(type, value, tb, file=tmp)
- print tmp.getvalue()
- # use a list to avoid counting a <nb lines> errors instead of one
- errorlog = self.errors.setdefault(key, [])
- if msg is None:
- errorlog.append(tmp.getvalue().splitlines())
- else:
- errorlog.append( (msg, tmp.getvalue().splitlines()) )
-
- def run(self):
- self.errors = {}
- for func, checks in self.generators:
- self._checks = {}
- func_name = func.__name__[4:] # XXX
- self.tell("Import '%s'..." % func_name)
- try:
- func(self)
- except:
- if self.catcherrors:
- self.record_error(func_name, 'While calling %s' % func.__name__)
- else:
- raise
- for key, func, title, help in checks:
- buckets = self._checks.get(key)
- if buckets:
- err = func(buckets)
- if err:
- self.errors[title] = (help, err)
- self.store.checkpoint()
- nberrors = sum(len(err[1]) for err in self.errors.values())
- self.tell('\nImport completed: %i entities, %i types, %i relations and %i errors'
- % (self.store.nb_inserted_entities,
- self.store.nb_inserted_types,
- self.store.nb_inserted_relations,
- nberrors))
- if self.errors:
- if self.askerror == 2 or (self.askerror and confirm('Display errors ?')):
- from pprint import pformat
- for errkey, error in self.errors.items():
- self.tell("\n%s (%s): %d\n" % (error[0], errkey, len(error[1])))
- self.tell(pformat(sorted(error[1])))
-
- def get_data(self, key):
- return self.data.get(key)
-
- def index(self, name, key, value, unique=False):
- """create a new index
-
- If unique is set to True, only first occurence will be kept not the following ones
- """
- if unique:
- try:
- if value in self.store.indexes[name][key]:
- return
- except KeyError:
- # we're sure that one is the first occurence; so continue...
- pass
- self.store.indexes.setdefault(name, {}).setdefault(key, []).append(value)
-
- def tell(self, msg):
- self._tell(msg)
-
- def iter_and_commit(self, datakey):
- """iter rows, triggering commit every self.commitevery iterations"""
- return commit_every(self.commitevery, self.store, self.get_data(datakey))
-
-
-
-from datetime import datetime
-from cubicweb.schema import META_RTYPES, VIRTUAL_RTYPES
-
-
-class NoHookRQLObjectStore(RQLObjectStore):
- """ObjectStore that works with an actual RQL repository (production mode)"""
- _rql = None # bw compat
-
- def __init__(self, session, metagen=None, baseurl=None):
- super(NoHookRQLObjectStore, self).__init__(session)
- self.source = session.repo.system_source
- self.rschema = session.repo.schema.rschema
- self.add_relation = self.source.add_relation
- if metagen is None:
- metagen = MetaGenerator(session, baseurl)
- self.metagen = metagen
- self._nb_inserted_entities = 0
- self._nb_inserted_types = 0
- self._nb_inserted_relations = 0
- self.rql = session.unsafe_execute
-
- def create_entity(self, etype, **kwargs):
- for k, v in kwargs.iteritems():
- kwargs[k] = getattr(v, 'eid', v)
- entity, rels = self.metagen.base_etype_dicts(etype)
- entity = copy(entity)
- entity._related_cache = {}
- self.metagen.init_entity(entity)
- entity.update(kwargs)
- session = self.session
- self.source.add_entity(session, entity)
- self.source.add_info(session, entity, self.source, complete=False)
- for rtype, targeteids in rels.iteritems():
- # targeteids may be a single eid or a list of eids
- inlined = self.rschema(rtype).inlined
- try:
- for targeteid in targeteids:
- self.add_relation(session, entity.eid, rtype, targeteid,
- inlined)
- except TypeError:
- self.add_relation(session, entity.eid, rtype, targeteids,
- inlined)
- self._nb_inserted_entities += 1
- return entity
-
- def relate(self, eid_from, rtype, eid_to):
- assert not rtype.startswith('reverse_')
- self.add_relation(self.session, eid_from, rtype, eid_to,
- self.rschema(rtype).inlined)
- self._nb_inserted_relations += 1
-
- @property
- def nb_inserted_entities(self):
- return self._nb_inserted_entities
- @property
- def nb_inserted_types(self):
- return self._nb_inserted_types
- @property
- def nb_inserted_relations(self):
- return self._nb_inserted_relations
-
- def _put(self, type, item):
- raise RuntimeError('use create entity')
-
-
-class MetaGenerator(object):
- def __init__(self, session, baseurl=None):
- self.session = session
- self.source = session.repo.system_source
- self.time = datetime.now()
- if baseurl is None:
- config = session.vreg.config
- baseurl = config['base-url'] or config.default_base_url()
- if not baseurl[-1] == '/':
- baseurl += '/'
- self.baseurl = baseurl
- # attributes/relations shared by all entities of the same type
- self.etype_attrs = []
- self.etype_rels = []
- # attributes/relations specific to each entity
- self.entity_attrs = ['eid', 'cwuri']
- #self.entity_rels = [] XXX not handled (YAGNI?)
- schema = session.vreg.schema
- rschema = schema.rschema
- for rtype in META_RTYPES:
- if rtype in ('eid', 'cwuri') or rtype in VIRTUAL_RTYPES:
- continue
- if rschema(rtype).final:
- self.etype_attrs.append(rtype)
- else:
- self.etype_rels.append(rtype)
- if not schema._eid_index:
- # test schema loaded from the fs
- self.gen_is = self.test_gen_is
- self.gen_is_instance_of = self.test_gen_is_instanceof
-
- @cached
- def base_etype_dicts(self, etype):
- entity = self.session.vreg['etypes'].etype_class(etype)(self.session)
- # entity are "surface" copied, avoid shared dict between copies
- del entity.cw_extra_kwargs
- for attr in self.etype_attrs:
- entity[attr] = self.generate(entity, attr)
- rels = {}
- for rel in self.etype_rels:
- rels[rel] = self.generate(entity, rel)
- return entity, rels
-
- def init_entity(self, entity):
- for attr in self.entity_attrs:
- entity[attr] = self.generate(entity, attr)
- entity.eid = entity['eid']
-
- def generate(self, entity, rtype):
- return getattr(self, 'gen_%s' % rtype)(entity)
-
- def gen_eid(self, entity):
- return self.source.create_eid(self.session)
-
- def gen_cwuri(self, entity):
- return u'%seid/%s' % (self.baseurl, entity['eid'])
-
- def gen_creation_date(self, entity):
- return self.time
- def gen_modification_date(self, entity):
- return self.time
-
- def gen_is(self, entity):
- return entity.e_schema.eid
- def gen_is_instance_of(self, entity):
- eids = []
- for etype in entity.e_schema.ancestors() + [entity.e_schema]:
- eids.append(entity.e_schema.eid)
- return eids
-
- def gen_created_by(self, entity):
- return self.session.user.eid
- def gen_owned_by(self, entity):
- return self.session.user.eid
-
- # implementations of gen_is / gen_is_instance_of to use during test where
- # schema has been loaded from the fs (hence entity type schema eids are not
- # known)
- def test_gen_is(self, entity):
- from cubicweb.hooks.metadata import eschema_eid
- return eschema_eid(self.session, entity.e_schema)
- def test_gen_is_instanceof(self, entity):
- from cubicweb.hooks.metadata import eschema_eid
- eids = []
- for eschema in entity.e_schema.ancestors() + [entity.e_schema]:
- eids.append(eschema_eid(self.session, eschema))
- return eids
-
-
-################################################################################
-
-utf8csvreader = deprecated('[3.6] use ucsvreader instead')(ucsvreader)
-
-@deprecated('[3.6] use required')
-def nonempty(value):
- return required(value)
-
-@deprecated("[3.6] use call_check_method('isdigit')")
-def alldigits(txt):
- if txt.isdigit():
- return txt
- else:
- return u''
-
-@deprecated("[3.7] too specific, will move away, copy me")
-def capitalize_if_unicase(txt):
- if txt.isupper() or txt.islower():
- return txt.capitalize()
- return txt
-
-@deprecated("[3.7] too specific, will move away, copy me")
-def yesno(value):
- """simple heuristic that returns boolean value
-
- >>> yesno("Yes")
- True
- >>> yesno("oui")
- True
- >>> yesno("1")
- True
- >>> yesno("11")
- True
- >>> yesno("")
- False
- >>> yesno("Non")
- False
- >>> yesno("blablabla")
- False
- """
- if value:
- return value.lower()[0] in 'yo1'
- return False
-
-@deprecated("[3.7] use call_check_method('isalpha')")
-def isalpha(value):
- if value.isalpha():
- return value
- raise ValueError("not all characters in the string alphabetic")
-
-@deprecated("[3.7] use call_transform_method('upper')")
-def uppercase(txt):
- return txt.upper()
-
-@deprecated("[3.7] use call_transform_method('lower')")
-def lowercase(txt):
- return txt.lower()
-
-@deprecated("[3.7] use call_transform_method('replace', ' ', '')")
-def no_space(txt):
- return txt.replace(' ','')
-
-@deprecated("[3.7] use call_transform_method('replace', u'\xa0', '')")
-def no_uspace(txt):
- return txt.replace(u'\xa0','')
-
-@deprecated("[3.7] use call_transform_method('replace', '-', '')")
-def no_dash(txt):
- return txt.replace('-','')
-
-@deprecated("[3.7] use call_transform_method('strip')")
-def strip(txt):
- return txt.strip()
-
-@deprecated("[3.7] use call_transform_method('replace', ',', '.'), float")
-def decimal(value):
- return comma_float(value)
-
-@deprecated('[3.7] use int builtin')
-def integer(value):
- return int(value)
+# pylint: disable-msg=W0614,W0401
+from warnings import warn
+warn('moved to cubicweb.dataimport', DeprecationWarning, stacklevel=2)
+from cubicweb.dataimport import *