# copyright 2003-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb. If not, see <http://www.gnu.org/licenses/>.
"""Postgres specific store"""
from __future__ import print_function

import threading
import warnings
import os.path as osp
from io import StringIO
from time import asctime
from datetime import date, datetime, time
from collections import defaultdict
from base64 import b64encode

from six import string_types, integer_types, text_type
from six.moves import cPickle as pickle, range

from cubicweb.utils import make_uid
from cubicweb.server.utils import eschema_eid
from cubicweb.server.sqlutils import SQL_PREFIX
from cubicweb.dataimport.stores import NoHookRQLObjectStore

def _import_statements(sql_connect, statements, nb_threads=3,
dump_output_dir=None,
support_copy_from=True, encoding='utf-8'):
"""
Import a bunch of sql statements, using different threads.
"""
try:
chunksize = (len(statements) / nb_threads) + 1
threads = []
for i in range(nb_threads):
chunks = statements[i*chunksize:(i+1)*chunksize]
thread = threading.Thread(target=_execmany_thread,
args=(sql_connect, chunks,
dump_output_dir,
support_copy_from,
encoding))
thread.start()
threads.append(thread)
for t in threads:
t.join()
except Exception:
print('Error in import statements')
def _execmany_thread_not_copy_from(cu, statement, data, table=None,
columns=None, encoding='utf-8'):
""" Execute thread without copy from
"""
cu.executemany(statement, data)
def _execmany_thread_copy_from(cu, statement, data, table,
                               columns, encoding='utf-8'):
    """Insert `data` into `table` with the ``COPY FROM`` command.

    The rows are serialized with ``_create_copyfrom_buffer``. When that is
    not possible (it raises ``ValueError`` for values of an unsupported type,
    or when the columns cannot be determined), fall back to a plain
    ``executemany`` of `statement`. ``None`` values are sent as ``NULL``.
    """
    try:
        buf = _create_copyfrom_buffer(data, columns, encoding=encoding)
    except ValueError:
        # unsupported data for COPY FROM: use the slower generic path
        buf = None
    if buf is None:
        _execmany_thread_not_copy_from(cu, statement, data)
    elif columns is None:
        cu.copy_from(buf, table, null=u'NULL')
    else:
        cu.copy_from(buf, table, null=u'NULL', columns=columns)
def _dump_failed_statement(dump_output_dir, statement, data, columns):
    """Pickle a failed statement, its data and columns into `dump_output_dir`.

    The dump file gets a unique name (UID.pickle). An IOError while writing
    it is reported on stdout and otherwise ignored.
    """
    pdata = {'data': data, 'statement': statement,
             'time': asctime(), 'columns': columns}
    filename = make_uid()
    try:
        with open(osp.join(dump_output_dir,
                           '%s.pickle' % filename), 'wb') as fobj:
            pickle.dump(pdata, fobj)
    except IOError:
        print('ERROR while pickling in', dump_output_dir, filename+'.pickle')


def _execmany_thread(sql_connect, statements, dump_output_dir=None,
                     support_copy_from=True, encoding='utf-8'):
    """
    Execute sql statements on a new connection. If 'INSERT INTO', try to use
    'COPY FROM' command (when `support_copy_from` is true), or fallback to
    execute_many.

    All the statements run in a single transaction, committed once all of
    them succeeded. On the first failure, the failing statement is pickled
    into `dump_output_dir` (if given), the transaction is rolled back and the
    exception is re-raised. The cursor and the connection are always closed.

    :statements: sequence of ``(statement, data)`` pairs, `data` being a list
      of parameter tuples/lists or dicts
    """
    if support_copy_from:
        execmany_func = _execmany_thread_copy_from
    else:
        execmany_func = _execmany_thread_not_copy_from
    cnx = sql_connect()
    try:
        cu = cnx.cursor()
        try:
            for statement, data in statements:
                if not data:
                    # nothing to execute (and data[0] below would fail)
                    continue
                table = None
                columns = None
                try:
                    if not statement.startswith('INSERT INTO'):
                        cu.executemany(statement, data)
                        continue
                    # 'INSERT INTO table ...' -> table
                    table = statement.split()[2]
                    if not isinstance(data[0], (tuple, list)):
                        # dict rows: columns are the keys of the first row
                        columns = list(data[0])
                    execmany_func(cu, statement, data, table, columns, encoding)
                except Exception:
                    print('unable to copy data into table %s' % table)
                    # Error in import statement, save data in dump_output_dir
                    if dump_output_dir is not None:
                        _dump_failed_statement(dump_output_dir, statement,
                                               data, columns)
                    cnx.rollback()
                    raise
            # every statement succeeded: commit them all at once
            cnx.commit()
        finally:
            cu.close()
    finally:
        # the connection was opened for this call only: do not leak it
        cnx.close()
