31 """ |
31 """ |
32 __docformat__ = "restructuredtext en" |
32 __docformat__ = "restructuredtext en" |
33 |
33 |
34 import threading |
34 import threading |
35 |
35 |
|
36 from logilab.database import FunctionDescr, SQL_FUNCTIONS_REGISTRY |
|
37 |
36 from rql import BadRQLQuery, CoercionError |
38 from rql import BadRQLQuery, CoercionError |
37 from rql.stmts import Union, Select |
39 from rql.stmts import Union, Select |
38 from rql.nodes import (SortTerm, VariableRef, Constant, Function, Not, |
40 from rql.nodes import (SortTerm, VariableRef, Constant, Function, Not, |
39 Variable, ColumnAlias, Relation, SubQuery, Exists) |
41 Variable, ColumnAlias, Relation, SubQuery, Exists) |
40 |
42 |
|
43 from cubicweb import QueryError |
41 from cubicweb.server.sqlutils import SQL_PREFIX |
44 from cubicweb.server.sqlutils import SQL_PREFIX |
42 from cubicweb.server.utils import cleanup_solutions |
45 from cubicweb.server.utils import cleanup_solutions |
43 |
46 |
# Give ColumnAlias a class-level ``_q_invariant`` default so callers can read
# the attribute without first checking whether they hold a ColumnAlias or a
# Variable.
setattr(ColumnAlias, '_q_invariant', False)

# By default an SQL function has no source-side implementation; functions
# left at None can not be applied on source-mapped attributes.
setattr(FunctionDescr, 'source_execute', None)
|
50 |
|
def default_update_cb_stack(self, stack):
    """Push this function descriptor's source callback onto *stack*.

    Default implementation used as ``FunctionDescr.update_cb_stack``.

    :param stack: list of callbacks, extended in place
    """
    callback = self.source_execute
    stack.append(callback)
|
# Every SQL function descriptor gets the default callback stack behaviour.
setattr(FunctionDescr, 'update_cb_stack', default_update_cb_stack)

# Descriptor of the SQL LENGTH function, which gets a source-side
# implementation further down.
LENGTH = SQL_FUNCTIONS_REGISTRY.get_function('LENGTH')
|
def length_source_execute(source, value):
    """Source-side implementation of the SQL ``LENGTH`` function.

    :param source: the source executing the callback (unused here, part of
      the callback signature)
    :param value: fetched value of a source-mapped attribute, expected to
      expose ``getvalue()`` (presumably a Binary / file-like buffer -- TODO
      confirm against callers)
    :return: length of the buffer content, or None when *value* is None,
      mirroring SQL where ``LENGTH(NULL)`` is NULL
    """
    if value is None:
        # NULL attribute: don't crash on None.getvalue()
        return None
    return len(value.getvalue())
|
# LENGTH may be applied on source-mapped attributes: it is evaluated by the
# source through this callback rather than in SQL.
setattr(LENGTH, 'source_execute', length_source_execute)
45 |
59 |
46 def _new_var(select, varname): |
60 def _new_var(select, varname): |
47 newvar = select.get_variable(varname) |
61 newvar = select.get_variable(varname) |
48 if not 'relations' in newvar.stinfo: |
62 if not 'relations' in newvar.stinfo: |
49 # not yet initialized |
63 # not yet initialized |
250 for vref in term.iget_nodes(VariableRef): |
264 for vref in term.iget_nodes(VariableRef): |
251 if not vref.name in selectedidx: |
265 if not vref.name in selectedidx: |
252 selectedidx.append(vref.name) |
266 selectedidx.append(vref.name) |
253 rqlst.selection.append(vref) |
267 rqlst.selection.append(vref) |
254 |
268 |
255 # IGenerator implementation for RQL->SQL ###################################### |
def iter_mapped_var_sels(stmt, variable):
    """Yield ``(selectidx, vref)`` for each selection term using *variable*.

    *variable* is a Variable or ColumnAlias node mapped to a source side
    callback. Such a variable can only be selected: it must be the right-hand
    side of at most one relation (none on a column alias) and must appear in
    the selection.

    :param stmt: the select statement whose ``selection`` is inspected
    :param variable: the mapped Variable / ColumnAlias node
    :raise QueryError: if the variable is used as a restriction, or if a
      selection term using it does not contain exactly one variable reference
    """
    if not (len(variable.stinfo['rhsrelations']) <= 1 and # < 1 on column alias
            variable.stinfo['selected']):
        raise QueryError("can't use %s as a restriction variable"
                         % variable.name)
    for selectidx in variable.stinfo['selected']:
        term = stmt.selection[selectidx]
        vrefs = term.get_nodes(VariableRef)
        if len(vrefs) != 1:
            # give an actual diagnostic instead of an empty QueryError
            raise QueryError("mapped variable %s must be the only variable "
                             "of selection term %s" % (variable.name, term))
        yield selectidx, vrefs[0]
