devtools/dataimport.py
branchstable
changeset 4818 9f9bfbcdecfd
parent 4734 4ae30c9ca11b
child 4847 9466604ef448
--- a/devtools/dataimport.py	Fri Mar 05 17:24:01 2010 +0100
+++ b/devtools/dataimport.py	Fri Mar 05 18:07:39 2010 +0100
@@ -59,10 +59,14 @@
 import traceback
 import os.path as osp
 from StringIO import StringIO
+from copy import copy
 
 from logilab.common import shellutils
+from logilab.common.date import strptime
+from logilab.common.decorators import cached
 from logilab.common.deprecation import deprecated
 
+
 def ucsvreader_pb(filepath, encoding='utf-8', separator=',', quote='"',
                   skipfirst=False, withpb=True):
     """same as ucsvreader but a progress bar is displayed as we iter on rows"""
@@ -90,8 +94,6 @@
     for row in it:
         yield [item.decode(encoding) for item in row]
 
-utf8csvreader = deprecated('use ucsvreader instead')(ucsvreader)
-
 def commit_every(nbit, store, it):
     for i, x in enumerate(it):
         yield x
@@ -129,19 +131,16 @@
     assert isinstance(row, dict)
     assert isinstance(map, list)
     for src, dest, funcs in map:
-        assert not (required in funcs and optional in funcs), "optional and required checks are exclusive"
+        assert not (required in funcs and optional in funcs), \
+               "optional and required checks are exclusive"
         res[dest] = row[src]
         try:
             for func in funcs:
                 res[dest] = func(res[dest])
-            if res[dest] is None:
-                raise AssertionError('undetermined value')
-        except AssertionError, err:
-            if optional in funcs:
-                # Forget this field if exception is coming from optional function
-               del res[dest]
-            else:
-               raise AssertionError('error with "%s" field: %s' % (src, err))
+                if res[dest] is None:
+                    break
+        except ValueError, err:
+            raise ValueError('error with %r field: %s' % (src, err))
     return res
 
 
@@ -178,98 +177,49 @@
                 return True # silent
 
 
-# base sanitizing functions ####################################################
-
-def capitalize_if_unicase(txt):
-    if txt.isupper() or txt.islower():
-        return txt.capitalize()
-    return txt
-
-def uppercase(txt):
-    return txt.upper()
-
-def lowercase(txt):
-    return txt.lower()
-
-def no_space(txt):
-    return txt.replace(' ','')
-
-def no_uspace(txt):
-    return txt.replace(u'\xa0','')
-
-def no_dash(txt):
-    return txt.replace('-','')
-
-def decimal(value):
-    """cast to float but with comma replacement
-
-    We take care of some locale format as replacing ',' by '.'"""
-    value = value.replace(',', '.')
-    try:
-        return float(value)
-    except Exception, err:
-        raise AssertionError(err)
-
-def integer(value):
-    try:
-        return int(value)
-    except Exception, err:
-        raise AssertionError(err)
-
-def strip(txt):
-    return txt.strip()
-
-def yesno(value):
-    """simple heuristic that returns boolean value
-
-    >>> yesno("Yes")
-    True
-    >>> yesno("oui")
-    True
-    >>> yesno("1")
-    True
-    >>> yesno("11")
-    True
-    >>> yesno("")
-    False
-    >>> yesno("Non")
-    False
-    >>> yesno("blablabla")
-    False
-    """
-    if value:
-        return value.lower()[0] in 'yo1'
-    return False
-
-def isalpha(value):
-    if value.isalpha():
-        return value
-    raise AssertionError("not all characters in the string alphabetic")
+# base sanitizing/coercing functions ###########################################
 
 def optional(value):
     """validation error will not been raised if you add this checker in chain"""
-    return value
+    if value:
+        return value
+    return None
 
 def required(value):
