|
1 # copyright 2003-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved. |
|
2 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr |
|
3 # |
|
4 # This file is part of CubicWeb. |
|
5 # |
|
6 # CubicWeb is free software: you can redistribute it and/or modify it under the |
|
7 # terms of the GNU Lesser General Public License as published by the Free |
|
8 # Software Foundation, either version 2.1 of the License, or (at your option) |
|
9 # any later version. |
|
10 # |
|
11 # CubicWeb is distributed in the hope that it will be useful, but WITHOUT |
|
12 # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS |
|
13 # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more |
|
14 # details. |
|
15 # |
|
16 # You should have received a copy of the GNU Lesser General Public License along |
|
17 # with CubicWeb. If not, see <http://www.gnu.org/licenses/>. |
|
18 """SQL utilities functions and classes.""" |
|
19 from __future__ import print_function |
|
20 |
|
21 __docformat__ = "restructuredtext en" |
|
22 |
|
23 import sys |
|
24 import re |
|
25 import subprocess |
|
26 from os.path import abspath |
|
27 from logging import getLogger |
|
28 from datetime import time, datetime, timedelta |
|
29 |
|
30 from six import string_types, text_type |
|
31 from six.moves import filter |
|
32 |
|
33 from pytz import utc |
|
34 |
|
35 from logilab import database as db, common as lgc |
|
36 from logilab.common.shellutils import ProgressBar, DummyProgressBar |
|
37 from logilab.common.deprecation import deprecated |
|
38 from logilab.common.logging_ext import set_log_methods |
|
39 from logilab.common.date import utctime, utcdatetime, strptime |
|
40 from logilab.database.sqlgen import SQLGenerator |
|
41 |
|
42 from cubicweb import Binary, ConfigurationError |
|
43 from cubicweb.uilib import remove_html_tags |
|
44 from cubicweb.schema import PURE_VIRTUAL_RTYPES |
|
45 from cubicweb.server import SQL_CONNECT_HOOKS |
|
46 from cubicweb.server.utils import crypt_password |
|
47 |
|
# force logilab-common to use the standard datetime module rather than
# mx.DateTime for date handling
lgc.USE_MX_DATETIME = False
# prefix of SQL tables/columns generated from the CubicWeb schema
SQL_PREFIX = 'cw_'
|
50 |
|
51 |
|
def _run_command(cmd):
    """Echo then run a backup/restore command, returning its exit status.

    `cmd` may be a single shell string (executed through the shell) or a
    list of arguments (executed directly).
    """
    if isinstance(cmd, string_types):
        print(cmd)
        # NOTE(review): shell=True -- commands come from the db helper's
        # backup/restore machinery here, never pass user-controlled strings
        return subprocess.call(cmd, shell=True)
    print(' '.join(cmd))
    return subprocess.call(cmd)
|
59 |
|
60 |
|
def sqlexec(sqlstmts, cursor_or_execute, withpb=True,
            pbtitle='', delimiter=';', cnx=None):
    """execute sql statements ignoring DROP/ CREATE GROUP or USER statements
    error.

    :sqlstmts_as_string: a string or a list of sql statements.
    :cursor_or_execute: sql cursor or a callback used to execute statements
    :cnx: if given, commit/rollback at each statement.

    :withpb: if True, display a progresse bar
    :pbtitle: a string displayed as the progress bar title (if `withpb=True`)

    :delimiter: a string used to split sqlstmts (if it is a string)

    Return the failed statements (same type as sqlstmts)
    """
    # accept either a cursor-like object or a plain callable
    execute = getattr(cursor_or_execute, 'execute', cursor_or_execute)
    given_as_string = isinstance(sqlstmts, string_types)
    if given_as_string:
        sqlstmts = sqlstmts.split(delimiter)
    pb = None
    if withpb:
        # only show a real progress bar on an interactive terminal
        pb = (ProgressBar(len(sqlstmts), title=pbtitle)
              if sys.stdout.isatty() else DummyProgressBar())
    failed = []
    for statement in sqlstmts:
        statement = statement.strip()
        if pb is not None:
            pb.update()
        if not statement:
            continue
        try:
            # some dbapi modules doesn't accept unicode for sql string
            execute(str(statement))
        except Exception:
            if cnx:
                cnx.rollback()
            failed.append(statement)
        else:
            if cnx:
                cnx.commit()
    if withpb:
        print()
    if given_as_string:
        return delimiter.join(failed)
    return failed
