--- a/server/sqlutils.py Mon Jan 04 18:40:30 2016 +0100
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,591 +0,0 @@
-# copyright 2003-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
-# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
-#
-# This file is part of CubicWeb.
-#
-# CubicWeb is free software: you can redistribute it and/or modify it under the
-# terms of the GNU Lesser General Public License as published by the Free
-# Software Foundation, either version 2.1 of the License, or (at your option)
-# any later version.
-#
-# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
-# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
-# details.
-#
-# You should have received a copy of the GNU Lesser General Public License along
-# with CubicWeb. If not, see <http://www.gnu.org/licenses/>.
-"""SQL utilities functions and classes."""
-from __future__ import print_function
-
-__docformat__ = "restructuredtext en"
-
-import sys
-import re
-import subprocess
-from os.path import abspath
-from logging import getLogger
-from datetime import time, datetime, timedelta
-
-from six import string_types, text_type
-from six.moves import filter
-
-from pytz import utc
-
-from logilab import database as db, common as lgc
-from logilab.common.shellutils import ProgressBar, DummyProgressBar
-from logilab.common.deprecation import deprecated
-from logilab.common.logging_ext import set_log_methods
-from logilab.common.date import utctime, utcdatetime, strptime
-from logilab.database.sqlgen import SQLGenerator
-
-from cubicweb import Binary, ConfigurationError
-from cubicweb.uilib import remove_html_tags
-from cubicweb.schema import PURE_VIRTUAL_RTYPES
-from cubicweb.server import SQL_CONNECT_HOOKS
-from cubicweb.server.utils import crypt_password
-
-lgc.USE_MX_DATETIME = False
-SQL_PREFIX = 'cw_'
-
-
-def _run_command(cmd):
- if isinstance(cmd, string_types):
- print(cmd)
- return subprocess.call(cmd, shell=True)
- else:
- print(' '.join(cmd))
- return subprocess.call(cmd)
-
-
-def sqlexec(sqlstmts, cursor_or_execute, withpb=True,
- pbtitle='', delimiter=';', cnx=None):
- """execute sql statements ignoring DROP/ CREATE GROUP or USER statements
- error.
-
- :sqlstmts_as_string: a string or a list of sql statements.
- :cursor_or_execute: sql cursor or a callback used to execute statements
- :cnx: if given, commit/rollback at each statement.
-
- :withpb: if True, display a progresse bar
- :pbtitle: a string displayed as the progress bar title (if `withpb=True`)
-
- :delimiter: a string used to split sqlstmts (if it is a string)
-
- Return the failed statements (same type as sqlstmts)
- """
- if hasattr(cursor_or_execute, 'execute'):
- execute = cursor_or_execute.execute
- else:
- execute = cursor_or_execute
- sqlstmts_as_string = False
- if isinstance(sqlstmts, string_types):
- sqlstmts_as_string = True
- sqlstmts = sqlstmts.split(delimiter)
- if withpb:
- if sys.stdout.isatty():
- pb = ProgressBar(len(sqlstmts), title=pbtitle)
- else:
- pb = DummyProgressBar()
- failed = []
- for sql in sqlstmts:
- sql = sql.strip()
- if withpb:
- pb.update()
- if not sql:
- continue
- try:
- # some dbapi modules doesn't accept unicode for sql string
- execute(str(sql))
- except Exception:
- if cnx:
- cnx.rollback()
- failed.append(sql)
- else:
- if cnx:
- cnx.commit()
- if withpb:
- print()
- if sqlstmts_as_string:
- failed = delimiter.join(failed)
- return failed
-
-
-def sqlgrants(schema, driver, user,
- text_index=True, set_owner=True,
- skip_relations=(), skip_entities=()):
- """return sql to give all access privileges to the given user on the system
- schema
- """
- from cubicweb.server.schema2sql import grant_schema
- from cubicweb.server.sources import native
- output = []
- w = output.append
- w(native.grant_schema(user, set_owner))
- w('')
- if text_index:
- dbhelper = db.get_db_helper(driver)
- w(dbhelper.sql_grant_user_on_fti(user))
- w('')
- w(grant_schema(schema, user, set_owner, skip_entities=skip_entities, prefix=SQL_PREFIX))
- return '\n'.join(output)
-
-
-def sqlschema(schema, driver, text_index=True,
- user=None, set_owner=False,
- skip_relations=PURE_VIRTUAL_RTYPES, skip_entities=()):
- """return the system sql schema, according to the given parameters"""
- from cubicweb.server.schema2sql import schema2sql
- from cubicweb.server.sources import native
- if set_owner:
- assert user, 'user is argument required when set_owner is true'
- output = []
- w = output.append
- w(native.sql_schema(driver))
- w('')
- dbhelper = db.get_db_helper(driver)
- if text_index:
- w(dbhelper.sql_init_fti().replace(';', ';;'))
- w('')
- w(schema2sql(dbhelper, schema, prefix=SQL_PREFIX,
- skip_entities=skip_entities,
- skip_relations=skip_relations).replace(';', ';;'))
- if dbhelper.users_support and user:
- w('')
- w(sqlgrants(schema, driver, user, text_index, set_owner,
- skip_relations, skip_entities).replace(';', ';;'))
- return '\n'.join(output)
-
-
-def sqldropschema(schema, driver, text_index=True,
- skip_relations=PURE_VIRTUAL_RTYPES, skip_entities=()):
- """return the sql to drop the schema, according to the given parameters"""
- from cubicweb.server.schema2sql import dropschema2sql
- from cubicweb.server.sources import native
- output = []
- w = output.append
- if text_index:
- dbhelper = db.get_db_helper(driver)
- w(dbhelper.sql_drop_fti())
- w('')
- w(dropschema2sql(dbhelper, schema, prefix=SQL_PREFIX,
- skip_entities=skip_entities,
- skip_relations=skip_relations))
- w('')
- w(native.sql_drop_schema(driver))
- return '\n'.join(output)
-
-
-_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION = re.compile('^(?!(sql|pg)_)').match
-def sql_drop_all_user_tables(driver_or_helper, sqlcursor):
- """Return ths sql to drop all tables found in the database system."""
- if not getattr(driver_or_helper, 'list_tables', None):
- dbhelper = db.get_db_helper(driver_or_helper)
- else:
- dbhelper = driver_or_helper
-
- cmds = [dbhelper.sql_drop_sequence('entities_id_seq')]
- # for mssql, we need to drop views before tables
- if hasattr(dbhelper, 'list_views'):
- cmds += ['DROP VIEW %s;' % name
- for name in filter(_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION, dbhelper.list_views(sqlcursor))]
- cmds += ['DROP TABLE %s;' % name
- for name in filter(_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION, dbhelper.list_tables(sqlcursor))]
- return '\n'.join(cmds)
-
-
-class ConnectionWrapper(object):
- """handle connection to the system source, at some point associated to a
- :class:`Session`
- """
-
- # since 3.19, we only have to manage the system source connection
- def __init__(self, system_source):
- # dictionary of (source, connection), indexed by sources'uri
- self._source = system_source
- self.cnx = system_source.get_connection()
- self.cu = self.cnx.cursor()
-
- def commit(self):
- """commit the current transaction for this user"""
- # let exception propagates
- self.cnx.commit()
-
- def rollback(self):
- """rollback the current transaction for this user"""
- # catch exceptions, rollback other sources anyway
- try:
- self.cnx.rollback()
- except Exception:
- self._source.critical('rollback error', exc_info=sys.exc_info())
- # error on rollback, the connection is much probably in a really
- # bad state. Replace it by a new one.
- self.reconnect()
-
- def close(self, i_know_what_i_do=False):
- """close all connections in the set"""
- if i_know_what_i_do is not True: # unexpected closing safety belt
- raise RuntimeError('connections set shouldn\'t be closed')
- try:
- self.cu.close()
- self.cu = None
- except Exception:
- pass
- try:
- self.cnx.close()
- self.cnx = None
- except Exception:
- pass
-
- # internals ###############################################################
-
- def cnxset_freed(self):
- """connections set is being freed from a session"""
- pass # no nothing by default
-
- def reconnect(self):
- """reopen a connection for this source or all sources if none specified
- """
- try:
- # properly close existing connection if any
- self.cnx.close()
- except Exception:
- pass
- self._source.info('trying to reconnect')
- self.cnx = self._source.get_connection()
- self.cu = self.cnx.cursor()
-
- @deprecated('[3.19] use .cu instead')
- def __getitem__(self, uri):
- assert uri == 'system'
- return self.cu
-
- @deprecated('[3.19] use repo.system_source instead')
- def source(self, uid):
- assert uid == 'system'
- return self._source
-
- @deprecated('[3.19] use .cnx instead')
- def connection(self, uid):
- assert uid == 'system'
- return self.cnx
-
-
-class SqliteConnectionWrapper(ConnectionWrapper):
- """Sqlite specific connection wrapper: close the connection each time it's
- freed (and reopen it later when needed)
- """
- def __init__(self, system_source):
- # don't call parent's __init__, we don't want to initiate the connection
- self._source = system_source
-
- _cnx = None
-
- def cnxset_freed(self):
- self.cu.close()
- self.cnx.close()
- self.cnx = self.cu = None
-
- @property
- def cnx(self):
- if self._cnx is None:
- self._cnx = self._source.get_connection()
- self._cu = self._cnx.cursor()
- return self._cnx
- @cnx.setter
- def cnx(self, value):
- self._cnx = value
-
- @property
- def cu(self):
- if self._cnx is None:
- self._cnx = self._source.get_connection()
- self._cu = self._cnx.cursor()
- return self._cu
- @cu.setter
- def cu(self, value):
- self._cu = value
-
-
-class SQLAdapterMixIn(object):
- """Mixin for SQL data sources, getting a connection from a configuration
- dictionary and handling connection locking
- """
- cnx_wrap = ConnectionWrapper
-
- def __init__(self, source_config, repairing=False):
- try:
- self.dbdriver = source_config['db-driver'].lower()
- dbname = source_config['db-name']
- except KeyError:
- raise ConfigurationError('missing some expected entries in sources file')
- dbhost = source_config.get('db-host')
- port = source_config.get('db-port')
- dbport = port and int(port) or None
- dbuser = source_config.get('db-user')
- dbpassword = source_config.get('db-password')
- dbencoding = source_config.get('db-encoding', 'UTF-8')
- dbextraargs = source_config.get('db-extra-arguments')
- dbnamespace = source_config.get('db-namespace')
- self.dbhelper = db.get_db_helper(self.dbdriver)
- self.dbhelper.record_connection_info(dbname, dbhost, dbport, dbuser,
- dbpassword, dbextraargs,
- dbencoding, dbnamespace)
- self.sqlgen = SQLGenerator()
- # copy back some commonly accessed attributes
- dbapi_module = self.dbhelper.dbapi_module
- self.OperationalError = dbapi_module.OperationalError
- self.InterfaceError = dbapi_module.InterfaceError
- self.DbapiError = dbapi_module.Error
- self._binary = self.dbhelper.binary_value
- self._process_value = dbapi_module.process_value
- self._dbencoding = dbencoding
- if self.dbdriver == 'sqlite':
- self.cnx_wrap = SqliteConnectionWrapper
- self.dbhelper.dbname = abspath(self.dbhelper.dbname)
- if not repairing:
- statement_timeout = int(source_config.get('db-statement-timeout', 0))
- if statement_timeout > 0:
- def set_postgres_timeout(cnx):
- cnx.cursor().execute('SET statement_timeout to %d' % statement_timeout)
- cnx.commit()
- postgres_hooks = SQL_CONNECT_HOOKS['postgres']
- postgres_hooks.append(set_postgres_timeout)
-
- def wrapped_connection(self):
- """open and return a connection to the database, wrapped into a class
- handling reconnection and all
- """
- return self.cnx_wrap(self)
-
- def get_connection(self):
- """open and return a connection to the database"""
- return self.dbhelper.get_connection()
-
- def backup_to_file(self, backupfile, confirm):
- for cmd in self.dbhelper.backup_commands(backupfile,
- keepownership=False):
- if _run_command(cmd):
- if not confirm(' [Failed] Continue anyway?', default='n'):
- raise Exception('Failed command: %s' % cmd)
-
- def restore_from_file(self, backupfile, confirm, drop=True):
- for cmd in self.dbhelper.restore_commands(backupfile,
- keepownership=False,
- drop=drop):
- if _run_command(cmd):
- if not confirm(' [Failed] Continue anyway?', default='n'):
- raise Exception('Failed command: %s' % cmd)
-
- def merge_args(self, args, query_args):
- if args is not None:
- newargs = {}
- for key, val in args.items():
- # convert cubicweb binary into db binary
- if isinstance(val, Binary):
- val = self._binary(val.getvalue())
- # convert timestamp to utc.
- # expect SET TiME ZONE to UTC at connection opening time.
- # This shouldn't change anything for datetime without TZ.
- elif isinstance(val, datetime) and val.tzinfo is not None:
- val = utcdatetime(val)
- elif isinstance(val, time) and val.tzinfo is not None:
- val = utctime(val)
- newargs[key] = val
- # should not collide
- assert not (frozenset(newargs) & frozenset(query_args)), \
- 'unexpected collision: %s' % (frozenset(newargs) & frozenset(query_args))
- newargs.update(query_args)
- return newargs
- return query_args
-
- def process_result(self, cursor, cnx=None, column_callbacks=None):
- """return a list of CubicWeb compliant values from data in the given cursor
- """
- return list(self.iter_process_result(cursor, cnx, column_callbacks))
-
- def iter_process_result(self, cursor, cnx, column_callbacks=None):
- """return a iterator on tuples of CubicWeb compliant values from data
- in the given cursor
- """
- # use two different implementations to avoid paying the price of
- # callback lookup for each *cell* in results when there is nothing to
- # lookup
- if not column_callbacks:
- return self.dbhelper.dbapi_module.process_cursor(cursor, self._dbencoding,
- Binary)
- assert cnx
- return self._cb_process_result(cursor, column_callbacks, cnx)
-
- def _cb_process_result(self, cursor, column_callbacks, cnx):
- # begin bind to locals for optimization
- descr = cursor.description
- encoding = self._dbencoding
- process_value = self._process_value
- binary = Binary
- # /end
- cursor.arraysize = 100
- while True:
- results = cursor.fetchmany()
- if not results:
- break
- for line in results:
- result = []
- for col, value in enumerate(line):
- if value is None:
- result.append(value)
- continue
- cbstack = column_callbacks.get(col, None)
- if cbstack is None:
- value = process_value(value, descr[col], encoding, binary)
- else:
- for cb in cbstack:
- value = cb(self, cnx, value)
- result.append(value)
- yield result
-
- def preprocess_entity(self, entity):
- """return a dictionary to use as extra argument to cursor.execute
- to insert/update an entity into a SQL database
- """
- attrs = {}
- eschema = entity.e_schema
- converters = getattr(self.dbhelper, 'TYPE_CONVERTERS', {})
- for attr, value in entity.cw_edited.items():
- if value is not None and eschema.subjrels[attr].final:
- atype = str(entity.e_schema.destination(attr))
- if atype in converters:
- # It is easier to modify preprocess_entity rather
- # than add_entity (native) as this behavior
- # may also be used for update.
- value = converters[atype](value)
- elif atype == 'Password': # XXX could be done using a TYPE_CONVERTERS callback
- # if value is a Binary instance, this mean we got it
- # from a query result and so it is already encrypted
- if isinstance(value, Binary):
- value = value.getvalue()
- else:
- value = crypt_password(value)
- value = self._binary(value)
- elif isinstance(value, Binary):
- value = self._binary(value.getvalue())
- attrs[SQL_PREFIX+str(attr)] = value
- attrs[SQL_PREFIX+'eid'] = entity.eid
- return attrs
-
- # these are overridden by set_log_methods below
- # only defining here to prevent pylint from complaining
- info = warning = error = critical = exception = debug = lambda msg,*a,**kw: None
-
-set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))
-
-
-# connection initialization functions ##########################################
-
-def _install_sqlite_querier_patch():
- """This monkey-patch hotfixes a bug sqlite causing some dates to be returned as strings rather than
- date objects (http://www.sqlite.org/cvstrac/tktview?tn=1327,33)
- """
- from cubicweb.server.querier import QuerierHelper
-
- if hasattr(QuerierHelper, '_sqlite_patched'):
- return # already monkey patched
-
- def wrap_execute(base_execute):
- def new_execute(*args, **kwargs):
- rset = base_execute(*args, **kwargs)
- if rset.description:
- found_date = False
- for row, rowdesc in zip(rset, rset.description):
- for cellindex, (value, vtype) in enumerate(zip(row, rowdesc)):
- if vtype in ('TZDatetime', 'Date', 'Datetime') \
- and isinstance(value, text_type):
- found_date = True
- value = value.rsplit('.', 1)[0]
- try:
- row[cellindex] = strptime(value, '%Y-%m-%d %H:%M:%S')
- except Exception:
- row[cellindex] = strptime(value, '%Y-%m-%d')
- if vtype == 'TZDatetime':
- row[cellindex] = row[cellindex].replace(tzinfo=utc)
- if vtype == 'Time' and isinstance(value, text_type):
- found_date = True
- try:
- row[cellindex] = strptime(value, '%H:%M:%S')
- except Exception:
- # DateTime used as Time?
- row[cellindex] = strptime(value, '%Y-%m-%d %H:%M:%S')
- if vtype == 'Interval' and isinstance(value, int):
- found_date = True
- # XXX value is in number of seconds?
- row[cellindex] = timedelta(0, value, 0)
- if not found_date:
- break
- return rset
- return new_execute
-
- QuerierHelper.execute = wrap_execute(QuerierHelper.execute)
- QuerierHelper._sqlite_patched = True
-
-
-def _init_sqlite_connection(cnx):
- """Internal function that will be called to init a sqlite connection"""
- _install_sqlite_querier_patch()
-
- class group_concat(object):
- def __init__(self):
- self.values = set()
- def step(self, value):
- if value is not None:
- self.values.add(value)
- def finalize(self):
- return ', '.join(text_type(v) for v in self.values)
-
- cnx.create_aggregate("GROUP_CONCAT", 1, group_concat)
-
- def _limit_size(text, maxsize, format='text/plain'):
- if len(text) < maxsize:
- return text
- if format in ('text/html', 'text/xhtml', 'text/xml'):
- text = remove_html_tags(text)
- if len(text) > maxsize:
- text = text[:maxsize] + '...'
- return text
-
- def limit_size3(text, format, maxsize):
- return _limit_size(text, maxsize, format)
- cnx.create_function("LIMIT_SIZE", 3, limit_size3)
-
- def limit_size2(text, maxsize):
- return _limit_size(text, maxsize)
- cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)
-
- from logilab.common.date import strptime
- def weekday(ustr):
- try:
- dt = strptime(ustr, '%Y-%m-%d %H:%M:%S')
- except:
- dt = strptime(ustr, '%Y-%m-%d')
- # expect sunday to be 1, saturday 7 while weekday method return 0 for
- # monday
- return (dt.weekday() + 1) % 7
- cnx.create_function("WEEKDAY", 1, weekday)
-
- cnx.cursor().execute("pragma foreign_keys = on")
-
- import yams.constraints
- yams.constraints.patch_sqlite_decimal()
-
-sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
-sqlite_hooks.append(_init_sqlite_connection)
-
-
-def _init_postgres_connection(cnx):
- """Internal function that will be called to init a postgresql connection"""
- cnx.cursor().execute('SET TIME ZONE UTC')
- # commit is needed, else setting are lost if the connection is first
- # rolled back
- cnx.commit()
-
-postgres_hooks = SQL_CONNECT_HOOKS.setdefault('postgres', [])
-postgres_hooks.append(_init_postgres_connection)