server/sqlutils.py
changeset 0 b97547f5f1fa
child 1016 26387b836099
child 1251 af40e615dc89
equal deleted inserted replaced
-1:000000000000 0:b97547f5f1fa
       
     1 """SQL utilities functions and classes.
       
     2 
       
     3 :organization: Logilab
       
     4 :copyright: 2001-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
       
     5 :contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
       
     6 """
       
     7 __docformat__ = "restructuredtext en"
       
     8 
       
     9 from logilab.common.shellutils import ProgressBar
       
    10 from logilab.common.db import get_dbapi_compliant_module
       
    11 from logilab.common.adbh import get_adv_func_helper
       
    12 from logilab.common.sqlgen import SQLGenerator
       
    13 
       
    14 from indexer import get_indexer
       
    15 
       
    16 from cubicweb import Binary, ConfigurationError
       
    17 from cubicweb.common.uilib import remove_html_tags
       
    18 from cubicweb.server import SQL_CONNECT_HOOKS
       
    19 from cubicweb.server.utils import crypt_password, cartesian_product
       
    20 
       
    21 
       
    22 def sqlexec(sqlstmts, cursor_or_execute, withpb=True, delimiter=';'):
       
    23     """execute sql statements ignoring DROP/ CREATE GROUP or USER statements
       
    24     error. If a cnx is given, commit at each statement
       
    25     """
       
    26     if hasattr(cursor_or_execute, 'execute'):
       
    27         execute = cursor_or_execute.execute
       
    28     else:
       
    29         execute = cursor_or_execute
       
    30     sqlstmts = sqlstmts.split(delimiter)
       
    31     if withpb:
       
    32         pb = ProgressBar(len(sqlstmts))
       
    33     for sql in sqlstmts:
       
    34         sql = sql.strip()
       
    35         if withpb:
       
    36             pb.update()
       
    37         if not sql:
       
    38             continue
       
    39         # some dbapi modules doesn't accept unicode for sql string
       
    40         execute(str(sql))
       
    41     if withpb:
       
    42         print
       
    43 
       
    44 
       
    45 def sqlgrants(schema, driver, user,
       
    46               text_index=True, set_owner=True,
       
    47               skip_relations=(), skip_entities=()):
       
    48     """return sql to give all access privileges to the given user on the system
       
    49     schema
       
    50     """
       
    51     from yams.schema2sql import grant_schema
       
    52     from cubicweb.server.sources import native
       
    53     output = []
       
    54     w = output.append
       
    55     w(native.grant_schema(user, set_owner))
       
    56     w('')
       
    57     if text_index:
       
    58         indexer = get_indexer(driver)
       
    59         w(indexer.sql_grant_user(user))
       
    60         w('')
       
    61     w(grant_schema(schema, user, set_owner, skip_entities=skip_entities))
       
    62     return '\n'.join(output)
       
    63 
       
    64                   
       
    65 def sqlschema(schema, driver, text_index=True, 
       
    66               user=None, set_owner=False,
       
    67               skip_relations=('has_text', 'identity'), skip_entities=()):
       
    68     """return the system sql schema, according to the given parameters"""
       
    69     from yams.schema2sql import schema2sql
       
    70     from cubicweb.server.sources import native
       
    71     if set_owner:
       
    72         assert user, 'user is argument required when set_owner is true'
       
    73     output = []
       
    74     w = output.append
       
    75     w(native.sql_schema(driver))
       
    76     w('')
       
    77     if text_index:
       
    78         indexer = get_indexer(driver)
       
    79         w(indexer.sql_init_fti())
       
    80         w('')
       
    81     dbhelper = get_adv_func_helper(driver)
       
    82     w(schema2sql(dbhelper, schema, 
       
    83                  skip_entities=skip_entities, skip_relations=skip_relations))
       
    84     if dbhelper.users_support and user:
       
    85         w('')
       
    86         w(sqlgrants(schema, driver, user, text_index, set_owner,
       
    87                     skip_relations, skip_entities))
       
    88     return '\n'.join(output)
       
    89 
       
    90                   
       
    91 def sqldropschema(schema, driver, text_index=True, 
       
    92                   skip_relations=('has_text', 'identity'), skip_entities=()):
       
    93     """return the sql to drop the schema, according to the given parameters"""
       
    94     from yams.schema2sql import dropschema2sql
       
    95     from cubicweb.server.sources import native
       
    96     output = []
       
    97     w = output.append
       
    98     w(native.sql_drop_schema(driver))
       
    99     w('')
       
   100     if text_index:
       
   101         indexer = get_indexer(driver)
       
   102         w(indexer.sql_drop_fti())
       
   103         w('')
       
   104     w(dropschema2sql(schema,
       
   105                      skip_entities=skip_entities, skip_relations=skip_relations))
       
   106     return '\n'.join(output)
       
   107 
       
   108 
       
   109 
       
   110 class SQLAdapterMixIn(object):
       
   111     """Mixin for SQL data sources, getting a connection from a configuration
       
   112     dictionary and handling connection locking
       
   113     """
       
   114     
       
   115     def __init__(self, source_config):
       
   116         try:
       
   117             self.dbdriver = source_config['db-driver'].lower()
       
   118             self.dbname = source_config['db-name']
       
   119         except KeyError:
       
   120             raise ConfigurationError('missing some expected entries in sources file')
       
   121         self.dbhost = source_config.get('db-host')
       
   122         port = source_config.get('db-port')
       
   123         self.dbport = port and int(port) or None
       
   124         self.dbuser = source_config.get('db-user')
       
   125         self.dbpasswd = source_config.get('db-password')
       
   126         self.encoding = source_config.get('db-encoding', 'UTF-8')
       
   127         self.dbapi_module = get_dbapi_compliant_module(self.dbdriver)
       
   128         self.binary = self.dbapi_module.Binary
       
   129         self.dbhelper = self.dbapi_module.adv_func_helper
       
   130         self.sqlgen = SQLGenerator()
       
   131         
       
   132     def get_connection(self, user=None, password=None):
       
   133         """open and return a connection to the database"""
       
   134         if user or self.dbuser:
       
   135             self.info('connecting to %s@%s for user %s', self.dbname,
       
   136                       self.dbhost or 'localhost', user or self.dbuser)
       
   137         else:
       
   138             self.info('connecting to %s@%s', self.dbname,
       
   139                       self.dbhost or 'localhost')
       
   140         cnx = self.dbapi_module.connect(self.dbhost, self.dbname,
       
   141                                         user or self.dbuser,
       
   142                                         password or self.dbpasswd,
       
   143                                         port=self.dbport)
       
   144         init_cnx(self.dbdriver, cnx)
       
   145         #self.dbapi_module.type_code_test(cnx.cursor())
       
   146         return cnx
       
   147 
       
   148     def merge_args(self, args, query_args):
       
   149         if args is not None:
       
   150             args = dict(args)
       
   151             for key, val in args.items():
       
   152                 # convert cubicweb binary into db binary
       
   153                 if isinstance(val, Binary):
       
   154                     val = self.binary(val.getvalue())
       
   155                 args[key] = val
       
   156             # should not collide
       
   157             args.update(query_args)
       
   158             return args
       
   159         return query_args
       
   160 
       
   161     def process_result(self, cursor):
       
   162         """return a list of CubicWeb compliant values from data in the given cursor
       
   163         """
       
   164         descr = cursor.description
       
   165         encoding = self.encoding
       
   166         process_value = self.dbapi_module.process_value
       
   167         binary = Binary
       
   168         results = cursor.fetchall()
       
   169         for i, line in enumerate(results):
       
   170             result = []
       
   171             for col, value in enumerate(line):
       
   172                 if value is None:
       
   173                     result.append(value)
       
   174                     continue
       
   175                 result.append(process_value(value, descr[col], encoding, binary))
       
   176             results[i] = result
       
   177         return results
       
   178 
       
   179 
       
   180     def preprocess_entity(self, entity):
       
   181         """return a dictionary to use as extra argument to cursor.execute
       
   182         to insert/update an entity
       
   183         """
       
   184         attrs = {}
       
   185         eschema = entity.e_schema
       
   186         for attr, value in entity.items():
       
   187             rschema = eschema.subject_relation(attr)
       
   188             if rschema.is_final():
       
   189                 atype = str(entity.e_schema.destination(attr))
       
   190                 if atype == 'Boolean':
       
   191                     value = self.dbhelper.boolean_value(value)
       
   192                 elif atype == 'Password':
       
   193                     # if value is a Binary instance, this mean we got it
       
   194                     # from a query result and so it is already encrypted
       
   195                     if isinstance(value, Binary):
       
   196                         value = value.getvalue()
       
   197                     else:
       
   198                         value = crypt_password(value)
       
   199                 elif isinstance(value, Binary):
       
   200                     value = self.binary(value.getvalue())
       
   201             attrs[str(attr)] = value
       
   202         return attrs
       
   203 
       
   204 
       
