author | Denis Laxalde <denis.laxalde@logilab.fr> |
Mon, 19 Jun 2017 18:15:28 +0200 | |
changeset 12188 | fea018b2e056 |
parent 12047 | 85416b43310a |
child 12220 | 3ba6016a459c |
permissions | -rw-r--r-- |
# copyright 2003-2016 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
"""SQL utilities functions and classes."""

from __future__ import print_function

import sys
import re
import subprocess
from os.path import abspath
from logging import getLogger
from datetime import time, datetime, timedelta

from six import string_types, text_type
from six.moves import filter

from pytz import utc

from logilab import database as db, common as lgc
from logilab.common.shellutils import ProgressBar, DummyProgressBar
from logilab.common.deprecation import deprecated
from logilab.common.logging_ext import set_log_methods
from logilab.common.date import utctime, utcdatetime, strptime
from logilab.database.sqlgen import SQLGenerator

from cubicweb import Binary, ConfigurationError
from cubicweb.uilib import remove_html_tags
from cubicweb.schema import PURE_VIRTUAL_RTYPES
from cubicweb.server import SQL_CONNECT_HOOKS
from cubicweb.server.utils import crypt_password

lgc.USE_MX_DATETIME = False
SQL_PREFIX = 'cw_'


def _run_command(cmd):
    """Echo `cmd` on stdout, run it and return its exit status.

    `cmd` is either a string, which is run through the shell, or a sequence
    of arguments, which is executed directly.
    """
    if isinstance(cmd, string_types):
        print(cmd)
        return subprocess.call(cmd, shell=True)
    else:
        print(' '.join(cmd))
        return subprocess.call(cmd)


def sqlexec(sqlstmts, cursor_or_execute, withpb=True,
            pbtitle='', delimiter=';', cnx=None):
    """Execute SQL statements one by one, going on when a statement fails.

    Any exception raised by a statement is printed on stderr and the
    statement is recorded as failed; remaining statements are still executed.

    :sqlstmts: a string or a list of sql statements.
    :cursor_or_execute: sql cursor or a callback used to execute statements
    :cnx: if given, commit after each successful statement and rollback after
      each failed one.
    :withpb: if True, display a progress bar
    :pbtitle: a string displayed as the progress bar title (if `withpb=True`)
    :delimiter: a string used to split sqlstmts (if it is a string)

    Return the failed statements (same type as sqlstmts)
    """
    if hasattr(cursor_or_execute, 'execute'):
        execute = cursor_or_execute.execute
    else:
        execute = cursor_or_execute
    sqlstmts_as_string = False
    if isinstance(sqlstmts, string_types):
        sqlstmts_as_string = True
        sqlstmts = sqlstmts.split(delimiter)
    if withpb:
        # only draw a real progress bar on an interactive terminal
        if sys.stdout.isatty():
            pb = ProgressBar(len(sqlstmts), title=pbtitle)
        else:
            pb = DummyProgressBar()
    failed = []
    for sql in sqlstmts:
        sql = sql.strip()
        if withpb:
            # empty statements are counted too, since the bar was sized on
            # the unfiltered list
            pb.update()
        if not sql:
            continue
        try:
            # some dbapi modules doesn't accept unicode for sql string
            execute(str(sql))
        except Exception as ex:
            print(ex, file=sys.stderr)
            if cnx:
                cnx.rollback()
            failed.append(sql)
        else:
            if cnx:
                cnx.commit()
    if withpb:
        print()
    if sqlstmts_as_string:
        failed = delimiter.join(failed)
    return failed


def sqlgrants(schema, driver, user,
              text_index=True, set_owner=True,
              skip_relations=(), skip_entities=()):
    """Return a list of SQL statements to give all access privileges to the
    given user on the database.

    `skip_relations` is accepted for signature symmetry with `sqlschema` but
    is not used here.
    """
    from cubicweb.server.schema2sql import grant_schema
    from cubicweb.server.sources import native
    stmts = list(native.grant_schema(user, set_owner))
    if text_index:
        dbhelper = db.get_db_helper(driver)
        # XXX should return a list of sql statements rather than ';' joined statements
        stmts += dbhelper.sql_grant_user_on_fti(user).split(';')
    stmts += grant_schema(schema, user, set_owner, skip_entities=skip_entities, prefix=SQL_PREFIX)
    return stmts


