31 """ |
31 """ |
32 __docformat__ = "restructuredtext en" |
32 __docformat__ = "restructuredtext en" |
33 |
33 |
34 import threading |
34 import threading |
35 |
35 |
|
36 from logilab.database import FunctionDescr, SQL_FUNCTIONS_REGISTRY |
|
37 |
36 from rql import BadRQLQuery, CoercionError |
38 from rql import BadRQLQuery, CoercionError |
37 from rql.stmts import Union, Select |
39 from rql.stmts import Union, Select |
38 from rql.nodes import (SortTerm, VariableRef, Constant, Function, Not, |
40 from rql.nodes import (SortTerm, VariableRef, Constant, Function, Not, |
39 Variable, ColumnAlias, Relation, SubQuery, Exists) |
41 Variable, ColumnAlias, Relation, SubQuery, Exists) |
40 |
42 |
|
43 from cubicweb import QueryError |
41 from cubicweb.server.sqlutils import SQL_PREFIX |
44 from cubicweb.server.sqlutils import SQL_PREFIX |
42 from cubicweb.server.utils import cleanup_solutions |
45 from cubicweb.server.utils import cleanup_solutions |
43 |
46 |
44 ColumnAlias._q_invariant = False # avoid to check for ColumnAlias / Variable |
47 ColumnAlias._q_invariant = False # avoid to check for ColumnAlias / Variable |
|
48 |
|
49 FunctionDescr.source_execute = None |
|
50 |
|
def default_update_cb_stack(self, stack):
    """Push this function description's source callback onto `stack`.

    This is the default implementation installed on
    ``FunctionDescr.update_cb_stack``.

    :param stack: list of callbacks, extended in place with
      ``self.source_execute``
    """
    callback = self.source_execute
    stack.append(callback)
|
53 FunctionDescr.update_cb_stack = default_update_cb_stack |
|
54 |
|
55 LENGTH = SQL_FUNCTIONS_REGISTRY.get_function('LENGTH') |
|
def length_source_execute(source, value):
    """Source-side implementation of the ``LENGTH`` function.

    :param source: the source running the callback (unused here)
    :param value: object exposing ``getvalue()``, or None
    :return: ``len(value.getvalue())``, or None when `value` is None
    """
    # NOTE(review): presumably `value` is None when the attribute is NULL or
    # an earlier callback in the stack could not load it -- confirm against
    # the callers. Mirror SQL semantics (LENGTH(NULL) is NULL) instead of
    # failing with an AttributeError.
    if value is None:
        return None
    return len(value.getvalue())
|
58 LENGTH.source_execute = length_source_execute |
45 |
59 |
46 def _new_var(select, varname): |
60 def _new_var(select, varname): |
47 newvar = select.get_variable(varname) |
61 newvar = select.get_variable(varname) |
48 if not 'relations' in newvar.stinfo: |
62 if not 'relations' in newvar.stinfo: |
49 # not yet initialized |
63 # not yet initialized |
250 for vref in term.iget_nodes(VariableRef): |
264 for vref in term.iget_nodes(VariableRef): |
251 if not vref.name in selectedidx: |
265 if not vref.name in selectedidx: |
252 selectedidx.append(vref.name) |
266 selectedidx.append(vref.name) |
253 rqlst.selection.append(vref) |
267 rqlst.selection.append(vref) |
254 |
268 |
255 # IGenerator implementation for RQL->SQL ###################################### |
def iter_mapped_var_sels(stmt, variable):
    """Iterate over the selected terms of `stmt` in which `variable` appears.

    `variable` is a Variable or ColumnAlias node mapped to a source side
    callback.

    :param stmt: statement whose ``selection`` list is indexed by the
      entries of ``variable.stinfo['selected']``
    :param variable: node providing ``stinfo`` and ``name``
    :return: generator of ``(selectidx, vref)`` tuples, `vref` being the
      only VariableRef found in the selected term
    :raise QueryError: if the variable has more than one entry in
      ``stinfo['rhsrelations']`` or is not selected, or if a selected term
      does not hold exactly one variable reference
    """
    stinfo = variable.stinfo
    # at most one rhs relation is allowed (< 1 on column alias), and the
    # variable has to be selected
    if len(stinfo['rhsrelations']) > 1 or not stinfo['selected']:
        raise QueryError("can't use %s as a restriction variable"
                         % variable.name)
    for selectidx in stinfo['selected']:
        vrefs = stmt.selection[selectidx].get_nodes(VariableRef)
        if len(vrefs) != 1:
            # previously raised without any message, which made the failure
            # impossible to diagnose from the error alone
            raise QueryError(
                "selected term #%s must reference exactly one variable to "
                "use source mapped variable %s (found %s)"
                % (selectidx, variable.name, len(vrefs)))
        yield selectidx, vrefs[0]