def _copyfrom_buffer_convert_None(value, **opts):
'''Convert None value to "NULL"'''
return u'NULL'
def _copyfrom_buffer_convert_number(value, **opts):
'''Convert a number into its string representation'''
return text_type(value)
def _copyfrom_buffer_convert_string(value, **opts):
'''Convert string value.
'''
escape_chars = ((u'\\', u'\\\\'), (u'\t', u'\\t'), (u'\r', u'\\r'),
(u'\n', u'\\n'))
for char, replace in escape_chars:
value = value.replace(char, replace)
return value
def _copyfrom_buffer_convert_date(value, **opts):
'''Convert date into "YYYY-MM-DD"'''
# Do not use strftime, as it yields issue with date < 1900
# (http://bugs.python.org/issue1777412)
return u'%04d-%02d-%02d' % (value.year, value.month, value.day)
def _copyfrom_buffer_convert_datetime(value, **opts):
    '''Format a datetime as "YYYY-MM-DD HH:MM:SS.UUUUUU".

    The date and time halves are produced by the date and time converters,
    which avoid ``strftime`` (it yields issues with dates before 1900, see
    http://bugs.python.org/issue1777412). Any tzinfo is ignored.
    '''
    date_part = _copyfrom_buffer_convert_date(value, **opts)
    time_part = _copyfrom_buffer_convert_time(value, **opts)
    return u' '.join((date_part, time_part))
def _copyfrom_buffer_convert_time(value, **opts):
'''Convert time into "HH:MM:SS.UUUUUU"'''
return u'%02d:%02d:%02d.%06d' % (value.hour, value.minute,
value.second, value.microsecond)
# (types, converter) list, scanned in order by _create_copyfrom_buffer: the
# first entry whose types match a value (isinstance check) converts that
# value to text.
# Order matters: ``datetime`` is a subclass of ``date`` and must therefore be
# listed before it, otherwise datetimes would lose their time part.
# NOTE(review): ``bool`` is a subclass of ``int`` and thus goes through
# _copyfrom_buffer_convert_number (giving u'True' / u'False').
_COPYFROM_BUFFER_CONVERTERS = [
    (type(None), _copyfrom_buffer_convert_None),
    (integer_types + (float,), _copyfrom_buffer_convert_number),
    (string_types, _copyfrom_buffer_convert_string),
    (datetime, _copyfrom_buffer_convert_datetime),
    (date, _copyfrom_buffer_convert_date),
    (time, _copyfrom_buffer_convert_time),
    ]
