# -*- coding: utf-8 -*-
# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb. If not, see <http://www.gnu.org/licenses/>.
"""This module provides tools to import tabular data.
Example of use (run this with `cubicweb-ctl shell instance import-script.py`):
.. sourcecode:: python
from cubicweb.dataimport import *
# define data generators
GENERATORS = []
USERS = [('Prenom', 'firstname', ()),
('Nom', 'surname', ()),
('Identifiant', 'login', ()),
]
def gen_users(ctl):
for row in ctl.iter_and_commit('utilisateurs'):
entity = mk_entity(row, USERS)
entity['upassword'] = 'motdepasse'
ctl.check('login', entity['login'], None)
entity = ctl.store.create_entity('CWUser', **entity)
email = ctl.store.create_entity('EmailAddress', address=row['email'])
ctl.store.relate(entity.eid, 'use_email', email.eid)
ctl.store.rql('SET U in_group G WHERE G name "users", U eid %(x)s', {'x':entity['eid']})
CHK = [('login', check_doubles, 'Utilisateurs Login',
'Deux utilisateurs ne devraient pas avoir le même login.'),
]
GENERATORS.append( (gen_users, CHK) )
# create controller
if 'cnx' in globals():
ctl = CWImportController(RQLObjectStore(cnx))
else:
print 'debug mode (not connected)'
print 'run through cubicweb-ctl shell to access an instance'
ctl = CWImportController(ObjectStore())
ctl.askerror = 1
ctl.generators = GENERATORS
ctl.data['utilisateurs'] = lazytable(ucsvreader(open('users.csv')))
# run
ctl.run()
.. BUG files with only one column are not parsable
.. TODO rollback() invocation is not possible yet
"""
__docformat__ = "restructuredtext en"
import sys
import csv
import traceback
import os.path as osp
from StringIO import StringIO
from copy import copy
from datetime import datetime
from logilab.common import shellutils, attrdict
from logilab.common.date import strptime
from logilab.common.decorators import cached
from logilab.common.deprecation import deprecated
from cubicweb import QueryError
from cubicweb.schema import META_RTYPES, VIRTUAL_RTYPES
from cubicweb.server.utils import eschema_eid
from cubicweb.server.edition import EditedEntity
def count_lines(stream_or_filename):
    """Return the number of lines found in a file.

    :param stream_or_filename: either a path (string), in which case the
      file is opened and closed by this function, or a seekable file-like
      object, which is rewound to its beginning both before and after the
      lines are counted.
    :return: the number of lines, 0 when the file is empty.
    """
    try:
        string_types = basestring
    except NameError:  # interpreter without the `basestring` builtin
        string_types = str
    if isinstance(stream_or_filename, string_types):
        # we opened the file ourselves: close it rather than leaking it
        with open(stream_or_filename) as stream:
            return sum(1 for _line in stream)
    stream = stream_or_filename
    stream.seek(0)
    # summing avoids relying on a loop variable that would be left unbound
    # by an empty stream
    nblines = sum(1 for _line in stream)
    # let the caller read the stream again from its start
    stream.seek(0)
    return nblines
def ucsvreader_pb(stream_or_path, encoding='utf-8', separator=',', quote='"',
                  skipfirst=False, withpb=True):
    """Same as `ucsvreader`, optionally displaying a progress bar.

    :param stream_or_path: path of the csv file to open, or an already
      opened stream (it is handed to `count_lines`, then to `ucsvreader`).
    :param encoding, separator, quote, skipfirst: forwarded to `ucsvreader`.
    :param withpb: when true, a `shellutils.ProgressBar` sized to the number
      of rows is updated after each yielded row.

    Once all rows have been consumed, the number of rows is printed.
    Raise `Exception` if a path is given that does not exist.
    """
    stream = stream_or_path
    if isinstance(stream_or_path, basestring):
        if not osp.exists(stream_or_path):
            raise Exception("file doesn't exists: %s" % stream_or_path)
        stream = open(stream_or_path)
    rowcount = count_lines(stream)
    if skipfirst:
        # the skipped row is not reported as imported
        rowcount -= 1
    progress = None
    if withpb:
        progress = shellutils.ProgressBar(rowcount, 50)
    rows = ucsvreader(stream, encoding, separator, quote, skipfirst)
    for urow in rows:
        yield urow
        if progress is not None:
            progress.update()
    print(' %s rows imported' % rowcount)
