server/sqlutils.py
changeset 4831 c5aec27c1bf7
parent 4719 aaed3f813ef8
child 4848 41f84eea63c9
--- a/server/sqlutils.py	Mon Mar 08 09:51:29 2010 +0100
+++ b/server/sqlutils.py	Mon Mar 08 17:57:29 2010 +0100
@@ -11,21 +11,17 @@
 import subprocess
 from datetime import datetime, date
 
-import logilab.common as lgc
-from logilab.common import db
+from logilab import db, common as lgc
 from logilab.common.shellutils import ProgressBar
-from logilab.common.adbh import get_adv_func_helper
-from logilab.common.sqlgen import SQLGenerator
 from logilab.common.date import todate, todatetime
-
-from indexer import get_indexer
+from logilab.db.sqlgen import SQLGenerator
 
 from cubicweb import Binary, ConfigurationError
 from cubicweb.uilib import remove_html_tags
 from cubicweb.schema import PURE_VIRTUAL_RTYPES
 from cubicweb.server import SQL_CONNECT_HOOKS
 from cubicweb.server.utils import crypt_password
-
+from rql.utils import RQL_FUNCTIONS_REGISTRY
 
 lgc.USE_MX_DATETIME = False
 SQL_PREFIX = 'cw_'
@@ -77,8 +73,8 @@
     w(native.grant_schema(user, set_owner))
     w('')
     if text_index:
-        indexer = get_indexer(driver)
-        w(indexer.sql_grant_user(user))
+        dbhelper = db.get_db_helper(driver)
+        w(dbhelper.sql_grant_user_on_fti(user))
         w('')
     w(grant_schema(schema, user, set_owner, skip_entities=skip_entities, prefix=SQL_PREFIX))
     return '\n'.join(output)
@@ -96,11 +92,10 @@
     w = output.append
     w(native.sql_schema(driver))
     w('')
+    dbhelper = db.get_db_helper(driver)
     if text_index:
-        indexer = get_indexer(driver)
-        w(indexer.sql_init_fti())
+        w(dbhelper.sql_init_fti())
         w('')
-    dbhelper = get_adv_func_helper(driver)
     w(schema2sql(dbhelper, schema, prefix=SQL_PREFIX,
                  skip_entities=skip_entities, skip_relations=skip_relations))
     if dbhelper.users_support and user:
@@ -120,8 +115,8 @@
     w(native.sql_drop_schema(driver))
     w('')
     if text_index:
