Provide a new add_cubes() migration function for cases where the new cubes are linked together by new relations.
In this case, all the new cubes need to be added at once.
"""SQL utilities functions and classes.
:organization: Logilab
:copyright: 2001-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
"""
__docformat__ = "restructuredtext en"
from logilab.common.shellutils import ProgressBar
from logilab.common.db import get_dbapi_compliant_module
from logilab.common.adbh import get_adv_func_helper
from logilab.common.sqlgen import SQLGenerator
from indexer import get_indexer
from cubicweb import Binary, ConfigurationError
from cubicweb.common.uilib import remove_html_tags
from cubicweb.server import SQL_CONNECT_HOOKS
from cubicweb.server.utils import crypt_password, cartesian_product
def sqlexec(sqlstmts, cursor_or_execute, withpb=True, delimiter=';'):
    """execute each statement of a delimiter separated sql script

    :param sqlstmts: string holding the statements, separated by `delimiter`
    :param cursor_or_execute:
      either an object with an `execute` method (a db-api cursor for
      instance) or directly a callable taking a sql string as argument
    :param withpb: when true, a progress bar is updated for each statement
    :param delimiter:
      statements separator. The split is purely textual, so the delimiter
      should not appear inside a statement.

    Statements are stripped and executed in order, empty ones being skipped.
    No error is caught and nothing is committed by this function: both are
    left to the caller.
    """
    if hasattr(cursor_or_execute, 'execute'):
        execute = cursor_or_execute.execute
    else:
        execute = cursor_or_execute
    sqlstmts = sqlstmts.split(delimiter)
    if withpb:
        pb = ProgressBar(len(sqlstmts))
    for sql in sqlstmts:
        sql = sql.strip()
        if withpb:
            # the bar was sized on the raw split, so empty chunks count too
            pb.update()
        if not sql:
            continue
        # some dbapi modules doesn't accept unicode for sql string
        execute(str(sql))
    if withpb:
        # terminate the progress bar line. Written as a call so the newline is
        # emitted whether `print` is a statement (python 2) or a function
        # (python 3, where a bare `print` would silently do nothing)
        print('')
def sqlgrants(schema, driver, user,
              text_index=True, set_owner=True,
              skip_relations=(), skip_entities=()):
    """return the sql script granting `user` access to the system schema

    The script is the concatenation, one chunk per line group, of:

    * the grants generated by the native source for its own tables
    * the grants generated by the indexer of `driver`, when `text_index` is
      true
    * the grants generated by yams for `schema`, entity types listed in
      `skip_entities` being passed along to yams

    `skip_relations` is accepted but not used by this function.
    """
    from yams.schema2sql import grant_schema
    from cubicweb.server.sources import native
    chunks = [native.grant_schema(user, set_owner), '']
    if text_index:
        fti_grants = get_indexer(driver).sql_grant_user(user)
        chunks.extend((fti_grants, ''))
    schema_grants = grant_schema(schema, user, set_owner,
                                 skip_entities=skip_entities)
    chunks.append(schema_grants)
    return '\n'.join(chunks)
def sqlschema(schema, driver, text_index=True,
              user=None, set_owner=False,
              skip_relations=('has_text', 'identity'), skip_entities=()):
    """return the sql script creating the system schema

    The script holds, in this order: the native source tables, the full text
    index initialisation (only when `text_index` is true), the tables
    generated by yams for `schema` and, when the database helper of `driver`
    supports users and a `user` is given, the grants computed by `sqlgrants`.

    `user` is required when `set_owner` is true.
    """
    from yams.schema2sql import schema2sql
    from cubicweb.server.sources import native
    if set_owner:
        assert user, 'user is argument required when set_owner is true'
    chunks = [native.sql_schema(driver), '']
    if text_index:
        fti_init = get_indexer(driver).sql_init_fti()
        chunks.extend((fti_init, ''))
    helper = get_adv_func_helper(driver)
    chunks.append(schema2sql(helper, schema,
                             skip_entities=skip_entities,
                             skip_relations=skip_relations))
    if not (helper.users_support and user):
        # no grants to generate
        return '\n'.join(chunks)
    chunks.append('')
    chunks.append(sqlgrants(schema, driver, user, text_index, set_owner,
                            skip_relations, skip_entities))
    return '\n'.join(chunks)