|
281 |
|
def update_source_cb_stack(state, stmt, node, stack):
    """Push onto *stack* the source callbacks of functions wrapping *node*.

    Walks up the parents of *node* until *stmt* is reached. Every node met on
    the way must be a function call whose SQL function descriptor has a
    ``source_execute`` implementation. Each such call is recorded in
    ``state.source_cb_funcs``; ``visit_function`` then emits only its
    argument. Its callback is added to *stack* through the descriptor's
    ``update_cb_stack``.

    :param state: the StateInfo of the statement being generated
    :param stmt: the statement node where the walk stops
    :param node: the variable reference node to start from
    :param stack: list of callbacks, extended in place
    :raise QueryError: if *node* is wrapped by something which is not a
      function call, or by a function that can't be executed by the source
    """
    while True:
        node = node.parent
        if node is stmt:
            break
        if not isinstance(node, Function):
            raise QueryError('%s can not be applied on mapped attribute'
                             % node.__class__.__name__)
        func = SQL_FUNCTIONS_REGISTRY.get_function(node.name)
        if func.source_execute is None:
            raise QueryError('%s can not be called on mapped attribute'
                             % node.name)
        state.source_cb_funcs.add(node)
        func.update_cb_stack(stack)
|
295 |
|
296 |
|
297 # IGenerator implementation for RQL->SQL ####################################### |
257 |
298 |
258 class StateInfo(object): |
299 class StateInfo(object): |
def __init__(self, existssols, unstablevars):
    """Create the generation state of a select statement.

    :param existssols: stored as is in ``existssols``
    :param unstablevars: stored as is in ``unstablevars``
    """
    self.existssols, self.unstablevars = existssols, unstablevars
    # table alias -> (0, sql) for subquery tables, filled by _subqueries_sql
    self.subtables = {}
    # selection index -> callback stack, None until merge_source_cbs is called
    self.needs_source_cb = None
    # selection index -> callback stack coming from subqueries, if any
    self.subquery_source_cb = None
    # Function nodes evaluated on the source instead of in SQL
    self.source_cb_funcs = set()
263 |
307 |
264 def reset(self, solution): |
308 def reset(self, solution): |
265 """reset some visit variables""" |
309 """reset some visit variables""" |
266 self.solution = solution |
310 self.solution = solution |
267 self.count = 0 |
311 self.count = 0 |
274 self.duplicate_switches = [] |
318 self.duplicate_switches = [] |
275 self.aliases = {} |
319 self.aliases = {} |
276 self.restrictions = [] |
320 self.restrictions = [] |
277 self._restr_stack = [] |
321 self._restr_stack = [] |
278 self.ignore_varmap = False |
322 self.ignore_varmap = False |
|
323 self._needs_source_cb = {} |
|
324 |
|
def merge_source_cbs(self, needs_source_cb):
    """Record, or check against the recorded one, a source callbacks mapping.

    The first call stores *needs_source_cb*. Later calls must give an equal
    mapping, otherwise part of the query would use source callbacks and part
    would not.

    :raise QueryError: when *needs_source_cb* differs from the recorded one
    """
    if self.needs_source_cb is not None:
        if needs_source_cb != self.needs_source_cb:
            raise QueryError('query fetch some source mapped attribute, some not')
    else:
        self.needs_source_cb = needs_source_cb