def _create_copyfrom_buffer(data, columns=None, **convert_opts):
"""
Create a StringIO buffer for 'COPY FROM' command.
Deals with Unicode, Int, Float, Date... (see ``converters``)
:data: a sequence/dict of tuples
:columns: list of columns to consider (default to all columns)
:converter_opts: keyword arguements given to converters
"""
# Create a list rather than directly create a StringIO
# to correctly write lines separated by '\n' in a single step
rows = []
if columns is None:
if isinstance(data[0], (tuple, list)):
columns = list(range(len(data[0])))
elif isinstance(data[0], dict):
columns = data[0].keys()
else:
raise ValueError('Could not get columns: you must provide columns.')
for row in data:
# Iterate over the different columns and the different values
# and try to convert them to a correct datatype.
# If an error is raised, do not continue.
formatted_row = []
for col in columns:
try:
value = row[col]
except KeyError:
warnings.warn(u"Column %s is not accessible in row %s"
% (col, row), RuntimeWarning)
# XXX 'value' set to None so that the import does not end in
# error.
# Instead, the extra keys are set to NULL from the
# database point of view.
value = None
for types, converter in _COPYFROM_BUFFER_CONVERTERS:
if isinstance(value, types):
value = converter(value, **convert_opts)
assert isinstance(value, text_type)
break
else:
raise ValueError("Unsupported value type %s" % type(value))
# We push the value to the new formatted row
# if the value is not None and could be converted to a string.
formatted_row.append(value)
rows.append('\t'.join(formatted_row))
return StringIO('\n'.join(rows))
class SQLGenObjectStore(NoHookRQLObjectStore):
    """Controller of the data import process. This version is based
    on direct insertions through SQL commands (COPY FROM or executemany).

    >>> store = SQLGenObjectStore(cnx)
    >>> store.create_entity('Person', ...)
    >>> store.flush()
    """

    def __init__(self, cnx, dump_output_dir=None, nb_threads_statement=3):
        """
        Initialize a SQLGenObjectStore.

        Parameters:

        - cnx: connection on the cubicweb instance
        - dump_output_dir: a directory to dump failed statements
          for easier recovery. Default is None (no dump).
        - nb_threads_statement: number of threads used
          for SQL insertion (default is 3).
        """
        super(SQLGenObjectStore, self).__init__(cnx)
        ### hijack default source
        self.source = SQLGenSourceWrapper(
            self.source, cnx.vreg.schema,
            dump_output_dir=dump_output_dir,
            nb_threads_statement=nb_threads_statement)
        ### XXX This is done in super().__init__(), but should be
        ### redone here to link to the correct source
        self.add_relation = self.source.add_relation
        # etype -> [(index name, attribute)], filled by drop_indexes() and
        # read back by create_indexes()
        self.indexes_etypes = {}

    def flush(self):
        """Flush data to the database"""
        self.source.flush()

    def relate(self, subj_eid, rtype, obj_eid, **kwargs):
        """Add relation `rtype` between two eids, in both directions for a
        symmetric relation. Nothing is done if one of the eids is None.

        Extra keyword arguments (e.g. ``subjtype`` for inlined relations) are
        given to the source's ``add_relation``.
        """
        if subj_eid is None or obj_eid is None:
            return
        rschema = self.rschema(rtype)
        # XXX Could subjtype be inferred ?
        self.source.add_relation(self._cnx, subj_eid, rtype, obj_eid,
                                 rschema.inlined, **kwargs)
        if rschema.symmetric:
            self.source.add_relation(self._cnx, obj_eid, rtype, subj_eid,
                                     rschema.inlined, **kwargs)

    def drop_indexes(self, etype):
        """Drop indexes for a given entity type.

        The dropped (index, attribute) pairs are remembered in
        `indexes_etypes` so that `create_indexes` can recreate them.
        """
        if etype not in self.indexes_etypes:
            cu = self._cnx.cnxset.cu
            # self.source is a SQLGenSourceWrapper, which does not expose a
            # `dbhelper` attribute: use the wrapped system source's one
            dbhelper = self.source.system_source.dbhelper

            def index_to_attr(index):
                """turn an index name to (database) attribute name"""
                return index.replace(etype.lower(), '').replace('idx', '').strip('_')
            # NOTE(review): list_indices() is given the bare etype while
            # create_indexes() targets table 'cw_' + etype; confirm which
            # table name list_indices expects and that index names map back
            # to column names through index_to_attr().
            indices = [(index, index_to_attr(index))
                       for index in dbhelper.list_indices(cu, etype)
                       # Do not consider 'cw_etype_pkey' index
                       if not index.endswith('key')]
            self.indexes_etypes[etype] = indices
        for index, attr in self.indexes_etypes[etype]:
            self._cnx.system_sql('DROP INDEX %s' % index)

    def create_indexes(self, etype):
        """Recreate indexes previously dropped by `drop_indexes` for a given
        entity type"""
        for index, attr in self.indexes_etypes.get(etype, []):
            sql = 'CREATE INDEX %s ON cw_%s(%s)' % (index, etype, attr)
            self._cnx.system_sql(sql)
