--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/server/sqlutils.py Wed Nov 05 15:52:50 2008 +0100
@@ -0,0 +1,254 @@
+"""SQL utilities functions and classes.
+
+:organization: Logilab
+:copyright: 2001-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
+"""
+__docformat__ = "restructuredtext en"
+
+from logilab.common.shellutils import ProgressBar
+from logilab.common.db import get_dbapi_compliant_module
+from logilab.common.adbh import get_adv_func_helper
+from logilab.common.sqlgen import SQLGenerator
+
+from indexer import get_indexer
+
+from cubicweb import Binary, ConfigurationError
+from cubicweb.common.uilib import remove_html_tags
+from cubicweb.server import SQL_CONNECT_HOOKS
+from cubicweb.server.utils import crypt_password, cartesian_product
+
+
+def sqlexec(sqlstmts, cursor_or_execute, withpb=True, delimiter=';'):
+ """execute sql statements ignoring DROP/ CREATE GROUP or USER statements
+ error. If a cnx is given, commit at each statement
+ """
+ if hasattr(cursor_or_execute, 'execute'):
+ execute = cursor_or_execute.execute
+ else:
+ execute = cursor_or_execute
+ sqlstmts = sqlstmts.split(delimiter)
+ if withpb:
+ pb = ProgressBar(len(sqlstmts))
+ for sql in sqlstmts:
+ sql = sql.strip()
+ if withpb:
+ pb.update()
+ if not sql:
+ continue
+ # some dbapi modules doesn't accept unicode for sql string
+ execute(str(sql))
+ if withpb:
+ print
+
+
+def sqlgrants(schema, driver, user,
+ text_index=True, set_owner=True,
+ skip_relations=(), skip_entities=()):
+ """return sql to give all access privileges to the given user on the system
+ schema
+ """
+ from yams.schema2sql import grant_schema
+ from cubicweb.server.sources import native
+ output = []
+ w = output.append
+ w(native.grant_schema(user, set_owner))
+ w('')
+ if text_index:
+ indexer = get_indexer(driver)
+ w(indexer.sql_grant_user(user))
+ w('')
+ w(grant_schema(schema, user, set_owner, skip_entities=skip_entities))
+ return '\n'.join(output)
+
+
+def sqlschema(schema, driver, text_index=True,
+ user=None, set_owner=False,
+ skip_relations=('has_text', 'identity'), skip_entities=()):
+ """return the system sql schema, according to the given parameters"""
+ from yams.schema2sql import schema2sql
+ from cubicweb.server.sources import native
+ if set_owner:
+ assert user, 'user is argument required when set_owner is true'
+ output = []
+ w = output.append
+ w(native.sql_schema(driver))
+ w('')
+ if text_index:
+ indexer = get_indexer(driver)
+ w(indexer.sql_init_fti())
+ w('')
+ dbhelper = get_adv_func_helper(driver)
+ w(schema2sql(dbhelper, schema,
+ skip_entities=skip_entities, skip_relations=skip_relations))
+ if dbhelper.users_support and user:
+ w('')
+ w(sqlgrants(schema, driver, user, text_index, set_owner,
+ skip_relations, skip_entities))
+ return '\n'.join(output)
+
+
+def sqldropschema(schema, driver, text_index=True,
+ skip_relations=('has_text', 'identity'), skip_entities=()):
+ """return the sql to drop the schema, according to the given parameters"""
+ from yams.schema2sql import dropschema2sql
+ from cubicweb.server.sources import native
+ output = []
+ w = output.append
+ w(native.sql_drop_schema(driver))
+ w('')
+ if text_index:
+ indexer = get_indexer(driver)
+ w(indexer.sql_drop_fti())
+ w('')
+ w(dropschema2sql(schema,
+ skip_entities=skip_entities, skip_relations=skip_relations))
+ return '\n'.join(output)
+
+
+
+class SQLAdapterMixIn(object):
+ """Mixin for SQL data sources, getting a connection from a configuration
+ dictionary and handling connection locking
+ """
+
+ def __init__(self, source_config):
+ try:
+ self.dbdriver = source_config['db-driver'].lower()
+ self.dbname = source_config['db-name']
+ except KeyError:
+ raise ConfigurationError('missing some expected entries in sources file')
+ self.dbhost = source_config.get('db-host')
+ port = source_config.get('db-port')
+ self.dbport = port and int(port) or None
+ self.dbuser = source_config.get('db-user')
+ self.dbpasswd = source_config.get('db-password')
+ self.encoding = source_config.get('db-encoding', 'UTF-8')
+ self.dbapi_module = get_dbapi_compliant_module(self.dbdriver)
+ self.binary = self.dbapi_module.Binary
+ self.dbhelper = self.dbapi_module.adv_func_helper
+ self.sqlgen = SQLGenerator()
+
+ def get_connection(self, user=None, password=None):
+ """open and return a connection to the database"""
+ if user or self.dbuser:
+ self.info('connecting to %s@%s for user %s', self.dbname,
+ self.dbhost or 'localhost', user or self.dbuser)
+ else:
+ self.info('connecting to %s@%s', self.dbname,
+ self.dbhost or 'localhost')
+ cnx = self.dbapi_module.connect(self.dbhost, self.dbname,
+ user or self.dbuser,
+ password or self.dbpasswd,
+ port=self.dbport)
+ init_cnx(self.dbdriver, cnx)
+ #self.dbapi_module.type_code_test(cnx.cursor())
+ return cnx
+
+ def merge_args(self, args, query_args):
+ if args is not None:
+ args = dict(args)
+ for key, val in args.items():
+ # convert cubicweb binary into db binary
+ if isinstance(val, Binary):
+ val = self.binary(val.getvalue())
+ args[key] = val
+ # should not collide
+ args.update(query_args)
+ return args
+ return query_args
+
+ def process_result(self, cursor):
+ """return a list of CubicWeb compliant values from data in the given cursor
+ """
+ descr = cursor.description
+ encoding = self.encoding
+ process_value = self.dbapi_module.process_value
+ binary = Binary
+ results = cursor.fetchall()
+ for i, line in enumerate(results):
+ result = []
+ for col, value in enumerate(line):
+ if value is None:
+ result.append(value)
+ continue
+ result.append(process_value(value, descr[col], encoding, binary))
+ results[i] = result
+ return results
+
+
+ def preprocess_entity(self, entity):
+ """return a dictionary to use as extra argument to cursor.execute
+ to insert/update an entity
+ """
+ attrs = {}
+ eschema = entity.e_schema
+ for attr, value in entity.items():
+ rschema = eschema.subject_relation(attr)
+ if rschema.is_final():
+ atype = str(entity.e_schema.destination(attr))
+ if atype == 'Boolean':
+ value = self.dbhelper.boolean_value(value)
+ elif atype == 'Password':
+ # if value is a Binary instance, this mean we got it
+ # from a query result and so it is already encrypted
+ if isinstance(value, Binary):
+ value = value.getvalue()
+ else:
+ value = crypt_password(value)
+ elif isinstance(value, Binary):
+ value = self.binary(value.getvalue())
+ attrs[str(attr)] = value
+ return attrs
+
+
+from logging import getLogger
+from cubicweb import set_log_methods
+set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))
+
+def init_sqlite_connexion(cnx):
+ # XXX should not be publicly exposed
+ #def comma_join(strings):
+ # return ', '.join(strings)
+ #cnx.create_function("COMMA_JOIN", 1, comma_join)
+
+ class concat_strings(object):
+ def __init__(self):
+ self.values = []
+ def step(self, value):
+ if value is not None:
+ self.values.append(value)
+ def finalize(self):
+ return ', '.join(self.values)
+ # renamed to GROUP_CONCAT in cubicweb 2.45, keep old name for bw compat for
+ # some time
+ cnx.create_aggregate("CONCAT_STRINGS", 1, concat_strings)
+ cnx.create_aggregate("GROUP_CONCAT", 1, concat_strings)
+
+ def _limit_size(text, maxsize, format='text/plain'):
+ if len(text) < maxsize:
+ return text
+ if format in ('text/html', 'text/xhtml', 'text/xml'):
+ text = remove_html_tags(text)
+ if len(text) > maxsize:
+ text = text[:maxsize] + '...'
+ return text
+
+ def limit_size3(text, format, maxsize):
+ return _limit_size(text, maxsize, format)
+ cnx.create_function("LIMIT_SIZE", 3, limit_size3)
+
+ def limit_size2(text, maxsize):
+ return _limit_size(text, maxsize)
+ cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)
+ import yams.constraints
+ if hasattr(yams.constraints, 'patch_sqlite_decimal'):
+ yams.constraints.patch_sqlite_decimal()
+
+
+sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
+sqlite_hooks.append(init_sqlite_connexion)
+
+def init_cnx(driver, cnx):
+ for hook in SQL_CONNECT_HOOKS.get(driver, ()):
+ hook(cnx)