dataimport/pgstore.py
author Julien Cristau <julien.cristau@logilab.fr>
Mon, 19 Oct 2015 14:29:06 +0200
changeset 10804 ee113e1e03de
parent 10662 10942ed172de
child 10810 0768bf2333a7
permissions -rw-r--r--
[devtools] pass a key to sort() method python3 dicts are not comparable

# copyright 2003-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
"""Postgres specific store"""
from __future__ import print_function

import threading
import warnings
import os.path as osp
from StringIO import StringIO
from time import asctime
from datetime import date, datetime, time
from collections import defaultdict
from base64 import b64encode

from six import string_types, integer_types
from six.moves import cPickle as pickle, range

from cubicweb.utils import make_uid
from cubicweb.server.sqlutils import SQL_PREFIX
from cubicweb.dataimport.stores import NoHookRQLObjectStore

def _import_statements(sql_connect, statements, nb_threads=3,
                       dump_output_dir=None,
                       support_copy_from=True, encoding='utf-8'):
    """
    Import a bunch of sql statements, using different threads.

    ``statements`` is split into ``nb_threads`` contiguous chunks, each one
    being executed by _execmany_thread in its own thread; this function
    returns once every started thread has been joined.

    :sql_connect: callable returning a new database connection (called once
                  per thread by _execmany_thread)
    :statements: list of ``(statement, data)`` couples
    :nb_threads: number of threads to use (values below 1 are treated as 1)
    :dump_output_dir, support_copy_from, encoding: forwarded to
                                                   _execmany_thread

    Exceptions raised inside a worker thread are not propagated here. Errors
    while starting or joining the threads are reported on stdout, with
    their traceback, and not re-raised.
    """
    try:
        nb_threads = max(1, nb_threads)
        # floor division: a float chunk size would break slicing on Python 3
        chunksize = (len(statements) // nb_threads) + 1
        threads = []
        for i in range(nb_threads):
            chunks = statements[i * chunksize:(i + 1) * chunksize]
            thread = threading.Thread(target=_execmany_thread,
                                      args=(sql_connect, chunks,
                                            dump_output_dir,
                                            support_copy_from,
                                            encoding))
            thread.start()
            threads.append(thread)
        for thread in threads:
            thread.join()
    except Exception:
        import traceback
        print('Error in import statements')
        # do not hide the cause of the failure
        traceback.print_exc()

def _execmany_thread_not_copy_from(cu, statement, data, table=None,
                                   columns=None, encoding='utf-8'):
    """ Execute thread without copy from
    """
    cu.executemany(statement, data)

def _execmany_thread_copy_from(cu, statement, data, table,
                               columns, encoding='utf-8'):
    """Load ``data`` into ``table`` with the cursor's ``copy_from``.

    The rows are serialized by _create_copyfrom_buffer; 'NULL' is the marker
    used for null values. If no buffer is returned (None), ``statement`` is
    run with executemany instead.
    NOTE(review): _create_copyfrom_buffer as written in this module raises
    rather than returning None, so that fallback looks unreachable.
    """
    buf = _create_copyfrom_buffer(data, columns, encoding=encoding)
    if buf is None:
        _execmany_thread_not_copy_from(cu, statement, data)
        return
    copy_kwargs = {'null': 'NULL'}
    if columns is not None:
        # restrict the COPY to the given columns only
        copy_kwargs['columns'] = columns
    cu.copy_from(buf, table, **copy_kwargs)

def _execmany_thread(sql_connect, statements, dump_output_dir=None,
                     support_copy_from=True, encoding='utf-8'):
    """
    Execute sql statement. If 'INSERT INTO', try to use 'COPY FROM' command,
    or fallback to execute_many.

    A dedicated connection is opened with ``sql_connect()``. On exit a commit
    is issued, then the cursor and the connection are closed.

    :statements: sequence of ``(statement, data)`` couples, ``data`` being
                 the list of rows (tuples/lists or dicts) for ``statement``;
                 couples with empty data are skipped
    :dump_output_dir: if not None, directory where the data of a failing
                      statement is pickled for later recovery
    :support_copy_from: whether the cursor provides ``copy_from``
    :encoding: encoding used to serialize strings for 'COPY FROM'

    On the first failing statement, the transaction is rolled back and the
    exception re-raised.
    """
    if support_copy_from:
        execmany_func = _execmany_thread_copy_from
    else:
        execmany_func = _execmany_thread_not_copy_from
    cnx = sql_connect()
    try:
        cu = cnx.cursor()
        try:
            for statement, data in statements:
                if not data:
                    # nothing to execute (and data[0] below would raise
                    # IndexError)
                    continue
                table = None
                columns = None
                try:
                    if not statement.startswith('INSERT INTO'):
                        cu.executemany(statement, data)
                        continue
                    # 'INSERT INTO <table> ...'
                    table = statement.split()[2]
                    if isinstance(data[0], (tuple, list)):
                        columns = None
                    else:
                        columns = list(data[0])
                    execmany_func(cu, statement, data, table, columns, encoding)
                except Exception:
                    print('unable to copy data into table %s' % table)
                    # Error in import statement, save data in dump_output_dir
                    if dump_output_dir is not None:
                        pdata = {'data': data, 'statement': statement,
                                 'time': asctime(), 'columns': columns}
                        filename = make_uid()
                        try:
                            with open(osp.join(dump_output_dir,
                                               '%s.pickle' % filename), 'wb') as fobj:
                                pickle.dump(pdata, fobj)
                        except IOError:
                            print('ERROR while pickling in', dump_output_dir, filename+'.pickle')
                    cnx.rollback()
                    raise
        finally:
            try:
                # after a rollback there is nothing left to commit
                cnx.commit()
            finally:
                cu.close()
    finally:
        # the connection was opened by this function only: release it
        # (it was previously leaked, one per thread and per flush)
        cnx.close()


def _copyfrom_buffer_convert_None(value, **opts):
    '''Convert None value to "NULL"'''
    return 'NULL'

def _copyfrom_buffer_convert_number(value, **opts):
    '''Convert a number into its string representation'''
    return str(value)

def _copyfrom_buffer_convert_string(value, **opts):
    '''Convert string value.

    Recognized keywords:
    :encoding: resulting string encoding (default: utf-8)
    '''
    encoding = opts.get('encoding','utf-8')
    escape_chars = ((u'\\', u'\\\\'), (u'\t', u'\\t'), (u'\r', u'\\r'),
                    (u'\n', u'\\n'))
    for char, replace in escape_chars:
        value = value.replace(char, replace)
    if isinstance(value, unicode):
        value = value.encode(encoding)
    return value

def _copyfrom_buffer_convert_date(value, **opts):
    '''Convert date into "YYYY-MM-DD"'''
    # Do not use strftime, as it yields issue with date < 1900
    # (http://bugs.python.org/issue1777412)
    return '%04d-%02d-%02d' % (value.year, value.month, value.day)

def _copyfrom_buffer_convert_datetime(value, **opts):
    '''Convert datetime into "YYYY-MM-DD HH:MM:SS.UUUUUU"'''
    # Formatted by hand rather than with strftime, which has issues with
    # years before 1900 (http://bugs.python.org/issue1777412).
    date_part = (value.year, value.month, value.day)
    time_part = (value.hour, value.minute, value.second, value.microsecond)
    return '%04d-%02d-%02d %02d:%02d:%02d.%06d' % (date_part + time_part)

def _copyfrom_buffer_convert_time(value, **opts):
    '''Convert time into "HH:MM:SS.UUUUUU"'''
    return '%02d:%02d:%02d.%06d' % (value.hour, value.minute,
                                    value.second, value.microsecond)

# (types, converter) list, scanned in order by _create_copyfrom_buffer: the
# first entry whose types match the value (isinstance) converts it.
# Order matters: datetime is a subclass of date, so it must come before it;
# bool, being a subclass of int, is handled as a number.
_COPYFROM_BUFFER_CONVERTERS = [
    (type(None), _copyfrom_buffer_convert_None),
    (integer_types + (float,), _copyfrom_buffer_convert_number),
    (string_types, _copyfrom_buffer_convert_string),
    (datetime, _copyfrom_buffer_convert_datetime),
    (date, _copyfrom_buffer_convert_date),
    (time, _copyfrom_buffer_convert_time),
]

def _create_copyfrom_buffer(data, columns=None, **convert_opts):
    """
    Create a StringIO buffer for 'COPY FROM' command.
    Deals with Unicode, Int, Float, Date... (see
    ``_COPYFROM_BUFFER_CONVERTERS``)

    :data: a sequence of rows; each row is either a tuple/list (values looked
           up by position) or a dict (values looked up by column name)
    :columns: list of columns to consider -- indexes for tuple/list rows,
              keys for dict rows (default to all columns of the first row)
    :convert_opts: keyword arguments given to converters (e.g. ``encoding``)

    Return a StringIO holding one tab-separated line per row (empty when
    ``data`` is empty). Raise ValueError when the columns cannot be guessed
    from the first row, or when a value has no converter.
    """
    if not data:
        # nothing to serialize; returning early also avoids data[0] below
        return StringIO('')
    if columns is None:
        first_row = data[0]
        if isinstance(first_row, (tuple, list)):
            columns = list(range(len(first_row)))
        elif isinstance(first_row, dict):
            columns = list(first_row)
        else:
            raise ValueError('Could not get columns: you must provide columns.')
    # Build a list of lines rather than writing to a StringIO directly, to
    # join them with '\n' in a single step.
    rows = []
    for row in data:
        formatted_row = []
        for col in columns:
            try:
                value = row[col]
            except KeyError:
                warnings.warn(u"Column %s is not accessible in row %s"
                              % (col, row), RuntimeWarning)
                # XXX 'value' set to None so that the import does not end in
                # error: the missing key is loaded as NULL in the database.
                value = None
            # the first converter whose types match wins (see the order of
            # _COPYFROM_BUFFER_CONVERTERS); stop on unsupported types
            for types, converter in _COPYFROM_BUFFER_CONVERTERS:
                if isinstance(value, types):
                    value = converter(value, **convert_opts)
                    break
            else:
                raise ValueError("Unsupported value type %s" % type(value))
            # every converter returns a string; None has become 'NULL', the
            # null marker given to copy_from
            formatted_row.append(value)
        rows.append('\t'.join(formatted_row))
    return StringIO('\n'.join(rows))


class SQLGenObjectStore(NoHookRQLObjectStore):
    """Controller of the data import process. This version is based
    on direct insertions through SQL commands (COPY FROM or execute many).

    >>> store = SQLGenObjectStore(cnx)
    >>> store.create_entity('Person', ...)
    >>> store.flush()
    """

    def __init__(self, cnx, dump_output_dir=None, nb_threads_statement=3):
        """
        Initialize a SQLGenObjectStore.

        Parameters:

          - cnx: connection on the cubicweb instance
          - dump_output_dir: a directory to dump failed statements
            for easier recovery. Default is None (no dump).
          - nb_threads_statement: number of threads used
            for SQL insertion (default is 3).
        """
        super(SQLGenObjectStore, self).__init__(cnx)
        # hijack the default source: statements are buffered by the wrapper
        # until flush()
        self.source = SQLGenSourceWrapper(
            self.source, cnx.vreg.schema,
            dump_output_dir=dump_output_dir,
            nb_threads_statement=nb_threads_statement)
        # XXX This is done in super().__init__(), but should be redone here
        # to link to the correct source
        self.add_relation = self.source.add_relation
        # etype -> list of (index name, indexed column), see drop_indexes()
        self.indexes_etypes = {}

    def flush(self):
        """Flush data to the database"""
        self.source.flush()

    def relate(self, subj_eid, rtype, obj_eid, **kwargs):
        """Buffer the relation ``subj_eid rtype obj_eid`` (and its reverse
        for a symmetric relation). Nothing is done if an eid is None.

        ``kwargs`` are passed to the source's add_relation (e.g.
        ``subjtype`` for inlined relations).
        """
        if subj_eid is None or obj_eid is None:
            return
        rschema = self.rschema(rtype)
        # XXX Could subjtype be inferred ?
        self.source.add_relation(self._cnx, subj_eid, rtype, obj_eid,
                                 rschema.inlined, **kwargs)
        if rschema.symmetric:
            self.source.add_relation(self._cnx, obj_eid, rtype, subj_eid,
                                     rschema.inlined, **kwargs)

    def drop_indexes(self, etype):
        """Drop indexes for a given entity type.

        Only indexes named ``<table>_<column>_idx`` (lower-cased) are
        considered, since their column can be recovered from their name;
        they are remembered so that create_indexes() can rebuild them.
        Primary key / unique indexes (named ``..._key``) and any other index
        are left untouched.
        NOTE(review): the naming scheme is assumed to be the one used by
        logilab.database for attribute indexes -- confirm.
        """
        if etype not in self.indexes_etypes:
            # indexes belong to the entity table, not to the bare etype
            table = SQL_PREFIX + etype
            cu = self._cnx.cnxset.cu
            # SQLGenSourceWrapper does not expose ``dbhelper`` itself: take
            # it from the wrapped system source when there is one
            source = getattr(self.source, 'system_source', self.source)
            prefix = table.lower() + '_'
            suffix = '_idx'
            def index_to_attr(index):
                """turn an index name to (database) attribute name, or None
                if the name does not follow the expected scheme"""
                if (index.startswith(prefix) and index.endswith(suffix)
                        and len(index) > len(prefix) + len(suffix)):
                    return index[len(prefix):-len(suffix)]
                return None
            indices = []
            for index in source.dbhelper.list_indices(cu, table):
                attr = index_to_attr(index)
                if attr is not None:
                    indices.append((index, attr))
            self.indexes_etypes[etype] = indices
        for index, attr in self.indexes_etypes[etype]:
            self._cnx.system_sql('DROP INDEX %s' % index)

    def create_indexes(self, etype):
        """Recreate indexes for a given entity type (those dropped by
        drop_indexes(); nothing is done for other entity types)"""
        for index, attr in self.indexes_etypes.get(etype, []):
            sql = 'CREATE INDEX %s ON %s%s(%s)' % (index, SQL_PREFIX, etype, attr)
            self._cnx.system_sql(sql)


###########################################################################
## SQL Source #############################################################
###########################################################################

class SQLGenSourceWrapper(object):
    """Wrapper around the system source which buffers SQL statements.

    The statements generated for entities, relations and system tables
    (entities, is, is_instance_of, cw_source) are not executed right away:
    they are accumulated, grouped by statement, in thread-local buffers and
    sent to the database by flush().
    """

    def __init__(self, system_source, schema,
                 dump_output_dir=None, nb_threads_statement=3):
        """
        :system_source: the system source being wrapped
        :schema: the instance schema
        :dump_output_dir: directory where failing statements are pickled,
                          or None to disable dumping
        :nb_threads_statement: number of threads used by flush()
        """
        self.system_source = system_source
        self._sql = threading.local()
        # Explicitly backport attributes from system source
        self._storage_handler = self.system_source._storage_handler
        self.preprocess_entity = self.system_source.preprocess_entity
        self.sqlgen = self.system_source.sqlgen
        self.uri = self.system_source.uri
        self.eid = self.system_source.eid
        # database helper, also needed by users of the wrapper (e.g.
        # SQLGenObjectStore.drop_indexes)
        self.dbhelper = self.system_source.dbhelper
        # Directory to write temporary files
        self.dump_output_dir = dump_output_dir
        # Allow to execute code with SQLite backend that does
        # not support (yet...) copy_from
        # XXX Should be dealt with in logilab.database
        spcfrom = system_source.dbhelper.dbapi_module.support_copy_from
        self.support_copy_from = spcfrom
        self.dbencoding = system_source.dbhelper.dbencoding
        self.nb_threads_statement = nb_threads_statement
        # initialize thread-local data for main thread
        self.init_thread_locals()
        # etype -> list of the SQL columns of its inlined relations
        self._inlined_rtypes_cache = {}
        self._fill_inlined_rtypes_cache(schema)
        self.schema = schema
        # full-text indexing is disabled (see the end of add_info)
        self.do_fti = False

    def _fill_inlined_rtypes_cache(self, schema):
        """Map each entity type to the list of SQL columns of the inlined
        relations returned by its ``ordered_relations()``."""
        cache = self._inlined_rtypes_cache
        for eschema in schema.entities():
            for rschema in eschema.ordered_relations():
                if rschema.inlined:
                    # accumulate: keeping only the last inlined relation
                    # would leave the other columns out of the INSERT
                    # statements built by add_entity
                    cache.setdefault(eschema.type, []).append(
                        SQL_PREFIX + rschema.type)

    def init_thread_locals(self):
        """initializes thread-local data"""
        # insert statement -> list of row data (dicts or tuples)
        self._sql.entities = defaultdict(list)
        # insert statement -> list of {'eid_from': .., 'eid_to': ..}
        self._sql.relations = {}
        # update statement -> list of {'cw_eid': .., <column>: ..}
        self._sql.inlined_relations = {}
        # keep track, for each eid of the corresponding data dict
        self._sql.eid_insertdicts = {}

    def flush(self):
        """Send the buffered statements to the database, then empty the
        buffers (also when the import fails).

        An inlined relation whose subject entity is inserted in the same
        batch is merged into that entity's insert data, which saves the
        corresponding UPDATE statement.
        """
        print('starting flush')
        _entities_sql = self._sql.entities
        _relations_sql = self._sql.relations
        _inlined_relations_sql = self._sql.inlined_relations
        _insertdicts = self._sql.eid_insertdicts
        try:
            # iterate over a copy: the dict is modified within the loop
            for statement, datalist in list(_inlined_relations_sql.items()):
                new_datalist = []
                # for a given inlined relation,
                # browse each couple to be inserted
                for data in datalist:
                    # For inlined relations, data has exactly two keys:
                    # 'cw_eid' and the column of the relation
                    rtype = [key for key in data if key != 'cw_eid'][0]
                    updated_eid = data['cw_eid']
                    if updated_eid in _insertdicts:
                        # the column is already part of the INSERT (added
                        # as None by add_entity), only its value changes
                        _insertdicts[updated_eid][rtype] = data[rtype]
                    else:
                        # could not find corresponding insert dict, keep the
                        # UPDATE query
                        new_datalist.append(data)
                if new_datalist:
                    _inlined_relations_sql[statement] = new_datalist
                else:
                    # everything was merged: no UPDATE left to run
                    del _inlined_relations_sql[statement]
            # dict views cannot be added together on Python 3; the order
            # (entities, relations, then inlined relations) is preserved.
            # NOTE(review): _import_statements splits this list into
            # per-thread chunks, so ordering between chunks is not guaranteed
            statements = (list(_entities_sql.items())
                          + list(_relations_sql.items())
                          + list(_inlined_relations_sql.items()))
            _import_statements(self.system_source.get_connection,
                               statements,
                               dump_output_dir=self.dump_output_dir,
                               nb_threads=self.nb_threads_statement,
                               support_copy_from=self.support_copy_from,
                               encoding=self.dbencoding)
        finally:
            _entities_sql.clear()
            _relations_sql.clear()
            _insertdicts.clear()
            _inlined_relations_sql.clear()

    def add_relation(self, cnx, subject, rtype, object,
                     inlined=False, **kwargs):
        """Buffer the relation ``subject rtype object``.

        An inlined relation becomes an update of the subject's table; its
        entity type is read from the ``subjtype`` keyword, or inferred from
        the schema when the relation has a single subject type (ValueError
        otherwise). Other relations become an insert into the
        ``<rtype>_relation`` table.
        """
        if inlined:
            _sql = self._sql.inlined_relations
            data = {'cw_eid': subject, SQL_PREFIX + rtype: object}
            subjtype = kwargs.get('subjtype')
            if subjtype is None:
                # Try to infer it
                targets = [t.type for t in
                           self.schema.rschema(rtype).subjects()]
                if len(targets) == 1:
                    subjtype = targets[0]
                else:
                    raise ValueError('You should give the subject etype for '
                                     'inlined relation %s'
                                     ', as it cannot be inferred: '
                                     'this type is given as keyword argument '
                                     '``subjtype``'% rtype)
            statement = self.sqlgen.update(SQL_PREFIX + subjtype,
                                           data, ['cw_eid'])
        else:
            _sql = self._sql.relations
            data = {'eid_from': subject, 'eid_to': object}
            statement = self.sqlgen.insert('%s_relation' % rtype, data)
        _sql.setdefault(statement, []).append(data)

    def add_entity(self, cnx, entity):
        """Buffer the insert of ``entity`` into its table.

        The columns of the entity type's inlined relations are added (as
        None) when missing, so that flush() can later fill them instead of
        running an UPDATE.
        """
        with self._storage_handler(entity, 'added'):
            attrs = self.preprocess_entity(entity)
            rtypes = self._inlined_rtypes_cache.get(entity.cw_etype, ())
            # a single column name (str, or unicode on Python 2) is accepted
            # as well as a sequence of them
            if isinstance(rtypes, string_types):
                rtypes = (rtypes,)
            for rtype in rtypes:
                attrs.setdefault(rtype, None)
            sql = self.sqlgen.insert(SQL_PREFIX + entity.cw_etype, attrs)
            self._sql.eid_insertdicts[entity.eid] = attrs
            self._append_to_entities(sql, attrs)

    def _append_to_entities(self, sql, attrs):
        """Buffer ``attrs`` as one more row for statement ``sql``."""
        self._sql.entities[sql].append(attrs)

    def _handle_insert_entity_sql(self, cnx, sql, attrs):
        # We have to overwrite the source given in parameters
        # as here, we directly use the system source
        attrs['asource'] = self.system_source.uri
        self._append_to_entities(sql, attrs)

    def _handle_is_relation_sql(self, cnx, sql, attrs):
        self._append_to_entities(sql, attrs)

    def _handle_is_instance_of_sql(self, cnx, sql, attrs):
        self._append_to_entities(sql, attrs)

    def _handle_source_relation_sql(self, cnx, sql, attrs):
        self._append_to_entities(sql, attrs)

    # add_info is _copypasted_ from the one in NativeSQLSource. We want it
    # there because it will use the _handlers of the SQLGenSourceWrapper, which
    # are not like the ones in the native source.
    def add_info(self, cnx, entity, source, extid):
        """add type and source info for an eid into the system table"""
        # local import: eschema_eid is used below but was never imported
        from cubicweb.server.utils import eschema_eid
        # begin by inserting eid/type/source/extid into the entities table
        if extid is not None:
            assert isinstance(extid, str)
            extid = b64encode(extid)
        attrs = {'type': entity.cw_etype, 'eid': entity.eid, 'extid': extid,
                 'asource': source.uri}
        self._handle_insert_entity_sql(cnx, self.sqlgen.insert('entities', attrs), attrs)
        # insert core relations: is, is_instance_of and cw_source
        try:
            self._handle_is_relation_sql(cnx, 'INSERT INTO is_relation(eid_from,eid_to) VALUES (%s,%s)',
                                         (entity.eid, eschema_eid(cnx, entity.e_schema)))
        except IndexError:
            # during schema serialization, skip
            pass
        else:
            for eschema in entity.e_schema.ancestors() + [entity.e_schema]:
                self._handle_is_relation_sql(cnx,
                                             'INSERT INTO is_instance_of_relation(eid_from,eid_to) VALUES (%s,%s)',
                                             (entity.eid, eschema_eid(cnx, eschema)))
        if 'CWSource' in self.schema and source.eid is not None: # else, cw < 3.10
            self._handle_is_relation_sql(cnx, 'INSERT INTO cw_source_relation(eid_from,eid_to) VALUES (%s,%s)',
                                         (entity.eid, source.eid))
        # now we can update the full text index
        # NOTE(review): need_fti_indexation / index_entity are not defined in
        # this class; do_fti is False, so this branch is not taken
        if self.do_fti and self.need_fti_indexation(entity.cw_etype):
            self.index_entity(cnx, entity=entity)