|
281 |
|
def update_source_cb_stack(state, stmt, node, stack):
    """Register the functions wrapping `node` as source side callbacks.

    Walk from `node` up through its parents until `stmt` is reached. Each
    node met on the way must be a Function whose descriptor in
    SQL_FUNCTIONS_REGISTRY has a ``source_execute`` callback; the node is
    added to ``state.source_cb_funcs`` and the descriptor's
    ``update_cb_stack`` is called with `stack`.

    :param state: object exposing the ``source_cb_funcs`` set
    :param stmt: statement node at which the walk stops
    :param node: starting node (itself not inspected, only its ancestors)
    :param stack: list of callbacks, updated in place
    :raise QueryError: if an ancestor is not a Function, or if the function
      has no ``source_execute`` implementation
    """
    node = node.parent
    while node is not stmt:
        if not isinstance(node, Function):
            # previously raised without any message
            raise QueryError(
                'source mapped attribute can only be wrapped by function '
                'calls in the selection (got %s node)'
                % type(node).__name__)
        func = SQL_FUNCTIONS_REGISTRY.get_function(node.name)
        if func.source_execute is None:
            raise QueryError('%s can not be called on mapped attribute'
                             % node.name)
        state.source_cb_funcs.add(node)
        func.update_cb_stack(stack)
        node = node.parent
|
295 |
|
296 |
|
297 # IGenerator implementation for RQL->SQL ####################################### |
257 |
298 |
258 class StateInfo(object): |
299 class StateInfo(object): |
def __init__(self, existssols, unstablevars):
    """Set up the state shared by every solution of a select.

    :param existssols: stored as-is on ``self.existssols``
    :param unstablevars: stored as-is on ``self.unstablevars``
    """
    # source callbacks bookkeeping: nothing known yet
    self.source_cb_funcs = set()
    self.subquery_source_cb = None
    self.needs_source_cb = None
    # sql of subqueries, indexed by table alias
    self.subtables = {}
    # values given by the caller
    self.unstablevars = unstablevars
    self.existssols = existssols
263 |
307 |
264 def reset(self, solution): |
308 def reset(self, solution): |
265 """reset some visit variables""" |
309 """reset some visit variables""" |
266 self.solution = solution |
310 self.solution = solution |
267 self.count = 0 |
311 self.count = 0 |
274 self.duplicate_switches = [] |
318 self.duplicate_switches = [] |
275 self.aliases = {} |
319 self.aliases = {} |
276 self.restrictions = [] |
320 self.restrictions = [] |
277 self._restr_stack = [] |
321 self._restr_stack = [] |
278 self.ignore_varmap = False |
322 self.ignore_varmap = False |
|
323 self._needs_source_cb = {} |
|
324 |
|
def merge_source_cbs(self, needs_source_cb):
    """Record `needs_source_cb`, checking it agrees with what is known.

    The first value given is stored on ``self.needs_source_cb``; any later
    value must compare equal to the stored one.

    :raise QueryError: if `needs_source_cb` differs from the stored value
    """
    known = self.needs_source_cb
    if known is None:
        # first call: nothing to compare with yet
        self.needs_source_cb = needs_source_cb
        return
    if needs_source_cb != known:
        raise QueryError('query fetch some source mapped attribute, some not')
|
330 |
|
def finalize_source_cbs(self):
    """Fold the callbacks coming from subqueries into ``needs_source_cb``.

    Does nothing when ``self.subquery_source_cb`` is None. Otherwise its
    entries are merged in place into ``self.needs_source_cb``.
    """
    subquery_cbs = self.subquery_source_cb
    if subquery_cbs is None:
        return
    if self.needs_source_cb is None:
        # needs_source_cb is still None when merge_source_cbs was never
        # called; start from a copy instead of failing on None.update()
        self.needs_source_cb = dict(subquery_cbs)
    else:
        self.needs_source_cb.update(subquery_cbs)
279 |
334 |
def add_restriction(self, restr):
    """Append `restr` to ``self.restrictions`` unless it is empty.

    :param restr: restriction to record; any false value (None, empty
      string...) is ignored
    """
    if not restr:
        return
    self.restrictions.append(restr)
