server/sqlutils.py
changeset 4831 c5aec27c1bf7
parent 4719 aaed3f813ef8
child 4848 41f84eea63c9
equal deleted inserted replaced
4829:3b79a0fc91db 4831:c5aec27c1bf7
     9 
     9 
    10 import os
    10 import os
    11 import subprocess
    11 import subprocess
    12 from datetime import datetime, date
    12 from datetime import datetime, date
    13 
    13 
    14 import logilab.common as lgc
    14 from logilab import db, common as lgc
    15 from logilab.common import db
       
    16 from logilab.common.shellutils import ProgressBar
    15 from logilab.common.shellutils import ProgressBar
    17 from logilab.common.adbh import get_adv_func_helper
       
    18 from logilab.common.sqlgen import SQLGenerator
       
    19 from logilab.common.date import todate, todatetime
    16 from logilab.common.date import todate, todatetime
    20 
    17 from logilab.db.sqlgen import SQLGenerator
    21 from indexer import get_indexer
       
    22 
    18 
    23 from cubicweb import Binary, ConfigurationError
    19 from cubicweb import Binary, ConfigurationError
    24 from cubicweb.uilib import remove_html_tags
    20 from cubicweb.uilib import remove_html_tags
    25 from cubicweb.schema import PURE_VIRTUAL_RTYPES
    21 from cubicweb.schema import PURE_VIRTUAL_RTYPES
    26 from cubicweb.server import SQL_CONNECT_HOOKS
    22 from cubicweb.server import SQL_CONNECT_HOOKS
    27 from cubicweb.server.utils import crypt_password
    23 from cubicweb.server.utils import crypt_password
    28 
    24 from rql.utils import RQL_FUNCTIONS_REGISTRY
    29 
    25 
    30 lgc.USE_MX_DATETIME = False
    26 lgc.USE_MX_DATETIME = False
    31 SQL_PREFIX = 'cw_'
    27 SQL_PREFIX = 'cw_'
    32 
    28 
    33 def _run_command(cmd):
    29 def _run_command(cmd):
    75     output = []
    71     output = []
    76     w = output.append
    72     w = output.append
    77     w(native.grant_schema(user, set_owner))
    73     w(native.grant_schema(user, set_owner))
    78     w('')
    74     w('')
    79     if text_index:
    75     if text_index:
    80         indexer = get_indexer(driver)
    76         dbhelper = db.get_db_helper(driver)
    81         w(indexer.sql_grant_user(user))
    77         w(dbhelper.sql_grant_user_on_fti(user))
    82         w('')
    78         w('')
    83     w(grant_schema(schema, user, set_owner, skip_entities=skip_entities, prefix=SQL_PREFIX))
    79     w(grant_schema(schema, user, set_owner, skip_entities=skip_entities, prefix=SQL_PREFIX))
    84     return '\n'.join(output)
    80     return '\n'.join(output)
    85 
    81 
    86 
    82 
    94         assert user, 'user is argument required when set_owner is true'
    90         assert user, 'user is argument required when set_owner is true'
    95     output = []
    91     output = []
    96     w = output.append
    92     w = output.append
    97     w(native.sql_schema(driver))
    93     w(native.sql_schema(driver))
    98     w('')
    94     w('')
       
    95     dbhelper = db.get_db_helper(driver)
    99     if text_index:
    96     if text_index:
   100         indexer = get_indexer(driver)
    97         w(dbhelper.sql_init_fti())
   101         w(indexer.sql_init_fti())
       
   102         w('')
    98         w('')
   103     dbhelper = get_adv_func_helper(driver)
       
   104     w(schema2sql(dbhelper, schema, prefix=SQL_PREFIX,
    99     w(schema2sql(dbhelper, schema, prefix=SQL_PREFIX,
   105                  skip_entities=skip_entities, skip_relations=skip_relations))
   100                  skip_entities=skip_entities, skip_relations=skip_relations))
   106     if dbhelper.users_support and user:
   101     if dbhelper.users_support and user:
   107         w('')
   102         w('')
   108         w(sqlgrants(schema, driver, user, text_index, set_owner,
   103         w(sqlgrants(schema, driver, user, text_index, set_owner,
   118     output = []
   113     output = []
   119     w = output.append
   114     w = output.append
   120     w(native.sql_drop_schema(driver))
   115     w(native.sql_drop_schema(driver))
   121     w('')
   116     w('')
   122     if text_index:
   117     if text_index:
   123         indexer = get_indexer(driver)
   118         dbhelper = db.get_db_helper(driver)
   124         w(indexer.sql_drop_fti())
   119         w(dbhelper.sql_drop_fti())
   125         w('')
   120         w('')
   126     w(dropschema2sql(schema, prefix=SQL_PREFIX,
   121     w(dropschema2sql(schema, prefix=SQL_PREFIX,
   127                      skip_entities=skip_entities,
   122                      skip_entities=skip_entities,
   128                      skip_relations=skip_relations))
   123                      skip_relations=skip_relations))
   129     return '\n'.join(output)
   124     return '\n'.join(output)
   135     """
   130     """
   136 
   131 
   137     def __init__(self, source_config):
   132     def __init__(self, source_config):
   138         try:
   133         try:
   139             self.dbdriver = source_config['db-driver'].lower()
   134             self.dbdriver = source_config['db-driver'].lower()
   140             self.dbname = source_config['db-name']
   135             dbname = source_config['db-name']
   141         except KeyError:
   136         except KeyError:
   142             raise ConfigurationError('missing some expected entries in sources file')
   137             raise ConfigurationError('missing some expected entries in sources file')
   143         self.dbhost = source_config.get('db-host')
   138         dbhost = source_config.get('db-host')
   144         port = source_config.get('db-port')
   139         port = source_config.get('db-port')
   145         self.dbport = port and int(port) or None
   140         dbport = port and int(port) or None
   146         self.dbuser = source_config.get('db-user')
   141         dbuser = source_config.get('db-user')
   147         self.dbpasswd = source_config.get('db-password')
   142         dbpassword = source_config.get('db-password')
   148         self.encoding = source_config.get('db-encoding', 'UTF-8')
   143         dbencoding = source_config.get('db-encoding', 'UTF-8')
   149         self.dbapi_module = db.get_dbapi_compliant_module(self.dbdriver)
   144         dbextraargs = source_config.get('db-extra-arguments')
   150         self.dbdriver_extra_args = source_config.get('db-extra-arguments')
   145         self.dbhelper = db.get_db_helper(self.dbdriver)
   151         self.binary = self.dbapi_module.Binary
   146         self.dbhelper.record_connection_info(dbname, dbhost, dbport, dbuser,
   152         self.dbhelper = self.dbapi_module.adv_func_helper
   147                                              dbpassword, dbextraargs,
       
   148                                              dbencoding)
   153         self.sqlgen = SQLGenerator()
   149         self.sqlgen = SQLGenerator()
   154 
   150         # copy back some commonly accessed attributes
   155     def get_connection(self, user=None, password=None):
   151         dbapi_module = self.dbhelper.dbapi_module
       
   152         self.OperationalError = dbapi_module.OperationalError
       
   153         self.InterfaceError = dbapi_module.InterfaceError
       
   154         self._binary = dbapi_module.Binary
       
   155         self._process_value = dbapi_module.process_value
       
   156         self._dbencoding = dbencoding
       
   157 
       
   158     def get_connection(self):
   156         """open and return a connection to the database"""
   159         """open and return a connection to the database"""
   157         if user or self.dbuser:
   160         return self.dbhelper.get_connection()
   158             self.info('connecting to %s@%s for user %s', self.dbname,
       
   159                       self.dbhost or 'localhost', user or self.dbuser)
       
   160         else:
       
   161             self.info('connecting to %s@%s', self.dbname,
       
   162                       self.dbhost or 'localhost')
       
   163         extra = {}
       
   164         if self.dbdriver_extra_args:
       
   165             extra = {'extra_args': self.dbdriver_extra_args}
       
   166         cnx = self.dbapi_module.connect(self.dbhost, self.dbname,
       
   167                                         user or self.dbuser,
       
   168                                         password or self.dbpasswd,
       
   169                                         port=self.dbport,
       
   170                                         **extra)
       
   171         init_cnx(self.dbdriver, cnx)
       
   172         #self.dbapi_module.type_code_test(cnx.cursor())
       
   173         return cnx
       
   174 
   161 
   175     def backup_to_file(self, backupfile):
   162     def backup_to_file(self, backupfile):
   176         for cmd in self.dbhelper.backup_commands(self.dbname, self.dbhost,
   163         for cmd in self.dbhelper.backup_commands(backupfile,
   177                                                  self.dbuser, backupfile,
       
   178                                                  dbport=self.dbport,
       
   179                                                  keepownership=False):
   164                                                  keepownership=False):
   180             if _run_command(cmd):
   165             if _run_command(cmd):
   181                 if not confirm('   [Failed] Continue anyway?', default='n'):
   166                 if not confirm('   [Failed] Continue anyway?', default='n'):
   182                     raise Exception('Failed command: %s' % cmd)
   167                     raise Exception('Failed command: %s' % cmd)
   183 
   168 
   184     def restore_from_file(self, backupfile, confirm, drop=True):
   169     def restore_from_file(self, backupfile, confirm, drop=True):
   185         for cmd in self.dbhelper.restore_commands(self.dbname, self.dbhost,
   170         for cmd in self.dbhelper.restore_commands(backupfile,
   186                                                   self.dbuser, backupfile,
       
   187                                                   self.encoding,
       
   188                                                   dbport=self.dbport,
       
   189                                                   keepownership=False,
   171                                                   keepownership=False,
   190                                                   drop=drop):
   172                                                   drop=drop):
   191             if _run_command(cmd):
   173             if _run_command(cmd):
   192                 if not confirm('   [Failed] Continue anyway?', default='n'):
   174                 if not confirm('   [Failed] Continue anyway?', default='n'):
   193                     raise Exception('Failed command: %s' % cmd)
   175                     raise Exception('Failed command: %s' % cmd)
   196         if args is not None:
   178         if args is not None:
   197             newargs = {}
   179             newargs = {}
   198             for key, val in args.iteritems():
   180             for key, val in args.iteritems():
   199                 # convert cubicweb binary into db binary
   181                 # convert cubicweb binary into db binary
   200                 if isinstance(val, Binary):
   182                 if isinstance(val, Binary):
   201                     val = self.binary(val.getvalue())
   183                     val = self._binary(val.getvalue())
   202                 newargs[key] = val
   184                 newargs[key] = val
   203             # should not collide
   185             # should not collide
   204             newargs.update(query_args)
   186             newargs.update(query_args)
   205             return newargs
   187             return newargs
   206         return query_args
   188         return query_args
   207 
   189 
   208     def process_result(self, cursor):
   190     def process_result(self, cursor):
   209         """return a list of CubicWeb compliant values from data in the given cursor
   191         """return a list of CubicWeb compliant values from data in the given cursor
   210         """
   192         """
       
   193         # begin bind to locals for optimization
   211         descr = cursor.description
   194         descr = cursor.description
   212         encoding = self.encoding
   195         encoding = self._dbencoding
   213         process_value = self.dbapi_module.process_value
   196         process_value = self._process_value
   214         binary = Binary
   197         binary = Binary
       
   198         # /end
   215         results = cursor.fetchall()
   199         results = cursor.fetchall()
   216         for i, line in enumerate(results):
   200         for i, line in enumerate(results):
   217             result = []
   201             result = []
   218             for col, value in enumerate(line):
   202             for col, value in enumerate(line):
   219                 if value is None:
   203                 if value is None:
   240                     # from a query result and so it is already encrypted
   224                     # from a query result and so it is already encrypted
   241                     if isinstance(value, Binary):
   225                     if isinstance(value, Binary):
   242                         value = value.getvalue()
   226                         value = value.getvalue()
   243                     else:
   227                     else:
   244                         value = crypt_password(value)
   228                         value = crypt_password(value)
   245                     value = self.binary(value)
   229                     value = self._binary(value)
   246                 # XXX needed for sqlite but I don't think it is for other backends
   230                 # XXX needed for sqlite but I don't think it is for other backends
   247                 elif atype == 'Datetime' and isinstance(value, date):
   231                 elif atype == 'Datetime' and isinstance(value, date):
   248                     value = todatetime(value)
   232                     value = todatetime(value)
   249                 elif atype == 'Date' and isinstance(value, datetime):
   233                 elif atype == 'Date' and isinstance(value, datetime):
   250                     value = todate(value)
   234                     value = todate(value)
   251                 elif isinstance(value, Binary):
   235                 elif isinstance(value, Binary):
   252                     value = self.binary(value.getvalue())
   236                     value = self._binary(value.getvalue())
   253             attrs[SQL_PREFIX+str(attr)] = value
   237             attrs[SQL_PREFIX+str(attr)] = value
   254         return attrs
   238         return attrs
   255 
   239 
   256 
   240 
   257 from logging import getLogger
   241 from logging import getLogger
   258 from cubicweb import set_log_methods
   242 from cubicweb import set_log_methods
   259 set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))
   243 set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))
   260 
   244 
   261 def init_sqlite_connexion(cnx):
   245 def init_sqlite_connexion(cnx):
   262     # XXX should not be publicly exposed
   246 
   263     #def comma_join(strings):
   247     class group_concat(object):
   264     #    return ', '.join(strings)
       
   265     #cnx.create_function("COMMA_JOIN", 1, comma_join)
       
   266 
       
   267     class concat_strings(object):
       
   268         def __init__(self):
   248         def __init__(self):
   269             self.values = []
   249             self.values = []
   270         def step(self, value):
   250         def step(self, value):
   271             if value is not None:
   251             if value is not None:
   272                 self.values.append(value)
   252                 self.values.append(value)
   273         def finalize(self):
   253         def finalize(self):
   274             return ', '.join(self.values)
   254             return ', '.join(self.values)
   275     # renamed to GROUP_CONCAT in cubicweb 2.45, keep old name for bw compat for
   255     cnx.create_aggregate("GROUP_CONCAT", 1, group_concat)
   276     # some time
       
   277     cnx.create_aggregate("CONCAT_STRINGS", 1, concat_strings)
       
   278     cnx.create_aggregate("GROUP_CONCAT", 1, concat_strings)
       
   279 
   256 
   280     def _limit_size(text, maxsize, format='text/plain'):
   257     def _limit_size(text, maxsize, format='text/plain'):
   281         if len(text) < maxsize:
   258         if len(text) < maxsize:
   282             return text
   259             return text
   283         if format in ('text/html', 'text/xhtml', 'text/xml'):
   260         if format in ('text/html', 'text/xhtml', 'text/xml'):
   291     cnx.create_function("LIMIT_SIZE", 3, limit_size3)
   268     cnx.create_function("LIMIT_SIZE", 3, limit_size3)
   292 
   269 
   293     def limit_size2(text, maxsize):
   270     def limit_size2(text, maxsize):
   294         return _limit_size(text, maxsize)
   271         return _limit_size(text, maxsize)
   295     cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)
   272     cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)
       
   273 
   296     import yams.constraints
   274     import yams.constraints
   297     if hasattr(yams.constraints, 'patch_sqlite_decimal'):
   275     yams.constraints.patch_sqlite_decimal()
   298         yams.constraints.patch_sqlite_decimal()
       
   299 
   276 
   300     def fspath(eid, etype, attr):
   277     def fspath(eid, etype, attr):
   301         try:
   278         try:
   302             cu = cnx.cursor()
   279             cu = cnx.cursor()
   303             cu.execute('SELECT X.cw_%s FROM cw_%s as X '
   280             cu.execute('SELECT X.cw_%s FROM cw_%s as X '
   318                 import traceback
   295                 import traceback
   319                 traceback.print_exc()
   296                 traceback.print_exc()
   320                 raise
   297                 raise
   321     cnx.create_function('_fsopen', 1, _fsopen)
   298     cnx.create_function('_fsopen', 1, _fsopen)
   322 
   299 
   323 
       
   324 sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
   300 sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
   325 sqlite_hooks.append(init_sqlite_connexion)
   301 sqlite_hooks.append(init_sqlite_connexion)
   326 
       
   327 def init_cnx(driver, cnx):
       
   328     for hook in SQL_CONNECT_HOOKS.get(driver, ()):
       
   329         hook(cnx)