dataimport.py
changeset 8625 7ee0752178e5
parent 8406 f3bc8ca0b715
child 8631 1053b9d0fdf7
--- a/dataimport.py	Mon Nov 26 12:52:33 2012 +0100
+++ b/dataimport.py	Fri Dec 14 14:08:14 2012 +0100
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
 #
 # This file is part of CubicWeb.
@@ -66,13 +66,18 @@
 """
 __docformat__ = "restructuredtext en"
 
+import csv
 import sys
-import csv
+import threading
 import traceback
+import cPickle
 import os.path as osp
-from StringIO import StringIO
+from collections import defaultdict
+from contextlib import contextmanager
 from copy import copy
-from datetime import datetime
+from datetime import date, datetime
+from time import asctime
+from StringIO import StringIO
 
 from logilab.common import shellutils, attrdict
 from logilab.common.date import strptime
@@ -80,9 +85,11 @@
 from logilab.common.deprecation import deprecated
 
 from cubicweb import QueryError
+from cubicweb.utils import make_uid
 from cubicweb.schema import META_RTYPES, VIRTUAL_RTYPES
+from cubicweb.server.edition import EditedEntity
+from cubicweb.server.sqlutils import SQL_PREFIX
 from cubicweb.server.utils import eschema_eid
-from cubicweb.server.edition import EditedEntity
 
 
 def count_lines(stream_or_filename):
@@ -299,6 +306,142 @@
             if k is not None and len(v) > 1]
 
 
+# sql generator utility functions #############################################
+
+
+def _import_statements(sql_connect, statements, nb_threads=3,
+                       dump_output_dir=None,
+                       support_copy_from=True, encoding='utf-8'):
+    """
+    Import a bunch of sql statements, using different threads.
+    """
+    try:
+        chunksize = (len(statements) / nb_threads) + 1
+        threads = []
+        for i in xrange(nb_threads):
+            chunks = statements[i*chunksize:(i+1)*chunksize]
+            thread = threading.Thread(target=_execmany_thread,
+                                      args=(sql_connect, chunks,
+                                            dump_output_dir,
+                                            support_copy_from,
+                                            encoding))
+            thread.start()
+            threads.append(thread)
+        for t in threads:
+            t.join()
+    except Exception:
+        print 'Error in import statements'
+
+def _execmany_thread_not_copy_from(cu, statement, data, table=None,
+                                   columns=None, encoding='utf-8'):
+    """ Execute thread without copy from
+    """
+    cu.executemany(statement, data)
+
+def _execmany_thread_copy_from(cu, statement, data, table,
+                               columns, encoding='utf-8'):
+    """ Execute thread with copy from
+    """
+    buf = _create_copyfrom_buffer(data, columns, encoding)
+    if buf is None:
+        _execmany_thread_not_copy_from(cu, statement, data)
+    else:
+        if columns is None:
+            cu.copy_from(buf, table, null='NULL')
+        else:
+            cu.copy_from(buf, table, null='NULL', columns=columns)
+
+def _execmany_thread(sql_connect, statements, dump_output_dir=None,
+                     support_copy_from=True, encoding='utf-8'):
+    """
+    Execute sql statement. If 'INSERT INTO', try to use 'COPY FROM' command,
+    or fallback to execute_many.
+    """
+    if support_copy_from:
+        execmany_func = _execmany_thread_copy_from
+    else:
+        execmany_func = _execmany_thread_not_copy_from
+    cnx = sql_connect()
+    cu = cnx.cursor()
+    try:
+        for statement, data in statements:
+            table = None
+            columns = None
+            try:
+                if not statement.startswith('INSERT INTO'):
+                    cu.executemany(statement, data)
+                    continue
+                table = statement.split()[2]
+                if isinstance(data[0], (tuple, list)):
+                    columns = None
+                else:
+                    columns = data[0].keys()
+                execmany_func(cu, statement, data, table, columns, encoding)
+            except Exception:
+                print 'unable to copy data into table %s', table
+                # Error in import statement, save data in dump_output_dir
+                if dump_output_dir is not None:
+                    pdata = {'data': data, 'statement': statement,
+                             'time': asctime(), 'columns': columns}
+                    filename = make_uid()
+                    try:
+                        with open(osp.join(dump_output_dir,
+                                           '%s.pickle' % filename), 'w') as fobj:
+                            fobj.write(cPickle.dumps(pdata))
+                    except IOError:
+                        print 'ERROR while pickling in', dump_output_dir, filename+'.pickle'
+                        pass
+                cnx.rollback()
+                raise
+    finally:
+        cnx.commit()
+        cu.close()
+
+def _create_copyfrom_buffer(data, columns, encoding='utf-8'):
+    """
+    Create a StringIO buffer for 'COPY FROM' command.
+    Deals with Unicode, Int, Float, Date...
+    """
+    # Create a list rather than directly create a StringIO
+    # to correctly write lines separated by '\n' in a single step
+    rows = []
+    if isinstance(data[0], (tuple, list)):
+        columns = range(len(data[0]))
+    for row in data:
+        # Iterate over the different columns and the different values
+        # and try to convert them to a correct datatype.
+        # If an error is raised, do not continue.
+        formatted_row = []
+        for col in columns:
+            value = row[col]
+            if value is None:
+                value = 'NULL'
+            elif isinstance(value, (long, int, float)):
+                value = str(value)
+            elif isinstance(value, (str, unicode)):
+                # Remove separators used in string formatting
+                if u'\t' in value or u'\r' in value or u'\n' in value:
+                    return
+                value = value.replace('\\', r'\\')
+                if not value:
+                    return
+                if isinstance(value, unicode):
+                    value = value.encode(encoding)
+            elif isinstance(value, (date, datetime)):
+                # Do not use strftime, as it yields issue
+                # with date < 1900
+                value = '%04d-%02d-%02d' % (value.year,
+                                            value.month,
+                                            value.day)
+            else:
+                return None
+            # We push the value to the new formatted row
+            # if the value is not None and could be converted to a string.
+            formatted_row.append(value)
+        rows.append('\t'.join(formatted_row))
+    return StringIO('\n'.join(rows))
+
+
 # object stores #################################################################
 
 class ObjectStore(object):
@@ -753,3 +896,261 @@
         return self.session.user.eid
     def gen_owned_by(self, entity):
         return self.session.user.eid
+
+
+###########################################################################
+## SQL object store #######################################################
+###########################################################################
+class SQLGenObjectStore(NoHookRQLObjectStore):
+    """Controller of the data import process. This version is based
+    on direct insertions throught SQL command (COPY FROM or execute many).
+
+    >>> store = SQLGenObjectStore(session)
+    >>> store.create_entity('Person', ...)
+    >>> store.flush()
+    """
+
+    def __init__(self, session, dump_output_dir=None, nb_threads_statement=3):
+        """
+        Initialize a SQLGenObjectStore.
+
+        Parameters:
+
+          - session: session on the cubicweb instance
+          - dump_output_dir: a directory to dump failed statements
+            for easier recovery. Default is None (no dump).
+          - nb_threads_statement: number of threads used
+            for SQL insertion (default is 3).
+        """
+        super(SQLGenObjectStore, self).__init__(session)
+        ### hijack default source
+        self.source = SQLGenSourceWrapper(
+            self.source, session.vreg.schema,
+            dump_output_dir=dump_output_dir,
+            nb_threads_statement=nb_threads_statement)
+        ### XXX This is done in super().__init__(), but should be
+        ### redone here to link to the correct source
+        self.add_relation = self.source.add_relation
+        self.indexes_etypes = {}
+
+    def flush(self):
+        """Flush data to the database"""
+        self.source.flush()
+
+    def relate(self, subj_eid, rtype, obj_eid, subjtype=None):
+        if subj_eid is None or obj_eid is None:
+            return
+        # XXX Could subjtype be inferred ?
+        self.source.add_relation(self.session, subj_eid, rtype, obj_eid,
+                                 self.rschema(rtype).inlined, subjtype)
+
+    def drop_indexes(self, etype):
+        """Drop indexes for a given entity type"""
+        if etype not in self.indexes_etypes:
+            cu = self.session.cnxset['system']
+            def index_to_attr(index):
+                """turn an index name to (database) attribute name"""
+                return index.replace(etype.lower(), '').replace('idx', '').strip('_')
+            indices = [(index, index_to_attr(index))
+                       for index in self.source.dbhelper.list_indices(cu, etype)
+                       # Do not consider 'cw_etype_pkey' index
+                       if not index.endswith('key')]
+            self.indexes_etypes[etype] = indices
+        for index, attr in self.indexes_etypes[etype]:
+            self.session.system_sql('DROP INDEX %s' % index)
+
+    def create_indexes(self, etype):
+        """Recreate indexes for a given entity type"""
+        for index, attr in self.indexes_etypes.get(etype, []):
+            sql = 'CREATE INDEX %s ON cw_%s(%s)' % (index, etype, attr)
+            self.session.system_sql(sql)
+
+
+###########################################################################
+## SQL Source #############################################################
+###########################################################################
+
+class SQLGenSourceWrapper(object):
+
+    def __init__(self, system_source, schema,
+                 dump_output_dir=None, nb_threads_statement=3):
+        self.system_source = system_source
+        self._sql = threading.local()
+        # Explicitely backport attributes from system source
+        self._storage_handler = self.system_source._storage_handler
+        self.preprocess_entity = self.system_source.preprocess_entity
+        self.sqlgen = self.system_source.sqlgen
+        self.copy_based_source = self.system_source.copy_based_source
+        self.uri = self.system_source.uri
+        self.eid = self.system_source.eid
+        # Directory to write temporary files
+        self.dump_output_dir = dump_output_dir
+        # Allow to execute code with SQLite backend that does
+        # not support (yet...) copy_from
+        # XXX Should be dealt with in logilab.database
+        spcfrom = system_source.dbhelper.dbapi_module.support_copy_from
+        self.support_copy_from = spcfrom
+        self.dbencoding = system_source.dbhelper.dbencoding
+        self.nb_threads_statement = nb_threads_statement
+        # initialize thread-local data for main thread
+        self.init_thread_locals()
+        self._inlined_rtypes_cache = {}
+        self._fill_inlined_rtypes_cache(schema)
+        self.schema = schema
+        self.do_fti = False
+
+    def _fill_inlined_rtypes_cache(self, schema):
+        cache = self._inlined_rtypes_cache
+        for eschema in schema.entities():
+            for rschema in eschema.ordered_relations():
+                if rschema.inlined:
+                    cache[eschema.type] = SQL_PREFIX + rschema.type
+
+    def init_thread_locals(self):
+        """initializes thread-local data"""
+        self._sql.entities = defaultdict(list)
+        self._sql.relations = {}
+        self._sql.inlined_relations = {}
+        # keep track, for each eid of the corresponding data dict
+        self._sql.eid_insertdicts = {}
+
+    def flush(self):
+        print 'starting flush'
+        _entities_sql = self._sql.entities
+        _relations_sql = self._sql.relations
+        _inlined_relations_sql = self._sql.inlined_relations
+        _insertdicts = self._sql.eid_insertdicts
+        try:
+            # try, for each inlined_relation, to find if we're also creating
+            # the host entity (i.e. the subject of the relation).
+            # In that case, simply update the insert dict and remove
+            # the need to make the
+            # UPDATE statement
+            for statement, datalist in _inlined_relations_sql.iteritems():
+                new_datalist = []
+                # for a given inlined relation,
+                # browse each couple to be inserted
+                for data in datalist:
+                    keys = data.keys()
+                    # For inlined relations, it exists only two case:
+                    # (rtype, cw_eid) or (cw_eid, rtype)
+                    if keys[0] == 'cw_eid':
+                        rtype = keys[1]
+                    else:
+                        rtype = keys[0]
+                    updated_eid = data['cw_eid']
+                    if updated_eid in _insertdicts:
+                        _insertdicts[updated_eid][rtype] = data[rtype]
+                    else:
+                        # could not find corresponding insert dict, keep the
+                        # UPDATE query
+                        new_datalist.append(data)
+                _inlined_relations_sql[statement] = new_datalist
+            _import_statements(self.system_source.get_connection,
+                               _entities_sql.items()
+                               + _relations_sql.items()
+                               + _inlined_relations_sql.items(),
+                               dump_output_dir=self.dump_output_dir,
+                               nb_threads=self.nb_threads_statement,
+                               support_copy_from=self.support_copy_from,
+                               encoding=self.dbencoding)
+        except:
+            print 'failed to flush'
+        finally:
+            _entities_sql.clear()
+            _relations_sql.clear()
+            _insertdicts.clear()
+            _inlined_relations_sql.clear()
+            print 'flush done'
+
+    def add_relation(self, session, subject, rtype, object,
+                     inlined=False, subjtype=None):
+        if inlined:
+            _sql = self._sql.inlined_relations
+            data = {'cw_eid': subject, SQL_PREFIX + rtype: object}
+            if subjtype is None:
+                # Try to infer it
+                targets = [t.type for t in
+                           self.schema.rschema(rtype).targets()]
+                if len(targets) == 1:
+                    subjtype = targets[0]
+                else:
+                    raise ValueError('You should give the subject etype for '
+                                     'inlined relation %s'
+                                     ', as it cannot be inferred' % rtype)
+            statement = self.sqlgen.update(SQL_PREFIX + subjtype,
+                                           data, ['cw_eid'])
+        else:
+            _sql = self._sql.relations
+            data = {'eid_from': subject, 'eid_to': object}
+            statement = self.sqlgen.insert('%s_relation' % rtype, data)
+        if statement in _sql:
+            _sql[statement].append(data)
+        else:
+            _sql[statement] = [data]
+
+    def add_entity(self, session, entity):
+        with self._storage_handler(entity, 'added'):
+            attrs = self.preprocess_entity(entity)
+            rtypes = self._inlined_rtypes_cache.get(entity.__regid__, ())
+            if isinstance(rtypes, str):
+                rtypes = (rtypes,)
+            for rtype in rtypes:
+                if rtype not in attrs:
+                    attrs[rtype] = None
+            sql = self.sqlgen.insert(SQL_PREFIX + entity.__regid__, attrs)
+            self._sql.eid_insertdicts[entity.eid] = attrs
+            self._append_to_entities(sql, attrs)
+
+    def _append_to_entities(self, sql, attrs):
+        self._sql.entities[sql].append(attrs)
+
+    def _handle_insert_entity_sql(self, session, sql, attrs):
+        # We have to overwrite the source given in parameters
+        # as here, we directly use the system source
+        attrs['source'] = 'system'
+        attrs['asource'] = self.system_source.uri
+        self._append_to_entities(sql, attrs)
+
+    def _handle_is_relation_sql(self, session, sql, attrs):
+        self._append_to_entities(sql, attrs)
+
+    def _handle_is_instance_of_sql(self, session, sql, attrs):
+        self._append_to_entities(sql, attrs)
+
+    def _handle_source_relation_sql(self, session, sql, attrs):
+        self._append_to_entities(sql, attrs)
+
+    # XXX add_info is similar to the one in NativeSQLSource. It is rewritten
+    # here to correctly used the _handle_xxx of the SQLGenSourceWrapper. This
+    # part should be rewritten in a more clearly way.
+    def add_info(self, session, entity, source, extid, complete):
+        """add type and source info for an eid into the system table"""
+        # begin by inserting eid/type/source/extid into the entities table
+        if extid is not None:
+            assert isinstance(extid, str)
+            extid = b64encode(extid)
+        uri = 'system' if source.copy_based_source else source.uri
+        attrs = {'type': entity.__regid__, 'eid': entity.eid, 'extid': extid,
+                 'source': uri, 'asource': source.uri, 'mtime': datetime.utcnow()}
+        self._handle_insert_entity_sql(session, self.sqlgen.insert('entities', attrs), attrs)
+        # insert core relations: is, is_instance_of and cw_source
+        try:
+            self._handle_is_relation_sql(session, 'INSERT INTO is_relation(eid_from,eid_to) VALUES (%s,%s)',
+                                         (entity.eid, eschema_eid(session, entity.e_schema)))
+        except IndexError:
+            # during schema serialization, skip
+            pass
+        else:
+            for eschema in entity.e_schema.ancestors() + [entity.e_schema]:
+                self._handle_is_relation_sql(session,
+                                             'INSERT INTO is_instance_of_relation(eid_from,eid_to) VALUES (%s,%s)',
+                                             (entity.eid, eschema_eid(session, eschema)))
+        if 'CWSource' in self.schema and source.eid is not None: # else, cw < 3.10
+            self._handle_is_relation_sql(session, 'INSERT INTO cw_source_relation(eid_from,eid_to) VALUES (%s,%s)',
+                                         (entity.eid, source.eid))
+        # now we can update the full text index
+        if self.do_fti and self.need_fti_indexation(entity.__regid__):
+            if complete:
+                entity.complete(entity.e_schema.indexable_attributes())
+            self.index_entity(session, entity=entity)