def sqlschema(schema, driver, text_index=True,
              user=None, set_owner=False,
              skip_relations=PURE_VIRTUAL_RTYPES, skip_entities=()):
    """Return the database SQL schema as a list of SQL statements, according
    to the given parameters.

    Grant statements for `user` are appended when a user is given and the
    database helper reports user support.
    """
    from cubicweb.server.schema2sql import schema2sql
    from cubicweb.server.sources import native
    if set_owner:
        assert user, 'user is argument required when set_owner is true'
    stmts = list(native.sql_schema(driver))
    dbhelper = db.get_db_helper(driver)
    if text_index:
        stmts += dbhelper.sql_init_fti().split(';')  # XXX
    stmts += schema2sql(dbhelper, schema, prefix=SQL_PREFIX,
                        skip_entities=skip_entities,
                        skip_relations=skip_relations)
    if dbhelper.users_support and user:
        stmts += sqlgrants(schema, driver, user, text_index, set_owner,
                           skip_relations, skip_entities)
    return stmts


# matches any name which does *not* start with 'sql_' or 'pg_'
_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION = re.compile('^(?!(sql|pg)_)').match


def sql_drop_all_user_tables(driver_or_helper, sqlcursor):
    """Return the sql to drop all tables found in the database system.

    `driver_or_helper` is either a database helper (anything with a
    `list_tables` attribute) or a driver name used to look the helper up.
    Views and tables whose name starts with 'sql_' or 'pg_' are left alone.
    """
    if not getattr(driver_or_helper, 'list_tables', None):
        dbhelper = db.get_db_helper(driver_or_helper)
    else:
        dbhelper = driver_or_helper

    stmts = [dbhelper.sql_drop_sequence('entities_id_seq')]
    # for mssql, we need to drop views before tables
    if hasattr(dbhelper, 'list_views'):
        stmts += ['DROP VIEW %s;' % name
                  for name in filter(_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION,
                                     dbhelper.list_views(sqlcursor))]
    stmts += ['DROP TABLE %s;' % name
              for name in filter(_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION,
                                 dbhelper.list_tables(sqlcursor))]
    return stmts


class ConnectionWrapper(object):
    """Wrap a connection to the system source's database, attempting to handle
    automatic reconnection.

    Attributes:

    * `cnx`, the connection obtained from the system source
    * `cu`, a cursor opened on that connection
    """

    # since 3.19, we only have to manage the system source connection
    def __init__(self, system_source):
        # keep the source around: it is needed to reconnect and to log
        self._source = system_source
        self.cnx = system_source.get_connection()
        self.cu = self.cnx.cursor()

    def commit(self):
        """commit the current transaction for this user"""
        # let exception propagates
        self.cnx.commit()

    def rollback(self):
        """rollback the current transaction for this user"""
        # catch exceptions, a failed rollback must not propagate
        try:
            self.cnx.rollback()
        except Exception:
            self._source.critical('rollback error', exc_info=sys.exc_info())
            # error on rollback, the connection is much probably in a really
            # bad state. Replace it by a new one.
            self.reconnect()

    def close(self, i_know_what_i_do=False):
        """close the cursor and the connection

        Raise `RuntimeError` unless `i_know_what_i_do` is exactly True.
        """
        if i_know_what_i_do is not True:  # unexpected closing safety belt
            raise RuntimeError('connections set shouldn\'t be closed')
        # closing is best-effort, but references are always dropped so that a
        # cursor or connection which failed to close is not kept around
        try:
            self.cu.close()
        except Exception:
            pass
        finally:
            self.cu = None
        try:
            self.cnx.close()
        except Exception:
            pass
        finally:
            self.cnx = None

    # internals ###############################################################

    def cnxset_freed(self):
        """connections set is being freed from a session"""
        pass  # do nothing by default

    def reconnect(self):
        """reopen a connection to the system source, replacing both the
        connection and the cursor
        """
        try:
            # properly close existing connection if any
            self.cnx.close()
        except Exception:
            pass
        self._source.info('trying to reconnect')
        self.cnx = self._source.get_connection()
        self.cu = self.cnx.cursor()

    @deprecated('[3.19] use .cu instead')
    def __getitem__(self, uri):
        assert uri == 'system'
        return self.cu

    @deprecated('[3.19] use repo.system_source instead')
    def source(self, uid):
        assert uid == 'system'
        return self._source

    @deprecated('[3.19] use .cnx instead')
    def connection(self, uid):
        assert uid == 'system'
        return self.cnx

