server/sqlutils.py
changeset 11057 0b59724cb3f2
parent 11052 058bb3dc685f
child 11058 23eb30449fe5
equal deleted inserted replaced
11052:058bb3dc685f 11057:0b59724cb3f2
     1 # copyright 2003-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
       
     2 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
       
     3 #
       
     4 # This file is part of CubicWeb.
       
     5 #
       
     6 # CubicWeb is free software: you can redistribute it and/or modify it under the
       
     7 # terms of the GNU Lesser General Public License as published by the Free
       
     8 # Software Foundation, either version 2.1 of the License, or (at your option)
       
     9 # any later version.
       
    10 #
       
    11 # CubicWeb is distributed in the hope that it will be useful, but WITHOUT
       
    12 # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
       
    13 # FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
       
    14 # details.
       
    15 #
       
    16 # You should have received a copy of the GNU Lesser General Public License along
       
    17 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
       
    18 """SQL utilities functions and classes."""
       
    19 from __future__ import print_function
       
    20 
       
    21 __docformat__ = "restructuredtext en"
       
    22 
       
    23 import sys
       
    24 import re
       
    25 import subprocess
       
    26 from os.path import abspath
       
    27 from logging import getLogger
       
    28 from datetime import time, datetime, timedelta
       
    29 
       
    30 from six import string_types, text_type
       
    31 from six.moves import filter
       
    32 
       
    33 from pytz import utc
       
    34 
       
    35 from logilab import database as db, common as lgc
       
    36 from logilab.common.shellutils import ProgressBar, DummyProgressBar
       
    37 from logilab.common.deprecation import deprecated
       
    38 from logilab.common.logging_ext import set_log_methods
       
    39 from logilab.common.date import utctime, utcdatetime, strptime
       
    40 from logilab.database.sqlgen import SQLGenerator
       
    41 
       
    42 from cubicweb import Binary, ConfigurationError
       
    43 from cubicweb.uilib import remove_html_tags
       
    44 from cubicweb.schema import PURE_VIRTUAL_RTYPES
       
    45 from cubicweb.server import SQL_CONNECT_HOOKS
       
    46 from cubicweb.server.utils import crypt_password
       
    47 
       