# Attach the logging methods of the 'cubicweb.sqladapter' logger to
# SQLAdapterMixIn. Presumably this is what provides the `self.info` used by
# SQLAdapterMixIn.get_connection -- confirm against set_log_methods. It has
# to be done once the class is defined.
from logging import getLogger
from cubicweb import set_log_methods
set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))
       
   208 
       
   209 def init_sqlite_connexion(cnx):
       
   210     # XXX should not be publicly exposed
       
   211     #def comma_join(strings):
       
   212     #    return ', '.join(strings)
       
   213     #cnx.create_function("COMMA_JOIN", 1, comma_join)
       
   214 
       
   215     class concat_strings(object):
       
   216         def __init__(self):
       
   217             self.values = []
       
   218         def step(self, value):
       
   219             if value is not None:
       
   220                 self.values.append(value)
       
   221         def finalize(self):
       
   222             return ', '.join(self.values)
       
   223     # renamed to GROUP_CONCAT in cubicweb 2.45, keep old name for bw compat for
       
   224     # some time
       
   225     cnx.create_aggregate("CONCAT_STRINGS", 1, concat_strings)
       
   226     cnx.create_aggregate("GROUP_CONCAT", 1, concat_strings)
       
   227     
       
   228     def _limit_size(text, maxsize, format='text/plain'):
       
   229         if len(text) < maxsize:
       
   230             return text
       
   231         if format in ('text/html', 'text/xhtml', 'text/xml'):
       
   232             text = remove_html_tags(text)
       
   233         if len(text) > maxsize:
       
   234             text = text[:maxsize] + '...'
       
   235         return text
       
   236 
       
   237     def limit_size3(text, format, maxsize):
       
   238         return _limit_size(text, maxsize, format)
       
   239     cnx.create_function("LIMIT_SIZE", 3, limit_size3)
       
   240 
       
   241     def limit_size2(text, maxsize):
       
   242         return _limit_size(text, maxsize)
       
   243     cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)
       
   244     import yams.constraints
       
   245     if hasattr(yams.constraints, 'patch_sqlite_decimal'):
       
   246         yams.constraints.patch_sqlite_decimal()
       
   247 
       
   248 
       
# Register the sqlite initialisation function among the connection hooks of
# the 'sqlite' driver, creating the hook list of that driver when there is
# none yet. init_cnx() calls these hooks on the connections it is given.
sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
sqlite_hooks.append(init_sqlite_connexion)
       
   251 
       
   252 def init_cnx(driver, cnx):
       
   253     for hook in SQL_CONNECT_HOOKS.get(driver, ()):
       
   254         hook(cnx)