def ucsvreader(stream, encoding='utf-8', separator=',', quote='"',
               skipfirst=False):
    """A csv reader that accepts files with any encoding and outputs unicode
    strings.

    :param stream: iterable of lines, handed to `csv.reader`.
    :param encoding: encoding used to decode each cell.
    :param separator: cell delimiter.
    :param quote: quote character.
    :param skipfirst: when true, the first row is dropped.
    :return: a generator yielding each row as a list of decoded cells.
    """
    rows = iter(csv.reader(stream, delimiter=separator, quotechar=quote))
    if skipfirst:
        # a default value is given so that an empty input simply yields
        # nothing instead of letting StopIteration escape the generator body
        next(rows, None)
    for row in rows:
        yield [item.decode(encoding) for item in row]
def callfunc_every(func, number, iterable):
    """yield items of `iterable` one by one and call function `func`
    every `number` iterations. Always call function `func` at the end.

    :param func: callable taking no argument.
    :param number: number of yielded items between two calls of `func`.
    :param iterable: the items to yield.

    `func` is called once the consumer asks for the item following each
    `number`-th one, and a last time when `iterable` is exhausted (hence
    once for an empty iterable).
    """
    # count from 1 so that `func` fires after `number` items have been
    # handed out, not right after the very first one
    for count, item in enumerate(iterable, 1):
        yield item
        if not count % number:
            func()
    func()
def lazytable(reader):
    """The first row is taken to be the header of the table and
    used to output a dict for each row of data.

    >>> data = lazytable(ucsvreader(open(filename)))

    :param reader: any iterable of rows (sequences of cells).
    :return: a generator yielding one dict per data row, mapping header
      cells to the cells of that row. Nothing is yielded when `reader` is
      empty or only holds the header.
    """
    rows = iter(reader)
    header = next(rows, None)
    if header is None:
        # empty table: no header, hence no data
        return
    for row in rows:
        yield dict(zip(header, row))
def lazydbtable(cu, table, headers, orderby=None):
    """Iterate over the rows of a sql table, each one being returned as a
    dictionary mapping the names given in `headers` to the fetched values.

    >>> data = lazydbtable(cu, 'experimentation', ('id', 'nickname', 'gps'))

    :param cu: database cursor providing `execute` and `fetchone`.
    :param table: name of the table to read.
    :param headers: names of the columns to select.
    :param orderby: optional sequence of column names to sort on.

    The table and column names are inserted as is in the query.
    """
    query = 'SELECT %s FROM %s' % (','.join(headers), table)
    if orderby:
        query = '%s ORDER BY %s' % (query, ','.join(orderby))
    cu.execute(query)
    record = cu.fetchone()
    # fetchone() returning None marks the end of the result set
    while record is not None:
        yield dict(zip(headers, record))
        record = cu.fetchone()
def mk_entity(row, map):
    """Return a dict made from sanitized mapped values.

    :param row: dict of raw values, indexed by source field name.
    :param map: list of `(src, dest, funcs)` triples: the value found under
      `src` in `row` is passed through each function of `funcs` in turn and
      the result is stored under `dest` in the returned dict. The chain
      stops as soon as a function returns None.

    ValueError can be raised on unexpected values found in checkers: it is
    raised again with the name of the faulty source field prepended to the
    message, keeping the original traceback. A source field missing from
    `row` raises KeyError.

    >>> row = {'myname': u'dupont'}
    >>> map = [('myname', u'name', (call_transform_method('title'),))]
    >>> mk_entity(row, map)
    {u'name': u'Dupont'}
    >>> row = {'myname': u'dupont', 'optname': u''}
    >>> map = [('myname', u'name', (call_transform_method('title'),)),
    ...        ('optname', u'MARKER', (optional,))]
    >>> mk_entity(row, map)
    {u'name': u'Dupont', u'MARKER': None}
    """
    res = {}
    assert isinstance(row, dict)
    assert isinstance(map, list)
    for src, dest, funcs in map:
        res[dest] = row[src]
        try:
            for func in funcs:
                res[dest] = func(res[dest])
                # a None result ends the chain for this field, the
                # remaining functions are not called
                if res[dest] is None:
                    break
        except ValueError, err:
            # same exception type with the field name added, re-raised with
            # the traceback of the original error
            raise ValueError('error with %r field: %s' % (src, err)), None, sys.exc_info()[-1]
    return res