|
330 |
|
def finalize_source_cbs(self):
    """Merge the subqueries' callback stacks into ``needs_source_cb``.

    A new dict is built instead of updating ``needs_source_cb`` in place:
    ``merge_source_cbs`` may have stored a mapping that belongs to another
    state, which must not be modified behind its back. A ``needs_source_cb``
    still None is treated as an empty mapping rather than crashing.
    """
    if self.subquery_source_cb is None:
        return
    merged = dict(self.needs_source_cb or ())
    merged.update(self.subquery_source_cb)
    self.needs_source_cb = merged
279 |
334 |
def add_restriction(self, restr):
    """Append SQL restriction *restr* to the WHERE clause parts.

    Empty restrictions (empty string, None...) are ignored.
    """
    if not restr:
        return
    self.restrictions.append(restr)
283 |
338 |
330 |
385 |
331 WARNING: a CubicWebSQLGenerator instance is not thread safe, but generate is |
386 WARNING: a CubicWebSQLGenerator instance is not thread safe, but generate is |
332 protected by a lock |
387 protected by a lock |
333 """ |
388 """ |
334 |
389 |
335 def __init__(self, schema, dbms_helper, attrmap=None): |
390 def __init__(self, schema, dbhelper, attrmap=None): |
336 self.schema = schema |
391 self.schema = schema |
337 self.dbms_helper = dbms_helper |
392 self.dbhelper = dbhelper |
338 self.dbencoding = dbms_helper.dbencoding |
393 self.dbencoding = dbhelper.dbencoding |
339 self.keyword_map = {'NOW' : self.dbms_helper.sql_current_timestamp, |
394 self.keyword_map = {'NOW' : self.dbhelper.sql_current_timestamp, |
340 'TODAY': self.dbms_helper.sql_current_date, |
395 'TODAY': self.dbhelper.sql_current_date, |
341 } |
396 } |
342 if not self.dbms_helper.union_parentheses_support: |
397 if not self.dbhelper.union_parentheses_support: |
343 self.union_sql = self.noparen_union_sql |
398 self.union_sql = self.noparen_union_sql |
344 if self.dbms_helper.fti_need_distinct: |
399 if self.dbhelper.fti_need_distinct: |
345 self.__union_sql = self.union_sql |
400 self.__union_sql = self.union_sql |
346 self.union_sql = self.has_text_need_distinct_union_sql |
401 self.union_sql = self.has_text_need_distinct_union_sql |
347 self._lock = threading.Lock() |
402 self._lock = threading.Lock() |
348 if attrmap is None: |
403 if attrmap is None: |
349 attrmap = {} |
404 attrmap = {} |
371 self._not_scope_offset = 0 |
426 self._not_scope_offset = 0 |
372 try: |
427 try: |
373 # union query for each rqlst / solution |
428 # union query for each rqlst / solution |
374 sql = self.union_sql(union) |
429 sql = self.union_sql(union) |
375 # we are done |
430 # we are done |
376 return sql, self._query_attrs |
431 return sql, self._query_attrs, self._state.needs_source_cb |
377 finally: |
432 finally: |
378 self._lock.release() |
433 self._lock.release() |
379 |
434 |
380 def has_text_need_distinct_union_sql(self, union, needalias=False): |
435 def has_text_need_distinct_union_sql(self, union, needalias=False): |
381 if getattr(union, 'has_text_query', False): |
436 if getattr(union, 'has_text_query', False): |
389 sqls = ('(%s)' % self.select_sql(select, needalias) |
444 sqls = ('(%s)' % self.select_sql(select, needalias) |
390 for select in union.children) |
445 for select in union.children) |
391 return '\nUNION ALL\n'.join(sqls) |
446 return '\nUNION ALL\n'.join(sqls) |
392 |
447 |
def noparen_union_sql(self, union, needalias=False):
    """Join the SQL of each select of *union* with UNION ALL, unparenthesized.

    :param union: node whose ``children`` are the unioned selects
    :param needalias: forwarded to ``select_sql``
    :return: the SQL string (empty if the union has no children)
    """
    # needed for sqlite backend which doesn't like parentheses around union
    # query. This may cause bug in some condition (sort in one of the
    # subquery) but will work in most case
    #
    # see http://www.sqlite.org/cvstrac/tktview?tn=3074
    sqls = (self.select_sql(select, needalias) for select in union.children)
    return '\nUNION ALL\n'.join(sqls)