def sqldropschema(schema, driver, text_index=True,
                  skip_relations=('has_text', 'identity'), skip_entities=()):
    """return the sql script dropping the system schema

    The script holds, in this order: the statements dropping the native
    source tables, the statements dropping the full text index (only when
    `text_index` is true) and the statements generated by yams to drop the
    tables of `schema`.
    """
    from yams.schema2sql import dropschema2sql
    from cubicweb.server.sources import native
    chunks = [native.sql_drop_schema(driver), '']
    if text_index:
        fti_drop = get_indexer(driver).sql_drop_fti()
        chunks.extend((fti_drop, ''))
    chunks.append(dropschema2sql(schema,
                                 skip_entities=skip_entities,
                                 skip_relations=skip_relations))
    return '\n'.join(chunks)
class SQLAdapterMixIn(object):
    """Mixin for SQL data sources, getting a connection from a configuration
    dictionary and converting values between CubicWeb and the db-api module

    The class calls `self.info` without defining it: a logging method of that
    name has to be available on instances.
    """

    def __init__(self, source_config):
        """read connection settings from the `source_config` dictionary

        'db-driver' and 'db-name' entries are mandatory, ConfigurationError
        being raised when one of them is missing. 'db-host', 'db-port',
        'db-user' and 'db-password' are optional, and 'db-encoding' defaults
        to 'UTF-8'.
        """
        try:
            self.dbdriver = source_config['db-driver'].lower()
            self.dbname = source_config['db-name']
        except KeyError:
            raise ConfigurationError('missing some expected entries in sources file')
        self.dbhost = source_config.get('db-host')
        port = source_config.get('db-port')
        # a missing, empty or zero port ends up as None
        self.dbport = port and int(port) or None
        self.dbuser = source_config.get('db-user')
        self.dbpasswd = source_config.get('db-password')
        self.encoding = source_config.get('db-encoding', 'UTF-8')
        self.dbapi_module = get_dbapi_compliant_module(self.dbdriver)
        self.binary = self.dbapi_module.Binary
        self.dbhelper = self.dbapi_module.adv_func_helper
        self.sqlgen = SQLGenerator()

    def get_connection(self, user=None, password=None):
        """open and return a connection to the database

        `user` and `password` default to the configured 'db-user' and
        'db-password'. Connection hooks registered for the driver are run on
        the new connection before it is returned.
        """
        if user or self.dbuser:
            self.info('connecting to %s@%s for user %s', self.dbname,
                      self.dbhost or 'localhost', user or self.dbuser)
        else:
            self.info('connecting to %s@%s', self.dbname,
                      self.dbhost or 'localhost')
        cnx = self.dbapi_module.connect(self.dbhost, self.dbname,
                                        user or self.dbuser,
                                        password or self.dbpasswd,
                                        port=self.dbport)
        init_cnx(self.dbdriver, cnx)
        #self.dbapi_module.type_code_test(cnx.cursor())
        return cnx

    def merge_args(self, args, query_args):
        """return the arguments dictionary to give to cursor.execute

        When `args` is None, `query_args` itself is returned. Otherwise a new
        dictionary is returned, built from `args` with cubicweb Binary values
        converted into the db-api module binary type, then updated with
        `query_args`. `args` is left untouched.
        """
        if args is None:
            return query_args
        merged = dict(args)
        for key, val in list(merged.items()):
            # convert cubicweb binary into db binary
            if isinstance(val, Binary):
                merged[key] = self.binary(val.getvalue())
        # should not collide
        merged.update(query_args)
        return merged

    def process_result(self, cursor):
        """return a list of CubicWeb compliant values from data in the given cursor

        Each fetched row becomes a list where None values are kept as is and
        other values are converted by the db-api module `process_value`
        function, according to the matching column description.
        """
        descr = cursor.description
        encoding = self.encoding
        process_value = self.dbapi_module.process_value
        binary = Binary
        # build a new list instead of assigning into the object returned by
        # fetchall(): the db-api only guarantees a sequence there, and some
        # drivers return a tuple which doesn't support item assignment
        results = []
        for line in cursor.fetchall():
            result = []
            for col, value in enumerate(line):
                if value is not None:
                    value = process_value(value, descr[col], encoding, binary)
                result.append(value)
            results.append(result)
        return results

    def preprocess_entity(self, entity):
        """return a dictionary to use as extra argument to cursor.execute
        to insert/update an entity

        Keys are the entity's attribute / relation names. Values of final
        relations are adapted first: Boolean values go through the database
        helper, Password values are encrypted unless already given as a
        Binary, and other Binary values are converted into the db-api module
        binary type.
        """
        attrs = {}
        eschema = entity.e_schema
        for attr, value in entity.items():
            rschema = eschema.subject_relation(attr)
            if rschema.is_final():
                atype = str(entity.e_schema.destination(attr))
                if atype == 'Boolean':
                    value = self.dbhelper.boolean_value(value)
                elif atype == 'Password':
                    # if value is a Binary instance, this mean we got it
                    # from a query result and so it is already encrypted
                    if isinstance(value, Binary):
                        value = value.getvalue()
                    else:
                        value = crypt_password(value)
                elif isinstance(value, Binary):
                    value = self.binary(value.getvalue())
            attrs[str(attr)] = value
        return attrs
