cubicweb/dataimport/pgstore.py
# copyright 2003-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
"""Postgres specific store"""
from __future__ import print_function

import warnings
import os.path as osp
from io import StringIO
from time import asctime
from datetime import date, datetime, time
from collections import defaultdict
from base64 import b64encode

from six import string_types, integer_types, text_type, binary_type
from six.moves import cPickle as pickle, range

from cubicweb.utils import make_uid
from cubicweb.server.utils import eschema_eid
from cubicweb.server.sqlutils import SQL_PREFIX
from cubicweb.dataimport.stores import NoHookRQLObjectStore

       
def _execmany_thread_not_copy_from(cu, statement, data, table=None,
                                   columns=None, encoding='utf-8'):
    """Execute the statement with a plain executemany(), without COPY FROM."""
    cu.executemany(statement, data)

def _execmany_thread_copy_from(cu, statement, data, table,
                               columns, encoding='utf-8'):
    """Execute the statement through a COPY FROM command, falling back to
    executemany() when the data cannot be converted to a COPY FROM buffer.
    """
    try:
        buf = _create_copyfrom_buffer(data, columns, encoding=encoding)
    except ValueError:
        _execmany_thread_not_copy_from(cu, statement, data)
    else:
        if columns is None:
            cu.copy_from(buf, table, null=u'NULL')
        else:
            cu.copy_from(buf, table, null=u'NULL', columns=columns)
       
def _execmany_thread(sql_connect, statements, dump_output_dir=None,
                     support_copy_from=True, encoding='utf-8'):
    """
    Execute a list of (statement, data) pairs. For 'INSERT INTO' statements,
    try to use a 'COPY FROM' command, or fall back to executemany().
    """
    if support_copy_from:
        execmany_func = _execmany_thread_copy_from
    else:
        execmany_func = _execmany_thread_not_copy_from
    cnx = sql_connect()
    cu = cnx.cursor()
    try:
        for statement, data in statements:
            table = None
            columns = None
            try:
                if not statement.startswith('INSERT INTO'):
                    cu.executemany(statement, data)
                    continue
                table = statement.split()[2]
                if isinstance(data[0], (tuple, list)):
                    columns = None
                else:
                    columns = list(data[0])
                execmany_func(cu, statement, data, table, columns, encoding)
            except Exception:
                print('unable to copy data into table %s' % table)
                # Error in import statement, save data in dump_output_dir
                if dump_output_dir is not None:
                    pdata = {'data': data, 'statement': statement,
                             'time': asctime(), 'columns': columns}
                    filename = make_uid()
                    try:
                        with open(osp.join(dump_output_dir,
                                           '%s.pickle' % filename), 'wb') as fobj:
                            pickle.dump(pdata, fobj)
                    except IOError:
                        print('ERROR while pickling in', dump_output_dir, filename + '.pickle')
                cnx.rollback()
                raise
    finally:
        cnx.commit()
        cu.close()
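
# Illustrative sketch (the table, column and connection factory below are
# hypothetical): each item of `statements` pairs one SQL string with the
# list of parameter rows it should be executed with.
#
#   statements = [
#       ('INSERT INTO cw_person (cw_name) VALUES (%(cw_name)s)',
#        [{'cw_name': u'Alice'}, {'cw_name': u'Bob'}]),
#   ]
#   _execmany_thread(system_source.get_connection, statements,
#                    support_copy_from=False)  # force the executemany() path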
       
       
def _copyfrom_buffer_convert_None(value, **opts):
    '''Convert None to "NULL"'''
    return u'NULL'

def _copyfrom_buffer_convert_number(value, **opts):
    '''Convert a number to its string representation'''
    return text_type(value)

def _copyfrom_buffer_convert_string(value, **opts):
    '''Escape the characters that are special to the COPY FROM text format
    (backslash, tab, carriage return, newline).
    '''
    escape_chars = ((u'\\', u'\\\\'), (u'\t', u'\\t'), (u'\r', u'\\r'),
                    (u'\n', u'\\n'))
    for char, replace in escape_chars:
        value = value.replace(char, replace)
    return value

def _copyfrom_buffer_convert_date(value, **opts):
    '''Convert a date to "YYYY-MM-DD"'''
    # Do not use strftime, as it fails on dates before 1900
    # (http://bugs.python.org/issue1777412)
    return u'%04d-%02d-%02d' % (value.year, value.month, value.day)

def _copyfrom_buffer_convert_datetime(value, **opts):
    '''Convert a datetime to "YYYY-MM-DD HH:MM:SS.UUUUUU"'''
    # Do not use strftime, as it fails on dates before 1900
    # (http://bugs.python.org/issue1777412)
    return u'%s %s' % (_copyfrom_buffer_convert_date(value, **opts),
                       _copyfrom_buffer_convert_time(value, **opts))

def _copyfrom_buffer_convert_time(value, **opts):
    '''Convert a time to "HH:MM:SS.UUUUUU"'''
    return u'%02d:%02d:%02d.%06d' % (value.hour, value.minute,
                                     value.second, value.microsecond)
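
# For instance (illustrative), a pre-1900 value that would make strftime()
# fail is rendered digit by digit:
#
#   >>> _copyfrom_buffer_convert_datetime(datetime(1850, 1, 2, 3, 4, 5, 6))
#   u'1850-01-02 03:04:05.000006'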
       
# (types, converter) list.
# Note: datetime is tested before date, since datetime is a subclass of date.
_COPYFROM_BUFFER_CONVERTERS = [
    (type(None), _copyfrom_buffer_convert_None),
    (integer_types + (float,), _copyfrom_buffer_convert_number),
    (string_types, _copyfrom_buffer_convert_string),
    (datetime, _copyfrom_buffer_convert_datetime),
    (date, _copyfrom_buffer_convert_date),
    (time, _copyfrom_buffer_convert_time),
]
       
def _create_copyfrom_buffer(data, columns=None, **convert_opts):
    """
    Create a StringIO buffer for the 'COPY FROM' command.
    Deals with Unicode, Int, Float, Date... (see ``_COPYFROM_BUFFER_CONVERTERS``)

    :data: a sequence of tuples/lists/dicts
    :columns: list of columns to consider (default to all columns)
    :convert_opts: keyword arguments given to converters
    """
    # Create a list rather than directly create a StringIO
    # to correctly write lines separated by '\n' in a single step
    rows = []
    if columns is None:
        if isinstance(data[0], (tuple, list)):
            columns = list(range(len(data[0])))
        elif isinstance(data[0], dict):
            columns = data[0].keys()
        else:
            raise ValueError('Could not get columns: you must provide columns.')
    for row in data:
        # Iterate over the different columns and the different values
        # and try to convert them to a proper datatype.
        # If an error is raised, do not continue.
        formatted_row = []
        for col in columns:
            try:
                value = row[col]
            except KeyError:
                warnings.warn(u"Column %s is not accessible in row %s"
                              % (col, row), RuntimeWarning)
                # XXX 'value' is set to None so that the import does not end
                # in error; missing keys are thus set to NULL from the
                # database point of view.
                value = None
            for types, converter in _COPYFROM_BUFFER_CONVERTERS:
                if isinstance(value, types):
                    value = converter(value, **convert_opts)
                    assert isinstance(value, text_type)
                    break
            else:
                raise ValueError("Unsupported value type %s" % type(value))
            # Push the converted value to the formatted row; at this point
            # it is guaranteed to be a text string.
            formatted_row.append(value)
        rows.append('\t'.join(formatted_row))
    return StringIO('\n'.join(rows))
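
# Illustrative example of the resulting buffer: rows may be tuples/lists
# (columns taken by position) or dicts (columns taken from the first row's
# keys); None becomes NULL and tab characters are escaped.
#
#   >>> buf = _create_copyfrom_buffer([(1, u'a\tb', None), (2, u'c', u'd')])
#   >>> buf.read()
#   u'1\ta\\tb\tNULL\n2\tc\td'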
       
       
class SQLGenObjectStore(NoHookRQLObjectStore):
    """Controller of the data import process. This version is based
    on direct insertions through SQL commands (COPY FROM or executemany).

    >>> store = SQLGenObjectStore(cnx)
    >>> store.create_entity('Person', ...)
    >>> store.flush()
    """

    def __init__(self, cnx, dump_output_dir=None, nb_threads_statement=1):
        """
        Initialize a SQLGenObjectStore.

        Parameters:

          - cnx: connection on the cubicweb instance
          - dump_output_dir: a directory to dump failed statements
            for easier recovery. Default is None (no dump).
        """
        super(SQLGenObjectStore, self).__init__(cnx)
        ### hijack default source
        self.source = SQLGenSourceWrapper(
            self.source, cnx.vreg.schema,
            dump_output_dir=dump_output_dir)
        ### XXX This is done in super().__init__(), but should be
        ### redone here to link to the correct source
        self.add_relation = self.source.add_relation
        self.indexes_etypes = {}
        if nb_threads_statement != 1:
            warnings.warn('[3.21] SQLGenObjectStore is no longer threaded',
                          DeprecationWarning)

    def flush(self):
        """Flush data to the database"""
        self.source.flush()

    def relate(self, subj_eid, rtype, obj_eid, **kwargs):
        if subj_eid is None or obj_eid is None:
            return
        # XXX Could subjtype be inferred?
        self.source.add_relation(self._cnx, subj_eid, rtype, obj_eid,
                                 self.rschema(rtype).inlined, **kwargs)
        if self.rschema(rtype).symmetric:
            self.source.add_relation(self._cnx, obj_eid, rtype, subj_eid,
                                     self.rschema(rtype).inlined, **kwargs)

    def drop_indexes(self, etype):
        """Drop indexes for a given entity type"""
        if etype not in self.indexes_etypes:
            cu = self._cnx.cnxset.cu
            def index_to_attr(index):
                """turn an index name into a (database) attribute name"""
                return index.replace(etype.lower(), '').replace('idx', '').strip('_')
            indices = [(index, index_to_attr(index))
                       for index in self.source.dbhelper.list_indices(cu, etype)
                       # Do not consider 'cw_etype_pkey' index
                       if not index.endswith('key')]
            self.indexes_etypes[etype] = indices
        for index, attr in self.indexes_etypes[etype]:
            self._cnx.system_sql('DROP INDEX %s' % index)

    def create_indexes(self, etype):
        """Recreate indexes for a given entity type"""
        for index, attr in self.indexes_etypes.get(etype, []):
            sql = 'CREATE INDEX %s ON cw_%s(%s)' % (index, etype, attr)
            self._cnx.system_sql(sql)
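
# Typical use (illustrative sketch; the 'Person' entity type, its 'name'
# attribute and the 'knows' relation are hypothetical):
#
#   store = SQLGenObjectStore(cnx)
#   alice = store.create_entity('Person', name=u'Alice')
#   bob = store.create_entity('Person', name=u'Bob')
#   store.relate(alice.eid, 'knows', bob.eid)
#   store.flush()
#   cnx.commit()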
       
       
###########################################################################
## SQL Source #############################################################
###########################################################################
       
class SQLGenSourceWrapper(object):

    def __init__(self, system_source, schema,
                 dump_output_dir=None):
        self.system_source = system_source
        # Explicitly backport attributes from the system source
        self._storage_handler = self.system_source._storage_handler
        self.preprocess_entity = self.system_source.preprocess_entity
        self.sqlgen = self.system_source.sqlgen
        self.uri = self.system_source.uri
        self.eid = self.system_source.eid
        # Directory to write temporary files
        self.dump_output_dir = dump_output_dir
        # Allow executing code with a SQLite backend, which does not
        # (yet...) support copy_from
        # XXX Should be dealt with in logilab.database
        spcfrom = system_source.dbhelper.dbapi_module.support_copy_from
        self.support_copy_from = spcfrom
        self.dbencoding = system_source.dbhelper.dbencoding
        self.init_statement_lists()
        self._inlined_rtypes_cache = {}
        self._fill_inlined_rtypes_cache(schema)
        self.schema = schema
        self.do_fti = False
       
    def _fill_inlined_rtypes_cache(self, schema):
        cache = self._inlined_rtypes_cache
        for eschema in schema.entities():
            for rschema in eschema.ordered_relations():
                if rschema.inlined:
                    cache[eschema.type] = SQL_PREFIX + rschema.type

    def init_statement_lists(self):
        self._sql_entities = defaultdict(list)
        self._sql_relations = {}
        self._sql_inlined_relations = {}
        self._sql_eids = defaultdict(list)
        # keep track, for each eid, of the corresponding data dict
        self._sql_eid_insertdicts = {}
       
    def flush(self):
        print('starting flush')
        _entities_sql = self._sql_entities
        _relations_sql = self._sql_relations
        _inlined_relations_sql = self._sql_inlined_relations
        _insertdicts = self._sql_eid_insertdicts
        try:
            # Try, for each inlined relation, to find out whether we are
            # also creating the host entity (i.e. the subject of the
            # relation). In that case, simply merge the value into the
            # pending insert dict, which avoids issuing a separate UPDATE
            # statement.
            for statement, datalist in _inlined_relations_sql.items():
                new_datalist = []
                # for a given inlined relation,
                # browse each pair to be inserted
                for data in datalist:
                    keys = list(data)
                    # For inlined relations, there are only two possible
                    # key orders: (rtype, cw_eid) or (cw_eid, rtype)
                    if keys[0] == 'cw_eid':
                        rtype = keys[1]
                    else:
                        rtype = keys[0]
                    updated_eid = data['cw_eid']
                    if updated_eid in _insertdicts:
                        _insertdicts[updated_eid][rtype] = data[rtype]
                    else:
                        # could not find the corresponding insert dict, keep
                        # the UPDATE query
                        new_datalist.append(data)
                _inlined_relations_sql[statement] = new_datalist
            _execmany_thread(self.system_source.get_connection,
                             list(self._sql_eids.items())
                             + list(_entities_sql.items())
                             + list(_relations_sql.items())
                             + list(_inlined_relations_sql.items()),
                             dump_output_dir=self.dump_output_dir,
                             support_copy_from=self.support_copy_from,
                             encoding=self.dbencoding)
        finally:
            _entities_sql.clear()
            _relations_sql.clear()
            _insertdicts.clear()
            _inlined_relations_sql.clear()
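
    # Folding example (illustrative): if entity 123 is still pending
    # insertion with
    #   _sql_eid_insertdicts == {123: {'cw_eid': 123, 'cw_name': u'Alice'}}
    # then a queued inlined-relation row {'cw_eid': 123, 'cw_lives_in': 456}
    # is merged into 123's INSERT parameters and dropped from the UPDATE
    # datalist; rows whose subject is not pending keep their UPDATE.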
       
    def add_relation(self, cnx, subject, rtype, object,
                     inlined=False, **kwargs):
        if inlined:
            _sql = self._sql_inlined_relations
            data = {'cw_eid': subject, SQL_PREFIX + rtype: object}
            subjtype = kwargs.get('subjtype')
            if subjtype is None:
                # Try to infer it
                targets = [t.type for t in
                           self.schema.rschema(rtype).subjects()]
                if len(targets) == 1:
                    subjtype = targets[0]
                else:
                    raise ValueError('You should give the subject etype for '
                                     'inlined relation %s, as it cannot be '
                                     'inferred: pass it as the ``subjtype`` '
                                     'keyword argument' % rtype)
            statement = self.sqlgen.update(SQL_PREFIX + subjtype,
                                           data, ['cw_eid'])
        else:
            _sql = self._sql_relations
            data = {'eid_from': subject, 'eid_to': object}
            statement = self.sqlgen.insert('%s_relation' % rtype, data)
        if statement in _sql:
            _sql[statement].append(data)
        else:
            _sql[statement] = [data]
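
    # Shapes of the generated statements (illustrative; sqlgen emits named
    # pyformat parameters, exact spelling may differ):
    #
    #   inlined:     UPDATE cw_Person SET cw_lives_in=%(cw_lives_in)s
    #                WHERE cw_eid=%(cw_eid)s
    #   non-inlined: INSERT INTO lives_in_relation (eid_from, eid_to)
    #                VALUES (%(eid_from)s, %(eid_to)s)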
       
    def add_entity(self, cnx, entity):
        with self._storage_handler(cnx, entity, 'added'):
            attrs = self.preprocess_entity(entity)
            rtypes = self._inlined_rtypes_cache.get(entity.cw_etype, ())
            if isinstance(rtypes, str):
                rtypes = (rtypes,)
            for rtype in rtypes:
                if rtype not in attrs:
                    attrs[rtype] = None
            sql = self.sqlgen.insert(SQL_PREFIX + entity.cw_etype, attrs)
            self._sql_eid_insertdicts[entity.eid] = attrs
            self._append_to_entities(sql, attrs)

    def _append_to_entities(self, sql, attrs):
        self._sql_entities[sql].append(attrs)

    def _handle_insert_entity_sql(self, cnx, sql, attrs):
        # We have to overwrite the source given as parameter,
        # since here we directly use the system source
        attrs['asource'] = self.system_source.uri
        self._sql_eids[sql].append(attrs)

    def _handle_is_relation_sql(self, cnx, sql, attrs):
        self._append_to_entities(sql, attrs)

    def _handle_is_instance_of_sql(self, cnx, sql, attrs):
        self._append_to_entities(sql, attrs)

    def _handle_source_relation_sql(self, cnx, sql, attrs):
        self._append_to_entities(sql, attrs)
       
    # add_info is _copypasted_ from the one in NativeSQLSource. We want it
    # here because it must use the _handle_* methods of SQLGenSourceWrapper,
    # which are not the same as the ones in the native source.
    def add_info(self, cnx, entity, source, extid):
        """add type and source info for an eid into the system table"""
        # begin by inserting eid/type/source/extid into the entities table
        if extid is not None:
            assert isinstance(extid, binary_type)
            extid = b64encode(extid).decode('ascii')
        attrs = {'type': entity.cw_etype, 'eid': entity.eid, 'extid': extid,
                 'asource': source.uri}
        self._handle_insert_entity_sql(cnx, self.sqlgen.insert('entities', attrs), attrs)
        # insert core relations: is, is_instance_of and cw_source
        try:
            self._handle_is_relation_sql(cnx, 'INSERT INTO is_relation(eid_from,eid_to) VALUES (%s,%s)',
                                         (entity.eid, eschema_eid(cnx, entity.e_schema)))
        except IndexError:
            # during schema serialization, skip
            pass
        else:
            for eschema in entity.e_schema.ancestors() + [entity.e_schema]:
                self._handle_is_relation_sql(cnx,
                                             'INSERT INTO is_instance_of_relation(eid_from,eid_to) VALUES (%s,%s)',
                                             (entity.eid, eschema_eid(cnx, eschema)))
        if 'CWSource' in self.schema and source.eid is not None:  # else, cw < 3.10
            self._handle_is_relation_sql(cnx, 'INSERT INTO cw_source_relation(eid_from,eid_to) VALUES (%s,%s)',
                                         (entity.eid, source.eid))
        # now we can update the full text index
        if self.do_fti and self.need_fti_indexation(entity.cw_etype):
            self.index_entity(cnx, entity=entity)