401 |
457 |
433 select.select_only_variables() |
489 select.select_only_variables() |
434 needwrap = True |
490 needwrap = True |
435 else: |
491 else: |
436 existssols, unstable = {}, () |
492 existssols, unstable = {}, () |
437 state = StateInfo(existssols, unstable) |
493 state = StateInfo(existssols, unstable) |
|
494 if self._state is not None: |
|
495 # state from a previous unioned select |
|
496 state.merge_source_cbs(self._state.needs_source_cb) |
438 # treat subqueries |
497 # treat subqueries |
439 self._subqueries_sql(select, state) |
498 self._subqueries_sql(select, state) |
440 # generate sql for this select node |
499 # generate sql for this select node |
441 selectidx = [str(term) for term in select.selection] |
500 selectidx = [str(term) for term in select.selection] |
442 if needwrap: |
501 if needwrap: |
488 fselectidx) |
547 fselectidx) |
489 for sortterm in sorts) |
548 for sortterm in sorts) |
490 if fneedwrap: |
549 if fneedwrap: |
491 selection = ['T1.C%s' % i for i in xrange(len(origselection))] |
550 selection = ['T1.C%s' % i for i in xrange(len(origselection))] |
492 sql = 'SELECT %s FROM (%s) AS T1' % (','.join(selection), sql) |
551 sql = 'SELECT %s FROM (%s) AS T1' % (','.join(selection), sql) |
|
552 state.finalize_source_cbs() |
493 finally: |
553 finally: |
494 select.selection = origselection |
554 select.selection = origselection |
495 # limit / offset |
555 # limit / offset |
496 limit = select.limit |
556 limit = select.limit |
497 if limit: |
557 if limit: |
502 return sql |
562 return sql |
503 |
563 |
def _subqueries_sql(self, select, state):
    """Generate SQL for the subqueries of *select* and register their tables.

    Each subquery in ``select.with_`` becomes a ``(<sql>) AS _T<i>`` table
    stored in ``state.subtables``. Its column aliases are bound to the
    matching ``_T<i>.C<colnum>`` column. When a subquery column needs source
    callbacks, its callback stack is propagated to every selection term of
    *select* using that alias, in ``state.subquery_source_cb``.
    """
    for i, subquery in enumerate(select.with_):
        sql = self.union_sql(subquery.query, needalias=True)
        tablealias = '_T%s' % i # XXX nested subqueries
        sql = '(%s) AS %s' % (sql, tablealias)
        state.subtables[tablealias] = (0, sql)
        # state left by generating the subquery (presumably its last select,
        # TODO confirm); it holds the callbacks needed by its columns
        subquery_cbs = self._state.needs_source_cb or {}
        for vref in subquery.aliases:
            alias = vref.variable
            alias._q_sqltable = tablealias
            alias._q_sql = '%s.C%s' % (tablealias, alias.colnum)
            # only the lookup may legitimately miss: don't let a broad
            # try/except hide KeyErrors raised by the helpers below
            basestack = subquery_cbs.get(alias.colnum)
            if basestack is None:
                continue
            if state.subquery_source_cb is None:
                state.subquery_source_cb = {}
            for selectidx, selvref in iter_mapped_var_sels(select, alias):
                # copy the subquery stack afresh for each selection term, so
                # functions wrapping one term don't leak into another's stack
                stack = basestack[:]
                update_source_cb_stack(state, select, selvref, stack)
                state.subquery_source_cb[selectidx] = stack