# user interactions ############################################################
def tell(msg):
    """Default reporting function: print `msg` on the standard output."""
    print(msg)
def confirm(question):
    """Ask `question` with Y / n / abort as possible answers (Y being the
    default).

    Return True if the answer is 'Y', False if it is 'n'; on 'abort' the
    process exits with status 1.
    """
    reply = shellutils.ASK.ask(question, ('Y', 'n', 'abort'), 'Y')
    if reply != 'abort':
        return reply == 'Y'
    sys.exit(1)
class catch_error(object):
    """Context manager recording on the controller the errors raised in
    its body.

    :param ctl: controller providing a `catcherrors` attribute and a
      `record_error` method.
    :param key: key under which the error is recorded.
    :param msg: optional message recorded along with the traceback.

    KeyboardInterrupt and SystemExit always propagate. Any other exception
    is recorded and silenced when `ctl.catcherrors` is true, else it
    propagates untouched.
    """

    def __init__(self, ctl, key='unexpected error', msg=None):
        self.ctl = ctl
        self.key = key
        self.msg = msg

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        if type is None:
            # body ended without error
            return None
        if issubclass(type, (KeyboardInterrupt, SystemExit)):
            # never swallow a request to stop the process
            return None
        if self.ctl.catcherrors:
            # forward the message given at construction time, it used to be
            # silently dropped
            self.ctl.record_error(self.key, self.msg, type, value, traceback)
            return True # silence the exception
        return None
# base sanitizing/coercing functions ###########################################
def optional(value):
    """checker to filter optional field

    A false value (ex: empty string) is turned into None, which breaks the
    checkers chain run by `mk_entity`; any other value is returned as is.

    Put it first in the chain so that further checkers are not fed with an
    undefined value, on which they would raise ValueError.

    >>> MAPPER = [(u'value', 'value', (optional, int))]
    >>> row = {'value': u''}
    >>> mk_entity(row, MAPPER)
    {'value': None}
    >>> row = {'value': u'100'}
    >>> mk_entity(row, MAPPER)
    {'value': 100}
    """
    return value or None
def required(value):
    """return `value` unchanged, or raise ValueError if it is false (ex:
    empty string)

    This check should be often found in last position in the chain.
    """
    if not value:
        raise ValueError("required")
    return value
def todatetime(format='%d/%m/%Y'):
    """return a transformation function parsing its string input with
    `strptime` according to `format`.

    The result is expected to be a `datetime.datetime` instance; follow this
    function by `todate` or `totime` from `logilab.common.date` if you want
    a `date` / `time` instance instead.
    """
    def coerce(value):
        parsed = strptime(value, format)
        return parsed
    return coerce
def call_transform_method(methodname, *args, **kwargs):
    """return a transformation function which calls the method named
    `methodname` on its input, with the given arguments, and returns the
    result of this call
    """
    def coerce(value):
        method = getattr(value, methodname)
        return method(*args, **kwargs)
    return coerce
def call_check_method(methodname, *args, **kwargs):
    """return a check function which calls the method named `methodname`
    on its input, with the given arguments: the input is returned unchanged
    if the result of this call is true, else ValueError is raised
    """
    def check(value):
        verified = getattr(value, methodname)(*args, **kwargs)
        if not verified:
            raise ValueError('%s not verified on %r' % (methodname, value))
        return value
    return check
# base integrity checking functions ############################################
def check_doubles(buckets):
    """Return a list of `(key, number of items)` pairs, one for each key of
    `buckets` whose bucket holds more than one item.
    """
    doubles = []
    for key, bucket in buckets.items():
        size = len(bucket)
        if size > 1:
            doubles.append((key, size))
    return doubles
def check_doubles_not_none(buckets):
    """Return a list of `(key, number of items)` pairs, one for each key of
    `buckets` whose bucket holds more than one item, ignoring the None key.
    """
    doubles = []
    for key, bucket in buckets.items():
        if key is None:
            continue
        size = len(bucket)
        if size > 1:
            doubles.append((key, size))
    return doubles