# NOTE(review): presumably makes logilab.common produce stdlib datetime objects
# rather than mx.DateTime ones -- confirm against logilab.common
lgc.USE_MX_DATETIME = False
# prefix prepended to entity table and column names in the system database
# (passed as ``prefix=SQL_PREFIX`` to the schema generators and used to build
# column keys in ``preprocess_entity``)
SQL_PREFIX = 'cw_'
       
    50 
       
    51 
       
    52 def _run_command(cmd):
       
    53     if isinstance(cmd, string_types):
       
    54         print(cmd)
       
    55         return subprocess.call(cmd, shell=True)
       
    56     else:
       
    57         print(' '.join(cmd))
       
    58         return subprocess.call(cmd)
       
    59 
       
    60 
       
    61 def sqlexec(sqlstmts, cursor_or_execute, withpb=True,
       
    62             pbtitle='', delimiter=';', cnx=None):
       
    63     """execute sql statements ignoring DROP/ CREATE GROUP or USER statements
       
    64     error.
       
    65 
       
    66     :sqlstmts_as_string: a string or a list of sql statements.
       
    67     :cursor_or_execute: sql cursor or a callback used to execute statements
       
    68     :cnx: if given, commit/rollback at each statement.
       
    69 
       
    70     :withpb: if True, display a progresse bar
       
    71     :pbtitle: a string displayed as the progress bar title (if `withpb=True`)
       
    72 
       
    73     :delimiter: a string used to split sqlstmts (if it is a string)
       
    74 
       
    75     Return the failed statements (same type as sqlstmts)
       
    76     """
       
    77     if hasattr(cursor_or_execute, 'execute'):
       
    78         execute = cursor_or_execute.execute
       
    79     else:
       
    80         execute = cursor_or_execute
       
    81     sqlstmts_as_string = False
       
    82     if isinstance(sqlstmts, string_types):
       
    83         sqlstmts_as_string = True
       
    84         sqlstmts = sqlstmts.split(delimiter)
       
    85     if withpb:
       
    86         if sys.stdout.isatty():
       
    87             pb = ProgressBar(len(sqlstmts), title=pbtitle)
       
    88         else:
       
    89             pb = DummyProgressBar()
       
    90     failed = []
       
    91     for sql in sqlstmts:
       
    92         sql = sql.strip()
       
    93         if withpb:
       
    94             pb.update()
       
    95         if not sql:
       
    96             continue
       
    97         try:
       
    98             # some dbapi modules doesn't accept unicode for sql string
       
    99             execute(str(sql))
       
   100         except Exception:
       
   101             if cnx:
       
   102                 cnx.rollback()
       
   103             failed.append(sql)
       
   104         else:
       
   105             if cnx:
       
   106                 cnx.commit()
       
   107     if withpb:
       
   108         print()
       
   109     if sqlstmts_as_string:
       
   110         failed = delimiter.join(failed)
       
   111     return failed
       
   112 
       
   113 
       
   114 def sqlgrants(schema, driver, user,
       
   115               text_index=True, set_owner=True,
       
   116               skip_relations=(), skip_entities=()):
       
   117     """return sql to give all access privileges to the given user on the system
       
   118     schema
       
   119     """
       
   120     from cubicweb.server.schema2sql import grant_schema
       
   121     from cubicweb.server.sources import native
       
   122     output = []
       
   123     w = output.append
       
   124     w(native.grant_schema(user, set_owner))
       
   125     w('')
       
   126     if text_index:
       
   127         dbhelper = db.get_db_helper(driver)
       
   128         w(dbhelper.sql_grant_user_on_fti(user))
       
   129         w('')
       
   130     w(grant_schema(schema, user, set_owner, skip_entities=skip_entities, prefix=SQL_PREFIX))
       
   131     return '\n'.join(output)
       
   132 
       
   133 
       
   134 def sqlschema(schema, driver, text_index=True,
       
   135               user=None, set_owner=False,
       
   136               skip_relations=PURE_VIRTUAL_RTYPES, skip_entities=()):
       
   137     """return the system sql schema, according to the given parameters"""
       
   138     from cubicweb.server.schema2sql import schema2sql
       
   139     from cubicweb.server.sources import native
       
   140     if set_owner:
       
   141         assert user, 'user is argument required when set_owner is true'
       
   142     output = []
       
   143     w = output.append
       
   144     w(native.sql_schema(driver))
       
   145     w('')
       
   146     dbhelper = db.get_db_helper(driver)
       
   147     if text_index:
       
   148         w(dbhelper.sql_init_fti().replace(';', ';;'))
       
   149         w('')
       
   150     w(schema2sql(dbhelper, schema, prefix=SQL_PREFIX,
       
   151                  skip_entities=skip_entities,
       
   152                  skip_relations=skip_relations).replace(';', ';;'))
       
   153     if dbhelper.users_support and user:
       
   154         w('')
       
   155         w(sqlgrants(schema, driver, user, text_index, set_owner,
       
   156                     skip_relations, skip_entities).replace(';', ';;'))
       
   157     return '\n'.join(output)
       
   158 
       
   159 
       
   160 def sqldropschema(schema, driver, text_index=True,
       
   161                   skip_relations=PURE_VIRTUAL_RTYPES, skip_entities=()):
       
   162     """return the sql to drop the schema, according to the given parameters"""
       
   163     from cubicweb.server.schema2sql import dropschema2sql
       
   164     from cubicweb.server.sources import native
       
   165     output = []
       
   166     w = output.append
       
   167     if text_index:
       
   168         dbhelper = db.get_db_helper(driver)
       
   169         w(dbhelper.sql_drop_fti())
       
   170         w('')
       
   171     w(dropschema2sql(dbhelper, schema, prefix=SQL_PREFIX,
       
   172                      skip_entities=skip_entities,
       
   173                      skip_relations=skip_relations))
       
   174     w('')
       
   175     w(native.sql_drop_schema(driver))
       
   176     return '\n'.join(output)
       
   177 
       
   178 
       
   179 _SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION = re.compile('^(?!(sql|pg)_)').match
       
   180 def sql_drop_all_user_tables(driver_or_helper, sqlcursor):
       
   181     """Return ths sql to drop all tables found in the database system."""
       
   182     if not getattr(driver_or_helper, 'list_tables', None):
       
   183         dbhelper = db.get_db_helper(driver_or_helper)
       
   184     else:
       
   185         dbhelper = driver_or_helper
       
   186 
       
   187     cmds = [dbhelper.sql_drop_sequence('entities_id_seq')]
       
   188     # for mssql, we need to drop views before tables
       
   189     if hasattr(dbhelper, 'list_views'):
       
   190         cmds += ['DROP VIEW %s;' % name
       
   191                  for name in filter(_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION, dbhelper.list_views(sqlcursor))]
       
   192     cmds += ['DROP TABLE %s;' % name
       
   193              for name in filter(_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION, dbhelper.list_tables(sqlcursor))]
       
   194     return '\n'.join(cmds)
       
   195 
       
   196 
       
   197 class ConnectionWrapper(object):
       
   198     """handle connection to the system source, at some point associated to a
       
   199     :class:`Session`
       
   200     """
       
   201 
       
   202     # since 3.19, we only have to manage the system source connection
       
   203     def __init__(self, system_source):
       
   204         # dictionary of (source, connection), indexed by sources'uri
       
   205         self._source = system_source
       
   206         self.cnx = system_source.get_connection()
       
   207         self.cu = self.cnx.cursor()
       
   208 
       
   209     def commit(self):
       
   210         """commit the current transaction for this user"""
       
   211         # let exception propagates
       
   212         self.cnx.commit()
       
   213 
       
   214     def rollback(self):
       
   215         """rollback the current transaction for this user"""
       
   216         # catch exceptions, rollback other sources anyway
       
   217         try:
       
   218             self.cnx.rollback()
       
   219         except Exception:
       
   220             self._source.critical('rollback error', exc_info=sys.exc_info())
       
   221             # error on rollback, the connection is much probably in a really
       
   222             # bad state. Replace it by a new one.
       
   223             self.reconnect()
       
   224 
       
   225     def close(self, i_know_what_i_do=False):
       
   226         """close all connections in the set"""
       
   227         if i_know_what_i_do is not True: # unexpected closing safety belt
       
   228             raise RuntimeError('connections set shouldn\'t be closed')
       
   229         try:
       
   230             self.cu.close()
       
   231             self.cu = None
       
   232         except Exception:
       
   233             pass
       
   234         try:
       
   235             self.cnx.close()
       
   236             self.cnx = None
       
   237         except Exception:
       
   238             pass
       
   239 
       
   240     # internals ###############################################################
       
   241 
       
   242     def cnxset_freed(self):
       
   243         """connections set is being freed from a session"""
       
   244         pass # no nothing by default
       
   245 
       
   246     def reconnect(self):
       
   247         """reopen a connection for this source or all sources if none specified
       
   248         """
       
   249         try:
       
   250             # properly close existing connection if any
       
   251             self.cnx.close()
       
   252         except Exception:
       
   253             pass
       
   254         self._source.info('trying to reconnect')
       
   255         self.cnx = self._source.get_connection()
       
   256         self.cu = self.cnx.cursor()
       
   257 
       
   258     @deprecated('[3.19] use .cu instead')
       
   259     def __getitem__(self, uri):
       
   260         assert uri == 'system'
       
   261         return self.cu
       
   262 
       
   263     @deprecated('[3.19] use repo.system_source instead')
       
   264     def source(self, uid):
       
   265         assert uid == 'system'
       
   266         return self._source
       
   267 
       
   268     @deprecated('[3.19] use .cnx instead')
       
   269     def connection(self, uid):
       
   270         assert uid == 'system'
       
   271         return self.cnx
       
   272 
       
   273 
       
   274 class SqliteConnectionWrapper(ConnectionWrapper):
       
   275     """Sqlite specific connection wrapper: close the connection each time it's
       
   276     freed (and reopen it later when needed)
       
   277     """
       
   278     def __init__(self, system_source):
       
   279         # don't call parent's __init__, we don't want to initiate the connection
       
   280         self._source = system_source
       
   281 
       
   282     _cnx = None
       
   283 
       
   284     def cnxset_freed(self):
       
   285         self.cu.close()
       
   286         self.cnx.close()
       
   287         self.cnx = self.cu = None
       
   288 
       
   289     @property
       
   290     def cnx(self):
       
   291         if self._cnx is None:
       
   292             self._cnx = self._source.get_connection()
       
   293             self._cu = self._cnx.cursor()
       
   294         return self._cnx
       
   295     @cnx.setter
       
   296     def cnx(self, value):
       
   297         self._cnx = value
       
   298 
       
   299     @property
       
   300     def cu(self):
       
   301         if self._cnx is None:
       
   302             self._cnx = self._source.get_connection()
       
   303             self._cu = self._cnx.cursor()
       
   304         return self._cu
       
   305     @cu.setter
       
   306     def cu(self, value):
       
   307         self._cu = value
       
   308 
       
   309 
       
   310 class SQLAdapterMixIn(object):
       
   311     """Mixin for SQL data sources, getting a connection from a configuration
       
   312     dictionary and handling connection locking
       
   313     """
       
   314     cnx_wrap = ConnectionWrapper
       
   315 
       
   316     def __init__(self, source_config, repairing=False):
       
   317         try:
       
   318             self.dbdriver = source_config['db-driver'].lower()
       
   319             dbname = source_config['db-name']
       
   320         except KeyError:
       
   321             raise ConfigurationError('missing some expected entries in sources file')
       
   322         dbhost = source_config.get('db-host')
       
   323         port = source_config.get('db-port')
       
   324         dbport = port and int(port) or None
       
   325         dbuser = source_config.get('db-user')
       
   326         dbpassword = source_config.get('db-password')
       
   327         dbencoding = source_config.get('db-encoding', 'UTF-8')
       
   328         dbextraargs = source_config.get('db-extra-arguments')
       
   329         dbnamespace = source_config.get('db-namespace')
       
   330         self.dbhelper = db.get_db_helper(self.dbdriver)
       
   331         self.dbhelper.record_connection_info(dbname, dbhost, dbport, dbuser,
       
   332                                              dbpassword, dbextraargs,
       
   333                                              dbencoding, dbnamespace)
       
   334         self.sqlgen = SQLGenerator()
       
   335         # copy back some commonly accessed attributes
       
   336         dbapi_module = self.dbhelper.dbapi_module
       
   337         self.OperationalError = dbapi_module.OperationalError
       
   338         self.InterfaceError = dbapi_module.InterfaceError
       
   339         self.DbapiError = dbapi_module.Error
       
   340         self._binary = self.dbhelper.binary_value
       
   341         self._process_value = dbapi_module.process_value
       
   342         self._dbencoding = dbencoding
       
   343         if self.dbdriver == 'sqlite':
       
   344             self.cnx_wrap = SqliteConnectionWrapper
       
   345             self.dbhelper.dbname = abspath(self.dbhelper.dbname)
       
   346         if not repairing:
       
   347             statement_timeout = int(source_config.get('db-statement-timeout', 0))
       
   348             if statement_timeout > 0:
       
   349                 def set_postgres_timeout(cnx):
       
   350                     cnx.cursor().execute('SET statement_timeout to %d' % statement_timeout)
       
   351                     cnx.commit()
       
   352                 postgres_hooks = SQL_CONNECT_HOOKS['postgres']
       
   353                 postgres_hooks.append(set_postgres_timeout)
       
   354 
       
   355     def wrapped_connection(self):
       
   356         """open and return a connection to the database, wrapped into a class
       
   357         handling reconnection and all
       
   358         """
       
   359         return self.cnx_wrap(self)
       
   360 
       
   361     def get_connection(self):
       
   362         """open and return a connection to the database"""
       
   363         return self.dbhelper.get_connection()
       
   364 
       
   365     def backup_to_file(self, backupfile, confirm):
       
   366         for cmd in self.dbhelper.backup_commands(backupfile,
       
   367                                                  keepownership=False):
       
   368             if _run_command(cmd):
       
   369                 if not confirm('   [Failed] Continue anyway?', default='n'):
       
   370                     raise Exception('Failed command: %s' % cmd)
       
   371 
       
   372     def restore_from_file(self, backupfile, confirm, drop=True):
       
   373         for cmd in self.dbhelper.restore_commands(backupfile,
       
   374                                                   keepownership=False,
       
   375                                                   drop=drop):
       
   376             if _run_command(cmd):
       
   377                 if not confirm('   [Failed] Continue anyway?', default='n'):
       
   378                     raise Exception('Failed command: %s' % cmd)
       
   379 
       
   380     def merge_args(self, args, query_args):
       
   381         if args is not None:
       
   382             newargs = {}
       
   383             for key, val in args.items():
       
   384                 # convert cubicweb binary into db binary
       
   385                 if isinstance(val, Binary):
       
   386                     val = self._binary(val.getvalue())
       
   387                 # convert timestamp to utc.
       
   388                 # expect SET TiME ZONE to UTC at connection opening time.
       
   389                 # This shouldn't change anything for datetime without TZ.
       
   390                 elif isinstance(val, datetime) and val.tzinfo is not None:
       
   391                     val = utcdatetime(val)
       
   392                 elif isinstance(val, time) and val.tzinfo is not None:
       
   393                     val = utctime(val)
       
   394                 newargs[key] = val
       
   395             # should not collide
       
   396             assert not (frozenset(newargs) & frozenset(query_args)), \
       
   397                 'unexpected collision: %s' % (frozenset(newargs) & frozenset(query_args))
       
   398             newargs.update(query_args)
       
   399             return newargs
       
   400         return query_args
       
   401 
       
   402     def process_result(self, cursor, cnx=None, column_callbacks=None):
       
   403         """return a list of CubicWeb compliant values from data in the given cursor
       
   404         """
       
   405         return list(self.iter_process_result(cursor, cnx, column_callbacks))
       
   406 
       
   407     def iter_process_result(self, cursor, cnx, column_callbacks=None):
       
   408         """return a iterator on tuples of CubicWeb compliant values from data
       
   409         in the given cursor
       
   410         """
       
   411         # use two different implementations to avoid paying the price of
       
   412         # callback lookup for each *cell* in results when there is nothing to
       
   413         # lookup
       
   414         if not column_callbacks:
       
   415             return self.dbhelper.dbapi_module.process_cursor(cursor, self._dbencoding,
       
   416                                                              Binary)
       
   417         assert cnx
       
   418         return self._cb_process_result(cursor, column_callbacks, cnx)
       
   419 
       
   420     def _cb_process_result(self, cursor, column_callbacks, cnx):
       
   421         # begin bind to locals for optimization
       
   422         descr = cursor.description
       
   423         encoding = self._dbencoding
       
   424         process_value = self._process_value
       
   425         binary = Binary
       
   426         # /end
       
   427         cursor.arraysize = 100
       
   428         while True:
       
   429             results = cursor.fetchmany()
       
   430             if not results:
       
   431                 break
       
   432             for line in results:
       
   433                 result = []
       
   434                 for col, value in enumerate(line):
       
   435                     if value is None:
       
   436                         result.append(value)
       
   437                         continue
       
   438                     cbstack = column_callbacks.get(col, None)
       
   439                     if cbstack is None:
       
   440                         value = process_value(value, descr[col], encoding, binary)
       
   441                     else:
       
   442                         for cb in cbstack:
       
   443                             value = cb(self, cnx, value)
       
   444                     result.append(value)
       
   445                 yield result
       
   446 
       
   447     def preprocess_entity(self, entity):
       
   448         """return a dictionary to use as extra argument to cursor.execute
       
   449         to insert/update an entity into a SQL database
       
   450         """
       
   451         attrs = {}
       
   452         eschema = entity.e_schema
       
   453         converters = getattr(self.dbhelper, 'TYPE_CONVERTERS', {})
       
   454         for attr, value in entity.cw_edited.items():
       
   455             if value is not None and eschema.subjrels[attr].final:
       
   456                 atype = str(entity.e_schema.destination(attr))
       
   457                 if atype in converters:
       
   458                     # It is easier to modify preprocess_entity rather
       
   459                     # than add_entity (native) as this behavior
       
   460                     # may also be used for update.
       
   461                     value = converters[atype](value)
       
   462                 elif atype == 'Password': # XXX could be done using a TYPE_CONVERTERS callback
       
   463                     # if value is a Binary instance, this mean we got it
       
   464                     # from a query result and so it is already encrypted
       
   465                     if isinstance(value, Binary):
       
   466                         value = value.getvalue()
       
   467                     else:
       
   468                         value = crypt_password(value)
       
   469                     value = self._binary(value)
       
   470                 elif isinstance(value, Binary):
       
   471                     value = self._binary(value.getvalue())
       
   472             attrs[SQL_PREFIX+str(attr)] = value
       
   473         attrs[SQL_PREFIX+'eid'] = entity.eid
       
   474         return attrs
       
   475 
       
   476     # these are overridden by set_log_methods below
       
   477     # only defining here to prevent pylint from complaining
       
   478     info = warning = error = critical = exception = debug = lambda msg,*a,**kw: None
       
   479 
       
   480 set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))
       
   481 
       
   482 
       
   483 # connection initialization functions ##########################################
       
   484 
       