514 |
585 |
515 def _solutions_sql(self, select, solutions, distinct, needalias): |
586 def _solutions_sql(self, select, solutions, distinct, needalias): |
516 sqls = [] |
587 sqls = [] |
517 for solution in solutions: |
588 for solution in solutions: |
518 self._state.reset(solution) |
589 self._state.reset(solution) |
520 if select.where is not None: |
591 if select.where is not None: |
521 self._state.add_restriction(select.where.accept(self)) |
592 self._state.add_restriction(select.where.accept(self)) |
522 sql = [self._selection_sql(select.selection, distinct, needalias)] |
593 sql = [self._selection_sql(select.selection, distinct, needalias)] |
523 if self._state.restrictions: |
594 if self._state.restrictions: |
524 sql.append('WHERE %s' % ' AND '.join(self._state.restrictions)) |
595 sql.append('WHERE %s' % ' AND '.join(self._state.restrictions)) |
|
596 self._state.merge_source_cbs(self._state._needs_source_cb) |
525 # add required tables |
597 # add required tables |
526 assert len(self._state.actual_tables) == 1, self._state.actual_tables |
598 assert len(self._state.actual_tables) == 1, self._state.actual_tables |
527 tables = self._state.actual_tables[-1] |
599 tables = self._state.actual_tables[-1] |
528 if tables: |
600 if tables: |
529 # sort for test predictability |
601 # sort for test predictability |
530 sql.insert(1, 'FROM %s' % ', '.join(sorted(tables))) |
602 sql.insert(1, 'FROM %s' % ', '.join(sorted(tables))) |
531 elif self._state.restrictions and self.dbms_helper.needs_from_clause: |
603 elif self._state.restrictions and self.dbhelper.needs_from_clause: |
532 sql.insert(1, 'FROM (SELECT 1) AS _T') |
604 sql.insert(1, 'FROM (SELECT 1) AS _T') |
533 sqls.append('\n'.join(sql)) |
605 sqls.append('\n'.join(sql)) |
534 if select.need_intersect: |
606 if select.need_intersect: |
535 #if distinct or not self.dbms_helper.intersect_all_support: |
607 #if distinct or not self.dbhelper.intersect_all_support: |
536 return '\nINTERSECT\n'.join(sqls) |
608 return '\nINTERSECT\n'.join(sqls) |
537 #else: |
609 #else: |
538 # return '\nINTERSECT ALL\n'.join(sqls) |
610 # return '\nINTERSECT ALL\n'.join(sqls) |
539 elif distinct: |
611 elif distinct: |
540 return '\nUNION\n'.join(sqls) |
612 return '\nUNION\n'.join(sqls) |
892 try: |
964 try: |
893 lhssql = self._varmap['%s.%s' % (lhs.name, rel.r_type)] |
965 lhssql = self._varmap['%s.%s' % (lhs.name, rel.r_type)] |
894 except KeyError: |
966 except KeyError: |
895 mapkey = '%s.%s' % (self._state.solution[lhs.name], rel.r_type) |
967 mapkey = '%s.%s' % (self._state.solution[lhs.name], rel.r_type) |
896 if mapkey in self.attr_map: |
968 if mapkey in self.attr_map: |
897 lhssql = self.attr_map[mapkey](self, lhs.variable, rel) |
969 cb, sourcecb = self.attr_map[mapkey] |
|
970 if sourcecb: |
|
971 # callback is a source callback, we can't use this |
|
972 # attribute in restriction |
|
973 raise QueryError("can't use %s (%s) in restriction" |
|
974 % (mapkey, rel.as_string())) |
|
975 lhssql = cb(self, lhs.variable, rel) |
898 elif rel.r_type == 'eid': |
976 elif rel.r_type == 'eid': |
899 lhssql = lhs.variable._q_sql |
977 lhssql = lhs.variable._q_sql |
900 else: |
978 else: |
901 lhssql = '%s.%s%s' % (table, SQL_PREFIX, rel.r_type) |
979 lhssql = '%s.%s%s' % (table, SQL_PREFIX, rel.r_type) |
902 try: |
980 try: |
941 if isinstance(rel.parent, Not): |
1019 if isinstance(rel.parent, Not): |
942 self._state.done.add(rel.parent) |
1020 self._state.done.add(rel.parent) |
943 not_ = True |
1021 not_ = True |
944 else: |
1022 else: |
945 not_ = False |
1023 not_ = False |
946 return self.dbms_helper.fti_restriction_sql(alias, const.eval(self._args), |
1024 return self.dbhelper.fti_restriction_sql(alias, const.eval(self._args), |
947 jointo, not_) + restriction |
1025 jointo, not_) + restriction |
948 |
1026 |
949 def visit_comparison(self, cmp): |
1027 def visit_comparison(self, cmp): |
950 """generate SQL for a comparison""" |
1028 """generate SQL for a comparison""" |
951 if len(cmp.children) == 2: |
1029 if len(cmp.children) == 2: |
954 else: |
1032 else: |
955 lhs = None |
1033 lhs = None |
956 rhs = cmp.children[0] |
1034 rhs = cmp.children[0] |
957 operator = cmp.operator |
1035 operator = cmp.operator |
958 if operator in ('IS', 'LIKE', 'ILIKE'): |
1036 if operator in ('IS', 'LIKE', 'ILIKE'): |
959 if operator == 'ILIKE' and not self.dbms_helper.ilike_support: |
1037 if operator == 'ILIKE' and not self.dbhelper.ilike_support: |
960 operator = ' LIKE ' |
1038 operator = ' LIKE ' |
961 else: |
1039 else: |
962 operator = ' %s ' % operator |
1040 operator = ' %s ' % operator |
963 elif (operator == '=' and isinstance(rhs, Constant) |
1041 elif (operator == '=' and isinstance(rhs, Constant) |
964 and rhs.eval(self._args) is None): |
1042 and rhs.eval(self._args) is None): |
984 pass |
1062 pass |
985 return '(%s %s %s)'% (lhs.accept(self), operator, rhs.accept(self)) |
1063 return '(%s %s %s)'% (lhs.accept(self), operator, rhs.accept(self)) |
986 |
1064 |
def visit_function(self, func):
    """generate SQL name for a function"""
    args = [child.accept(self) for child in func.children]
    if func not in self._state.source_cb_funcs:
        # func_as_sql will check function is supported by the backend
        return self.dbhelper.func_as_sql(func.name, args)
    # function executed as a callback on the source: SQL only fetches its
    # single argument
    assert len(args) == 1
    return args[0]