|
112 |
|
113 |
|
def sqlgrants(schema, driver, user,
              text_index=True, set_owner=True,
              skip_relations=(), skip_entities=()):
    """return sql to give all access privileges to the given user on the system
    schema
    """
    from cubicweb.server.schema2sql import grant_schema
    from cubicweb.server.sources import native
    # base grants on the native source's own tables
    statements = [native.grant_schema(user, set_owner), '']
    if text_index:
        helper = db.get_db_helper(driver)
        # grants on the full-text-index tables
        statements += [helper.sql_grant_user_on_fti(user), '']
    # grants on the schema-generated (cw_ prefixed) tables
    statements.append(grant_schema(schema, user, set_owner,
                                   skip_entities=skip_entities,
                                   prefix=SQL_PREFIX))
    return '\n'.join(statements)
|
132 |
|
133 |
|
def sqlschema(schema, driver, text_index=True,
              user=None, set_owner=False,
              skip_relations=PURE_VIRTUAL_RTYPES, skip_entities=()):
    """return the system sql schema, according to the given parameters"""
    from cubicweb.server.schema2sql import schema2sql
    from cubicweb.server.sources import native
    if set_owner:
        assert user, 'user is argument required when set_owner is true'
    # native source tables first
    parts = [native.sql_schema(driver), '']
    dbhelper = db.get_db_helper(driver)
    if text_index:
        # ';;' is used as statement separator by the caller's sql scripts
        parts += [dbhelper.sql_init_fti().replace(';', ';;'), '']
    parts.append(schema2sql(dbhelper, schema, prefix=SQL_PREFIX,
                            skip_entities=skip_entities,
                            skip_relations=skip_relations).replace(';', ';;'))
    if dbhelper.users_support and user:
        parts += ['', sqlgrants(schema, driver, user, text_index, set_owner,
                                skip_relations, skip_entities).replace(';', ';;')]
    return '\n'.join(parts)
|
158 |
|
159 |
|
def sqldropschema(schema, driver, text_index=True,
                  skip_relations=PURE_VIRTUAL_RTYPES, skip_entities=()):
    """Return the sql to drop the schema, according to the given parameters.

    :param schema: the CubicWeb schema whose tables should be dropped
    :param driver: database driver name, used to pick the db helper
    :param text_index: also drop the full-text-index tables
    :param skip_relations: relation types to ignore
    :param skip_entities: entity types to ignore
    """
    from cubicweb.server.schema2sql import dropschema2sql
    from cubicweb.server.sources import native
    # BUGFIX: dbhelper was previously only assigned inside the `if text_index`
    # branch, making this function raise NameError when text_index was False;
    # dropschema2sql needs it unconditionally
    dbhelper = db.get_db_helper(driver)
    output = []
    w = output.append
    if text_index:
        w(dbhelper.sql_drop_fti())
        w('')
    w(dropschema2sql(dbhelper, schema, prefix=SQL_PREFIX,
                     skip_entities=skip_entities,
                     skip_relations=skip_relations))
    w('')
    w(native.sql_drop_schema(driver))
    return '\n'.join(output)
|
177 |
|
178 |
|
# predicate keeping only user tables/views (i.e. names not starting with the
# sqlserver 'sql_' or postgres 'pg_' system prefixes)
_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION = re.compile('^(?!(sql|pg)_)').match


def sql_drop_all_user_tables(driver_or_helper, sqlcursor):
    """Return the sql dropping every user table found in the database system."""
    # accept either a driver name or an already-built db helper
    if getattr(driver_or_helper, 'list_tables', None):
        dbhelper = driver_or_helper
    else:
        dbhelper = db.get_db_helper(driver_or_helper)

    keep = _SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION
    statements = [dbhelper.sql_drop_sequence('entities_id_seq')]
    # for mssql, we need to drop views before tables
    if hasattr(dbhelper, 'list_views'):
        statements.extend('DROP VIEW %s;' % view
                          for view in filter(keep, dbhelper.list_views(sqlcursor)))
    statements.extend('DROP TABLE %s;' % table
                      for table in filter(keep, dbhelper.list_tables(sqlcursor)))
    return '\n'.join(statements)
