--- a/server/sqlutils.py Mon Mar 08 09:51:29 2010 +0100
+++ b/server/sqlutils.py Mon Mar 08 17:57:29 2010 +0100
@@ -11,21 +11,17 @@
import subprocess
from datetime import datetime, date
-import logilab.common as lgc
-from logilab.common import db
+from logilab import db, common as lgc
from logilab.common.shellutils import ProgressBar
-from logilab.common.adbh import get_adv_func_helper
-from logilab.common.sqlgen import SQLGenerator
from logilab.common.date import todate, todatetime
-
-from indexer import get_indexer
+from logilab.db.sqlgen import SQLGenerator
from cubicweb import Binary, ConfigurationError
from cubicweb.uilib import remove_html_tags
from cubicweb.schema import PURE_VIRTUAL_RTYPES
from cubicweb.server import SQL_CONNECT_HOOKS
from cubicweb.server.utils import crypt_password
-
+from rql.utils import RQL_FUNCTIONS_REGISTRY
lgc.USE_MX_DATETIME = False
SQL_PREFIX = 'cw_'
@@ -77,8 +73,8 @@
w(native.grant_schema(user, set_owner))
w('')
if text_index:
- indexer = get_indexer(driver)
- w(indexer.sql_grant_user(user))
+ dbhelper = db.get_db_helper(driver)
+ w(dbhelper.sql_grant_user_on_fti(user))
w('')
w(grant_schema(schema, user, set_owner, skip_entities=skip_entities, prefix=SQL_PREFIX))
return '\n'.join(output)
@@ -96,11 +92,10 @@
w = output.append
w(native.sql_schema(driver))
w('')
+ dbhelper = db.get_db_helper(driver)
if text_index:
- indexer = get_indexer(driver)
- w(indexer.sql_init_fti())
+ w(dbhelper.sql_init_fti())
w('')
- dbhelper = get_adv_func_helper(driver)
w(schema2sql(dbhelper, schema, prefix=SQL_PREFIX,
skip_entities=skip_entities, skip_relations=skip_relations))
if dbhelper.users_support and user:
@@ -120,8 +115,8 @@
w(native.sql_drop_schema(driver))
w('')
if text_index:
- indexer = get_indexer(driver)
- w(indexer.sql_drop_fti())
+ dbhelper = db.get_db_helper(driver)
+ w(dbhelper.sql_drop_fti())
w('')
w(dropschema2sql(schema, prefix=SQL_PREFIX,
skip_entities=skip_entities,
@@ -137,55 +132,42 @@
def __init__(self, source_config):
try:
self.dbdriver = source_config['db-driver'].lower()
- self.dbname = source_config['db-name']
+ dbname = source_config['db-name']
except KeyError:
raise ConfigurationError('missing some expected entries in sources file')
- self.dbhost = source_config.get('db-host')
+ dbhost = source_config.get('db-host')
port = source_config.get('db-port')
- self.dbport = port and int(port) or None
- self.dbuser = source_config.get('db-user')
- self.dbpasswd = source_config.get('db-password')
- self.encoding = source_config.get('db-encoding', 'UTF-8')
- self.dbapi_module = db.get_dbapi_compliant_module(self.dbdriver)
- self.dbdriver_extra_args = source_config.get('db-extra-arguments')
- self.binary = self.dbapi_module.Binary
- self.dbhelper = self.dbapi_module.adv_func_helper
+ dbport = port and int(port) or None
+ dbuser = source_config.get('db-user')
+ dbpassword = source_config.get('db-password')
+ dbencoding = source_config.get('db-encoding', 'UTF-8')
+ dbextraargs = source_config.get('db-extra-arguments')
+ self.dbhelper = db.get_db_helper(self.dbdriver)
+ self.dbhelper.record_connection_info(dbname, dbhost, dbport, dbuser,
+ dbpassword, dbextraargs,
+ dbencoding)
self.sqlgen = SQLGenerator()
+ # copy back some commonly accessed attributes
+ dbapi_module = self.dbhelper.dbapi_module
+ self.OperationalError = dbapi_module.OperationalError
+ self.InterfaceError = dbapi_module.InterfaceError
+ self._binary = dbapi_module.Binary
+ self._process_value = dbapi_module.process_value
+ self._dbencoding = dbencoding
- def get_connection(self, user=None, password=None):
+ def get_connection(self):
"""open and return a connection to the database"""
- if user or self.dbuser:
- self.info('connecting to %s@%s for user %s', self.dbname,
- self.dbhost or 'localhost', user or self.dbuser)
- else:
- self.info('connecting to %s@%s', self.dbname,
- self.dbhost or 'localhost')
- extra = {}
- if self.dbdriver_extra_args:
- extra = {'extra_args': self.dbdriver_extra_args}
- cnx = self.dbapi_module.connect(self.dbhost, self.dbname,
- user or self.dbuser,
- password or self.dbpasswd,
- port=self.dbport,
- **extra)
- init_cnx(self.dbdriver, cnx)
- #self.dbapi_module.type_code_test(cnx.cursor())
- return cnx
+ return self.dbhelper.get_connection()
def backup_to_file(self, backupfile):
- for cmd in self.dbhelper.backup_commands(self.dbname, self.dbhost,
- self.dbuser, backupfile,
- dbport=self.dbport,
+ for cmd in self.dbhelper.backup_commands(backupfile,
keepownership=False):
if _run_command(cmd):
if not confirm(' [Failed] Continue anyway?', default='n'):
raise Exception('Failed command: %s' % cmd)
def restore_from_file(self, backupfile, confirm, drop=True):
- for cmd in self.dbhelper.restore_commands(self.dbname, self.dbhost,
- self.dbuser, backupfile,
- self.encoding,
- dbport=self.dbport,
+ for cmd in self.dbhelper.restore_commands(backupfile,
keepownership=False,
drop=drop):
if _run_command(cmd):
@@ -198,7 +180,7 @@
for key, val in args.iteritems():
# convert cubicweb binary into db binary
if isinstance(val, Binary):
- val = self.binary(val.getvalue())
+ val = self._binary(val.getvalue())
newargs[key] = val
# should not collide
newargs.update(query_args)
@@ -208,10 +190,12 @@
def process_result(self, cursor):
"""return a list of CubicWeb compliant values from data in the given cursor
"""
+ # begin bind to locals for optimization
descr = cursor.description
- encoding = self.encoding
- process_value = self.dbapi_module.process_value
+ encoding = self._dbencoding
+ process_value = self._process_value
binary = Binary
+ # /end
results = cursor.fetchall()
for i, line in enumerate(results):
result = []
@@ -242,14 +226,14 @@
value = value.getvalue()
else:
value = crypt_password(value)
- value = self.binary(value)
+ value = self._binary(value)
# XXX needed for sqlite but I don't think it is for other backends
elif atype == 'Datetime' and isinstance(value, date):
value = todatetime(value)
elif atype == 'Date' and isinstance(value, datetime):
value = todate(value)
elif isinstance(value, Binary):
- value = self.binary(value.getvalue())
+ value = self._binary(value.getvalue())
attrs[SQL_PREFIX+str(attr)] = value
return attrs
@@ -259,12 +243,8 @@
set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))
def init_sqlite_connexion(cnx):
- # XXX should not be publicly exposed
- #def comma_join(strings):
- # return ', '.join(strings)
- #cnx.create_function("COMMA_JOIN", 1, comma_join)
- class concat_strings(object):
+ class group_concat(object):
def __init__(self):
self.values = []
def step(self, value):
@@ -272,10 +252,7 @@
self.values.append(value)
def finalize(self):
return ', '.join(self.values)
- # renamed to GROUP_CONCAT in cubicweb 2.45, keep old name for bw compat for
- # some time
- cnx.create_aggregate("CONCAT_STRINGS", 1, concat_strings)
- cnx.create_aggregate("GROUP_CONCAT", 1, concat_strings)
+ cnx.create_aggregate("GROUP_CONCAT", 1, group_concat)
def _limit_size(text, maxsize, format='text/plain'):
if len(text) < maxsize:
@@ -293,9 +270,9 @@
def limit_size2(text, maxsize):
return _limit_size(text, maxsize)
cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)
+
import yams.constraints
- if hasattr(yams.constraints, 'patch_sqlite_decimal'):
- yams.constraints.patch_sqlite_decimal()
+ yams.constraints.patch_sqlite_decimal()
def fspath(eid, etype, attr):
try:
@@ -320,10 +297,5 @@
raise
cnx.create_function('_fsopen', 1, _fsopen)
-
sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
sqlite_hooks.append(init_sqlite_connexion)
-
-def init_cnx(driver, cnx):
- for hook in SQL_CONNECT_HOOKS.get(driver, ()):
- hook(cnx)