# object stores #################################################################
class ObjectStore(object):
    """Store objects in memory for *faster* validation (development mode)

    But it will not enforce the constraints of the schema and hence will
    miss some problems

    >>> store = ObjectStore()
    >>> user = store.create_entity('CWUser', login=u'johndoe')
    >>> group = store.create_entity('CWGroup', name=u'unknown')
    >>> store.relate(user.eid, 'in_group', group.eid)
    """

    def __init__(self):
        # created items, in creation order (see `_put`)
        self.items = []
        # eid -> item
        self.eids = {}
        # entity type -> list of eids
        self.types = {}
        # set of (eid_from, rtype, eid_to) tuples
        self.relations = set()
        # index name -> {key: [values]}
        self.indexes = {}
        # optional callable used by `rql` when set
        self._rql = None
        self._commit = None

    def _put(self, type, item):
        """append `item` to the store and return its position in
        `self.items`, which is used as its eid
        """
        self.items.append(item)
        position = len(self.items) - 1
        return position

    def create_entity(self, etype, **data):
        """register a new entity of type `etype` holding `data` and return
        it as an `attrdict` including its 'eid'
        """
        entity = attrdict(data)
        eid = self._put(etype, entity)
        entity['eid'] = eid
        self.eids[eid] = entity
        self.types.setdefault(etype, []).append(eid)
        return entity

    @deprecated("[3.11] add is deprecated, use create_entity instead")
    def add(self, etype, item):
        """create an entity of type `etype` from the `item` dict, which is
        returned once updated with the 'eid' of the new entity
        """
        assert isinstance(item, dict), 'item is not a dict but a %s' % type(item)
        created = self.create_entity(etype, **item)
        item['eid'] = created['eid']
        return item

    def relate(self, eid_from, rtype, eid_to, inlined=False):
        """Add new relation and return it as a
        `(eid_from, rtype, eid_to)` tuple (`inlined` is not used here)
        """
        relation = (eid_from, rtype, eid_to)
        self.relations.add(relation)
        return relation

    def commit(self):
        """do nothing: this store has nothing to commit

        Subclasses working on an actual repository override this method.
        """
        pass

    def rql(self, *args):
        """return the result of the `_rql` callable called with `args` if
        one has been set, else an empty list
        """
        if self._rql is None:
            return []
        return self._rql(*args)

    @property
    def nb_inserted_entities(self):
        """number of entities known by the store"""
        return len(self.eids)

    @property
    def nb_inserted_types(self):
        """number of distinct entity types known by the store"""
        return len(self.types)

    @property
    def nb_inserted_relations(self):
        """number of distinct relations known by the store"""
        return len(self.relations)

    @deprecated("[3.7] index support will disappear")
    def build_index(self, name, type, func=None, can_be_empty=False):
        """build internal index for further search

        Entities of type `type` are grouped by the key `func` computes from
        each of them (their eid when `func` is not a callable), and the
        resulting `{key: [eids]}` dict is stored under `name`.
        """
        if not callable(func):
            # also covers the default (None)
            func = lambda x: x['eid']
        index = {}
        for eid in self.types[type]:
            indexkey = func(self.eids[eid])
            index.setdefault(indexkey, []).append(eid)
        assert can_be_empty or index, "new index '%s' cannot be empty" % name
        self.indexes[name] = index

    @deprecated("[3.7] index support will disappear")
    def build_rqlindex(self, name, type, key, rql, rql_params=False,
                       func=None, can_be_empty=False):
        """build an index by rql query

        rql should return eid in first column

        ctl.store.build_rqlindex('index_name', 'users', 'login', 'Any U WHERE U is CWUser')

        Entities found replace the ones previously known for `type`, then
        the index is built on their `key` attribute (`func` is not used).
        """
        known = self.types[type] = []
        rset = self.rql(rql, rql_params or {})
        assert can_be_empty or rset, "new index type '%s' cannot be empty (0 record found)" % type
        for entity in rset.entities():
            getattr(entity, key) # autopopulate entity with key attribute
            eid = entity.eid
            self.eids[eid] = dict(entity)
            if eid not in known:
                known.append(eid)
        # Build index with specified key
        self.build_index(name, type, lambda x: x[key],
                         can_be_empty=can_be_empty)

    @deprecated("[3.7] index support will disappear")
    def fetch(self, name, key, unique=False, decorator=None):
        """index fetcher method

        Return what the index `name` holds for `key` (an empty list when
        the key is unknown).

        decorator is a callable or an iterable of callables (usually lambda
        functions): each one is called in turn on the fetched list and its
        result replaces it, e.g.

            decorator=lambda x: x[:1] (only the first value is kept)

        If `unique` is true, exactly one value is expected and it is
        returned instead of the list.
        """
        found = self.indexes[name].get(key, [])
        if decorator is not None:
            if not hasattr(decorator, '__iter__'):
                decorator = (decorator,)
            for transform in decorator:
                found = transform(found)
        if not unique:
            return found
        assert len(found) == 1, u'expected a single one value for key "%s" in index "%s". Got %i' % (key, name, len(found))
        return found[0]

    @deprecated("[3.7] index support will disappear")
    def find(self, type, key, value):
        """yield items of type `type` whose `key` is equal to `value`"""
        for position in self.types[type]:
            candidate = self.items[position]
            if candidate[key] == value:
                yield candidate

    @deprecated("[3.7] checkpoint() deprecated. use commit() instead")
    def checkpoint(self):
        """deprecated alias for `commit`"""
        self.commit()
