diff -r 000000000000 -r b97547f5f1fa server/sqlutils.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/server/sqlutils.py Wed Nov 05 15:52:50 2008 +0100 @@ -0,0 +1,254 @@ +"""SQL utilities functions and classes. + +:organization: Logilab +:copyright: 2001-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr +""" +__docformat__ = "restructuredtext en" + +from logilab.common.shellutils import ProgressBar +from logilab.common.db import get_dbapi_compliant_module +from logilab.common.adbh import get_adv_func_helper +from logilab.common.sqlgen import SQLGenerator + +from indexer import get_indexer + +from cubicweb import Binary, ConfigurationError +from cubicweb.common.uilib import remove_html_tags +from cubicweb.server import SQL_CONNECT_HOOKS +from cubicweb.server.utils import crypt_password, cartesian_product + + +def sqlexec(sqlstmts, cursor_or_execute, withpb=True, delimiter=';'): + """execute sql statements ignoring DROP/ CREATE GROUP or USER statements + error. If a cnx is given, commit at each statement + """ + if hasattr(cursor_or_execute, 'execute'): + execute = cursor_or_execute.execute + else: + execute = cursor_or_execute + sqlstmts = sqlstmts.split(delimiter) + if withpb: + pb = ProgressBar(len(sqlstmts)) + for sql in sqlstmts: + sql = sql.strip() + if withpb: + pb.update() + if not sql: + continue + # some dbapi modules doesn't accept unicode for sql string + execute(str(sql)) + if withpb: + print + + +def sqlgrants(schema, driver, user, + text_index=True, set_owner=True, + skip_relations=(), skip_entities=()): + """return sql to give all access privileges to the given user on the system + schema + """ + from yams.schema2sql import grant_schema + from cubicweb.server.sources import native + output = [] + w = output.append + w(native.grant_schema(user, set_owner)) + w('') + if text_index: + indexer = get_indexer(driver) + w(indexer.sql_grant_user(user)) + w('') + w(grant_schema(schema, user, set_owner, skip_entities=skip_entities)) + return '\n'.join(output) + + +def sqlschema(schema, driver, text_index=True, + user=None, set_owner=False, + skip_relations=('has_text', 'identity'), skip_entities=()): + """return the system sql schema, according to the given parameters""" + from yams.schema2sql import schema2sql + from cubicweb.server.sources import native + if set_owner: + assert user, 'user is argument required when set_owner is true' + output = [] + w = output.append + w(native.sql_schema(driver)) + w('') + if text_index: + indexer = get_indexer(driver) + w(indexer.sql_init_fti()) + w('') + dbhelper = get_adv_func_helper(driver) + w(schema2sql(dbhelper, schema, + skip_entities=skip_entities, skip_relations=skip_relations)) + if dbhelper.users_support and user: + w('') + w(sqlgrants(schema, driver, user, text_index, set_owner, + skip_relations, skip_entities)) + return '\n'.join(output) + + +def sqldropschema(schema, driver, text_index=True, + skip_relations=('has_text', 'identity'), skip_entities=()): + """return the sql to drop the schema, according to the given parameters""" + from yams.schema2sql import dropschema2sql + from cubicweb.server.sources import native + output = [] + w = output.append + w(native.sql_drop_schema(driver)) + w('') + if text_index: + indexer = get_indexer(driver) + w(indexer.sql_drop_fti()) + w('') + w(dropschema2sql(schema, + skip_entities=skip_entities, skip_relations=skip_relations)) + return '\n'.join(output) + + + +class SQLAdapterMixIn(object): + """Mixin for SQL data sources, getting a connection from a configuration + dictionary and handling connection locking + """ + + def __init__(self, source_config): + try: + self.dbdriver = source_config['db-driver'].lower() + self.dbname = source_config['db-name'] + except KeyError: + raise ConfigurationError('missing some expected entries in sources file') + self.dbhost = source_config.get('db-host') + port = source_config.get('db-port') + self.dbport = port and int(port) or None + self.dbuser = source_config.get('db-user') + self.dbpasswd = source_config.get('db-password') + self.encoding = source_config.get('db-encoding', 'UTF-8') + self.dbapi_module = get_dbapi_compliant_module(self.dbdriver) + self.binary = self.dbapi_module.Binary + self.dbhelper = self.dbapi_module.adv_func_helper + self.sqlgen = SQLGenerator() + + def get_connection(self, user=None, password=None): + """open and return a connection to the database""" + if user or self.dbuser: + self.info('connecting to %s@%s for user %s', self.dbname, + self.dbhost or 'localhost', user or self.dbuser) + else: + self.info('connecting to %s@%s', self.dbname, + self.dbhost or 'localhost') + cnx = self.dbapi_module.connect(self.dbhost, self.dbname, + user or self.dbuser, + password or self.dbpasswd, + port=self.dbport) + init_cnx(self.dbdriver, cnx) + #self.dbapi_module.type_code_test(cnx.cursor()) + return cnx + + def merge_args(self, args, query_args): + if args is not None: + args = dict(args) + for key, val in args.items(): + # convert cubicweb binary into db binary + if isinstance(val, Binary): + val = self.binary(val.getvalue()) + args[key] = val + # should not collide + args.update(query_args) + return args + return query_args + + def process_result(self, cursor): + """return a list of CubicWeb compliant values from data in the given cursor + """ + descr = cursor.description + encoding = self.encoding + process_value = self.dbapi_module.process_value + binary = Binary + results = cursor.fetchall() + for i, line in enumerate(results): + result = [] + for col, value in enumerate(line): + if value is None: + result.append(value) + continue + result.append(process_value(value, descr[col], encoding, binary)) + results[i] = result + return results + + + def preprocess_entity(self, entity): + """return a dictionary to use as extra argument to cursor.execute + to insert/update an entity + """ + attrs = {} + eschema = entity.e_schema + for attr, value in entity.items(): + rschema = eschema.subject_relation(attr) + if rschema.is_final(): + atype = str(entity.e_schema.destination(attr)) + if atype == 'Boolean': + value = self.dbhelper.boolean_value(value) + elif atype == 'Password': + # if value is a Binary instance, this mean we got it + # from a query result and so it is already encrypted + if isinstance(value, Binary): + value = value.getvalue() + else: + value = crypt_password(value) + elif isinstance(value, Binary): + value = self.binary(value.getvalue()) + attrs[str(attr)] = value + return attrs + + +from logging import getLogger +from cubicweb import set_log_methods +set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter')) + +def init_sqlite_connexion(cnx): + # XXX should not be publicly exposed + #def comma_join(strings): + # return ', '.join(strings) + #cnx.create_function("COMMA_JOIN", 1, comma_join) + + class concat_strings(object): + def __init__(self): + self.values = [] + def step(self, value): + if value is not None: + self.values.append(value) + def finalize(self): + return ', '.join(self.values) + # renamed to GROUP_CONCAT in cubicweb 2.45, keep old name for bw compat for + # some time + cnx.create_aggregate("CONCAT_STRINGS", 1, concat_strings) + cnx.create_aggregate("GROUP_CONCAT", 1, concat_strings) + + def _limit_size(text, maxsize, format='text/plain'): + if len(text) < maxsize: + return text + if format in ('text/html', 'text/xhtml', 'text/xml'): + text = remove_html_tags(text) + if len(text) > maxsize: + text = text[:maxsize] + '...' + return text + + def limit_size3(text, format, maxsize): + return _limit_size(text, maxsize, format) + cnx.create_function("LIMIT_SIZE", 3, limit_size3) + + def limit_size2(text, maxsize): + return _limit_size(text, maxsize) + cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2) + import yams.constraints + if hasattr(yams.constraints, 'patch_sqlite_decimal'): + yams.constraints.patch_sqlite_decimal() + + +sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', []) +sqlite_hooks.append(init_sqlite_connexion) + +def init_cnx(driver, cnx): + for hook in SQL_CONNECT_HOOKS.get(driver, ()): + hook(cnx)