diff -r 3b79a0fc91db -r c5aec27c1bf7 server/sqlutils.py --- a/server/sqlutils.py Mon Mar 08 09:51:29 2010 +0100 +++ b/server/sqlutils.py Mon Mar 08 17:57:29 2010 +0100 @@ -11,21 +11,17 @@ import subprocess from datetime import datetime, date -import logilab.common as lgc -from logilab.common import db +from logilab import db, common as lgc from logilab.common.shellutils import ProgressBar -from logilab.common.adbh import get_adv_func_helper -from logilab.common.sqlgen import SQLGenerator from logilab.common.date import todate, todatetime - -from indexer import get_indexer +from logilab.db.sqlgen import SQLGenerator from cubicweb import Binary, ConfigurationError from cubicweb.uilib import remove_html_tags from cubicweb.schema import PURE_VIRTUAL_RTYPES from cubicweb.server import SQL_CONNECT_HOOKS from cubicweb.server.utils import crypt_password - +from rql.utils import RQL_FUNCTIONS_REGISTRY lgc.USE_MX_DATETIME = False SQL_PREFIX = 'cw_' @@ -77,8 +73,8 @@ w(native.grant_schema(user, set_owner)) w('') if text_index: - indexer = get_indexer(driver) - w(indexer.sql_grant_user(user)) + dbhelper = db.get_db_helper(driver) + w(dbhelper.sql_grant_user_on_fti(user)) w('') w(grant_schema(schema, user, set_owner, skip_entities=skip_entities, prefix=SQL_PREFIX)) return '\n'.join(output) @@ -96,11 +92,10 @@ w = output.append w(native.sql_schema(driver)) w('') + dbhelper = db.get_db_helper(driver) if text_index: - indexer = get_indexer(driver) - w(indexer.sql_init_fti()) + w(dbhelper.sql_init_fti()) w('') - dbhelper = get_adv_func_helper(driver) w(schema2sql(dbhelper, schema, prefix=SQL_PREFIX, skip_entities=skip_entities, skip_relations=skip_relations)) if dbhelper.users_support and user: @@ -120,8 +115,8 @@ w(native.sql_drop_schema(driver)) w('') if text_index: - indexer = get_indexer(driver) - w(indexer.sql_drop_fti()) + dbhelper = db.get_db_helper(driver) + w(dbhelper.sql_drop_fti()) w('') w(dropschema2sql(schema, prefix=SQL_PREFIX, skip_entities=skip_entities, @@ -137,55 +132,42 @@ def __init__(self, source_config): try: self.dbdriver = source_config['db-driver'].lower() - self.dbname = source_config['db-name'] + dbname = source_config['db-name'] except KeyError: raise ConfigurationError('missing some expected entries in sources file') - self.dbhost = source_config.get('db-host') + dbhost = source_config.get('db-host') port = source_config.get('db-port') - self.dbport = port and int(port) or None - self.dbuser = source_config.get('db-user') - self.dbpasswd = source_config.get('db-password') - self.encoding = source_config.get('db-encoding', 'UTF-8') - self.dbapi_module = db.get_dbapi_compliant_module(self.dbdriver) - self.dbdriver_extra_args = source_config.get('db-extra-arguments') - self.binary = self.dbapi_module.Binary - self.dbhelper = self.dbapi_module.adv_func_helper + dbport = port and int(port) or None + dbuser = source_config.get('db-user') + dbpassword = source_config.get('db-password') + dbencoding = source_config.get('db-encoding', 'UTF-8') + dbextraargs = source_config.get('db-extra-arguments') + self.dbhelper = db.get_db_helper(self.dbdriver) + self.dbhelper.record_connection_info(dbname, dbhost, dbport, dbuser, + dbpassword, dbextraargs, + dbencoding) self.sqlgen = SQLGenerator() + # copy back some commonly accessed attributes + dbapi_module = self.dbhelper.dbapi_module + self.OperationalError = dbapi_module.OperationalError + self.InterfaceError = dbapi_module.InterfaceError + self._binary = dbapi_module.Binary + self._process_value = dbapi_module.process_value + self._dbencoding = dbencoding - def get_connection(self, user=None, password=None): + def get_connection(self): """open and return a connection to the database""" - if user or self.dbuser: - self.info('connecting to %s@%s for user %s', self.dbname, - self.dbhost or 'localhost', user or self.dbuser) - else: - self.info('connecting to %s@%s', self.dbname, - self.dbhost or 'localhost') - extra = {} - if self.dbdriver_extra_args: - extra = {'extra_args': self.dbdriver_extra_args} - cnx = self.dbapi_module.connect(self.dbhost, self.dbname, - user or self.dbuser, - password or self.dbpasswd, - port=self.dbport, - **extra) - init_cnx(self.dbdriver, cnx) - #self.dbapi_module.type_code_test(cnx.cursor()) - return cnx + return self.dbhelper.get_connection() def backup_to_file(self, backupfile): - for cmd in self.dbhelper.backup_commands(self.dbname, self.dbhost, - self.dbuser, backupfile, - dbport=self.dbport, + for cmd in self.dbhelper.backup_commands(backupfile, keepownership=False): if _run_command(cmd): if not confirm(' [Failed] Continue anyway?', default='n'): raise Exception('Failed command: %s' % cmd) def restore_from_file(self, backupfile, confirm, drop=True): - for cmd in self.dbhelper.restore_commands(self.dbname, self.dbhost, - self.dbuser, backupfile, - self.encoding, - dbport=self.dbport, + for cmd in self.dbhelper.restore_commands(backupfile, keepownership=False, drop=drop): if _run_command(cmd): @@ -198,7 +180,7 @@ for key, val in args.iteritems(): # convert cubicweb binary into db binary if isinstance(val, Binary): - val = self.binary(val.getvalue()) + val = self._binary(val.getvalue()) newargs[key] = val # should not collide newargs.update(query_args) @@ -208,10 +190,12 @@ def process_result(self, cursor): """return a list of CubicWeb compliant values from data in the given cursor """ + # begin bind to locals for optimization descr = cursor.description - encoding = self.encoding - process_value = self.dbapi_module.process_value + encoding = self._dbencoding + process_value = self._process_value binary = Binary + # /end results = cursor.fetchall() for i, line in enumerate(results): result = [] @@ -242,14 +226,14 @@ value = value.getvalue() else: value = crypt_password(value) - value = self.binary(value) + value = self._binary(value) # XXX needed for sqlite but I don't think it is for other backends elif atype == 'Datetime' and isinstance(value, date): value = todatetime(value) elif atype == 'Date' and isinstance(value, datetime): value = todate(value) elif isinstance(value, Binary): - value = self.binary(value.getvalue()) + value = self._binary(value.getvalue()) attrs[SQL_PREFIX+str(attr)] = value return attrs @@ -259,12 +243,8 @@ set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter')) def init_sqlite_connexion(cnx): - # XXX should not be publicly exposed - #def comma_join(strings): - # return ', '.join(strings) - #cnx.create_function("COMMA_JOIN", 1, comma_join) - class concat_strings(object): + class group_concat(object): def __init__(self): self.values = [] def step(self, value): @@ -272,10 +252,7 @@ self.values.append(value) def finalize(self): return ', '.join(self.values) - # renamed to GROUP_CONCAT in cubicweb 2.45, keep old name for bw compat for - # some time - cnx.create_aggregate("CONCAT_STRINGS", 1, concat_strings) - cnx.create_aggregate("GROUP_CONCAT", 1, concat_strings) + cnx.create_aggregate("GROUP_CONCAT", 1, group_concat) def _limit_size(text, maxsize, format='text/plain'): if len(text) < maxsize: @@ -293,9 +270,9 @@ def limit_size2(text, maxsize): return _limit_size(text, maxsize) cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2) + import yams.constraints - if hasattr(yams.constraints, 'patch_sqlite_decimal'): - yams.constraints.patch_sqlite_decimal() + yams.constraints.patch_sqlite_decimal() def fspath(eid, etype, attr): try: @@ -320,10 +297,5 @@ raise cnx.create_function('_fsopen', 1, _fsopen) - sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', []) sqlite_hooks.append(init_sqlite_connexion) - -def init_cnx(driver, cnx): - for hook in SQL_CONNECT_HOOKS.get(driver, ()): - hook(cnx)