class RQLObjectStore(ObjectStore):
    """ObjectStore that works with an actual RQL repository (production mode)"""
    _rql = None # bw compat

    def __init__(self, session=None, commit=None):
        """
        :param session: a session (object providing `set_cnxset`), or else a
          connection from which a request is built and used as session.
          The process exits if none is given.
        :param commit: optional commit function; defaults to the `commit`
          method of the connection, or of the session.
        """
        ObjectStore.__init__(self)
        if session is None:
            sys.exit('please provide a session of run this script with cubicweb-ctl shell and pass cnx as session')
        if hasattr(session, 'set_cnxset'):
            session.set_cnxset()
        else:
            # we got a connection: work on a request built from it, on
            # which set_cnxset is made a no-op
            cnx = session
            session = cnx.request()
            session.set_cnxset = lambda : None
            if not commit:
                commit = cnx.commit
        self.session = session
        self._commit = commit or session.commit

    @deprecated("[3.7] checkpoint() deprecated. use commit() instead")
    def checkpoint(self):
        """deprecated alias for `commit`"""
        self.commit()

    def commit(self):
        """call the commit function, then `set_cnxset` on the session, and
        return what the commit function returned
        """
        result = self._commit()
        self.session.set_cnxset()
        return result

    def rql(self, *args):
        """execute a query using the `_rql` callable if one has been set,
        else the session
        """
        if self._rql is None:
            return self.session.execute(*args)
        return self._rql(*args)

    def create_entity(self, *args, **kwargs):
        """create an entity through the session (the entity type being the
        first positional argument), register it in the store and return it
        """
        created = self.session.create_entity(*args, **kwargs)
        eid = created.eid
        self.eids[eid] = created
        self.types.setdefault(args[0], []).append(eid)
        return created

    def _put(self, type, item):
        """insert an entity of type `type` with an INSERT query setting
        each key of `item` as an attribute, and return the first column of
        the first row of the result
        """
        query = 'INSERT %s X' % type
        if item:
            restrictions = ['X %s %%(%s)s' % (attr, attr) for attr in item]
            query += ': ' + ', '.join(restrictions)
        return self.rql(query, item)[0][0]

    def relate(self, eid_from, rtype, eid_to, inlined=False):
        """record the relation in the store, then set it with a SET query
        (`inlined` is not used)
        """
        subj, relname, obj = super(RQLObjectStore, self).relate(
            eid_from, rtype, eid_to)
        query = 'SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % relname
        self.rql(query, {'x': int(subj), 'y': int(obj)})

    def find_entities(self, *args, **kwargs):
        """delegate to the session's `find_entities`"""
        return self.session.find_entities(*args, **kwargs)

    def find_one_entity(self, *args, **kwargs):
        """delegate to the session's `find_one_entity`"""
        return self.session.find_one_entity(*args, **kwargs)