|
195 |
|
196 |
|
class ConnectionWrapper(object):
    """Wrapper of the DB-API connection to the system source, at some point
    associated to a :class:`Session`
    """

    # since 3.19, only the system source connection has to be managed
    def __init__(self, system_source):
        self._source = system_source
        self.cnx = system_source.get_connection()
        self.cu = self.cnx.cursor()

    def commit(self):
        """commit the current transaction for this user"""
        # exceptions are deliberately left to propagate to the caller
        self.cnx.commit()

    def rollback(self):
        """rollback the current transaction for this user"""
        try:
            self.cnx.rollback()
        except Exception:
            self._source.critical('rollback error', exc_info=sys.exc_info())
            # a failing rollback means the connection is most probably in a
            # really bad state: replace it with a fresh one
            self.reconnect()

    def close(self, i_know_what_i_do=False):
        """close all connections in the set"""
        if i_know_what_i_do is not True:  # unexpected closing safety belt
            raise RuntimeError('connections set shouldn\'t be closed')
        # best effort: attempt to close cursor then connection, clearing each
        # attribute only when its close() succeeded
        for attrname in ('cu', 'cnx'):
            try:
                getattr(self, attrname).close()
                setattr(self, attrname, None)
            except Exception:
                pass

    # internals ###############################################################

    def cnxset_freed(self):
        """connections set is being freed from a session"""
        pass  # nothing to do by default

    def reconnect(self):
        """reopen a connection for this source or all sources if none specified
        """
        try:
            # properly close existing connection if any
            self.cnx.close()
        except Exception:
            pass
        self._source.info('trying to reconnect')
        self.cnx = self._source.get_connection()
        self.cu = self.cnx.cursor()

    @deprecated('[3.19] use .cu instead')
    def __getitem__(self, uri):
        assert uri == 'system'
        return self.cu

    @deprecated('[3.19] use repo.system_source instead')
    def source(self, uid):
        assert uid == 'system'
        return self._source

    @deprecated('[3.19] use .cnx instead')
    def connection(self, uid):
        assert uid == 'system'
        return self.cnx
|
272 |
|
273 |
|
class SqliteConnectionWrapper(ConnectionWrapper):
    """Sqlite specific connection wrapper: close the connection each time it's
    freed (and reopen it later when needed)
    """

    # lazily created underlying connection / cursor (see the properties
    # below). FIX: `_cu` is now also declared at class level, mirroring
    # `_cnx`; previously, reading `.cu` after the `cnx` setter had run (but
    # before any getter opened the connection) raised AttributeError instead
    # of returning None
    _cnx = None
    _cu = None

    def __init__(self, system_source):
        # don't call parent's __init__, we don't want to initiate the
        # connection eagerly; it is opened on first `.cnx` / `.cu` access
        self._source = system_source

    def cnxset_freed(self):
        # close the connection each time the set is freed; the properties
        # below reopen it transparently on next access
        self.cu.close()
        self.cnx.close()
        self.cnx = self.cu = None

    @property
    def cnx(self):
        # open the connection (and its cursor) on first access
        if self._cnx is None:
            self._cnx = self._source.get_connection()
            self._cu = self._cnx.cursor()
        return self._cnx

    @cnx.setter
    def cnx(self, value):
        self._cnx = value

    @property
    def cu(self):
        # open the connection (and its cursor) on first access
        if self._cnx is None:
            self._cnx = self._source.get_connection()
            self._cu = self._cnx.cursor()
        return self._cu

    @cu.setter
    def cu(self, value):
        self._cu = value