class SqliteConnectionWrapper(ConnectionWrapper):
    """Sqlite specific connection wrapper: close the connection each time it's
    freed (and reopen it later when needed)
    """
    def __init__(self, system_source):
        # don't call parent's __init__, we don't want to initiate the connection
        self._source = system_source

    # underlying connection, None while no connection is open
    _cnx = None

    def _open_if_needed(self):
        """Open the connection and its cursor unless a connection is set."""
        if self._cnx is None:
            self._cnx = self._source.get_connection()
            self._cu = self._cnx.cursor()

    def cnxset_freed(self):
        """close the connection and its cursor, if any was opened"""
        if self._cnx is None:
            # going through the `cu` / `cnx` properties here would open a
            # fresh connection only to close it right away
            return
        self.cu.close()
        self.cnx.close()
        self.cnx = self.cu = None

    @property
    def cnx(self):
        """the connection, opened on first access"""
        self._open_if_needed()
        return self._cnx

    @cnx.setter
    def cnx(self, value):
        self._cnx = value

    @property
    def cu(self):
        """the cursor, opened together with the connection on first access"""
        self._open_if_needed()
        return self._cu

    @cu.setter
    def cu(self, value):
        self._cu = value


class SQLAdapterMixIn(object):
    """Mixin for SQL data sources, getting a connection from a configuration
    dictionary and handling connection locking
    """
    # class used by `wrapped_connection`, replaced per instance for sqlite
    cnx_wrap = ConnectionWrapper

    def __init__(self, source_config, repairing=False):
        """Set up the database helper from `source_config`.

        'db-driver' and 'db-name' are mandatory, `ConfigurationError` is
        raised if either is missing. Unless `repairing` is true, a positive
        'db-statement-timeout' registers a connection hook setting the
        statement timeout.
        """
        try:
            self.dbdriver = source_config['db-driver'].lower()
            dbname = source_config['db-name']
        except KeyError:
            raise ConfigurationError('missing some expected entries in sources file')
        dbhost = source_config.get('db-host')
        port = source_config.get('db-port')
        dbport = port and int(port) or None
        dbuser = source_config.get('db-user')
        dbpassword = source_config.get('db-password')
        dbencoding = source_config.get('db-encoding', 'UTF-8')
        dbextraargs = source_config.get('db-extra-arguments')
        dbnamespace = source_config.get('db-namespace')
        self.dbhelper = db.get_db_helper(self.dbdriver)
        self.dbhelper.record_connection_info(dbname, dbhost, dbport, dbuser,
                                             dbpassword, dbextraargs,
                                             dbencoding, dbnamespace)
        self.sqlgen = SQLGenerator()
        # copy back some commonly accessed attributes
        dbapi_module = self.dbhelper.dbapi_module
        self.OperationalError = dbapi_module.OperationalError
        self.InterfaceError = dbapi_module.InterfaceError
        self.DbapiError = dbapi_module.Error
        self._binary = self.dbhelper.binary_value
        self._process_value = dbapi_module.process_value
        self._dbencoding = dbencoding
        if self.dbdriver == 'sqlite':
            self.cnx_wrap = SqliteConnectionWrapper
            self.dbhelper.dbname = abspath(self.dbhelper.dbname)
        if not repairing:
            statement_timeout = int(source_config.get('db-statement-timeout', 0))
            if statement_timeout > 0:
                def set_postgres_timeout(cnx):
                    cnx.cursor().execute('SET statement_timeout to %d' % statement_timeout)
                    cnx.commit()
                # NOTE(review): the hook goes into the module-wide 'postgres'
                # hook list whatever `self.dbdriver` is, and once more for
                # each instance created -- confirm this is intended
                postgres_hooks = SQL_CONNECT_HOOKS['postgres']
                postgres_hooks.append(set_postgres_timeout)

    def wrapped_connection(self):
        """open and return a connection to the database, wrapped into a class
        handling reconnection and all
        """
        return self.cnx_wrap(self)

    def get_connection(self):
        """open and return a connection to the database"""
        return self.dbhelper.get_connection()

    def backup_to_file(self, backupfile, confirm):
        """Run the helper's backup commands for `backupfile`.

        When a command exits with a non-zero status, `confirm` is called;
        an exception is raised if it answers false.
        """
        for cmd in self.dbhelper.backup_commands(backupfile,
                                                 keepownership=False):
            if _run_command(cmd):
                if not confirm('   [Failed] Continue anyway?', default='n'):
                    raise Exception('Failed command: %s' % cmd)

    def restore_from_file(self, backupfile, confirm, drop=True):
        """Run the helper's restore commands for `backupfile`.

        When a command exits with a non-zero status, `confirm` is called;
        an exception is raised if it answers false.
        """
        for cmd in self.dbhelper.restore_commands(backupfile,
                                                  keepownership=False,
                                                  drop=drop):
            if _run_command(cmd):
                if not confirm('   [Failed] Continue anyway?', default='n'):
                    raise Exception('Failed command: %s' % cmd)

    def merge_args(self, args, query_args):
        """Return the dictionary of arguments to give to the cursor.

        Values of `args` are converted (binaries to their database
        representation, timezone-aware dates and times to UTC), then merged
        with `query_args`. `query_args` is returned untouched when `args` is
        None.
        """
        if args is not None:
            newargs = {}
            for key, val in args.items():
                # convert cubicweb binary into db binary
                if isinstance(val, Binary):
                    val = self._binary(val.getvalue())
                # convert timestamp to utc.
                # expect SET TiME ZONE to UTC at connection opening time.
                # This shouldn't change anything for datetime without TZ.
                elif isinstance(val, datetime) and val.tzinfo is not None:
                    val = utcdatetime(val)
                elif isinstance(val, time) and val.tzinfo is not None:
                    val = utctime(val)
                newargs[key] = val
            # should not collide
            assert not (frozenset(newargs) & frozenset(query_args)), \
                'unexpected collision: %s' % (frozenset(newargs) & frozenset(query_args))
            newargs.update(query_args)
            return newargs
        return query_args

    def process_result(self, cursor, cnx=None, column_callbacks=None):
        """return a list of CubicWeb compliant values from data in the given cursor
        """
        return list(self.iter_process_result(cursor, cnx, column_callbacks))

    def iter_process_result(self, cursor, cnx, column_callbacks=None):
        """return a iterator on tuples of CubicWeb compliant values from data
        in the given cursor

        `cnx` is required when `column_callbacks` is given.
        """
        # use two different implementations to avoid paying the price of
        # callback lookup for each *cell* in results when there is nothing to
        # lookup
        if not column_callbacks:
            return self.dbhelper.dbapi_module.process_cursor(cursor, self._dbencoding,
                                                             Binary)
        assert cnx
        return self._cb_process_result(cursor, column_callbacks, cnx)

    def _cb_process_result(self, cursor, column_callbacks, cnx):
        """Generate processed rows, applying `column_callbacks` (a mapping of
        column index to a sequence of callbacks) on columns which have some.
        """
        # begin bind to locals for optimization
        descr = cursor.description
        encoding = self._dbencoding
        process_value = self._process_value
        binary = Binary
        # /end
        cursor.arraysize = 100
        while True:
            results = cursor.fetchmany()
            if not results:
                break
            for line in results:
                result = []
                for col, value in enumerate(line):
                    if value is None:
                        result.append(value)
                        continue
                    cbstack = column_callbacks.get(col, None)
                    if cbstack is None:
                        value = process_value(value, descr[col], encoding, binary)
                    else:
                        # callbacks are chained, each one getting the value
                        # returned by the previous one
                        for cb in cbstack:
                            value = cb(self, cnx, value)
                    result.append(value)
                yield result

    def preprocess_entity(self, entity):
        """return a dictionary to use as extra argument to cursor.execute
        to insert/update an entity into a SQL database
        """
        attrs = {}
        eschema = entity.e_schema
        converters = getattr(self.dbhelper, 'TYPE_CONVERTERS', {})
        for attr, value in entity.cw_edited.items():
            if value is not None and eschema.subjrels[attr].final:
                atype = str(entity.e_schema.destination(attr))
                if atype in converters:
                    # It is easier to modify preprocess_entity rather
                    # than add_entity (native) as this behavior
                    # may also be used for update.
                    value = converters[atype](value)
                elif atype == 'Password':  # XXX could be done using a TYPE_CONVERTERS callback
                    # if value is a Binary instance, this mean we got it
                    # from a query result and so it is already encrypted
                    if isinstance(value, Binary):
                        value = value.getvalue()
                    else:
                        value = crypt_password(value)
                    value = self._binary(value)
                elif isinstance(value, Binary):
                    value = self._binary(value.getvalue())
            attrs[SQL_PREFIX + str(attr)] = value
        attrs[SQL_PREFIX + 'eid'] = entity.eid
        return attrs

    # these are overridden by set_log_methods below
    # only defining here to prevent pylint from complaining
    info = warning = error = critical = exception = debug = lambda msg, *a, **kw: None