# the import controller ########################################################
class CWImportController(object):
    """Controller of the data import process.

    >>> ctl = CWImportController(store)
    >>> ctl.generators = list_of_data_generators
    >>> ctl.data = dict_of_data_tables
    >>> ctl.run()

    `generators` is a list of `(function, checks)` pairs: the function is
    called with the controller as argument, and `checks` is a list of
    `(type, check function, title, help)` tuples run on the values the
    function recorded through `check`.

    After `run`, `errors` maps either a check title to a
    `(help, problems)` tuple, or a key given to `record_error` to the list
    of errors recorded under it.
    """

    def __init__(self, store, askerror=0, catcherrors=None, tell=tell,
                 commitevery=50):
        """
        :param store: the object store entities are created with.
        :param askerror: 0 to never display errors at the end of `run`, 2
          to always display them, any other true value to ask first.
        :param catcherrors: whether errors are recorded instead of stopping
          the import; defaults to `askerror`.
        :param tell: function used to report messages.
        :param commitevery: number of iterations between commits issued by
          `iter_and_commit`.
        """
        self.store = store
        self.generators = None
        self.data = {}
        self.errors = None
        self.askerror = askerror
        if catcherrors is None:
            catcherrors = askerror
        self.catcherrors = catcherrors
        self.commitevery = commitevery # set to None to do a single commit
        self._tell = tell
        # reset by `run` before each generator; created here so that
        # `check` may be called outside of `run`
        self._checks = {}

    def check(self, type, key, value):
        """record `value` in the bucket of `key`, among the values to be
        verified by the checks registered for `type`
        """
        self._checks.setdefault(type, {}).setdefault(key, []).append(value)

    def check_map(self, entity, key, map, default):
        """replace `entity[key]` by the value associated to it in `map`;
        if there is none, record the unmapped value using `check` and set
        `default` instead
        """
        try:
            entity[key] = map[entity[key]]
        except KeyError:
            self.check(key, entity[key], None)
            entity[key] = default

    def record_error(self, key, msg=None, type=None, value=None, tb=None):
        """record an error under `key`

        The traceback is the one of the given exception information, or of
        the exception being handled if `type` is None. It is stored as a
        list of lines, in a `(msg, lines)` tuple if `msg` is given.
        """
        tmp = StringIO()
        if type is None:
            traceback.print_exc(file=tmp)
        else:
            traceback.print_exception(type, value, tb, file=tmp)
        # use a list to avoid counting a <nb lines> errors instead of one
        errorlog = self.errors.setdefault(key, [])
        if msg is None:
            errorlog.append(tmp.getvalue().splitlines())
        else:
            errorlog.append( (msg, tmp.getvalue().splitlines()) )

    def run(self):
        """run each generator followed by its checks, commit, print
        statistics and display the errors according to `askerror`

        An exception raised by a generator is recorded if `catcherrors` is
        true, else statistics are printed and it is raised again.
        """
        self.errors = {}
        if self.commitevery is None:
            self.tell('Will commit all or nothing.')
        else:
            self.tell('Will commit every %s iterations' % self.commitevery)
        for func, checks in self.generators:
            self._checks = {}
            func_name = func.__name__
            self.tell("Run import function '%s'..." % func_name)
            try:
                func(self)
            except Exception:
                if self.catcherrors:
                    self.record_error(func_name, 'While calling %s' % func_name)
                else:
                    self._print_stats()
                    raise
            for key, checkfunc, title, helpmsg in checks:
                buckets = self._checks.get(key)
                if buckets:
                    err = checkfunc(buckets)
                    if err:
                        self.errors[title] = (helpmsg, err)
        try:
            txuuid = self.store.commit()
            if txuuid is not None:
                self.tell('Transaction commited (txuuid: %s)' % txuuid)
        except QueryError as ex:
            self.tell('Transaction aborted: %s' % ex)
        self._print_stats()
        if self.errors:
            if self.askerror == 2 or (self.askerror and confirm('Display errors ?')):
                self._display_errors()

    @staticmethod
    def _error_details(error):
        """return the list of problems held by a value of `errors`"""
        if isinstance(error, tuple):
            # (help, problems) stored by `run` for a failed check
            return error[1]
        # list of errors stored by `record_error`
        return error

    def _display_errors(self):
        """report the content of `errors`, whichever way it was recorded"""
        from pprint import pformat
        for errkey, error in self.errors.items():
            details = self._error_details(error)
            if isinstance(error, tuple):
                self.tell("\n%s (%s): %d\n" % (error[0], errkey, len(details)))
                self.tell(pformat(sorted(details)))
            else:
                # recorded errors have no help message and are kept in the
                # order they occurred
                self.tell("\n%s: %d\n" % (errkey, len(details)))
                self.tell(pformat(details))

    def _print_stats(self):
        """report the store's counters and the number of errors"""
        nberrors = sum(len(self._error_details(err))
                       for err in self.errors.values())
        self.tell('\nImport statistics: %i entities, %i types, %i relations and %i errors'
                  % (self.store.nb_inserted_entities,
                     self.store.nb_inserted_types,
                     self.store.nb_inserted_relations,
                     nberrors))

    def get_data(self, key):
        """return the data table registered under `key`, None if there is
        none
        """
        return self.data.get(key)

    def index(self, name, key, value, unique=False):
        """add `value` to the values stored for `key` in the store index
        `name`, which is created if necessary

        If unique is set to True, `value` is not added when it is already
        there for this key.
        NOTE(review): other values are still appended to the same key,
        while this was documented as keeping the first occurence only --
        confirm which behaviour is expected.
        """
        if unique:
            try:
                if value in self.store.indexes[name][key]:
                    return
            except KeyError:
                # we're sure that one is the first occurence; so continue...
                pass
        self.store.indexes.setdefault(name, {}).setdefault(key, []).append(value)

    def tell(self, msg):
        """report `msg` using the function given at construction time"""
        self._tell(msg)

    def iter_and_commit(self, datakey):
        """iter rows, triggering commit every self.commitevery iterations"""
        if self.commitevery is None:
            return self.get_data(datakey)
        else:
            return callfunc_every(self.store.commit,
                                  self.commitevery,
                                  self.get_data(datakey))