-    """raise AssertionError is value is empty
+    """raise ValueError is value is empty
 
     This check should be often found in last position in the chain.
     """
-    if bool(value):
+    if value:
         return value
-    raise AssertionError("required")
+    raise ValueError("required")
 
-@deprecated('use required(value)')
-def nonempty(value):
-    return required(value)
+def todatetime(format='%d/%m/%Y'):
+    """return a transformation function to turn string input value into a
+    `datetime.datetime` instance, using given format.
+
+    Follow it by `todate` or `totime` functions from `logilab.common.date` if
+    you want a `date`/`time` instance instead of `datetime`.
+    """
+    def coerce(value):
+        return strptime(value, format)
+    return coerce
 
-@deprecated('use integer(value)')
-def alldigits(txt):
-    if txt.isdigit():
-        return txt
-    else:
-        return u''
+def call_transform_method(methodname, *args, **kwargs):
+    """return value returned by calling the given method on input"""
+    def coerce(value):
+        return getattr(value, methodname)(*args, **kwargs)
+    return coerce
 
+def call_check_method(methodname, *args, **kwargs):
+    """check value returned by calling the given method on input is true,
+    else raise ValueError
+    """
+    def check(value):
+        if getattr(value, methodname)(*args, **kwargs):
+            return value
+        raise ValueError('%s not verified on %r' % (methodname, value))
+    return check
 
 # base integrity checking functions ############################################
 
@@ -316,7 +266,7 @@
         self.eids[eid] = item
         self.types.setdefault(type, []).append(eid)
 
-    def relate(self, eid_from, rtype, eid_to):
+    def relate(self, eid_from, rtype, eid_to, inlined=False):
         """Add new relation (reverse type support is available)
 
         >>> 1,2 = eid_from, eid_to
@@ -359,14 +309,6 @@
         func = lambda x: x[key]
         self.build_index(name, type, func)
 