|
308 |
|
309 |
|
class SQLAdapterMixIn(object):
    """Mixin for SQL data sources, getting a connection from a configuration
    dictionary and handling connection locking
    """
    # class of the wrapper returned by wrapped_connection(); replaced by
    # SqliteConnectionWrapper on instances using the sqlite driver (__init__)
    cnx_wrap = ConnectionWrapper

    def __init__(self, source_config, repairing=False):
        """Configure the db helper from the `db-*` entries of `source_config`.

        :param source_config: dictionary read from the sources file; must
          contain at least 'db-driver' and 'db-name'.
        :param repairing: when true, skip registration of the postgres
          statement-timeout connection hook.
        :raise ConfigurationError: if a mandatory entry is missing.
        """
        try:
            self.dbdriver = source_config['db-driver'].lower()
            dbname = source_config['db-name']
        except KeyError:
            raise ConfigurationError('missing some expected entries in sources file')
        dbhost = source_config.get('db-host')
        port = source_config.get('db-port')
        # 'db-port' may be missing or empty: normalize to int or None
        dbport = port and int(port) or None
        dbuser = source_config.get('db-user')
        dbpassword = source_config.get('db-password')
        dbencoding = source_config.get('db-encoding', 'UTF-8')
        dbextraargs = source_config.get('db-extra-arguments')
        dbnamespace = source_config.get('db-namespace')
        self.dbhelper = db.get_db_helper(self.dbdriver)
        self.dbhelper.record_connection_info(dbname, dbhost, dbport, dbuser,
                                             dbpassword, dbextraargs,
                                             dbencoding, dbnamespace)
        self.sqlgen = SQLGenerator()
        # copy back some commonly accessed attributes
        dbapi_module = self.dbhelper.dbapi_module
        self.OperationalError = dbapi_module.OperationalError
        self.InterfaceError = dbapi_module.InterfaceError
        self.DbapiError = dbapi_module.Error
        self._binary = self.dbhelper.binary_value
        self._process_value = dbapi_module.process_value
        self._dbencoding = dbencoding
        if self.dbdriver == 'sqlite':
            # sqlite connections are closed when freed (see the wrapper) and
            # the db file path is made absolute since relative paths would
            # break if the process changes directory
            self.cnx_wrap = SqliteConnectionWrapper
            self.dbhelper.dbname = abspath(self.dbhelper.dbname)
        if not repairing:
            statement_timeout = int(source_config.get('db-statement-timeout', 0))
            if statement_timeout > 0:
                # register a postgres connection hook enforcing the configured
                # per-statement timeout on every new connection
                def set_postgres_timeout(cnx):
                    cnx.cursor().execute('SET statement_timeout to %d' % statement_timeout)
                    cnx.commit()
                postgres_hooks = SQL_CONNECT_HOOKS['postgres']
                postgres_hooks.append(set_postgres_timeout)

    def wrapped_connection(self):
        """open and return a connection to the database, wrapped into a class
        handling reconnection and all
        """
        return self.cnx_wrap(self)

    def get_connection(self):
        """open and return a connection to the database"""
        return self.dbhelper.get_connection()

    def backup_to_file(self, backupfile, confirm):
        """Dump the database into `backupfile` using the helper's external
        backup commands; on a failing command, ask `confirm` whether to go on.
        """
        for cmd in self.dbhelper.backup_commands(backupfile,
                                                 keepownership=False):
            if _run_command(cmd):
                if not confirm(' [Failed] Continue anyway?', default='n'):
                    raise Exception('Failed command: %s' % cmd)

    def restore_from_file(self, backupfile, confirm, drop=True):
        """Restore the database from `backupfile` using the helper's external
        restore commands; on a failing command, ask `confirm` whether to go on.
        """
        for cmd in self.dbhelper.restore_commands(backupfile,
                                                  keepownership=False,
                                                  drop=drop):
            if _run_command(cmd):
                if not confirm(' [Failed] Continue anyway?', default='n'):
                    raise Exception('Failed command: %s' % cmd)

    def merge_args(self, args, query_args):
        """Merge RQL `args` (converted to db-compatible values) with SQL
        generation `query_args`, asserting their keys don't collide.
        """
        if args is not None:
            newargs = {}
            for key, val in args.items():
                # convert cubicweb binary into db binary
                if isinstance(val, Binary):
                    val = self._binary(val.getvalue())
                # convert timestamp to utc.
                # expect SET TIME ZONE to UTC at connection opening time.
                # This shouldn't change anything for datetime without TZ.
                elif isinstance(val, datetime) and val.tzinfo is not None:
                    val = utcdatetime(val)
                elif isinstance(val, time) and val.tzinfo is not None:
                    val = utctime(val)
                newargs[key] = val
            # should not collide
            assert not (frozenset(newargs) & frozenset(query_args)), \
                'unexpected collision: %s' % (frozenset(newargs) & frozenset(query_args))
            newargs.update(query_args)
            return newargs
        return query_args

    def process_result(self, cursor, cnx=None, column_callbacks=None):
        """return a list of CubicWeb compliant values from data in the given cursor
        """
        return list(self.iter_process_result(cursor, cnx, column_callbacks))

    def iter_process_result(self, cursor, cnx, column_callbacks=None):
        """return a iterator on tuples of CubicWeb compliant values from data
        in the given cursor
        """
        # use two different implementations to avoid paying the price of
        # callback lookup for each *cell* in results when there is nothing to
        # lookup
        if not column_callbacks:
            return self.dbhelper.dbapi_module.process_cursor(cursor, self._dbencoding,
                                                             Binary)
        assert cnx
        return self._cb_process_result(cursor, column_callbacks, cnx)

    def _cb_process_result(self, cursor, column_callbacks, cnx):
        """Generator variant of iter_process_result applying per-column
        callbacks to each cell value.
        """
        # begin bind to locals for optimization
        descr = cursor.description
        encoding = self._dbencoding
        process_value = self._process_value
        binary = Binary
        # /end
        cursor.arraysize = 100
        while True:
            results = cursor.fetchmany()
            if not results:
                break
            for line in results:
                result = []
                for col, value in enumerate(line):
                    if value is None:
                        # NULL cells bypass both callbacks and conversion
                        result.append(value)
                        continue
                    cbstack = column_callbacks.get(col, None)
                    if cbstack is None:
                        value = process_value(value, descr[col], encoding, binary)
                    else:
                        # callbacks are chained: each receives the previous
                        # callback's output
                        for cb in cbstack:
                            value = cb(self, cnx, value)
                    result.append(value)
                yield result

    def preprocess_entity(self, entity):
        """return a dictionary to use as extra argument to cursor.execute
        to insert/update an entity into a SQL database
        """
        attrs = {}
        eschema = entity.e_schema
        converters = getattr(self.dbhelper, 'TYPE_CONVERTERS', {})
        for attr, value in entity.cw_edited.items():
            if value is not None and eschema.subjrels[attr].final:
                atype = str(entity.e_schema.destination(attr))
                if atype in converters:
                    # It is easier to modify preprocess_entity rather
                    # than add_entity (native) as this behavior
                    # may also be used for update.
                    value = converters[atype](value)
                elif atype == 'Password':  # XXX could be done using a TYPE_CONVERTERS callback
                    # if value is a Binary instance, this mean we got it
                    # from a query result and so it is already encrypted
                    if isinstance(value, Binary):
                        value = value.getvalue()
                    else:
                        value = crypt_password(value)
                    value = self._binary(value)
                elif isinstance(value, Binary):
                    value = self._binary(value.getvalue())
            attrs[SQL_PREFIX+str(attr)] = value
        attrs[SQL_PREFIX+'eid'] = entity.eid
        return attrs

    # these are overridden by set_log_methods below
    # only defining here to prevent pylint from complaining
    info = warning = error = critical = exception = debug = lambda msg,*a,**kw: None