# Hand the mixin class and the 'cubicweb.sqladapter' logger over to
# set_log_methods. NOTE(review): presumably this is what installs the `info`
# method called by SQLAdapterMixIn.get_connection -- confirm against
# cubicweb.set_log_methods.
from logging import getLogger
from cubicweb import set_log_methods
set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))
def init_sqlite_connexion(cnx):
    """register on the sqlite connection `cnx` the extra sql functions

    * CONCAT_STRINGS / GROUP_CONCAT: aggregates joining the non NULL values
      they receive with ', '
    * LIMIT_SIZE(text, format, maxsize): text is returned untouched when
      shorter than `maxsize`; else html tags are removed if `format` is an
      html / xml mime type, and the result is cut to `maxsize` characters
      followed by '...' when still longer than `maxsize`
    * TEXT_LIMIT_SIZE(text, maxsize): same thing, for 'text/plain' content

    `yams.constraints.patch_sqlite_decimal` is also called when yams
    provides it.
    """
    # XXX should not be publicly exposed

    class concat_strings(object):
        """aggregate accumulating non NULL values"""

        def __init__(self):
            self.values = []

        def step(self, value):
            if value is None:
                return
            self.values.append(value)

        def finalize(self):
            return ', '.join(self.values)

    # renamed to GROUP_CONCAT in cubicweb 2.45, keep old name for bw compat for
    # some time
    for aggregate_name in ("CONCAT_STRINGS", "GROUP_CONCAT"):
        cnx.create_aggregate(aggregate_name, 1, concat_strings)

    markup_formats = ('text/html', 'text/xhtml', 'text/xml')

    def _limit_size(text, maxsize, format='text/plain'):
        if len(text) >= maxsize:
            if format in markup_formats:
                text = remove_html_tags(text)
            # tags removal may have been enough to fit in maxsize
            if len(text) > maxsize:
                text = text[:maxsize] + '...'
        return text

    def limit_size3(text, format, maxsize):
        # sql side gives the format before the size
        return _limit_size(text, maxsize, format)

    def limit_size2(text, maxsize):
        return _limit_size(text, maxsize)

    cnx.create_function("LIMIT_SIZE", 3, limit_size3)
    cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)

    import yams.constraints
    # not every yams version provides this patch
    if hasattr(yams.constraints, 'patch_sqlite_decimal'):
        yams.constraints.patch_sqlite_decimal()
# Register init_sqlite_connexion as a connection hook for the 'sqlite' driver,
# creating the driver's hook list if it doesn't exist yet: init_cnx() calls
# every hook found in SQL_CONNECT_HOOKS under the driver name.
sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
sqlite_hooks.append(init_sqlite_connexion)
def init_cnx(driver, cnx):
    """call on `cnx` each hook registered in SQL_CONNECT_HOOKS for `driver`

    Hooks are called in registration order, with the connection as only
    argument. Nothing is done for a driver without registered hook.
    """
    driver_hooks = SQL_CONNECT_HOOKS.get(driver, ())
    for init_hook in driver_hooks:
        init_hook(cnx)