"""SQL utilities functions and classes.
:organization: Logilab
:copyright: 2001-2009 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
"""
__docformat__ = "restructuredtext en"
from warnings import warn
from datetime import datetime, date, timedelta
import logilab.common as lgc
from logilab.common.shellutils import ProgressBar
from logilab.common import db
from logilab.common.adbh import get_adv_func_helper
from logilab.common.sqlgen import SQLGenerator
from indexer import get_indexer
from cubicweb import Binary, ConfigurationError
from cubicweb.utils import todate, todatetime
from cubicweb.common.uilib import remove_html_tags
from cubicweb.server import SQL_CONNECT_HOOKS
from cubicweb.server.utils import crypt_password
lgc.USE_MX_DATETIME = False
SQL_PREFIX = 'cw_'
def sqlexec(sqlstmts, cursor_or_execute, withpb=True, delimiter=';'):
"""execute sql statements ignoring DROP/ CREATE GROUP or USER statements
error. If a cnx is given, commit at each statement
"""
if hasattr(cursor_or_execute, 'execute'):
execute = cursor_or_execute.execute
else:
execute = cursor_or_execute
sqlstmts = sqlstmts.split(delimiter)
if withpb:
pb = ProgressBar(len(sqlstmts))
for sql in sqlstmts:
sql = sql.strip()
if withpb:
pb.update()
if not sql:
continue
# some dbapi modules doesn't accept unicode for sql string
execute(str(sql))
if withpb:
print
def sqlgrants(schema, driver, user,
text_index=True, set_owner=True,
skip_relations=(), skip_entities=()):
"""return sql to give all access privileges to the given user on the system
schema
"""
from yams.schema2sql import grant_schema
from cubicweb.server.sources import native
output = []
w = output.append
w(native.grant_schema(user, set_owner))
w('')
if text_index:
indexer = get_indexer(driver)
w(indexer.sql_grant_user(user))
w('')
w(grant_schema(schema, user, set_owner, skip_entities=skip_entities, prefix=SQL_PREFIX))
return '\n'.join(output)
def sqlschema(schema, driver, text_index=True,
user=None, set_owner=False,
skip_relations=('has_text', 'identity'), skip_entities=()):
"""return the system sql schema, according to the given parameters"""
from yams.schema2sql import schema2sql
from cubicweb.server.sources import native
if set_owner:
assert user, 'user is argument required when set_owner is true'
output = []
w = output.append
w(native.sql_schema(driver))
w('')
if text_index:
indexer = get_indexer(driver)
w(indexer.sql_init_fti())
w('')
dbhelper = get_adv_func_helper(driver)
w(schema2sql(dbhelper, schema, prefix=SQL_PREFIX,
skip_entities=skip_entities, skip_relations=skip_relations))
if dbhelper.users_support and user:
w('')
w(sqlgrants(schema, driver, user, text_index, set_owner,
skip_relations, skip_entities))
return '\n'.join(output)
def sqldropschema(schema, driver, text_index=True,
skip_relations=('has_text', 'identity'), skip_entities=()):
"""return the sql to drop the schema, according to the given parameters"""
from yams.schema2sql import dropschema2sql
from cubicweb.server.sources import native
output = []
w = output.append
w(native.sql_drop_schema(driver))
w('')
if text_index:
indexer = get_indexer(driver)
w(indexer.sql_drop_fti())
w('')
w(dropschema2sql(schema, prefix=SQL_PREFIX,
skip_entities=skip_entities, skip_relations=skip_relations))
return '\n'.join(output)
try:
from mx.DateTime import DateTimeType, DateTimeDeltaType
except ImportError:
DateTimeType, DateTimeDeltaType = None
class SQLAdapterMixIn(object):
"""Mixin for SQL data sources, getting a connection from a configuration
dictionary and handling connection locking
"""
def __init__(self, source_config):
try:
self.dbdriver = source_config['db-driver'].lower()
self.dbname = source_config['db-name']
except KeyError:
raise ConfigurationError('missing some expected entries in sources file')
self.dbhost = source_config.get('db-host')
port = source_config.get('db-port')
self.dbport = port and int(port) or None
self.dbuser = source_config.get('db-user')
self.dbpasswd = source_config.get('db-password')
self.encoding = source_config.get('db-encoding', 'UTF-8')
self.dbapi_module = db.get_dbapi_compliant_module(self.dbdriver)
self.binary = self.dbapi_module.Binary
self.dbhelper = self.dbapi_module.adv_func_helper
self.sqlgen = SQLGenerator()
def get_connection(self, user=None, password=None):
"""open and return a connection to the database"""
if user or self.dbuser:
self.info('connecting to %s@%s for user %s', self.dbname,
self.dbhost or 'localhost', user or self.dbuser)
else:
self.info('connecting to %s@%s', self.dbname,
self.dbhost or 'localhost')
cnx = self.dbapi_module.connect(self.dbhost, self.dbname,
user or self.dbuser,
password or self.dbpasswd,
port=self.dbport)
init_cnx(self.dbdriver, cnx)
#self.dbapi_module.type_code_test(cnx.cursor())
return cnx
def merge_args(self, args, query_args):
if args is not None:
args = dict(args)
for key, val in args.items():
# convert cubicweb binary into db binary
if isinstance(val, Binary):
val = self.binary(val.getvalue())
# XXX <3.2 bw compat
elif type(val) is DateTimeType:
warn('found mx date time instance, please update to use datetime',
DeprecationWarning)
val = datetime(val.year, val.month, val.day,
val.hour, val.minute, int(val.second))
elif type(val) is DateTimeDeltaType:
warn('found mx date time instance, please update to use datetime',
DeprecationWarning)
val = timedelta(0, int(val.seconds), 0)
args[key] = val
# should not collide
args.update(query_args)
return args
return query_args
def process_result(self, cursor):
"""return a list of CubicWeb compliant values from data in the given cursor
"""
descr = cursor.description
encoding = self.encoding
process_value = self.dbapi_module.process_value
binary = Binary
results = cursor.fetchall()
for i, line in enumerate(results):
result = []
for col, value in enumerate(line):
if value is None:
result.append(value)
continue
result.append(process_value(value, descr[col], encoding, binary))
results[i] = result
return results
def preprocess_entity(self, entity):
"""return a dictionary to use as extra argument to cursor.execute
to insert/update an entity into a SQL database
"""
attrs = {}
eschema = entity.e_schema
for attr, value in entity.items():
rschema = eschema.subject_relation(attr)
if rschema.is_final():
atype = str(entity.e_schema.destination(attr))
if atype == 'Boolean':
value = self.dbhelper.boolean_value(value)
elif atype == 'Password':
# if value is a Binary instance, this mean we got it
# from a query result and so it is already encrypted
if isinstance(value, Binary):
value = value.getvalue()
else:
value = crypt_password(value)
# XXX needed for sqlite but I don't think it is for other backends
elif atype == 'Datetime' and isinstance(value, date):
value = todatetime(value)
elif atype == 'Date' and isinstance(value, datetime):
value = todate(value)
elif isinstance(value, Binary):
value = self.binary(value.getvalue())
# XXX <3.2 bw compat
elif type(value) is DateTimeType:
warn('found mx date time instance, please update to use datetime',
DeprecationWarning)
value = datetime(value.year, value.month, value.day,
value.hour, value.minute, int(value.second))
elif type(value) is DateTimeDeltaType:
warn('found mx date time instance, please update to use datetime',
DeprecationWarning)
value = timedelta(0, int(value.seconds), 0)
attrs[SQL_PREFIX+str(attr)] = value
return attrs
from logging import getLogger
from cubicweb import set_log_methods
set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))
def init_sqlite_connexion(cnx):
# XXX should not be publicly exposed
#def comma_join(strings):
# return ', '.join(strings)
#cnx.create_function("COMMA_JOIN", 1, comma_join)
class concat_strings(object):
def __init__(self):
self.values = []
def step(self, value):
if value is not None:
self.values.append(value)
def finalize(self):
return ', '.join(self.values)
# renamed to GROUP_CONCAT in cubicweb 2.45, keep old name for bw compat for
# some time
cnx.create_aggregate("CONCAT_STRINGS", 1, concat_strings)
cnx.create_aggregate("GROUP_CONCAT", 1, concat_strings)
def _limit_size(text, maxsize, format='text/plain'):
if len(text) < maxsize:
return text
if format in ('text/html', 'text/xhtml', 'text/xml'):
text = remove_html_tags(text)
if len(text) > maxsize:
text = text[:maxsize] + '...'
return text
def limit_size3(text, format, maxsize):
return _limit_size(text, maxsize, format)
cnx.create_function("LIMIT_SIZE", 3, limit_size3)
def limit_size2(text, maxsize):
return _limit_size(text, maxsize)
cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)
import yams.constraints
if hasattr(yams.constraints, 'patch_sqlite_decimal'):
yams.constraints.patch_sqlite_decimal()
sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
sqlite_hooks.append(init_sqlite_connexion)
def init_cnx(driver, cnx):
for hook in SQL_CONNECT_HOOKS.get(driver, ()):
hook(cnx)