|
479 |
|
# replace the placeholder log methods on the mixin with real ones bound to
# the 'cubicweb.sqladapter' logger
set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))
|
481 |
|
482 |
|
483 # connection initialization functions ########################################## |
|
484 |
|
def _install_sqlite_querier_patch():
    """This monkey-patch hotfixes a bug sqlite causing some dates to be returned as strings rather than
    date objects (http://www.sqlite.org/cvstrac/tktview?tn=1327,33)
    """
    from cubicweb.server.querier import QuerierHelper

    if hasattr(QuerierHelper, '_sqlite_patched'):
        return # already monkey patched

    def wrap_execute(base_execute):
        # wrap QuerierHelper.execute so that date/time typed cells returned
        # by sqlite as plain strings are converted back, in place, to proper
        # datetime/time/timedelta objects
        def new_execute(*args, **kwargs):
            rset = base_execute(*args, **kwargs)
            if rset.description:
                found_date = False
                for row, rowdesc in zip(rset, rset.description):
                    for cellindex, (value, vtype) in enumerate(zip(row, rowdesc)):
                        if vtype in ('TZDatetime', 'Date', 'Datetime') \
                           and isinstance(value, text_type):
                            found_date = True
                            # drop fractional seconds, the strptime formats
                            # below don't expect them
                            value = value.rsplit('.', 1)[0]
                            try:
                                row[cellindex] = strptime(value, '%Y-%m-%d %H:%M:%S')
                            except Exception:
                                row[cellindex] = strptime(value, '%Y-%m-%d')
                            if vtype == 'TZDatetime':
                                row[cellindex] = row[cellindex].replace(tzinfo=utc)
                        if vtype == 'Time' and isinstance(value, text_type):
                            found_date = True
                            try:
                                row[cellindex] = strptime(value, '%H:%M:%S')
                            except Exception:
                                # DateTime used as Time?
                                row[cellindex] = strptime(value, '%Y-%m-%d %H:%M:%S')
                        if vtype == 'Interval' and isinstance(value, int):
                            found_date = True
                            # XXX value is in number of seconds?
                            row[cellindex] = timedelta(0, value, 0)
                    if not found_date:
                        # optimization: if the first row holds no date-like
                        # cell, later rows won't either (same description)
                        break
            return rset
        return new_execute

    QuerierHelper.execute = wrap_execute(QuerierHelper.execute)
    QuerierHelper._sqlite_patched = True