class NoHookRQLObjectStore(RQLObjectStore):
    """RQLObjectStore adding entities and relations straight through the
    system source (`add_entity`, `add_info`, `add_relation`) instead of
    using the session's `create_entity` or RQL queries.

    Read and write security are deactivated on the given session. Meta-data
    of created entities come from a `MetaGenerator`.
    """
    _rql = None # bw compat

    def __init__(self, session, metagen=None, baseurl=None):
        """
        :param session: session to work with; its `repo` attribute gives
          access to the system source and to the schema.
        :param metagen: object generating entities' meta-data; a
          `MetaGenerator` is built if none is given.
        :param baseurl: given to the `MetaGenerator` built by default.
        """
        super(NoHookRQLObjectStore, self).__init__(session)
        self.source = session.repo.system_source
        self.rschema = session.repo.schema.rschema
        self.add_relation = self.source.add_relation
        if metagen is None:
            metagen = MetaGenerator(session, baseurl)
        self.metagen = metagen
        self._nb_inserted_entities = 0
        self._nb_inserted_types = 0
        self._nb_inserted_relations = 0
        self.rql = session.execute
        # deactivate security
        session.set_read_security(False)
        session.set_write_security(False)

    def create_entity(self, etype, **kwargs):
        """add an entity of type `etype` to the system source, together
        with its meta-data and meta relations, and return it

        Values of `kwargs` having an `eid` attribute are replaced by this
        eid.
        """
        kwargs = dict((attr, getattr(value, 'eid', value))
                      for attr, value in kwargs.items())
        entity, rels = self.metagen.base_etype_dicts(etype)
        # make a copy to keep cached entity pristine
        entity = copy(entity)
        entity.cw_edited = copy(entity.cw_edited)
        entity.cw_clear_relation_cache()
        self.metagen.init_entity(entity)
        entity.cw_edited.update(kwargs, skipsec=False)
        session = self.session
        self.source.add_entity(session, entity)
        self.source.add_info(session, entity, self.source, None, complete=False)
        for rtype, targeteids in rels.items():
            inlined = self.rschema(rtype).inlined
            # targeteids may be a single eid or a list of eids: only probe
            # for iterability here, so that a TypeError raised while adding
            # a relation is not mistaken for "this is a single eid"
            try:
                targets = iter(targeteids)
            except TypeError:
                targets = (targeteids,)
            for targeteid in targets:
                self.add_relation(session, entity.eid, rtype, targeteid,
                                  inlined)
        self._nb_inserted_entities += 1
        return entity

    def relate(self, eid_from, rtype, eid_to, inlined=False):
        """add a relation through the system source

        `inlined` is only accepted for compatibility with the other stores'
        `relate` signature: the value actually used is read from the
        schema.
        """
        assert not rtype.startswith('reverse_')
        self.add_relation(self.session, eid_from, rtype, eid_to,
                          self.rschema(rtype).inlined)
        self._nb_inserted_relations += 1

    @property
    def nb_inserted_entities(self):
        """number of entities created by `create_entity`"""
        return self._nb_inserted_entities

    @property
    def nb_inserted_types(self):
        """counter of inserted types, which is never incremented by this
        class (hence 0)
        """
        return self._nb_inserted_types

    @property
    def nb_inserted_relations(self):
        """number of relations added by `relate`"""
        return self._nb_inserted_relations

    def _put(self, type, item):
        """not supported by this store"""
        raise RuntimeError('use create entity')