283 |
338 |
371 self._not_scope_offset = 0 |
426 self._not_scope_offset = 0 |
372 try: |
427 try: |
373 # union query for each rqlst / solution |
428 # union query for each rqlst / solution |
374 sql = self.union_sql(union) |
429 sql = self.union_sql(union) |
375 # we are done |
430 # we are done |
376 return sql, self._query_attrs |
431 return sql, self._query_attrs, self._state.needs_source_cb |
377 finally: |
432 finally: |
378 self._lock.release() |
433 self._lock.release() |
379 |
434 |
380 def has_text_need_distinct_union_sql(self, union, needalias=False): |
435 def has_text_need_distinct_union_sql(self, union, needalias=False): |
381 if getattr(union, 'has_text_query', False): |
436 if getattr(union, 'has_text_query', False): |
434 select.select_only_variables() |
489 select.select_only_variables() |
435 needwrap = True |
490 needwrap = True |
436 else: |
491 else: |
437 existssols, unstable = {}, () |
492 existssols, unstable = {}, () |
438 state = StateInfo(existssols, unstable) |
493 state = StateInfo(existssols, unstable) |
|
494 if self._state is not None: |
|
495 # state from a previous unioned select |
|
496 state.merge_source_cbs(self._state.needs_source_cb) |
439 # treat subqueries |
497 # treat subqueries |
440 self._subqueries_sql(select, state) |
498 self._subqueries_sql(select, state) |
441 # generate sql for this select node |
499 # generate sql for this select node |
442 selectidx = [str(term) for term in select.selection] |
500 selectidx = [str(term) for term in select.selection] |
443 if needwrap: |
501 if needwrap: |
489 fselectidx) |
547 fselectidx) |
490 for sortterm in sorts) |
548 for sortterm in sorts) |
491 if fneedwrap: |
549 if fneedwrap: |
492 selection = ['T1.C%s' % i for i in xrange(len(origselection))] |
550 selection = ['T1.C%s' % i for i in xrange(len(origselection))] |
493 sql = 'SELECT %s FROM (%s) AS T1' % (','.join(selection), sql) |
551 sql = 'SELECT %s FROM (%s) AS T1' % (','.join(selection), sql) |
|
552 state.finalize_source_cbs() |
494 finally: |
553 finally: |
495 select.selection = origselection |
554 select.selection = origselection |
496 # limit / offset |
555 # limit / offset |
497 limit = select.limit |
556 limit = select.limit |
498 if limit: |
557 if limit: |
506 for i, subquery in enumerate(select.with_): |
565 for i, subquery in enumerate(select.with_): |
507 sql = self.union_sql(subquery.query, needalias=True) |
566 sql = self.union_sql(subquery.query, needalias=True) |
508 tablealias = '_T%s' % i # XXX nested subqueries |
567 tablealias = '_T%s' % i # XXX nested subqueries |
509 sql = '(%s) AS %s' % (sql, tablealias) |
568 sql = '(%s) AS %s' % (sql, tablealias) |
510 state.subtables[tablealias] = (0, sql) |
569 state.subtables[tablealias] = (0, sql) |
|
570 latest_state = self._state |
511 for vref in subquery.aliases: |
571 for vref in subquery.aliases: |
512 alias = vref.variable |
572 alias = vref.variable |
513 alias._q_sqltable = tablealias |
573 alias._q_sqltable = tablealias |
514 alias._q_sql = '%s.C%s' % (tablealias, alias.colnum) |
574 alias._q_sql = '%s.C%s' % (tablealias, alias.colnum) |
|
575 try: |
|
576 stack = latest_state.needs_source_cb[alias.colnum] |
|
577 if state.subquery_source_cb is None: |
|
578 state.subquery_source_cb = {} |
|
579 for selectidx, vref in iter_mapped_var_sels(select, alias): |
|
580 stack = stack[:] |
|
581 update_source_cb_stack(state, select, vref, stack) |
|
582 state.subquery_source_cb[selectidx] = stack |
|
583 except KeyError: |
|
584 continue |
515 |
585 |
516 def _solutions_sql(self, select, solutions, distinct, needalias): |
586 def _solutions_sql(self, select, solutions, distinct, needalias): |
517 sqls = [] |
587 sqls = [] |
518 for solution in solutions: |
588 for solution in solutions: |
519 self._state.reset(solution) |
589 self._state.reset(solution) |
521 if select.where is not None: |
591 if select.where is not None: |
522 self._state.add_restriction(select.where.accept(self)) |
592 self._state.add_restriction(select.where.accept(self)) |
523 sql = [self._selection_sql(select.selection, distinct, needalias)] |
593 sql = [self._selection_sql(select.selection, distinct, needalias)] |
524 if self._state.restrictions: |
594 if self._state.restrictions: |
525 sql.append('WHERE %s' % ' AND '.join(self._state.restrictions)) |
595 sql.append('WHERE %s' % ' AND '.join(self._state.restrictions)) |
|
596 self._state.merge_source_cbs(self._state._needs_source_cb) |
526 # add required tables |
597 # add required tables |
527 assert len(self._state.actual_tables) == 1, self._state.actual_tables |
598 assert len(self._state.actual_tables) == 1, self._state.actual_tables |
528 tables = self._state.actual_tables[-1] |
599 tables = self._state.actual_tables[-1] |
529 if tables: |
600 if tables: |
530 # sort for test predictability |
601 # sort for test predictability |
893 try: |
964 try: |
894 lhssql = self._varmap['%s.%s' % (lhs.name, rel.r_type)] |
965 lhssql = self._varmap['%s.%s' % (lhs.name, rel.r_type)] |
895 except KeyError: |
966 except KeyError: |
896 mapkey = '%s.%s' % (self._state.solution[lhs.name], rel.r_type) |
967 mapkey = '%s.%s' % (self._state.solution[lhs.name], rel.r_type) |
897 if mapkey in self.attr_map: |
968 if mapkey in self.attr_map: |
898 lhssql = self.attr_map[mapkey](self, lhs.variable, rel) |
969 cb, sourcecb = self.attr_map[mapkey] |
|
970 if sourcecb: |
|
971 # callback is a source callback, we can't use this |
|
972 # attribute in restriction |
|
973 raise QueryError("can't use %s (%s) in restriction" |
|
974 % (mapkey, rel.as_string())) |
|
975 lhssql = cb(self, lhs.variable, rel) |
899 elif rel.r_type == 'eid': |
976 elif rel.r_type == 'eid': |
900 lhssql = lhs.variable._q_sql |
977 lhssql = lhs.variable._q_sql |
901 else: |
978 else: |
902 lhssql = '%s.%s%s' % (table, SQL_PREFIX, rel.r_type) |
979 lhssql = '%s.%s%s' % (table, SQL_PREFIX, rel.r_type) |
903 try: |
980 try: |
985 pass |
1062 pass |
986 return '(%s %s %s)'% (lhs.accept(self), operator, rhs.accept(self)) |
1063 return '(%s %s %s)'% (lhs.accept(self), operator, rhs.accept(self)) |
987 |
1064 |
def visit_function(self, func):
    """generate SQL name for a function

    Every child of `func` is visited first. When `func` has been registered
    in ``self._state.source_cb_funcs`` the SQL of its single argument is
    returned as-is; otherwise the call is rendered by the db helper.
    """
    args = []
    for child in func.children:
        args.append(child.accept(self))
    if func not in self._state.source_cb_funcs:
        # func_as_sql will check function is supported by the backend
        return self.dbhelper.func_as_sql(func.name, args)
    # function executed as a callback on the source
    assert len(args) == 1
    return args[0]