|
529 |
|
530 |
|
def _init_sqlite_connection(cnx):
    """Internal function that will be called to init a sqlite connection.

    Registers on the connection the SQL functions/aggregates CubicWeb's RQL
    translation relies on (GROUP_CONCAT, LIMIT_SIZE, TEXT_LIMIT_SIZE,
    WEEKDAY), enables foreign keys and patches decimal handling.
    """
    _install_sqlite_querier_patch()

    class group_concat(object):
        """GROUP_CONCAT aggregate: comma-joined values.

        NOTE: values are accumulated in a set, so duplicates are dropped and
        ordering is unspecified.
        """
        def __init__(self):
            self.values = set()
        def step(self, value):
            if value is not None:
                self.values.add(value)
        def finalize(self):
            return ', '.join(text_type(v) for v in self.values)

    cnx.create_aggregate("GROUP_CONCAT", 1, group_concat)

    def _limit_size(text, maxsize, format='text/plain'):
        # short enough: return as-is
        if len(text) < maxsize:
            return text
        # strip markup before truncating so we never cut inside a tag
        if format in ('text/html', 'text/xhtml', 'text/xml'):
            text = remove_html_tags(text)
        if len(text) > maxsize:
            text = text[:maxsize] + '...'
        return text

    def limit_size3(text, format, maxsize):
        return _limit_size(text, maxsize, format)
    cnx.create_function("LIMIT_SIZE", 3, limit_size3)

    def limit_size2(text, maxsize):
        return _limit_size(text, maxsize)
    cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)

    def weekday(ustr):
        # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt
        # and SystemExit; narrowed to Exception
        try:
            dt = strptime(ustr, '%Y-%m-%d %H:%M:%S')
        except Exception:
            dt = strptime(ustr, '%Y-%m-%d')
        # NOTE(review): the historical comment claimed sunday == 1 and
        # saturday == 7, but this expression actually yields monday == 1 ..
        # saturday == 6 and sunday == 0 (dt.weekday() is 0 for monday).
        # Behavior kept as-is -- confirm the intended convention before
        # changing it.
        return (dt.weekday() + 1) % 7
    cnx.create_function("WEEKDAY", 1, weekday)

    # enforce referential integrity on this connection
    cnx.cursor().execute("pragma foreign_keys = on")

    import yams.constraints
    yams.constraints.patch_sqlite_decimal()
|
578 |
|
# register the sqlite initialization hook: every new sqlite connection will
# go through _init_sqlite_connection
sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
sqlite_hooks.append(_init_sqlite_connection)
|
581 |
|
582 |
|
583 def _init_postgres_connection(cnx): |
|
584 """Internal function that will be called to init a postgresql connection""" |
|
585 cnx.cursor().execute('SET TIME ZONE UTC') |
|
586 # commit is needed, else setting are lost if the connection is first |
|
587 # rolled back |
|
588 cnx.commit() |
|
589 |
|
# register the postgres initialization hook: every new postgres connection
# will be switched to the UTC time zone
postgres_hooks = SQL_CONNECT_HOOKS.setdefault('postgres', [])
postgres_hooks.append(_init_postgres_connection)