class MetaGenerator(object):
    """Generate values of meta attributes and relations for entities
    created by a store.

    The value of relation `<rtype>` is computed by the `gen_<rtype>`
    method, called with the entity as argument.
    """
    # meta relations handled here: virtual ones and those listed below are
    # left out
    META_RELATIONS = (META_RTYPES
                      - VIRTUAL_RTYPES
                      - set(('eid', 'cwuri',
                             'is', 'is_instance_of', 'cw_source')))

    def __init__(self, session, baseurl=None):
        """
        :param session: session giving access to the repository, the
          registry (`vreg`) and the user.
        :param baseurl: url prefix used by `gen_cwuri`; taken from the
          configuration when not given. A trailing '/' is added if missing.
        """
        self.session = session
        self.source = session.repo.system_source
        # single timestamp shared by every generated date
        self.time = datetime.now()
        if baseurl is None:
            config = session.vreg.config
            baseurl = config['base-url'] or config.default_base_url()
        if baseurl[-1] != '/':
            baseurl = baseurl + '/'
        self.baseurl = baseurl
        # attributes/relations shared by all entities of the same type
        self.etype_attrs = []
        self.etype_rels = []
        # attributes/relations specific to each entity
        self.entity_attrs = ['cwuri']
        #self.entity_rels = [] XXX not handled (YAGNI?)
        rschema = session.vreg.schema.rschema
        for rtype in self.META_RELATIONS:
            # final relations are attributes
            if rschema(rtype).final:
                bucket = self.etype_attrs
            else:
                bucket = self.etype_rels
            bucket.append(rtype)

    @cached
    def base_etype_dicts(self, etype):
        """return an `(entity, rels)` pair for `etype`: a template entity
        whose edited attributes hold the values shared by all entities of
        this type, and a dict mapping each shared relation to its generated
        value
        """
        template = self.session.vreg['etypes'].etype_class(etype)(self.session)
        # entity are "surface" copied, avoid shared dict between copies
        del template.cw_extra_kwargs
        template.cw_edited = EditedEntity(template)
        for attr in self.etype_attrs:
            value = self.generate(template, attr)
            template.cw_edited.edited_attribute(attr, value)
        rels = dict((rel, self.generate(template, rel))
                    for rel in self.etype_rels)
        return template, rels

    def init_entity(self, entity):
        """give a new eid to `entity`, then set its entity specific
        attributes
        """
        entity.eid = self.source.create_eid(self.session)
        for attr in self.entity_attrs:
            value = self.generate(entity, attr)
            entity.cw_edited.edited_attribute(attr, value)

    def generate(self, entity, rtype):
        """return the value of `rtype` for `entity`, as computed by the
        matching `gen_<rtype>` method
        """
        generator = getattr(self, 'gen_%s' % rtype)
        return generator(entity)

    def gen_cwuri(self, entity):
        """uri built from the base url and the eid of the entity"""
        return u'%seid/%s' % (self.baseurl, entity.eid)

    def gen_creation_date(self, entity):
        """date at which this generator has been instantiated"""
        return self.time

    def gen_modification_date(self, entity):
        """date at which this generator has been instantiated"""
        return self.time

    def gen_created_by(self, entity):
        """eid of the session's user"""
        return self.session.user.eid

    def gen_owned_by(self, entity):
        """eid of the session's user"""
        return self.session.user.eid