diff -r 058bb3dc685f -r 0b59724cb3f2 cubicweb/server/sqlutils.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/cubicweb/server/sqlutils.py	Sat Jan 16 13:48:51 2016 +0100
@@ -0,0 +1,591 @@
+# copyright 2003-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of CubicWeb.
+#
+# CubicWeb is free software: you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
+"""SQL utilities functions and classes."""
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+import sys
+import re
+import subprocess
+from os.path import abspath
+from logging import getLogger
+from datetime import time, datetime, timedelta
+
+from six import string_types, text_type
+from six.moves import filter
+
+from pytz import utc
+
+from logilab import database as db, common as lgc
+from logilab.common.shellutils import ProgressBar, DummyProgressBar
+from logilab.common.deprecation import deprecated
+from logilab.common.logging_ext import set_log_methods
+from logilab.common.date import utctime, utcdatetime, strptime
+from logilab.database.sqlgen import SQLGenerator
+
+from cubicweb import Binary, ConfigurationError
+from cubicweb.uilib import remove_html_tags
+from cubicweb.schema import PURE_VIRTUAL_RTYPES
+from cubicweb.server import SQL_CONNECT_HOOKS
+from cubicweb.server.utils import crypt_password
+
+lgc.USE_MX_DATETIME = False
+SQL_PREFIX = 'cw_'
+
+
+def _run_command(cmd):
+    if isinstance(cmd, string_types):
+        print(cmd)
+        return subprocess.call(cmd, shell=True)
+    else:
+        print(' '.join(cmd))
+        return subprocess.call(cmd)
+
+
+def sqlexec(sqlstmts, cursor_or_execute, withpb=True,
+            pbtitle='', delimiter=';', cnx=None):
+    """execute sql statements, ignoring errors from DROP / CREATE GROUP or
+    USER statements.
+
+    :sqlstmts: a string or a list of sql statements.
+    :cursor_or_execute: sql cursor or a callback used to execute statements
+    :cnx: if given, commit/rollback at each statement.
+
+    :withpb: if True, display a progress bar
+    :pbtitle: a string displayed as the progress bar title (if `withpb=True`)
+
+    :delimiter: a string used to split sqlstmts (if it is a string)
+
+    Return the failed statements (same type as sqlstmts)
+    """
+    if hasattr(cursor_or_execute, 'execute'):
+        execute = cursor_or_execute.execute
+    else:
+        execute = cursor_or_execute
+    sqlstmts_as_string = False
+    if isinstance(sqlstmts, string_types):
+        sqlstmts_as_string = True
+        sqlstmts = sqlstmts.split(delimiter)
+    if withpb:
+        if sys.stdout.isatty():
+            pb = ProgressBar(len(sqlstmts), title=pbtitle)
+        else:
+            pb = DummyProgressBar()
+    failed = []
+    for sql in sqlstmts:
+        sql = sql.strip()
+        if withpb:
+            pb.update()
+        if not sql:
+            continue
+        try:
+            # some dbapi modules don't accept unicode for the sql string
+            execute(str(sql))
+        except Exception:
+            if cnx:
+                cnx.rollback()
+            failed.append(sql)
+        else:
+            if cnx:
+                cnx.commit()
+    if withpb:
+        print()
+    if sqlstmts_as_string:
+        failed = delimiter.join(failed)
+    return failed
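+
+# Illustrative usage sketch (comment only, not part of the module's API):
+# replaying a ';;'-delimited script with sqlexec. `cursor` stands for an open
+# DB-API cursor and is an assumption of the example.
+#
+#     stmts = "DROP TABLE foo;; DROP TABLE bar"
+#     failed = sqlexec(stmts, cursor.execute, withpb=False, delimiter=';;')
+#     if failed:
+#         print('statements that could not be executed:', failed)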
+
+
+def sqlgrants(schema, driver, user,
+              text_index=True, set_owner=True,
+              skip_relations=(), skip_entities=()):
+    """return sql to give all access privileges to the given user on the system
+    schema
+    """
+    from cubicweb.server.schema2sql import grant_schema
+    from cubicweb.server.sources import native
+    output = []
+    w = output.append
+    w(native.grant_schema(user, set_owner))
+    w('')
+    if text_index:
+        dbhelper = db.get_db_helper(driver)
+        w(dbhelper.sql_grant_user_on_fti(user))
+        w('')
+    w(grant_schema(schema, user, set_owner, skip_entities=skip_entities, prefix=SQL_PREFIX))
+    return '\n'.join(output)
+
+
+def sqlschema(schema, driver, text_index=True,
+              user=None, set_owner=False,
+              skip_relations=PURE_VIRTUAL_RTYPES, skip_entities=()):
+    """return the system sql schema, according to the given parameters"""
+    from cubicweb.server.schema2sql import schema2sql
+    from cubicweb.server.sources import native
+    if set_owner:
+        assert user, 'user argument is required when set_owner is true'
+    output = []
+    w = output.append
+    w(native.sql_schema(driver))
+    w('')
+    dbhelper = db.get_db_helper(driver)
+    if text_index:
+        w(dbhelper.sql_init_fti().replace(';', ';;'))
+        w('')
+    w(schema2sql(dbhelper, schema, prefix=SQL_PREFIX,
+                 skip_entities=skip_entities,
+                 skip_relations=skip_relations).replace(';', ';;'))
+    if dbhelper.users_support and user:
+        w('')
+        w(sqlgrants(schema, driver, user, text_index, set_owner,
+                    skip_relations, skip_entities).replace(';', ';;'))
+    return '\n'.join(output)
+
+
+def sqldropschema(schema, driver, text_index=True,
+                  skip_relations=PURE_VIRTUAL_RTYPES, skip_entities=()):
+    """return the sql to drop the schema, according to the given parameters"""
+    from cubicweb.server.schema2sql import dropschema2sql
+    from cubicweb.server.sources import native
+    output = []
+    w = output.append
+    # get the helper unconditionally: it is also needed by dropschema2sql below
+    dbhelper = db.get_db_helper(driver)
+    if text_index:
+        w(dbhelper.sql_drop_fti())
+        w('')
+    w(dropschema2sql(dbhelper, schema, prefix=SQL_PREFIX,
+                     skip_entities=skip_entities,
+                     skip_relations=skip_relations))
+    w('')
+    w(native.sql_drop_schema(driver))
+    return '\n'.join(output)
+
+
+_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION = re.compile('^(?!(sql|pg)_)').match
+
+
+def sql_drop_all_user_tables(driver_or_helper, sqlcursor):
+    """Return the sql to drop all tables found in the database."""
+    if not getattr(driver_or_helper, 'list_tables', None):
+        dbhelper = db.get_db_helper(driver_or_helper)
+    else:
+        dbhelper = driver_or_helper
+
+    cmds = [dbhelper.sql_drop_sequence('entities_id_seq')]
+    # for mssql, we need to drop views before tables
+    if hasattr(dbhelper, 'list_views'):
+        cmds += ['DROP VIEW %s;' % name
+                 for name in filter(_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION,
+                                    dbhelper.list_views(sqlcursor))]
+    cmds += ['DROP TABLE %s;' % name
+             for name in filter(_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION,
+                                dbhelper.list_tables(sqlcursor))]
+    return '\n'.join(cmds)
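+
+# Illustrative sketch: the typical pairing of the helpers above during
+# instance creation / removal. `schema` and `cursor` (a yams schema and an
+# open cursor on the system database) are assumptions of the example.
+#
+#     creation_script = sqlschema(schema, 'postgres', user='cubicweb', set_owner=True)
+#     sqlexec(creation_script, cursor.execute, withpb=False, delimiter=';;')
+#     # ... and to wipe everything, including leftover non-cubicweb tables:
+#     sqlexec(sql_drop_all_user_tables('postgres', cursor), cursor.execute)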
+
+
+class ConnectionWrapper(object):
+    """handle connection to the system source, at some point associated to a
+    :class:`Session`
+    """
+
+    # since 3.19, we only have to manage the system source connection
+    def __init__(self, system_source):
+        self._source = system_source
+        self.cnx = system_source.get_connection()
+        self.cu = self.cnx.cursor()
+
+    def commit(self):
+        """commit the current transaction for this user"""
+        # let exceptions propagate
+        self.cnx.commit()
+
+    def rollback(self):
+        """rollback the current transaction for this user"""
+        try:
+            self.cnx.rollback()
+        except Exception:
+            self._source.critical('rollback error', exc_info=sys.exc_info())
+            # error on rollback: the connection is most probably in a really
+            # bad state. Replace it by a new one.
+            self.reconnect()
+
+    def close(self, i_know_what_i_do=False):
+        """close the connection held by this wrapper"""
+        if i_know_what_i_do is not True:  # unexpected closing safety belt
+            raise RuntimeError('connections set shouldn\'t be closed')
+        try:
+            self.cu.close()
+            self.cu = None
+        except Exception:
+            pass
+        try:
+            self.cnx.close()
+            self.cnx = None
+        except Exception:
+            pass
+
+    # internals ###############################################################
+
+    def cnxset_freed(self):
+        """connections set is being freed from a session"""
+        pass  # do nothing by default
+
+    def reconnect(self):
+        """reopen the connection to the system source"""
+        try:
+            # properly close existing connection if any
+            self.cnx.close()
+        except Exception:
+            pass
+        self._source.info('trying to reconnect')
+        self.cnx = self._source.get_connection()
+        self.cu = self.cnx.cursor()
+
+    @deprecated('[3.19] use .cu instead')
+    def __getitem__(self, uri):
+        assert uri == 'system'
+        return self.cu
+
+    @deprecated('[3.19] use repo.system_source instead')
+    def source(self, uid):
+        assert uid == 'system'
+        return self._source
+
+    @deprecated('[3.19] use .cnx instead')
+    def connection(self, uid):
+        assert uid == 'system'
+        return self.cnx
+
+
+class SqliteConnectionWrapper(ConnectionWrapper):
+    """Sqlite specific connection wrapper: close the connection each time it's
+    freed (and reopen it later when needed)
+    """
+    def __init__(self, system_source):
+        # don't call parent's __init__: the connection is opened lazily by the
+        # `cnx` / `cu` properties below
+        self._source = system_source
+
+    _cnx = None
+
+    def cnxset_freed(self):
+        self.cu.close()
+        self.cnx.close()
+        self.cnx = self.cu = None
+
+    @property
+    def cnx(self):
+        if self._cnx is None:
+            self._cnx = self._source.get_connection()
+            self._cu = self._cnx.cursor()
+        return self._cnx
+
+    @cnx.setter
+    def cnx(self, value):
+        self._cnx = value
+
+    @property
+    def cu(self):
+        if self._cnx is None:
+            self._cnx = self._source.get_connection()
+            self._cu = self._cnx.cursor()
+        return self._cu
+
+    @cu.setter
+    def cu(self, value):
+        self._cu = value
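+
+# Illustrative lifecycle sketch for the wrappers above. `source` is assumed
+# to be a system source instance (an SQLAdapterMixIn subclass, see below),
+# which hands out wrappers through wrapped_connection():
+#
+#     wrapper = source.wrapped_connection()
+#     wrapper.cu.execute('SELECT 1')
+#     wrapper.commit()
+#     wrapper.cnxset_freed()  # sqlite wrapper: closes now, reopens lazily
+#     wrapper.close(i_know_what_i_do=True)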
+
+
+class SQLAdapterMixIn(object):
+    """Mixin for SQL data sources, getting a connection from a configuration
+    dictionary and handling connection locking
+    """
+    cnx_wrap = ConnectionWrapper
+
+    def __init__(self, source_config, repairing=False):
+        try:
+            self.dbdriver = source_config['db-driver'].lower()
+            dbname = source_config['db-name']
+        except KeyError:
+            raise ConfigurationError('missing some expected entries in sources file')
+        dbhost = source_config.get('db-host')
+        port = source_config.get('db-port')
+        dbport = port and int(port) or None
+        dbuser = source_config.get('db-user')
+        dbpassword = source_config.get('db-password')
+        dbencoding = source_config.get('db-encoding', 'UTF-8')
+        dbextraargs = source_config.get('db-extra-arguments')
+        dbnamespace = source_config.get('db-namespace')
+        self.dbhelper = db.get_db_helper(self.dbdriver)
+        self.dbhelper.record_connection_info(dbname, dbhost, dbport, dbuser,
+                                             dbpassword, dbextraargs,
+                                             dbencoding, dbnamespace)
+        self.sqlgen = SQLGenerator()
+        # copy back some commonly accessed attributes
+        dbapi_module = self.dbhelper.dbapi_module
+        self.OperationalError = dbapi_module.OperationalError
+        self.InterfaceError = dbapi_module.InterfaceError
+        self.DbapiError = dbapi_module.Error
+        self._binary = self.dbhelper.binary_value
+        self._process_value = dbapi_module.process_value
+        self._dbencoding = dbencoding
+        if self.dbdriver == 'sqlite':
+            self.cnx_wrap = SqliteConnectionWrapper
+            self.dbhelper.dbname = abspath(self.dbhelper.dbname)
+        if not repairing:
+            statement_timeout = int(source_config.get('db-statement-timeout', 0))
+            if statement_timeout > 0:
+                def set_postgres_timeout(cnx):
+                    cnx.cursor().execute('SET statement_timeout to %d' % statement_timeout)
+                    cnx.commit()
+                postgres_hooks = SQL_CONNECT_HOOKS['postgres']
+                postgres_hooks.append(set_postgres_timeout)
+
+    def wrapped_connection(self):
+        """open and return a connection to the database, wrapped into a class
+        handling reconnection and all
+        """
+        return self.cnx_wrap(self)
+
+    def get_connection(self):
+        """open and return a connection to the database"""
+        return self.dbhelper.get_connection()
+
+    def backup_to_file(self, backupfile, confirm):
+        for cmd in self.dbhelper.backup_commands(backupfile,
+                                                 keepownership=False):
+            if _run_command(cmd):
+                if not confirm('  [Failed] Continue anyway?', default='n'):
+                    raise Exception('Failed command: %s' % cmd)
+
+    def restore_from_file(self, backupfile, confirm, drop=True):
+        for cmd in self.dbhelper.restore_commands(backupfile,
+                                                  keepownership=False,
+                                                  drop=drop):
+            if _run_command(cmd):
+                if not confirm('  [Failed] Continue anyway?', default='n'):
+                    raise Exception('Failed command: %s' % cmd)
+
+    def merge_args(self, args, query_args):
+        if args is not None:
+            newargs = {}
+            for key, val in args.items():
+                # convert cubicweb binary into db binary
+                if isinstance(val, Binary):
+                    val = self._binary(val.getvalue())
+                # convert timestamps to UTC; expects SET TIME ZONE UTC at
+                # connection opening time. This shouldn't change anything for
+                # datetimes without tzinfo.
+                elif isinstance(val, datetime) and val.tzinfo is not None:
+                    val = utcdatetime(val)
+                elif isinstance(val, time) and val.tzinfo is not None:
+                    val = utctime(val)
+                newargs[key] = val
+            # should not collide
+            assert not (frozenset(newargs) & frozenset(query_args)), \
+                'unexpected collision: %s' % (frozenset(newargs) & frozenset(query_args))
+            newargs.update(query_args)
+            return newargs
+        return query_args
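+
+    # Illustrative sketch of merge_args semantics: cubicweb Binary values are
+    # rewrapped for the dbapi driver, tz-aware datetimes/times are converted
+    # to naive UTC values, then query_args is merged in (keys must not
+    # collide). `source` is an assumption of the example.
+    #
+    #     from datetime import datetime
+    #     from pytz import utc
+    #     args = {'data': Binary(b'spam'),
+    #             'when': datetime(2016, 1, 16, 13, 48, tzinfo=utc)}
+    #     merged = source.merge_args(args, {'eid': 1234})
+    #     # merged['when'] is now a naive UTC datetime, merged['eid'] == 1234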
+
+    def process_result(self, cursor, cnx=None, column_callbacks=None):
+        """return a list of CubicWeb compliant values from data in the given cursor
+        """
+        return list(self.iter_process_result(cursor, cnx, column_callbacks))
+
+    def iter_process_result(self, cursor, cnx, column_callbacks=None):
+        """return an iterator on tuples of CubicWeb compliant values from data
+        in the given cursor
+        """
+        # use two different implementations to avoid paying the price of
+        # callback lookup for each *cell* in results when there is nothing to
+        # lookup
+        if not column_callbacks:
+            return self.dbhelper.dbapi_module.process_cursor(cursor, self._dbencoding,
+                                                             Binary)
+        assert cnx
+        return self._cb_process_result(cursor, column_callbacks, cnx)
+
+    def _cb_process_result(self, cursor, column_callbacks, cnx):
+        # begin bind to locals for optimization
+        descr = cursor.description
+        encoding = self._dbencoding
+        process_value = self._process_value
+        binary = Binary
+        # /end
+        cursor.arraysize = 100
+        while True:
+            results = cursor.fetchmany()
+            if not results:
+                break
+            for line in results:
+                result = []
+                for col, value in enumerate(line):
+                    if value is None:
+                        result.append(value)
+                        continue
+                    cbstack = column_callbacks.get(col, None)
+                    if cbstack is None:
+                        value = process_value(value, descr[col], encoding, binary)
+                    else:
+                        for cb in cbstack:
+                            value = cb(self, cnx, value)
+                    result.append(value)
+                yield result
+
+    def preprocess_entity(self, entity):
+        """return a dictionary to use as extra argument to cursor.execute
+        to insert/update an entity into a SQL database
+        """
+        attrs = {}
+        eschema = entity.e_schema
+        converters = getattr(self.dbhelper, 'TYPE_CONVERTERS', {})
+        for attr, value in entity.cw_edited.items():
+            if value is not None and eschema.subjrels[attr].final:
+                atype = str(entity.e_schema.destination(attr))
+                if atype in converters:
+                    # It is easier to modify preprocess_entity than
+                    # add_entity (native) as this behavior may also be
+                    # used for update.
+                    value = converters[atype](value)
+                elif atype == 'Password':  # XXX could be done using a TYPE_CONVERTERS callback
+                    # if value is a Binary instance, this means we got it
+                    # from a query result and so it is already encrypted
+                    if isinstance(value, Binary):
+                        value = value.getvalue()
+                    else:
+                        value = crypt_password(value)
+                    value = self._binary(value)
+                elif isinstance(value, Binary):
+                    value = self._binary(value.getvalue())
+            attrs[SQL_PREFIX + str(attr)] = value
+        attrs[SQL_PREFIX + 'eid'] = entity.eid
+        return attrs
+
+    # these are overridden by set_log_methods below
+    # only defining here to prevent pylint from complaining
+    info = warning = error = critical = exception = debug = lambda msg, *a, **kw: None
+
+set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))
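+
+# Illustrative sketch: for a (hypothetical) entity with eid 1234 whose edited
+# attributes are {'login': u'babar', 'upassword': u'secret'} (upassword being
+# a Password attribute), preprocess_entity returns roughly:
+#
+#     {'cw_login': u'babar',
+#      'cw_upassword': <dbapi binary of the encrypted password>,
+#      'cw_eid': 1234}
+#
+# i.e. columns are prefixed with SQL_PREFIX and passwords are encrypted, then
+# wrapped for the driver.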
+
+
+# connection initialization functions ##########################################
+
+def _install_sqlite_querier_patch():
+    """This monkey-patch hotfixes a sqlite bug causing some dates to be
+    returned as strings rather than date objects
+    (http://www.sqlite.org/cvstrac/tktview?tn=1327,33)
+    """
+    from cubicweb.server.querier import QuerierHelper
+
+    if hasattr(QuerierHelper, '_sqlite_patched'):
+        return  # already monkey patched
+
+    def wrap_execute(base_execute):
+        def new_execute(*args, **kwargs):
+            rset = base_execute(*args, **kwargs)
+            if rset.description:
+                found_date = False
+                for row, rowdesc in zip(rset, rset.description):
+                    for cellindex, (value, vtype) in enumerate(zip(row, rowdesc)):
+                        if vtype in ('TZDatetime', 'Date', 'Datetime') \
+                           and isinstance(value, text_type):
+                            found_date = True
+                            value = value.rsplit('.', 1)[0]
+                            try:
+                                row[cellindex] = strptime(value, '%Y-%m-%d %H:%M:%S')
+                            except Exception:
+                                row[cellindex] = strptime(value, '%Y-%m-%d')
+                            if vtype == 'TZDatetime':
+                                row[cellindex] = row[cellindex].replace(tzinfo=utc)
+                        if vtype == 'Time' and isinstance(value, text_type):
+                            found_date = True
+                            try:
+                                row[cellindex] = strptime(value, '%H:%M:%S')
+                            except Exception:
+                                # DateTime used as Time?
+                                row[cellindex] = strptime(value, '%Y-%m-%d %H:%M:%S')
+                        if vtype == 'Interval' and isinstance(value, int):
+                            found_date = True
+                            # XXX value is in number of seconds?
+                            row[cellindex] = timedelta(0, value, 0)
+                    if not found_date:
+                        break
+            return rset
+        return new_execute
+
+    QuerierHelper.execute = wrap_execute(QuerierHelper.execute)
+    QuerierHelper._sqlite_patched = True
+
+
+def _init_sqlite_connection(cnx):
+    """Internal function that will be called to init a sqlite connection"""
+    _install_sqlite_querier_patch()
+
+    class group_concat(object):
+        def __init__(self):
+            self.values = set()
+
+        def step(self, value):
+            if value is not None:
+                self.values.add(value)
+
+        def finalize(self):
+            return ', '.join(text_type(v) for v in self.values)
+
+    cnx.create_aggregate("GROUP_CONCAT", 1, group_concat)
+
+    def _limit_size(text, maxsize, format='text/plain'):
+        if len(text) < maxsize:
+            return text
+        if format in ('text/html', 'text/xhtml', 'text/xml'):
+            text = remove_html_tags(text)
+        if len(text) > maxsize:
+            text = text[:maxsize] + '...'
+        return text
+
+    def limit_size3(text, format, maxsize):
+        return _limit_size(text, maxsize, format)
+    cnx.create_function("LIMIT_SIZE", 3, limit_size3)
+
+    def limit_size2(text, maxsize):
+        return _limit_size(text, maxsize)
+    cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)
+
+    def weekday(ustr):
+        try:
+            dt = strptime(ustr, '%Y-%m-%d %H:%M:%S')
+        except ValueError:
+            dt = strptime(ustr, '%Y-%m-%d')
+        # expect sunday to be 1 and saturday 7, while the weekday method
+        # returns 0 for monday
+        return (dt.weekday() + 1) % 7
+    cnx.create_function("WEEKDAY", 1, weekday)
+
+    cnx.cursor().execute("pragma foreign_keys = on")
+
+    import yams.constraints
+    yams.constraints.patch_sqlite_decimal()
+
+sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
+sqlite_hooks.append(_init_sqlite_connection)
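+
+# Illustrative sketch: once the hook above has run on a sqlite connection,
+# the registered functions are usable from plain SQL (table and column names
+# below are hypothetical):
+#
+#     cu = cnx.cursor()
+#     cu.execute('SELECT GROUP_CONCAT(cw_name) FROM cw_tag')
+#     cu.execute("SELECT LIMIT_SIZE(cw_content, 'text/html', 100) FROM cw_card")
+#     cu.execute('SELECT WEEKDAY(cw_creation_date) FROM cw_card')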
+
+
+def _init_postgres_connection(cnx):
+    """Internal function that will be called to init a postgresql connection"""
+    cnx.cursor().execute('SET TIME ZONE UTC')
+    # commit is needed, else the setting is lost if the connection is first
+    # rolled back
+    cnx.commit()
+
+postgres_hooks = SQL_CONNECT_HOOKS.setdefault('postgres', [])
+postgres_hooks.append(_init_postgres_connection)
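+
+# Illustrative sketch: cubes can register additional per-connection
+# initialization callbacks through the same mechanism (the hook below and its
+# SQL are hypothetical):
+#
+#     def _init_unaccent(cnx):
+#         cnx.cursor().execute('CREATE EXTENSION IF NOT EXISTS unaccent')
+#         cnx.commit()
+#
+#     SQL_CONNECT_HOOKS.setdefault('postgres', []).append(_init_unaccent)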