993 |
1074 |
994 def visit_constant(self, constant): |
1075 def visit_constant(self, constant): |
995 """generate SQL name for a constant""" |
1076 """generate SQL name for a constant""" |
996 value = constant.value |
1077 value = constant.value |
997 if constant.type is None: |
1078 if constant.type is None: |
1154 if rel.r_type == 'eid': |
1235 if rel.r_type == 'eid': |
1155 return linkedvar.accept(self) |
1236 return linkedvar.accept(self) |
1156 if isinstance(linkedvar, ColumnAlias): |
1237 if isinstance(linkedvar, ColumnAlias): |
1157 raise BadRQLQuery('variable %s should be selected by the subquery' |
1238 raise BadRQLQuery('variable %s should be selected by the subquery' |
1158 % variable.name) |
1239 % variable.name) |
1159 mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type) |
|
1160 if mapkey in self.attr_map: |
|
1161 return self.attr_map[mapkey](self, linkedvar, rel) |
|
1162 try: |
1240 try: |
1163 sql = self._varmap['%s.%s' % (linkedvar.name, rel.r_type)] |
1241 sql = self._varmap['%s.%s' % (linkedvar.name, rel.r_type)] |
1164 except KeyError: |
1242 except KeyError: |
|
1243 mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type) |
|
1244 if mapkey in self.attr_map: |
|
1245 cb, sourcecb = self.attr_map[mapkey] |
|
1246 if not sourcecb: |
|
1247 return cb(self, linkedvar, rel) |
|
1248 # attribute mapped at the source level (bfss for instance) |
|
1249 stmt = rel.stmt |
|
1250 for selectidx, vref in iter_mapped_var_sels(stmt, variable): |
|
1251 stack = [cb] |
|
1252 update_source_cb_stack(self._state, stmt, vref, stack) |
|
1253 self._state._needs_source_cb[selectidx] = stack |
1165 linkedvar.accept(self) |
1254 linkedvar.accept(self) |
1166 sql = '%s.%s%s' % (linkedvar._q_sqltable, SQL_PREFIX, rel.r_type) |
1255 sql = '%s.%s%s' % (linkedvar._q_sqltable, SQL_PREFIX, rel.r_type) |
1167 return sql |
1256 return sql |
1168 |
1257 |
1169 # tables handling ######################################################### |
1258 # tables handling ######################################################### |