server/sqlutils.py
changeset 0 b97547f5f1fa
child 1016 26387b836099
child 1251 af40e615dc89
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/server/sqlutils.py	Wed Nov 05 15:52:50 2008 +0100
@@ -0,0 +1,254 @@
+"""SQL utilities functions and classes.
+
+:organization: Logilab
+:copyright: 2001-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
+"""
+__docformat__ = "restructuredtext en"
+
+from logilab.common.shellutils import ProgressBar
+from logilab.common.db import get_dbapi_compliant_module
+from logilab.common.adbh import get_adv_func_helper
+from logilab.common.sqlgen import SQLGenerator
+
+from indexer import get_indexer
+
+from cubicweb import Binary, ConfigurationError
+from cubicweb.common.uilib import remove_html_tags
+from cubicweb.server import SQL_CONNECT_HOOKS
+from cubicweb.server.utils import crypt_password, cartesian_product
+
+
+def sqlexec(sqlstmts, cursor_or_execute, withpb=True, delimiter=';'):
+    """execute sql statements ignoring DROP/ CREATE GROUP or USER statements
+    error. If a cnx is given, commit at each statement
+    """
+    if hasattr(cursor_or_execute, 'execute'):
+        execute = cursor_or_execute.execute
+    else:
+        execute = cursor_or_execute
+    sqlstmts = sqlstmts.split(delimiter)
+    if withpb:
+        pb = ProgressBar(len(sqlstmts))
+    for sql in sqlstmts:
+        sql = sql.strip()
+        if withpb:
+            pb.update()
+        if not sql:
+            continue
+        # some dbapi modules doesn't accept unicode for sql string
+        execute(str(sql))
+    if withpb:
+        print
+
+
+def sqlgrants(schema, driver, user,
+              text_index=True, set_owner=True,
+              skip_relations=(), skip_entities=()):
+    """return sql to give all access privileges to the given user on the system
+    schema
+    """
+    from yams.schema2sql import grant_schema
+    from cubicweb.server.sources import native
+    output = []
+    w = output.append
+    w(native.grant_schema(user, set_owner))
+    w('')
+    if text_index:
+        indexer = get_indexer(driver)
+        w(indexer.sql_grant_user(user))
+        w('')
+    w(grant_schema(schema, user, set_owner, skip_entities=skip_entities))
+    return '\n'.join(output)
+
+                  
+def sqlschema(schema, driver, text_index=True, 
+              user=None, set_owner=False,
+              skip_relations=('has_text', 'identity'), skip_entities=()):
+    """return the system sql schema, according to the given parameters"""
+    from yams.schema2sql import schema2sql
+    from cubicweb.server.sources import native
+    if set_owner:
+        assert user, 'user is argument required when set_owner is true'
+    output = []
+    w = output.append
+    w(native.sql_schema(driver))
+    w('')
+    if text_index:
+        indexer = get_indexer(driver)
+        w(indexer.sql_init_fti())
+        w('')
+    dbhelper = get_adv_func_helper(driver)
+    w(schema2sql(dbhelper, schema, 
+                 skip_entities=skip_entities, skip_relations=skip_relations))
+    if dbhelper.users_support and user:
+        w('')
+        w(sqlgrants(schema, driver, user, text_index, set_owner,
+                    skip_relations, skip_entities))
+    return '\n'.join(output)
+
+                  
+def sqldropschema(schema, driver, text_index=True, 
+                  skip_relations=('has_text', 'identity'), skip_entities=()):
+    """return the sql to drop the schema, according to the given parameters"""
+    from yams.schema2sql import dropschema2sql
+    from cubicweb.server.sources import native
+    output = []
+    w = output.append
+    w(native.sql_drop_schema(driver))
+    w('')
+    if text_index:
+        indexer = get_indexer(driver)
+        w(indexer.sql_drop_fti())
+        w('')
+    w(dropschema2sql(schema,
+                     skip_entities=skip_entities, skip_relations=skip_relations))
+    return '\n'.join(output)
+
+
+
+class SQLAdapterMixIn(object):
+    """Mixin for SQL data sources, getting a connection from a configuration
+    dictionary and handling connection locking
+    """
+    
+    def __init__(self, source_config):
+        try:
+            self.dbdriver = source_config['db-driver'].lower()
+            self.dbname = source_config['db-name']
+        except KeyError:
+            raise ConfigurationError('missing some expected entries in sources file')
+        self.dbhost = source_config.get('db-host')
+        port = source_config.get('db-port')
+        self.dbport = port and int(port) or None
+        self.dbuser = source_config.get('db-user')
+        self.dbpasswd = source_config.get('db-password')
+        self.encoding = source_config.get('db-encoding', 'UTF-8')
+        self.dbapi_module = get_dbapi_compliant_module(self.dbdriver)
+        self.binary = self.dbapi_module.Binary
+        self.dbhelper = self.dbapi_module.adv_func_helper
+        self.sqlgen = SQLGenerator()
+        
+    def get_connection(self, user=None, password=None):
+        """open and return a connection to the database"""
+        if user or self.dbuser:
+            self.info('connecting to %s@%s for user %s', self.dbname,
+                      self.dbhost or 'localhost', user or self.dbuser)
+        else:
+            self.info('connecting to %s@%s', self.dbname,
+                      self.dbhost or 'localhost')
+        cnx = self.dbapi_module.connect(self.dbhost, self.dbname,
+                                        user or self.dbuser,
+                                        password or self.dbpasswd,
+                                        port=self.dbport)
+        init_cnx(self.dbdriver, cnx)
+        #self.dbapi_module.type_code_test(cnx.cursor())
+        return cnx
+
+    def merge_args(self, args, query_args):
+        if args is not None:
+            args = dict(args)
+            for key, val in args.items():
+                # convert cubicweb binary into db binary
+                if isinstance(val, Binary):
+                    val = self.binary(val.getvalue())
+                args[key] = val
+            # should not collide
+            args.update(query_args)
+            return args
+        return query_args
+
+    def process_result(self, cursor):
+        """return a list of CubicWeb compliant values from data in the given cursor
+        """
+        descr = cursor.description
+        encoding = self.encoding
+        process_value = self.dbapi_module.process_value
+        binary = Binary
+        results = cursor.fetchall()
+        for i, line in enumerate(results):
+            result = []
+            for col, value in enumerate(line):
+                if value is None:
+                    result.append(value)
+                    continue
+                result.append(process_value(value, descr[col], encoding, binary))
+            results[i] = result
+        return results
+
+
+    def preprocess_entity(self, entity):
+        """return a dictionary to use as extra argument to cursor.execute
+        to insert/update an entity
+        """
+        attrs = {}
+        eschema = entity.e_schema
+        for attr, value in entity.items():
+            rschema = eschema.subject_relation(attr)
+            if rschema.is_final():
+                atype = str(entity.e_schema.destination(attr))
+                if atype == 'Boolean':
+                    value = self.dbhelper.boolean_value(value)
+                elif atype == 'Password':
+                    # if value is a Binary instance, this mean we got it
+                    # from a query result and so it is already encrypted
+                    if isinstance(value, Binary):
+                        value = value.getvalue()
+                    else:
+                        value = crypt_password(value)
+                elif isinstance(value, Binary):
+                    value = self.binary(value.getvalue())
+            attrs[str(attr)] = value
+        return attrs
+
+
+from logging import getLogger
+from cubicweb import set_log_methods
+set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))
+
+def init_sqlite_connexion(cnx):
+    # XXX should not be publicly exposed
+    #def comma_join(strings):
+    #    return ', '.join(strings)
+    #cnx.create_function("COMMA_JOIN", 1, comma_join)
+
+    class concat_strings(object):
+        def __init__(self):
+            self.values = []
+        def step(self, value):
+            if value is not None:
+                self.values.append(value)
+        def finalize(self):
+            return ', '.join(self.values)
+    # renamed to GROUP_CONCAT in cubicweb 2.45, keep old name for bw compat for
+    # some time
+    cnx.create_aggregate("CONCAT_STRINGS", 1, concat_strings)
+    cnx.create_aggregate("GROUP_CONCAT", 1, concat_strings)
+    
+    def _limit_size(text, maxsize, format='text/plain'):
+        if len(text) < maxsize:
+            return text
+        if format in ('text/html', 'text/xhtml', 'text/xml'):
+            text = remove_html_tags(text)
+        if len(text) > maxsize:
+            text = text[:maxsize] + '...'
+        return text
+
+    def limit_size3(text, format, maxsize):
+        return _limit_size(text, maxsize, format)
+    cnx.create_function("LIMIT_SIZE", 3, limit_size3)
+
+    def limit_size2(text, maxsize):
+        return _limit_size(text, maxsize)
+    cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)
+    import yams.constraints
+    if hasattr(yams.constraints, 'patch_sqlite_decimal'):
+        yams.constraints.patch_sqlite_decimal()
+
+
+sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
+sqlite_hooks.append(init_sqlite_connexion)
+
+def init_cnx(driver, cnx):
+    for hook in SQL_CONNECT_HOOKS.get(driver, ()):
+        hook(cnx)