-    @deprecated('get_many() deprecated. Use fetch() instead')
-    def get_many(self, name, key):
-        return self.fetch(name, key, unique=False)
-
-    @deprecated('get_one() deprecated. Use fetch(..., unique=True) instead')
-    def get_one(self, name, key):
-        return self.fetch(name, key, unique=True)
-
     def fetch(self, name, key, unique=False, decorator=None):
         """
             decorator is a callable method or an iterator of callable methods (usually a lambda function)
@@ -398,6 +340,24 @@
     def checkpoint(self):
         pass
 
+    @property
+    def nb_inserted_entities(self):
+        return len(self.eids)
+    @property
+    def nb_inserted_types(self):
+        return len(self.types)
+    @property
+    def nb_inserted_relations(self):
+        return len(self.relations)
+
+    @deprecated('[3.6] get_many() deprecated. Use fetch() instead')
+    def get_many(self, name, key):
+        return self.fetch(name, key, unique=False)
+
+    @deprecated('[3.6] get_one() deprecated. Use fetch(..., unique=True) instead')
+    def get_one(self, name, key):
+        return self.fetch(name, key, unique=True)
+
 
 class RQLObjectStore(ObjectStore):
     """ObjectStore that works with an actual RQL repository (production mode)"""
@@ -412,19 +372,24 @@
                 session = session.request()
                 session.set_pool = lambda : None
                 checkpoint = checkpoint or cnx.commit
+            else:
+                session.set_pool()
             self.session = session
-            self.checkpoint = checkpoint or session.commit
+            self._checkpoint = checkpoint or session.commit
         elif checkpoint is not None:
-            self.checkpoint = checkpoint
+            self._checkpoint = checkpoint
+            # XXX .session
+
+    def checkpoint(self):
+        self._checkpoint()
+        self.session.set_pool()
 
     def rql(self, *args):
         if self._rql is not None:
             return self._rql(*args)
-        self.session.set_pool()
         return self.session.execute(*args)
 
     def create_entity(self, *args, **kwargs):
-        self.session.set_pool()
         entity = self.session.create_entity(*args, **kwargs)
         self.eids[entity.eid] = entity
         self.types.setdefault(args[0], []).append(entity.eid)
@@ -435,7 +400,7 @@
                                                      for k in item)
         return self.rql(query, item)[0][0]
 
-    def relate(self, eid_from, rtype, eid_to):
+    def relate(self, eid_from, rtype, eid_to, inlined=False):
         # if reverse relation is found, eids are exchanged
         eid_from, rtype, eid_to = super(RQLObjectStore, self).relate(
             eid_from, rtype, eid_to)
@@ -513,8 +478,10 @@
         self.store.checkpoint()
         nberrors = sum(len(err[1]) for err in self.errors.values())
         self.tell('\nImport completed: %i entities, %i types, %i relations and %i errors'
-                  % (len(self.store.eids), len(self.store.types),
-                     len(self.store.relations), nberrors))
+                  % (self.store.nb_inserted_entities,
+                     self.store.nb_inserted_types,
+                     self.store.nb_inserted_relations,
+                     nberrors))
         if self.errors:
             if self.askerror == 2 or (self.askerror and confirm('Display errors ?')):
                 from pprint import pformat
@@ -545,3 +512,241 @@
     def iter_and_commit(self, datakey):
         """iter rows, triggering commit every self.commitevery iterations"""
         return commit_every(self.commitevery, self.store, self.get_data(datakey))
+
+
+
+from datetime import datetime
+from cubicweb.schema import META_RTYPES, VIRTUAL_RTYPES
+
+
+class NoHookRQLObjectStore(RQLObjectStore):
+    """ObjectStore that works with an actual RQL repository (production mode)"""
+    _rql = None # bw compat
+
+    def __init__(self, session, metagen=None, baseurl=None):
+        super(NoHookRQLObjectStore, self).__init__(session)
+        self.source = session.repo.system_source
+        self.rschema = session.repo.schema.rschema
+        self.add_relation = self.source.add_relation
+        if metagen is None:
+            metagen = MetaGenerator(session, baseurl)
+        self.metagen = metagen
+        self._nb_inserted_entities = 0
+        self._nb_inserted_types = 0
+        self._nb_inserted_relations = 0
+        self.rql = session.unsafe_execute
+
+    def create_entity(self, etype, **kwargs):
+        for k, v in kwargs.iteritems():
+            kwargs[k] = getattr(v, 'eid', v)
+        entity, rels = self.metagen.base_etype_dicts(etype)
+        entity = copy(entity)
+        entity._related_cache = {}
+        self.metagen.init_entity(entity)
+        entity.update(kwargs)
+        session = self.session
+        self.source.add_entity(session, entity)
+        self.source.add_info(session, entity, self.source, complete=False)
+        for rtype, targeteids in rels.iteritems():
+            # targeteids may be a single eid or a list of eids
+            inlined = self.rschema(rtype).inlined
+            try:
+                for targeteid in targeteids:
+                    self.add_relation(session, entity.eid, rtype, targeteid,
+                                      inlined)
+            except TypeError:
+                self.add_relation(session, entity.eid, rtype, targeteids,
+                                  inlined)
+        self._nb_inserted_entities += 1
+        return entity
+
+    def relate(self, eid_from, rtype, eid_to):
+        assert not rtype.startswith('reverse_')
+        self.add_relation(self.session, eid_from, rtype, eid_to,
+                          self.rschema(rtype).inlined)
+        self._nb_inserted_relations += 1
+
+    @property
+    def nb_inserted_entities(self):
+        return self._nb_inserted_entities
+    @property
+    def nb_inserted_types(self):
+        return self._nb_inserted_types
+    @property
+    def nb_inserted_relations(self):
+        return self._nb_inserted_relations
+
+    def _put(self, type, item):
+        raise RuntimeError('use create entity')
+
+
+class MetaGenerator(object):
+    def __init__(self, session, baseurl=None):
+        self.session = session
+        self.source = session.repo.system_source
+        self.time = datetime.now()
+        if baseurl is None:
+            config = session.vreg.config
+            baseurl = config['base-url'] or config.default_base_url()
+        if not baseurl[-1] == '/':
+            baseurl += '/'
+        self.baseurl =  baseurl
+        # attributes/relations shared by all entities of the same type
+        self.etype_attrs = []
+        self.etype_rels = []
+        # attributes/relations specific to each entity
+        self.entity_attrs = ['eid', 'cwuri']
+        #self.entity_rels = [] XXX not handled (YAGNI?)
+        schema = session.vreg.schema
+        rschema = schema.rschema
+        for rtype in META_RTYPES:
+            if rtype in ('eid', 'cwuri') or rtype in VIRTUAL_RTYPES:
+                continue
+            if rschema(rtype).final:
+                self.etype_attrs.append(rtype)
+            else:
+                self.etype_rels.append(rtype)
+        if not schema._eid_index:
+            # test schema loaded from the fs
+            self.gen_is = self.test_gen_is
+            self.gen_is_instance_of = self.test_gen_is_instanceof
+
+    @cached
+    def base_etype_dicts(self, etype):
+        entity = self.session.vreg['etypes'].etype_class(etype)(self.session)
+        # entity are "surface" copied, avoid shared dict between copies
+        del entity.cw_extra_kwargs
+        for attr in self.etype_attrs:
+            entity[attr] = self.generate(entity, attr)
+        rels = {}
+        for rel in self.etype_rels:
+            rels[rel] = self.generate(entity, rel)
+        return entity, rels
+
+    def init_entity(self, entity):
+        for attr in self.entity_attrs:
+            entity[attr] = self.generate(entity, attr)
+        entity.eid = entity['eid']
+
+    def generate(self, entity, rtype):
+        return getattr(self, 'gen_%s' % rtype)(entity)
+
+    def gen_eid(self, entity):
+        return self.source.create_eid(self.session)
+
+    def gen_cwuri(self, entity):
+        return u'%seid/%s' % (self.baseurl, entity['eid'])
+
+    def gen_creation_date(self, entity):
+        return self.time
+    def gen_modification_date(self, entity):
+        return self.time
+
+    def gen_is(self, entity):
+        return entity.e_schema.eid
+    def gen_is_instance_of(self, entity):
+        eids = []
+        for etype in entity.e_schema.ancestors() + [entity.e_schema]:
+            eids.append(entity.e_schema.eid)
+        return eids
+
+    def gen_created_by(self, entity):
+        return self.session.user.eid
+    def gen_owned_by(self, entity):
+        return self.session.user.eid
+
+    # implementations of gen_is / gen_is_instance_of to use during test where
+    # schema has been loaded from the fs (hence entity type schema eids are not
+    # known)
+    def test_gen_is(self, entity):
+        from cubicweb.hooks.metadata import eschema_eid
+        return eschema_eid(self.session, entity.e_schema)
+    def test_gen_is_instanceof(self, entity):
+        from cubicweb.hooks.metadata import eschema_eid
+        eids = []
+        for eschema in entity.e_schema.ancestors() + [entity.e_schema]:
+            eids.append(eschema_eid(self.session, eschema))
+        return eids
+
+
+################################################################################
+
+utf8csvreader = deprecated('[3.6] use ucsvreader instead')(ucsvreader)
+
+@deprecated('[3.6] use required')
+def nonempty(value):
+    return required(value)
+
+@deprecated("[3.6] use call_check_method('isdigit')")
+def alldigits(txt):
+    if txt.isdigit():
+        return txt
+    else:
+        return u''
+
+@deprecated("[3.7] too specific, will move away, copy me")
+def capitalize_if_unicase(txt):
+    if txt.isupper() or txt.islower():
+        return txt.capitalize()
+    return txt
+
+@deprecated("[3.7] too specific, will move away, copy me")
+def yesno(value):
+    """simple heuristic that returns boolean value
+
+    >>> yesno("Yes")
+    True
+    >>> yesno("oui")
+    True
+    >>> yesno("1")
+    True
+    >>> yesno("11")
+    True
+    >>> yesno("")
+    False
+    >>> yesno("Non")
+    False
+    >>> yesno("blablabla")
+    False
+    """
+    if value:
+        return value.lower()[0] in 'yo1'
+    return False
+
+@deprecated("[3.7] use call_check_method('isalpha')")
+def isalpha(value):
+    if value.isalpha():
+        return value
+    raise ValueError("not all characters in the string alphabetic")
+
+@deprecated("[3.7] use call_transform_method('upper')")
+def uppercase(txt):
+    return txt.upper()
+
+@deprecated("[3.7] use call_transform_method('lower')")
+def lowercase(txt):
+    return txt.lower()
+
+@deprecated("[3.7] use call_transform_method('replace', ' ', '')")
+def no_space(txt):
+    return txt.replace(' ','')
+
+@deprecated("[3.7] use call_transform_method('replace', u'\xa0', '')")
+def no_uspace(txt):
+    return txt.replace(u'\xa0','')
+
+@deprecated("[3.7] use call_transform_method('replace', '-', '')")
+def no_dash(txt):
+    return txt.replace('-','')
+
+@deprecated("[3.7] use call_transform_method('strip')")
+def strip(txt):
+    return txt.strip()
+
+@deprecated("[3.7] use call_transform_method('replace', ',', '.'), float")
+def decimal(value):
+    return comma_float(value)
+
+@deprecated('[3.7] use int builtin')
+def integer(value):
+    return int(value)