def _install_sqlite_querier_patch():
    """This monkey-patch hotfixes a bug sqlite causing some dates to be returned as strings rather than
    date objects (http://www.sqlite.org/cvstrac/tktview?tn=1327,33)

    ``QuerierHelper.execute`` is wrapped so that, in returned result sets,
    text values of TZDatetime / Date / Datetime / Time columns are parsed
    back into objects and integer Interval values become timedeltas. The
    patch is applied once per process (guarded by
    ``QuerierHelper._sqlite_patched``).
    """
    from cubicweb.server.querier import QuerierHelper

    if hasattr(QuerierHelper, '_sqlite_patched'):
        return  # already monkey patched

    def wrap_execute(base_execute):
        def new_execute(*args, **kwargs):
            rset = base_execute(*args, **kwargs)
            if rset.description:
                # stays False while no cell needed conversion; never reset
                found_date = False
                for row, rowdesc in zip(rset, rset.description):
                    for cellindex, (value, vtype) in enumerate(zip(row, rowdesc)):
                        if vtype in ('TZDatetime', 'Date', 'Datetime') \
                           and isinstance(value, text_type):
                            found_date = True
                            # drop anything after the last '.' (fractional
                            # seconds), which the formats below don't accept
                            value = value.rsplit('.', 1)[0]
                            try:
                                row[cellindex] = strptime(value, '%Y-%m-%d %H:%M:%S')
                            except Exception:
                                # date-only value
                                row[cellindex] = strptime(value, '%Y-%m-%d')
                            if vtype == 'TZDatetime':
                                row[cellindex] = row[cellindex].replace(tzinfo=utc)
                        if vtype == 'Time' and isinstance(value, text_type):
                            found_date = True
                            try:
                                row[cellindex] = strptime(value, '%H:%M:%S')
                            except Exception:
                                # DateTime used as Time?
                                row[cellindex] = strptime(value, '%Y-%m-%d %H:%M:%S')
                        if vtype == 'Interval' and isinstance(value, int):
                            found_date = True
                            # XXX value is in number of seconds?
                            row[cellindex] = timedelta(0, value, 0)
                    # if the first row had nothing to convert, stop scanning:
                    # remaining rows are assumed to need no conversion either
                    if not found_date:
                        break
            # rows are modified in place (assumes rset rows are mutable lists
            # -- NOTE(review): confirm)
            return rset
        return new_execute

    QuerierHelper.execute = wrap_execute(QuerierHelper.execute)
    QuerierHelper._sqlite_patched = True
       
   529 
       
   530 
       
   531 def _init_sqlite_connection(cnx):
       
   532     """Internal function that will be called to init a sqlite connection"""
       
   533     _install_sqlite_querier_patch()
       
   534 
       
   535     class group_concat(object):
       
   536         def __init__(self):
       
   537             self.values = set()
       
   538         def step(self, value):
       
   539             if value is not None:
       
   540                 self.values.add(value)
       
   541         def finalize(self):
       
   542             return ', '.join(text_type(v) for v in self.values)
       
   543 
       
   544     cnx.create_aggregate("GROUP_CONCAT", 1, group_concat)
       
   545 
       
   546     def _limit_size(text, maxsize, format='text/plain'):
       
   547         if len(text) < maxsize:
       
   548             return text
       
   549         if format in ('text/html', 'text/xhtml', 'text/xml'):
       
   550             text = remove_html_tags(text)
       
   551         if len(text) > maxsize:
       
   552             text = text[:maxsize] + '...'
       
   553         return text
       
   554 
       
   555     def limit_size3(text, format, maxsize):
       
   556         return _limit_size(text, maxsize, format)
       
   557     cnx.create_function("LIMIT_SIZE", 3, limit_size3)
       
   558 
       
   559     def limit_size2(text, maxsize):
       
   560         return _limit_size(text, maxsize)
       
   561     cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)
       
   562 
       
   563     from logilab.common.date import strptime
       
   564     def weekday(ustr):
       
   565         try:
       
   566             dt = strptime(ustr, '%Y-%m-%d %H:%M:%S')
       
   567         except:
       
   568             dt =  strptime(ustr, '%Y-%m-%d')
       
   569         # expect sunday to be 1, saturday 7 while weekday method return 0 for
       
   570         # monday
       
   571         return (dt.weekday() + 1) % 7
       
   572     cnx.create_function("WEEKDAY", 1, weekday)
       
   573 
       
   574     cnx.cursor().execute("pragma foreign_keys = on")
       
   575 
       
   576     import yams.constraints
       
   577     yams.constraints.patch_sqlite_decimal()
       
   578 
       
   579 sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
       
   580 sqlite_hooks.append(_init_sqlite_connection)
       
   581 
       
   582 
       
   583 def _init_postgres_connection(cnx):
       
   584     """Internal function that will be called to init a postgresql connection"""
       
   585     cnx.cursor().execute('SET TIME ZONE UTC')
       
   586     # commit is needed, else setting are lost if the connection is first
       
   587     # rolled back
       
   588     cnx.commit()
       
   589 
       
   590 postgres_hooks = SQL_CONNECT_HOOKS.setdefault('postgres', [])
       
   591 postgres_hooks.append(_init_postgres_connection)