set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))


# connection initialization functions ##########################################

def _install_sqlite_querier_patch():
    """This monkey-patch hotfixes a bug sqlite causing some dates to be returned
    as strings rather than date objects
    (http://www.sqlite.org/cvstrac/tktview?tn=1327,33)

    `QuerierHelper.execute` is wrapped once; later calls are no-ops.
    """
    from cubicweb.server.querier import QuerierHelper

    if hasattr(QuerierHelper, '_sqlite_patched'):
        return  # already monkey patched

    def wrap_execute(base_execute):
        def new_execute(*args, **kwargs):
            rset = base_execute(*args, **kwargs)
            if rset.description:
                found_date = False
                for row, rowdesc in zip(rset, rset.description):
                    for cellindex, (value, vtype) in enumerate(zip(row, rowdesc)):
                        if vtype in ('TZDatetime', 'Date', 'Datetime') \
                           and isinstance(value, text_type):
                            found_date = True
                            # drop the fractional seconds part, if any
                            value = value.rsplit('.', 1)[0]
                            try:
                                row[cellindex] = strptime(value, '%Y-%m-%d %H:%M:%S')
                            except Exception:
                                row[cellindex] = strptime(value, '%Y-%m-%d')
                            if vtype == 'TZDatetime':
                                row[cellindex] = row[cellindex].replace(tzinfo=utc)
                        if vtype == 'Time' and isinstance(value, text_type):
                            found_date = True
                            try:
                                row[cellindex] = strptime(value, '%H:%M:%S')
                            except Exception:
                                # DateTime used as Time?
                                row[cellindex] = strptime(value, '%Y-%m-%d %H:%M:%S')
                        if vtype == 'Interval' and isinstance(value, int):
                            found_date = True
                            # XXX value is in number of seconds?
                            row[cellindex] = timedelta(0, value, 0)
                    # nothing was converted in this row: stop scanning the
                    # remaining rows
                    if not found_date:
                        break
            return rset
        return new_execute

    QuerierHelper.execute = wrap_execute(QuerierHelper.execute)
    QuerierHelper._sqlite_patched = True


