dataimport/pgstore.py
       
# copyright 2003-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
"""Postgres specific store"""
       
import threading
import warnings
import cPickle
import os.path as osp
from StringIO import StringIO
from time import asctime
from datetime import date, datetime, time
from collections import defaultdict
from base64 import b64encode

from cubicweb.utils import make_uid
# eschema_eid is needed by add_info() below
from cubicweb.server.utils import eschema_eid
from cubicweb.server.sqlutils import SQL_PREFIX
from cubicweb.dataimport.stores import NoHookRQLObjectStore
       
def _import_statements(sql_connect, statements, nb_threads=3,
                       dump_output_dir=None,
                       support_copy_from=True, encoding='utf-8'):
    """
    Import a bunch of SQL statements, distributed over several threads.
    """
    try:
        chunksize = (len(statements) / nb_threads) + 1
        threads = []
        for i in xrange(nb_threads):
            chunks = statements[i*chunksize:(i+1)*chunksize]
            thread = threading.Thread(target=_execmany_thread,
                                      args=(sql_connect, chunks,
                                            dump_output_dir,
                                            support_copy_from,
                                            encoding))
            thread.start()
            threads.append(thread)
        for t in threads:
            t.join()
    except Exception as exc:
        print 'Error in import statements: %s' % exc
       
def _execmany_thread_not_copy_from(cu, statement, data, table=None,
                                   columns=None, encoding='utf-8'):
    """ Execute the statement with executemany (no COPY FROM)
    """
    cu.executemany(statement, data)
       
def _execmany_thread_copy_from(cu, statement, data, table,
                               columns, encoding='utf-8'):
    """ Execute the statement using COPY FROM, falling back to
    executemany when the data cannot be turned into a copy buffer
    """
    try:
        buf = _create_copyfrom_buffer(data, columns, encoding=encoding)
    except ValueError:
        # unsupported value type: fall back to executemany
        _execmany_thread_not_copy_from(cu, statement, data)
    else:
        if columns is None:
            cu.copy_from(buf, table, null='NULL')
        else:
            cu.copy_from(buf, table, null='NULL', columns=columns)
       
def _execmany_thread(sql_connect, statements, dump_output_dir=None,
                     support_copy_from=True, encoding='utf-8'):
    """
    Execute SQL statements. For 'INSERT INTO' statements, try to use the
    'COPY FROM' command, falling back to executemany.
    """
    if support_copy_from:
        execmany_func = _execmany_thread_copy_from
    else:
        execmany_func = _execmany_thread_not_copy_from
    cnx = sql_connect()
    cu = cnx.cursor()
    try:
        for statement, data in statements:
            table = None
            columns = None
            try:
                if not statement.startswith('INSERT INTO'):
                    cu.executemany(statement, data)
                    continue
                table = statement.split()[2]
                if isinstance(data[0], (tuple, list)):
                    columns = None
                else:
                    columns = list(data[0])
                execmany_func(cu, statement, data, table, columns, encoding)
            except Exception:
                print 'unable to copy data into table %s' % table
                # Error in import statement, save data in dump_output_dir
                if dump_output_dir is not None:
                    pdata = {'data': data, 'statement': statement,
                             'time': asctime(), 'columns': columns}
                    filename = make_uid()
                    try:
                        with open(osp.join(dump_output_dir,
                                           '%s.pickle' % filename), 'w') as fobj:
                            fobj.write(cPickle.dumps(pdata))
                    except IOError:
                        print 'ERROR while pickling in', dump_output_dir, filename + '.pickle'
                cnx.rollback()
                raise
    finally:
        cnx.commit()
        cu.close()
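
# Illustrative sketch (not part of the original module): the `statements`
# argument consumed above (and by _import_statements) is a list of
# (statement, data) pairs, where `data` is a list of parameter dicts or
# tuples; the exact SQL text below is hypothetical, e.g.:
#
#     statements = [
#         ('INSERT INTO cw_Person ...',
#          [{'cw_eid': 42, 'cw_name': u'bob'},
#           {'cw_eid': 43, 'cw_name': u'alice'}]),
#     ]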
       
def _copyfrom_buffer_convert_None(value, **opts):
    '''Convert None value to "NULL"'''
    return 'NULL'
       
def _copyfrom_buffer_convert_number(value, **opts):
    '''Convert a number into its string representation'''
    return str(value)
       
def _copyfrom_buffer_convert_string(value, **opts):
    '''Convert a string value.

    Recognized keywords:
    :encoding: resulting string encoding (default: utf-8)
    '''
    encoding = opts.get('encoding', 'utf-8')
    # Encode before escaping: replacing with unicode patterns on a byte
    # string would trigger an implicit ascii decode and fail on non-ascii
    # input.
    if isinstance(value, unicode):
        value = value.encode(encoding)
    escape_chars = (('\\', r'\\'), ('\t', r'\t'), ('\r', r'\r'),
                    ('\n', r'\n'))
    for char, replace in escape_chars:
        value = value.replace(char, replace)
    return value
       
def _copyfrom_buffer_convert_date(value, **opts):
    '''Convert date into "YYYY-MM-DD"'''
    # Do not use strftime, as it has issues with dates before 1900
    # (http://bugs.python.org/issue1777412)
    return '%04d-%02d-%02d' % (value.year, value.month, value.day)
       
def _copyfrom_buffer_convert_datetime(value, **opts):
    '''Convert datetime into "YYYY-MM-DD HH:MM:SS.UUUUUU"'''
    # Do not use strftime, as it has issues with dates before 1900
    # (http://bugs.python.org/issue1777412)
    return '%s %s' % (_copyfrom_buffer_convert_date(value, **opts),
                      _copyfrom_buffer_convert_time(value, **opts))
       
def _copyfrom_buffer_convert_time(value, **opts):
    '''Convert time into "HH:MM:SS.UUUUUU"'''
    return '%02d:%02d:%02d.%06d' % (value.hour, value.minute,
                                    value.second, value.microsecond)
       
# List of (types, converter) pairs. Order matters: datetime must come
# before date, as datetime is a subclass of date.
_COPYFROM_BUFFER_CONVERTERS = [
    (type(None), _copyfrom_buffer_convert_None),
    ((long, int, float), _copyfrom_buffer_convert_number),
    (basestring, _copyfrom_buffer_convert_string),
    (datetime, _copyfrom_buffer_convert_datetime),
    (date, _copyfrom_buffer_convert_date),
    (time, _copyfrom_buffer_convert_time),
]
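
# Illustrative sketch (not part of the original module): the dispatch over
# _COPYFROM_BUFFER_CONVERTERS picks the first entry whose type matches,
# e.g.:
#
#     >>> from datetime import datetime
#     >>> _copyfrom_buffer_convert_datetime(datetime(1850, 1, 2, 3, 4, 5))
#     '1850-01-02 03:04:05.000000'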
       
def _create_copyfrom_buffer(data, columns=None, **convert_opts):
    """
    Create a StringIO buffer for the 'COPY FROM' command.
    Deals with Unicode, Int, Float, Date... (see ``_COPYFROM_BUFFER_CONVERTERS``)

    :data: a sequence of tuples/lists or dicts
    :columns: list of columns to consider (default to all columns)
    :convert_opts: keyword arguments given to converters
    """
    # Create a list rather than writing to a StringIO directly, so that
    # all lines can be joined with '\n' in a single step.
    rows = []
    if columns is None:
        if isinstance(data[0], (tuple, list)):
            columns = range(len(data[0]))
        elif isinstance(data[0], dict):
            columns = data[0].keys()
        else:
            raise ValueError('Could not get columns: you must provide columns.')
    for row in data:
        # Iterate over the different columns and the different values
        # and try to convert them to a proper datatype.
        # If an error is raised, do not continue.
        formatted_row = []
        for col in columns:
            try:
                value = row[col]
            except KeyError:
                warnings.warn(u"Column %s is not accessible in row %s"
                              % (col, row), RuntimeWarning)
                # XXX 'value' is set to None so that the import does not
                # fail. Instead, the missing keys end up as NULL from the
                # database point of view.
                value = None
            for types, converter in _COPYFROM_BUFFER_CONVERTERS:
                if isinstance(value, types):
                    value = converter(value, **convert_opts)
                    break
            else:
                raise ValueError("Unsupported value type %s" % type(value))
            # The value has been converted to a string (None became the
            # 'NULL' marker), push it onto the formatted row.
            formatted_row.append(value)
        rows.append('\t'.join(formatted_row))
    return StringIO('\n'.join(rows))
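
# Illustrative sketch (not part of the original module): the resulting
# buffer is what psycopg2's cursor.copy_from() expects, one tab-separated
# line per row, e.g.:
#
#     >>> _create_copyfrom_buffer([(1, u'bob', None)]).getvalue()
#     '1\tbob\tNULL'
#
# None is rendered as the literal NULL marker passed as
# copy_from(..., null='NULL').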
       
class SQLGenObjectStore(NoHookRQLObjectStore):
    """Controller of the data import process. This version is based
    on direct insertions through SQL commands (COPY FROM or executemany).

    >>> store = SQLGenObjectStore(cnx)
    >>> store.create_entity('Person', ...)
    >>> store.flush()
    """
       
    def __init__(self, cnx, dump_output_dir=None, nb_threads_statement=3):
        """
        Initialize a SQLGenObjectStore.

        Parameters:

          - cnx: connection to the cubicweb instance
          - dump_output_dir: a directory to dump failed statements
            for easier recovery. Default is None (no dump).
          - nb_threads_statement: number of threads used
            for SQL insertion (default is 3).
        """
        super(SQLGenObjectStore, self).__init__(cnx)
        ### hijack default source
        self.source = SQLGenSourceWrapper(
            self.source, cnx.vreg.schema,
            dump_output_dir=dump_output_dir,
            nb_threads_statement=nb_threads_statement)
        ### XXX This is done in super().__init__(), but should be
        ### redone here to link to the correct source
        self.add_relation = self.source.add_relation
        self.indexes_etypes = {}
       
    def flush(self):
        """Flush data to the database"""
        self.source.flush()
       
    def relate(self, subj_eid, rtype, obj_eid, **kwargs):
        if subj_eid is None or obj_eid is None:
            return
        # XXX Could subjtype be inferred?
        self.source.add_relation(self._cnx, subj_eid, rtype, obj_eid,
                                 self.rschema(rtype).inlined, **kwargs)
        if self.rschema(rtype).symmetric:
            self.source.add_relation(self._cnx, obj_eid, rtype, subj_eid,
                                     self.rschema(rtype).inlined, **kwargs)
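
    # Illustrative sketch (not part of the original module): for an
    # inlined relation whose subject type cannot be inferred from the
    # schema, pass it explicitly (names hypothetical):
    #
    #     store.relate(person_eid, 'works_for', company_eid,
    #                  subjtype='Person')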
       
    def drop_indexes(self, etype):
        """Drop indexes for a given entity type"""
        if etype not in self.indexes_etypes:
            cu = self._cnx.cnxset.cu
            def index_to_attr(index):
                """turn an index name into a (database) attribute name"""
                return index.replace(etype.lower(), '').replace('idx', '').strip('_')
            indices = [(index, index_to_attr(index))
                       for index in self.source.dbhelper.list_indices(cu, etype)
                       # Do not consider 'cw_etype_pkey' index
                       if not index.endswith('key')]
            self.indexes_etypes[etype] = indices
        for index, attr in self.indexes_etypes[etype]:
            self._cnx.system_sql('DROP INDEX %s' % index)
       
    def create_indexes(self, etype):
        """Recreate indexes for a given entity type"""
        for index, attr in self.indexes_etypes.get(etype, []):
            sql = 'CREATE INDEX %s ON cw_%s(%s)' % (index, etype, attr)
            self._cnx.system_sql(sql)
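
# Illustrative sketch (not part of the original module): a typical bulk
# import with this store, assuming `cnx` is an open cubicweb connection
# and `rows` an iterable of attribute dicts:
#
#     store = SQLGenObjectStore(cnx)
#     store.drop_indexes('Person')       # speed up mass insertion
#     for row in rows:
#         store.create_entity('Person', **row)
#     store.flush()                      # issue COPY FROM / executemany
#     store.create_indexes('Person')     # restore the dropped indexes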
       
###########################################################################
## SQL Source #############################################################
###########################################################################

class SQLGenSourceWrapper(object):
       
    def __init__(self, system_source, schema,
                 dump_output_dir=None, nb_threads_statement=3):
        self.system_source = system_source
        self._sql = threading.local()
        # Explicitly backport attributes from the system source
        self._storage_handler = self.system_source._storage_handler
        self.preprocess_entity = self.system_source.preprocess_entity
        self.sqlgen = self.system_source.sqlgen
        self.uri = self.system_source.uri
        self.eid = self.system_source.eid
        # Directory to write temporary files
        self.dump_output_dir = dump_output_dir
        # Allow executing code with a SQLite backend, which does not
        # (yet...) support copy_from
        # XXX Should be dealt with in logilab.database
        spcfrom = system_source.dbhelper.dbapi_module.support_copy_from
        self.support_copy_from = spcfrom
        self.dbencoding = system_source.dbhelper.dbencoding
        self.nb_threads_statement = nb_threads_statement
        # initialize thread-local data for the main thread
        self.init_thread_locals()
        self._inlined_rtypes_cache = {}
        self._fill_inlined_rtypes_cache(schema)
        self.schema = schema
        self.do_fti = False
       
    def _fill_inlined_rtypes_cache(self, schema):
        cache = self._inlined_rtypes_cache
        for eschema in schema.entities():
            for rschema in eschema.ordered_relations():
                if rschema.inlined:
                    # accumulate in a list: an entity type may have several
                    # inlined relations (a plain assignment would keep only
                    # the last one)
                    cache.setdefault(eschema.type, []).append(
                        SQL_PREFIX + rschema.type)
       
    def init_thread_locals(self):
        """initializes thread-local data"""
        self._sql.entities = defaultdict(list)
        self._sql.relations = {}
        self._sql.inlined_relations = {}
        # keep track, for each eid, of the corresponding data dict
        self._sql.eid_insertdicts = {}
       
    def flush(self):
        print 'starting flush'
        _entities_sql = self._sql.entities
        _relations_sql = self._sql.relations
        _inlined_relations_sql = self._sql.inlined_relations
        _insertdicts = self._sql.eid_insertdicts
        try:
            # For each inlined relation, check whether we are also creating
            # the host entity (i.e. the subject of the relation). In that
            # case, simply update the insert dict and drop the now-useless
            # UPDATE statement.
            for statement, datalist in _inlined_relations_sql.iteritems():
                new_datalist = []
                # for a given inlined relation, iterate over each pair to
                # be inserted
                for data in datalist:
                    keys = list(data)
                    # For inlined relations, there are only two cases:
                    # (rtype, cw_eid) or (cw_eid, rtype)
                    if keys[0] == 'cw_eid':
                        rtype = keys[1]
                    else:
                        rtype = keys[0]
                    updated_eid = data['cw_eid']
                    if updated_eid in _insertdicts:
                        _insertdicts[updated_eid][rtype] = data[rtype]
                    else:
                        # could not find the corresponding insert dict, keep
                        # the UPDATE query
                        new_datalist.append(data)
                _inlined_relations_sql[statement] = new_datalist
            _import_statements(self.system_source.get_connection,
                               _entities_sql.items()
                               + _relations_sql.items()
                               + _inlined_relations_sql.items(),
                               dump_output_dir=self.dump_output_dir,
                               nb_threads=self.nb_threads_statement,
                               support_copy_from=self.support_copy_from,
                               encoding=self.dbencoding)
        finally:
            _entities_sql.clear()
            _relations_sql.clear()
            _insertdicts.clear()
            _inlined_relations_sql.clear()
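
    # Illustrative sketch (not part of the original module): if entity 42
    # is pending insertion with {'cw_eid': 42, 'cw_name': u'bob'} and an
    # inlined relation row {'cw_eid': 42, 'cw_author': 7} is queued,
    # flush() folds it into the insert dict, giving
    # {'cw_eid': 42, 'cw_name': u'bob', 'cw_author': 7}, and the
    # corresponding UPDATE statement is dropped.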
       
    def add_relation(self, cnx, subject, rtype, object,
                     inlined=False, **kwargs):
        if inlined:
            _sql = self._sql.inlined_relations
            data = {'cw_eid': subject, SQL_PREFIX + rtype: object}
            subjtype = kwargs.get('subjtype')
            if subjtype is None:
                # Try to infer it
                targets = [t.type for t in
                           self.schema.rschema(rtype).subjects()]
                if len(targets) == 1:
                    subjtype = targets[0]
                else:
                    raise ValueError('You should give the subject etype for '
                                     'inlined relation %s, as it cannot be '
                                     'inferred: pass it as the ``subjtype`` '
                                     'keyword argument' % rtype)
            statement = self.sqlgen.update(SQL_PREFIX + subjtype,
                                           data, ['cw_eid'])
        else:
            _sql = self._sql.relations
            data = {'eid_from': subject, 'eid_to': object}
            statement = self.sqlgen.insert('%s_relation' % rtype, data)
        if statement in _sql:
            _sql[statement].append(data)
        else:
            _sql[statement] = [data]
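
    # Illustrative sketch (not part of the original module): a plain
    # (non-inlined) relation accumulates as an INSERT statement mapped to
    # its parameter dicts, roughly (exact SQL text hypothetical):
    #
    #     {'INSERT INTO works_for_relation ...':
    #      [{'eid_from': 42, 'eid_to': 7}, {'eid_from': 43, 'eid_to': 7}]}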
       
    def add_entity(self, cnx, entity):
        with self._storage_handler(entity, 'added'):
            attrs = self.preprocess_entity(entity)
            rtypes = self._inlined_rtypes_cache.get(entity.cw_etype, ())
            if isinstance(rtypes, str):
                rtypes = (rtypes,)
            for rtype in rtypes:
                if rtype not in attrs:
                    attrs[rtype] = None
            sql = self.sqlgen.insert(SQL_PREFIX + entity.cw_etype, attrs)
            self._sql.eid_insertdicts[entity.eid] = attrs
            self._append_to_entities(sql, attrs)
       
    def _append_to_entities(self, sql, attrs):
        self._sql.entities[sql].append(attrs)
       
    def _handle_insert_entity_sql(self, cnx, sql, attrs):
        # We have to overwrite the source given in parameters, as here we
        # directly use the system source
        attrs['asource'] = self.system_source.uri
        self._append_to_entities(sql, attrs)

    def _handle_is_relation_sql(self, cnx, sql, attrs):
        self._append_to_entities(sql, attrs)

    def _handle_is_instance_of_sql(self, cnx, sql, attrs):
        self._append_to_entities(sql, attrs)

    def _handle_source_relation_sql(self, cnx, sql, attrs):
        self._append_to_entities(sql, attrs)
       
    # add_info is _copypasted_ from the one in NativeSQLSource. We want it
    # here because it will use the _handle_* methods of the
    # SQLGenSourceWrapper, which are not like the ones in the native source.
    def add_info(self, cnx, entity, source, extid):
        """add type and source info for an eid into the system table"""
        # begin by inserting eid/type/source/extid into the entities table
        if extid is not None:
            assert isinstance(extid, str)
            extid = b64encode(extid)
        attrs = {'type': entity.cw_etype, 'eid': entity.eid, 'extid': extid,
                 'asource': source.uri}
        self._handle_insert_entity_sql(cnx, self.sqlgen.insert('entities', attrs), attrs)
        # insert core relations: is, is_instance_of and cw_source
        try:
            self._handle_is_relation_sql(cnx, 'INSERT INTO is_relation(eid_from,eid_to) VALUES (%s,%s)',
                                         (entity.eid, eschema_eid(cnx, entity.e_schema)))
        except IndexError:
            # during schema serialization, skip
            pass
        else:
            for eschema in entity.e_schema.ancestors() + [entity.e_schema]:
                self._handle_is_relation_sql(cnx,
                                             'INSERT INTO is_instance_of_relation(eid_from,eid_to) VALUES (%s,%s)',
                                             (entity.eid, eschema_eid(cnx, eschema)))
        if 'CWSource' in self.schema and source.eid is not None:  # else, cw < 3.10
            self._handle_is_relation_sql(cnx, 'INSERT INTO cw_source_relation(eid_from,eid_to) VALUES (%s,%s)',
                                         (entity.eid, source.eid))
        # now we can update the full text index
        if self.do_fti and self.need_fti_indexation(entity.cw_etype):
            self.index_entity(cnx, entity=entity)