992 |
1074 |
993 def visit_constant(self, constant): |
1075 def visit_constant(self, constant): |
994 """generate SQL name for a constant""" |
1076 """generate SQL name for a constant""" |
995 value = constant.value |
1077 value = constant.value |
996 if constant.type is None: |
1078 if constant.type is None: |
1001 rel = constant.relation() |
1083 rel = constant.relation() |
1002 if rel is not None: |
1084 if rel is not None: |
1003 rel._q_needcast = value |
1085 rel._q_needcast = value |
1004 return self.keyword_map[value]() |
1086 return self.keyword_map[value]() |
1005 if constant.type == 'Boolean': |
1087 if constant.type == 'Boolean': |
1006 value = self.dbms_helper.boolean_value(value) |
1088 value = self.dbhelper.boolean_value(value) |
1007 if constant.type == 'Substitute': |
1089 if constant.type == 'Substitute': |
1008 _id = constant.value |
1090 _id = constant.value |
1009 if isinstance(_id, unicode): |
1091 if isinstance(_id, unicode): |
1010 _id = _id.encode() |
1092 _id = _id.encode() |
1011 else: |
1093 else: |
1063 etypes = ','.join("'%s'" % et for et in pts) |
1145 etypes = ','.join("'%s'" % et for et in pts) |
1064 restr = '%s.type IN (%s)' % (vtablename, etypes) |
1146 restr = '%s.type IN (%s)' % (vtablename, etypes) |
1065 self._state.add_restriction(restr) |
1147 self._state.add_restriction(restr) |
1066 elif principal.r_type == 'has_text': |
1148 elif principal.r_type == 'has_text': |
1067 sql = '%s.%s' % (self._fti_table(principal), |
1149 sql = '%s.%s' % (self._fti_table(principal), |
1068 self.dbms_helper.fti_uid_attr) |
1150 self.dbhelper.fti_uid_attr) |
1069 elif principal in variable.stinfo['rhsrelations']: |
1151 elif principal in variable.stinfo['rhsrelations']: |
1070 if self.schema.rschema(principal.r_type).inlined: |
1152 if self.schema.rschema(principal.r_type).inlined: |
1071 sql = self._linked_var_sql(variable) |
1153 sql = self._linked_var_sql(variable) |
1072 else: |
1154 else: |
1073 sql = '%s.eid_to' % self._relation_table(principal) |
1155 sql = '%s.eid_to' % self._relation_table(principal) |
1153 if rel.r_type == 'eid': |
1235 if rel.r_type == 'eid': |
1154 return linkedvar.accept(self) |
1236 return linkedvar.accept(self) |
1155 if isinstance(linkedvar, ColumnAlias): |
1237 if isinstance(linkedvar, ColumnAlias): |
1156 raise BadRQLQuery('variable %s should be selected by the subquery' |
1238 raise BadRQLQuery('variable %s should be selected by the subquery' |
1157 % variable.name) |
1239 % variable.name) |
1158 mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type) |
|
1159 if mapkey in self.attr_map: |
|
1160 return self.attr_map[mapkey](self, linkedvar, rel) |
|
1161 try: |
1240 try: |
1162 sql = self._varmap['%s.%s' % (linkedvar.name, rel.r_type)] |
1241 sql = self._varmap['%s.%s' % (linkedvar.name, rel.r_type)] |
1163 except KeyError: |
1242 except KeyError: |
|
1243 mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type) |
|
1244 if mapkey in self.attr_map: |
|
1245 cb, sourcecb = self.attr_map[mapkey] |
|
1246 if not sourcecb: |
|
1247 return cb(self, linkedvar, rel) |
|
1248 # attribute mapped at the source level (bfss for instance) |
|
1249 stmt = rel.stmt |
|
1250 for selectidx, vref in iter_mapped_var_sels(stmt, variable): |
|
1251 stack = [cb] |
|
1252 update_source_cb_stack(self._state, stmt, vref, stack) |
|
1253 self._state._needs_source_cb[selectidx] = stack |
1164 linkedvar.accept(self) |
1254 linkedvar.accept(self) |
1165 sql = '%s.%s%s' % (linkedvar._q_sqltable, SQL_PREFIX, rel.r_type) |
1255 sql = '%s.%s%s' % (linkedvar._q_sqltable, SQL_PREFIX, rel.r_type) |
1166 return sql |
1256 return sql |
1167 |
1257 |
1168 # tables handling ######################################################### |
1258 # tables handling ######################################################### |
1265 try: |
1355 try: |
1266 return relation._q_sqltable |
1356 return relation._q_sqltable |
1267 except AttributeError: |
1357 except AttributeError: |
1268 pass |
1358 pass |
1269 self._state.done.add(relation) |
1359 self._state.done.add(relation) |
1270 alias = self.alias_and_add_table(self.dbms_helper.fti_table) |
1360 alias = self.alias_and_add_table(self.dbhelper.fti_table) |
1271 relation._q_sqltable = alias |
1361 relation._q_sqltable = alias |
1272 return alias |
1362 return alias |
1273 |
1363 |
1274 def _varmap_table_scope(self, select, table): |
1364 def _varmap_table_scope(self, select, table): |
1275 """since a varmap table may be used for multiple variable, its scope is |
1365 """since a varmap table may be used for multiple variable, its scope is |