def _init_sqlite_connection(cnx):
    """Internal function that will be called to init a sqlite connection

    It registers the GROUP_CONCAT aggregate and the LIMIT_SIZE,
    TEXT_LIMIT_SIZE and WEEKDAY functions on `cnx`, then turns foreign keys
    support on.
    """
    _install_sqlite_querier_patch()

    class group_concat(object):
        """aggregate joining distinct non-NULL values with ', '"""

        def __init__(self):
            self.values = set()

        def step(self, value):
            if value is not None:
                self.values.add(value)

        def finalize(self):
            return ', '.join(text_type(v) for v in self.values)

    cnx.create_aggregate("GROUP_CONCAT", 1, group_concat)

    def _limit_size(text, maxsize, format='text/plain'):
        if len(text) < maxsize:
            return text
        if format in ('text/html', 'text/xhtml', 'text/xml'):
            text = remove_html_tags(text)
        if len(text) > maxsize:
            text = text[:maxsize] + '...'
        return text

    def limit_size3(text, format, maxsize):
        return _limit_size(text, maxsize, format)
    cnx.create_function("LIMIT_SIZE", 3, limit_size3)

    def limit_size2(text, maxsize):
        return _limit_size(text, maxsize)
    cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)

    from logilab.common.date import strptime

    def weekday(ustr):
        try:
            dt = strptime(ustr, '%Y-%m-%d %H:%M:%S')
        except Exception:
            # not a date and time, try a date only
            dt = strptime(ustr, '%Y-%m-%d')
        # expect sunday to be 1, saturday 7 while weekday method return 0 for
        # monday
        # NOTE(review): as written, sunday yields 0 and saturday 6, which is
        # not the 1..7 range announced above -- confirm which one is expected
        return (dt.weekday() + 1) % 7
    cnx.create_function("WEEKDAY", 1, weekday)

    cnx.cursor().execute("pragma foreign_keys = on")

    import yams.constraints
    yams.constraints.patch_sqlite_decimal()


sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
sqlite_hooks.append(_init_sqlite_connection)


def _init_postgres_connection(cnx):
    """Internal function that will be called to init a postgresql connection"""
    cnx.cursor().execute('SET TIME ZONE UTC')
    # commit is needed, else setting are lost if the connection is first
    # rolled back
    cnx.commit()


postgres_hooks = SQL_CONNECT_HOOKS.setdefault('postgres', [])
postgres_hooks.append(_init_postgres_connection)