-        indexer = get_indexer(driver)
-        w(indexer.sql_drop_fti())
+        dbhelper = db.get_db_helper(driver)
+        w(dbhelper.sql_drop_fti())
         w('')
     w(dropschema2sql(schema, prefix=SQL_PREFIX,
                      skip_entities=skip_entities,
@@ -137,55 +132,42 @@
     def __init__(self, source_config):
         try:
             self.dbdriver = source_config['db-driver'].lower()
-            self.dbname = source_config['db-name']
+            dbname = source_config['db-name']
         except KeyError:
             raise ConfigurationError('missing some expected entries in sources file')
-        self.dbhost = source_config.get('db-host')
+        dbhost = source_config.get('db-host')
         port = source_config.get('db-port')
-        self.dbport = port and int(port) or None
-        self.dbuser = source_config.get('db-user')
-        self.dbpasswd = source_config.get('db-password')
-        self.encoding = source_config.get('db-encoding', 'UTF-8')
-        self.dbapi_module = db.get_dbapi_compliant_module(self.dbdriver)
-        self.dbdriver_extra_args = source_config.get('db-extra-arguments')
-        self.binary = self.dbapi_module.Binary
-        self.dbhelper = self.dbapi_module.adv_func_helper
+        dbport = port and int(port) or None
+        dbuser = source_config.get('db-user')
+        dbpassword = source_config.get('db-password')
+        dbencoding = source_config.get('db-encoding', 'UTF-8')
+        dbextraargs = source_config.get('db-extra-arguments')
+        self.dbhelper = db.get_db_helper(self.dbdriver)
+        self.dbhelper.record_connection_info(dbname, dbhost, dbport, dbuser,
+                                             dbpassword, dbextraargs,
+                                             dbencoding)
         self.sqlgen = SQLGenerator()
+        # copy back some commonly accessed attributes
+        dbapi_module = self.dbhelper.dbapi_module
+        self.OperationalError = dbapi_module.OperationalError
+        self.InterfaceError = dbapi_module.InterfaceError
+        self._binary = dbapi_module.Binary
+        self._process_value = dbapi_module.process_value
+        self._dbencoding = dbencoding
 
-    def get_connection(self, user=None, password=None):
+    def get_connection(self):
         """open and return a connection to the database"""
-        if user or self.dbuser:
-            self.info('connecting to %s@%s for user %s', self.dbname,
-                      self.dbhost or 'localhost', user or self.dbuser)
-        else:
-            self.info('connecting to %s@%s', self.dbname,
-                      self.dbhost or 'localhost')
-        extra = {}
-        if self.dbdriver_extra_args:
-            extra = {'extra_args': self.dbdriver_extra_args}
-        cnx = self.dbapi_module.connect(self.dbhost, self.dbname,
-                                        user or self.dbuser,
-                                        password or self.dbpasswd,
-                                        port=self.dbport,
-                                        **extra)
-        init_cnx(self.dbdriver, cnx)
-        #self.dbapi_module.type_code_test(cnx.cursor())
-        return cnx
+        return self.dbhelper.get_connection()
 
     def backup_to_file(self, backupfile):
-        for cmd in self.dbhelper.backup_commands(self.dbname, self.dbhost,
-                                                 self.dbuser, backupfile,
-                                                 dbport=self.dbport,
+        for cmd in self.dbhelper.backup_commands(backupfile,
                                                  keepownership=False):
             if _run_command(cmd):
                 if not confirm('   [Failed] Continue anyway?', default='n'):
                     raise Exception('Failed command: %s' % cmd)
 
     def restore_from_file(self, backupfile, confirm, drop=True):
-        for cmd in self.dbhelper.restore_commands(self.dbname, self.dbhost,
-                                                  self.dbuser, backupfile,
-                                                  self.encoding,
-                                                  dbport=self.dbport,
+        for cmd in self.dbhelper.restore_commands(backupfile,
                                                   keepownership=False,
                                                   drop=drop):
             if _run_command(cmd):
@@ -198,7 +180,7 @@
             for key, val in args.iteritems():
                 # convert cubicweb binary into db binary
                 if isinstance(val, Binary):
-                    val = self.binary(val.getvalue())
+                    val = self._binary(val.getvalue())
                 newargs[key] = val
             # should not collide
             newargs.update(query_args)
@@ -208,10 +190,12 @@
     def process_result(self, cursor):
         """return a list of CubicWeb compliant values from data in the given cursor
         """
+        # begin bind to locals for optimization
         descr = cursor.description
-        encoding = self.encoding
-        process_value = self.dbapi_module.process_value
+        encoding = self._dbencoding
+        process_value = self._process_value
         binary = Binary
+        # /end
         results = cursor.fetchall()
         for i, line in enumerate(results):
             result = []
@@ -242,14 +226,14 @@
                         value = value.getvalue()
                     else:
                         value = crypt_password(value)
-                    value = self.binary(value)
+                    value = self._binary(value)
                 # XXX needed for sqlite but I don't think it is for other backends
                 elif atype == 'Datetime' and isinstance(value, date):
                     value = todatetime(value)
                 elif atype == 'Date' and isinstance(value, datetime):
                     value = todate(value)
                 elif isinstance(value, Binary):
-                    value = self.binary(value.getvalue())
+                    value = self._binary(value.getvalue())
             attrs[SQL_PREFIX+str(attr)] = value
         return attrs
 
@@ -259,12 +243,8 @@
 set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))
 
 def init_sqlite_connexion(cnx):
-    # XXX should not be publicly exposed
-    #def comma_join(strings):
-    #    return ', '.join(strings)
-    #cnx.create_function("COMMA_JOIN", 1, comma_join)
 
-    class concat_strings(object):
+    class group_concat(object):
         def __init__(self):
             self.values = []
         def step(self, value):
@@ -272,10 +252,7 @@
                 self.values.append(value)
         def finalize(self):
             return ', '.join(self.values)
-    # renamed to GROUP_CONCAT in cubicweb 2.45, keep old name for bw compat for
-    # some time
-    cnx.create_aggregate("CONCAT_STRINGS", 1, concat_strings)
-    cnx.create_aggregate("GROUP_CONCAT", 1, concat_strings)
+    cnx.create_aggregate("GROUP_CONCAT", 1, group_concat)
 
     def _limit_size(text, maxsize, format='text/plain'):
         if len(text) < maxsize:
@@ -293,9 +270,9 @@
     def limit_size2(text, maxsize):
         return _limit_size(text, maxsize)
     cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)
+
     import yams.constraints
-    if hasattr(yams.constraints, 'patch_sqlite_decimal'):
-        yams.constraints.patch_sqlite_decimal()
+    yams.constraints.patch_sqlite_decimal()
 
     def fspath(eid, etype, attr):
         try:
@@ -320,10 +297,5 @@
                 raise
     cnx.create_function('_fsopen', 1, _fsopen)
 
-
 sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
 sqlite_hooks.append(init_sqlite_connexion)
-
-def init_cnx(driver, cnx):
-    for hook in SQL_CONNECT_HOOKS.get(driver, ()):
-        hook(cnx)