###########################################################################
## SQL Source #############################################################
###########################################################################
class SQLGenSourceWrapper(object):
    """Stand-in for the system source, used by SQLGenObjectStore.

    Entities and relations added through this wrapper are not written right
    away: the SQL statements and their parameters are buffered in
    thread-local storage, and executed in bulk by `flush`.
    """

    def __init__(self, system_source, schema,
                 dump_output_dir=None, nb_threads_statement=3):
        self.system_source = system_source
        self._sql = threading.local()
        # Explicitly backport attributes from system source
        self._storage_handler = self.system_source._storage_handler
        self.preprocess_entity = self.system_source.preprocess_entity
        self.sqlgen = self.system_source.sqlgen
        self.uri = self.system_source.uri
        self.eid = self.system_source.eid
        # Directory to write temporary files
        self.dump_output_dir = dump_output_dir
        # Allow to execute code with SQLite backend that does
        # not support (yet...) copy_from
        # XXX Should be dealt with in logilab.database
        spcfrom = system_source.dbhelper.dbapi_module.support_copy_from
        self.support_copy_from = spcfrom
        self.dbencoding = system_source.dbhelper.dbencoding
        self.nb_threads_statement = nb_threads_statement
        # initialize thread-local data for the current (constructing) thread
        self.init_thread_locals()
        # etype -> list of the SQL columns of its inlined relations
        self._inlined_rtypes_cache = {}
        self._fill_inlined_rtypes_cache(schema)
        self.schema = schema
        self.do_fti = False

    def _fill_inlined_rtypes_cache(self, schema):
        """Record, for each entity type, the SQL columns of all its inlined
        relations."""
        cache = self._inlined_rtypes_cache
        for eschema in schema.entities():
            for rschema in eschema.ordered_relations():
                if rschema.inlined:
                    # keep every inlined relation of the entity type, not
                    # only the last one seen
                    cache.setdefault(eschema.type, []).append(
                        SQL_PREFIX + rschema.type)

    def init_thread_locals(self):
        """initializes thread-local data"""
        # statement -> list of parameter dicts/tuples
        self._sql.entities = defaultdict(list)
        self._sql.relations = {}
        self._sql.inlined_relations = {}
        # keep track, for each eid of the corresponding data dict
        self._sql.eid_insertdicts = {}

    @staticmethod
    def _merge_inlined_relations(inlined_relations_sql, insertdicts):
        """Fold buffered inlined relations into pending entity inserts.

        For each inlined relation whose subject is being inserted in the same
        flush, and whose insert dict already holds the relation column, set
        the column value in that insert dict instead of keeping the UPDATE
        statement. Relations that cannot be folded are kept; statements left
        without data are removed.
        """
        for statement in list(inlined_relations_sql):
            remaining = []
            for data in inlined_relations_sql[statement]:
                # inlined relation data is {'cw_eid': subject, column: object}
                column = next(key for key in data if key != 'cw_eid')
                insertdict = insertdicts.get(data['cw_eid'])
                # the INSERT statement was generated from the insert dict's
                # keys: a column missing from it would silently be dropped
                if insertdict is not None and column in insertdict:
                    insertdict[column] = data[column]
                else:
                    remaining.append(data)
            if remaining:
                inlined_relations_sql[statement] = remaining
            else:
                del inlined_relations_sql[statement]

    def flush(self):
        """Execute all the statements buffered by the current thread.

        Inlined relations are first folded into pending inserts when possible
        (see `_merge_inlined_relations`); all statements are then given to
        ``_import_statements``. The buffers are emptied afterwards, even if an
        error occurred.
        """
        print('starting flush')
        _entities_sql = self._sql.entities
        _relations_sql = self._sql.relations
        _inlined_relations_sql = self._sql.inlined_relations
        _insertdicts = self._sql.eid_insertdicts
        try:
            self._merge_inlined_relations(_inlined_relations_sql, _insertdicts)
            # dict views cannot be concatenated with '+' on Python 3
            statements = (list(_entities_sql.items())
                          + list(_relations_sql.items())
                          + list(_inlined_relations_sql.items()))
            _import_statements(self.system_source.get_connection,
                               statements,
                               dump_output_dir=self.dump_output_dir,
                               nb_threads=self.nb_threads_statement,
                               support_copy_from=self.support_copy_from,
                               encoding=self.dbencoding)
        finally:
            _entities_sql.clear()
            _relations_sql.clear()
            _insertdicts.clear()
            _inlined_relations_sql.clear()

    def add_relation(self, cnx, subject, rtype, object,
                     inlined=False, **kwargs):
        """Buffer relation `rtype` from eid `subject` to eid `object`.

        Inlined relations become an UPDATE of the subject's entity table; the
        subject entity type is taken from the ``subjtype`` keyword argument,
        or inferred from the schema when the relation has a single subject
        type (ValueError otherwise). Other relations become an INSERT into
        the relation table.
        """
        if inlined:
            _sql = self._sql.inlined_relations
            data = {'cw_eid': subject, SQL_PREFIX + rtype: object}
            subjtype = kwargs.get('subjtype')
            if subjtype is None:
                # Try to infer it
                targets = [t.type for t in
                           self.schema.rschema(rtype).subjects()]
                if len(targets) == 1:
                    subjtype = targets[0]
                else:
                    raise ValueError('You should give the subject etype for '
                                     'inlined relation %s'
                                     ', as it cannot be inferred: '
                                     'this type is given as keyword argument '
                                     '``subjtype``' % rtype)
            statement = self.sqlgen.update(SQL_PREFIX + subjtype,
                                           data, ['cw_eid'])
        else:
            _sql = self._sql.relations
            data = {'eid_from': subject, 'eid_to': object}
            statement = self.sqlgen.insert('%s_relation' % rtype, data)
        _sql.setdefault(statement, []).append(data)

    def add_entity(self, cnx, entity):
        """Buffer the INSERT of `entity` into its entity table.

        Every inlined relation column of the entity type is present in the
        insert dict (None unless already set), so that `flush` can later fill
        it instead of issuing an UPDATE.
        """
        with self._storage_handler(entity, 'added'):
            attrs = self.preprocess_entity(entity)
            rtypes = self._inlined_rtypes_cache.get(entity.cw_etype, ())
            if isinstance(rtypes, string_types):
                # tolerate a single column name instead of a list
                rtypes = (rtypes,)
            for rtype in rtypes:
                if rtype not in attrs:
                    attrs[rtype] = None
            sql = self.sqlgen.insert(SQL_PREFIX + entity.cw_etype, attrs)
            self._sql.eid_insertdicts[entity.eid] = attrs
            self._append_to_entities(sql, attrs)

    def _append_to_entities(self, sql, attrs):
        """Buffer `attrs` as parameters of statement `sql`."""
        self._sql.entities[sql].append(attrs)

    def _handle_insert_entity_sql(self, cnx, sql, attrs):
        # We have to overwrite the source given in parameters
        # as here, we directly use the system source
        attrs['asource'] = self.system_source.uri
        self._append_to_entities(sql, attrs)

    def _handle_is_relation_sql(self, cnx, sql, attrs):
        self._append_to_entities(sql, attrs)

    def _handle_is_instance_of_sql(self, cnx, sql, attrs):
        self._append_to_entities(sql, attrs)

    def _handle_source_relation_sql(self, cnx, sql, attrs):
        self._append_to_entities(sql, attrs)

    # add_info is _copypasted_ from the one in NativeSQLSource. We want it
    # there because it will use the _handlers of the SQLGenSourceWrapper, which
    # are not like the ones in the native source.
    def add_info(self, cnx, entity, source, extid):
        """add type and source info for an eid into the system table"""
        # begin by inserting eid/type/source/extid into the entities table
        if extid is not None:
            assert isinstance(extid, bytes)
            # b64encode() takes and returns bytes: store the ascii text
            extid = b64encode(extid).decode('ascii')
        attrs = {'type': entity.cw_etype, 'eid': entity.eid, 'extid': extid,
                 'asource': source.uri}
        self._handle_insert_entity_sql(cnx, self.sqlgen.insert('entities', attrs), attrs)
        # insert core relations: is, is_instance_of and cw_source
        try:
            self._handle_is_relation_sql(cnx, 'INSERT INTO is_relation(eid_from,eid_to) VALUES (%s,%s)',
                                         (entity.eid, eschema_eid(cnx, entity.e_schema)))
        except IndexError:
            # during schema serialization, skip
            pass
        else:
            for eschema in entity.e_schema.ancestors() + [entity.e_schema]:
                self._handle_is_instance_of_sql(cnx,
                                                'INSERT INTO is_instance_of_relation(eid_from,eid_to) VALUES (%s,%s)',
                                                (entity.eid, eschema_eid(cnx, eschema)))
            if 'CWSource' in self.schema and source.eid is not None:  # else, cw < 3.10
                self._handle_source_relation_sql(cnx, 'INSERT INTO cw_source_relation(eid_from,eid_to) VALUES (%s,%s)',
                                                 (entity.eid, source.eid))
        # now we can update the full text index
        # NOTE(review): need_fti_indexation / index_entity are not defined on
        # this class; this branch is only reached if `do_fti` is set to a
        # true value.
        if self.do_fti and self.need_fti_indexation(entity.cw_etype):
            self.index_entity(cnx, entity=entity)