server/sources/rql2sql.py
branchstable
changeset 5013 ad91f93bbb93
parent 5010 b2c5aee8ca3f
child 5016 b3b0b808a0ed
child 5280 7e13bb484a19
equal deleted inserted replaced
5012:9c4ea944ecf9 5013:ad91f93bbb93
    31 """
    31 """
    32 __docformat__ = "restructuredtext en"
    32 __docformat__ = "restructuredtext en"
    33 
    33 
    34 import threading
    34 import threading
    35 
    35 
       
    36 from logilab.database import FunctionDescr, SQL_FUNCTIONS_REGISTRY
       
    37 
    36 from rql import BadRQLQuery, CoercionError
    38 from rql import BadRQLQuery, CoercionError
    37 from rql.stmts import Union, Select
    39 from rql.stmts import Union, Select
    38 from rql.nodes import (SortTerm, VariableRef, Constant, Function, Not,
    40 from rql.nodes import (SortTerm, VariableRef, Constant, Function, Not,
    39                        Variable, ColumnAlias, Relation, SubQuery, Exists)
    41                        Variable, ColumnAlias, Relation, SubQuery, Exists)
    40 
    42 
       
    43 from cubicweb import QueryError
    41 from cubicweb.server.sqlutils import SQL_PREFIX
    44 from cubicweb.server.sqlutils import SQL_PREFIX
    42 from cubicweb.server.utils import cleanup_solutions
    45 from cubicweb.server.utils import cleanup_solutions
    43 
    46 
    44 ColumnAlias._q_invariant = False # avoid to check for ColumnAlias / Variable
    47 ColumnAlias._q_invariant = False # avoid to check for ColumnAlias / Variable
       
    48 
       
    49 FunctionDescr.source_execute = None
       
    50 
       
    51 def default_update_cb_stack(self, stack):
       
    52     stack.append(self.source_execute)
       
    53 FunctionDescr.update_cb_stack = default_update_cb_stack
       
    54 
       
    55 LENGTH = SQL_FUNCTIONS_REGISTRY.get_function('LENGTH')
       
    56 def length_source_execute(source, value):
       
    57     return len(value.getvalue())
       
    58 LENGTH.source_execute = length_source_execute
    45 
    59 
    46 def _new_var(select, varname):
    60 def _new_var(select, varname):
    47     newvar = select.get_variable(varname)
    61     newvar = select.get_variable(varname)
    48     if not 'relations' in newvar.stinfo:
    62     if not 'relations' in newvar.stinfo:
    49         # not yet initialized
    63         # not yet initialized
   250                 for vref in term.iget_nodes(VariableRef):
   264                 for vref in term.iget_nodes(VariableRef):
   251                     if not vref.name in selectedidx:
   265                     if not vref.name in selectedidx:
   252                         selectedidx.append(vref.name)
   266                         selectedidx.append(vref.name)
   253                         rqlst.selection.append(vref)
   267                         rqlst.selection.append(vref)
   254 
   268 
   255 # IGenerator implementation for RQL->SQL ######################################
   269 def iter_mapped_var_sels(stmt, variable):
   256 
   270     # variable is a Variable or ColumnAlias node mapped to a source side
       
   271     # callback
       
   272     if not (len(variable.stinfo['rhsrelations']) <= 1 and # < 1 on column alias
       
   273             variable.stinfo['selected']):
       
   274         raise QueryError("can't use %s as a restriction variable"
       
   275                          % variable.name)
       
   276     for selectidx in variable.stinfo['selected']:
       
   277         vrefs = stmt.selection[selectidx].get_nodes(VariableRef)
       
   278         if len(vrefs) != 1:
       
   279             raise QueryError()
       
   280         yield selectidx, vrefs[0]
       
   281 
       
   282 def update_source_cb_stack(state, stmt, node, stack):
       
   283     while True:
       
   284         node = node.parent
       
   285         if node is stmt:
       
   286             break
       
   287         if not isinstance(node, Function):
       
   288             raise QueryError()
       
   289         func = SQL_FUNCTIONS_REGISTRY.get_function(node.name)
       
   290         if func.source_execute is None:
       
   291             raise QueryError('%s can not be called on mapped attribute'
       
   292                              % node.name)
       
   293         state.source_cb_funcs.add(node)
       
   294         func.update_cb_stack(stack)
       
   295 
       
   296 
       
   297 # IGenerator implementation for RQL->SQL #######################################
   257 
   298 
   258 class StateInfo(object):
   299 class StateInfo(object):
   259     def __init__(self, existssols, unstablevars):
   300     def __init__(self, existssols, unstablevars):
   260         self.existssols = existssols
   301         self.existssols = existssols
   261         self.unstablevars = unstablevars
   302         self.unstablevars = unstablevars
   262         self.subtables = {}
   303         self.subtables = {}
       
   304         self.needs_source_cb = None
       
   305         self.subquery_source_cb = None
       
   306         self.source_cb_funcs = set()
   263 
   307 
   264     def reset(self, solution):
   308     def reset(self, solution):
   265         """reset some visit variables"""
   309         """reset some visit variables"""
   266         self.solution = solution
   310         self.solution = solution
   267         self.count = 0
   311         self.count = 0
   274         self.duplicate_switches = []
   318         self.duplicate_switches = []
   275         self.aliases = {}
   319         self.aliases = {}
   276         self.restrictions = []
   320         self.restrictions = []
   277         self._restr_stack = []
   321         self._restr_stack = []
   278         self.ignore_varmap = False
   322         self.ignore_varmap = False
       
   323         self._needs_source_cb = {}
       
   324 
       
   325     def merge_source_cbs(self, needs_source_cb):
       
   326         if self.needs_source_cb is None:
       
   327             self.needs_source_cb = needs_source_cb
       
   328         elif needs_source_cb != self.needs_source_cb:
       
   329             raise QueryError('query fetch some source mapped attribute, some not')
       
   330 
       
   331     def finalize_source_cbs(self):
       
   332         if self.subquery_source_cb is not None:
       
   333             self.needs_source_cb.update(self.subquery_source_cb)
   279 
   334 
   280     def add_restriction(self, restr):
   335     def add_restriction(self, restr):
   281         if restr:
   336         if restr:
   282             self.restrictions.append(restr)
   337             self.restrictions.append(restr)
   283 
   338 
   371         self._not_scope_offset = 0
   426         self._not_scope_offset = 0
   372         try:
   427         try:
   373             # union query for each rqlst / solution
   428             # union query for each rqlst / solution
   374             sql = self.union_sql(union)
   429             sql = self.union_sql(union)
   375             # we are done
   430             # we are done
   376             return sql, self._query_attrs
   431             return sql, self._query_attrs, self._state.needs_source_cb
   377         finally:
   432         finally:
   378             self._lock.release()
   433             self._lock.release()
   379 
   434 
   380     def has_text_need_distinct_union_sql(self, union, needalias=False):
   435     def has_text_need_distinct_union_sql(self, union, needalias=False):
   381         if getattr(union, 'has_text_query', False):
   436         if getattr(union, 'has_text_query', False):
   434                     select.select_only_variables()
   489                     select.select_only_variables()
   435                     needwrap = True
   490                     needwrap = True
   436         else:
   491         else:
   437             existssols, unstable = {}, ()
   492             existssols, unstable = {}, ()
   438         state = StateInfo(existssols, unstable)
   493         state = StateInfo(existssols, unstable)
       
   494         if self._state is not None:
       
   495             # state from a previous unioned select
       
   496             state.merge_source_cbs(self._state.needs_source_cb)
   439         # treat subqueries
   497         # treat subqueries
   440         self._subqueries_sql(select, state)
   498         self._subqueries_sql(select, state)
   441         # generate sql for this select node
   499         # generate sql for this select node
   442         selectidx = [str(term) for term in select.selection]
   500         selectidx = [str(term) for term in select.selection]
   443         if needwrap:
   501         if needwrap:
   489                                                                      fselectidx)
   547                                                                      fselectidx)
   490                                                   for sortterm in sorts)
   548                                                   for sortterm in sorts)
   491                 if fneedwrap:
   549                 if fneedwrap:
   492                     selection = ['T1.C%s' % i for i in xrange(len(origselection))]
   550                     selection = ['T1.C%s' % i for i in xrange(len(origselection))]
   493                     sql = 'SELECT %s FROM (%s) AS T1' % (','.join(selection), sql)
   551                     sql = 'SELECT %s FROM (%s) AS T1' % (','.join(selection), sql)
       
   552             state.finalize_source_cbs()
   494         finally:
   553         finally:
   495             select.selection = origselection
   554             select.selection = origselection
   496         # limit / offset
   555         # limit / offset
   497         limit = select.limit
   556         limit = select.limit
   498         if limit:
   557         if limit:
   506         for i, subquery in enumerate(select.with_):
   565         for i, subquery in enumerate(select.with_):
   507             sql = self.union_sql(subquery.query, needalias=True)
   566             sql = self.union_sql(subquery.query, needalias=True)
   508             tablealias = '_T%s' % i # XXX nested subqueries
   567             tablealias = '_T%s' % i # XXX nested subqueries
   509             sql = '(%s) AS %s' % (sql, tablealias)
   568             sql = '(%s) AS %s' % (sql, tablealias)
   510             state.subtables[tablealias] = (0, sql)
   569             state.subtables[tablealias] = (0, sql)
       
   570             latest_state = self._state
   511             for vref in subquery.aliases:
   571             for vref in subquery.aliases:
   512                 alias = vref.variable
   572                 alias = vref.variable
   513                 alias._q_sqltable = tablealias
   573                 alias._q_sqltable = tablealias
   514                 alias._q_sql = '%s.C%s' % (tablealias, alias.colnum)
   574                 alias._q_sql = '%s.C%s' % (tablealias, alias.colnum)
       
   575                 try:
       
   576                     stack = latest_state.needs_source_cb[alias.colnum]
       
   577                     if state.subquery_source_cb is None:
       
   578                         state.subquery_source_cb = {}
       
   579                     for selectidx, vref in iter_mapped_var_sels(select, alias):
       
   580                         stack = stack[:]
       
   581                         update_source_cb_stack(state, select, vref, stack)
       
   582                         state.subquery_source_cb[selectidx] = stack
       
   583                 except KeyError:
       
   584                     continue
   515 
   585 
   516     def _solutions_sql(self, select, solutions, distinct, needalias):
   586     def _solutions_sql(self, select, solutions, distinct, needalias):
   517         sqls = []
   587         sqls = []
   518         for solution in solutions:
   588         for solution in solutions:
   519             self._state.reset(solution)
   589             self._state.reset(solution)
   521             if select.where is not None:
   591             if select.where is not None:
   522                 self._state.add_restriction(select.where.accept(self))
   592                 self._state.add_restriction(select.where.accept(self))
   523             sql = [self._selection_sql(select.selection, distinct, needalias)]
   593             sql = [self._selection_sql(select.selection, distinct, needalias)]
   524             if self._state.restrictions:
   594             if self._state.restrictions:
   525                 sql.append('WHERE %s' % ' AND '.join(self._state.restrictions))
   595                 sql.append('WHERE %s' % ' AND '.join(self._state.restrictions))
       
   596             self._state.merge_source_cbs(self._state._needs_source_cb)
   526             # add required tables
   597             # add required tables
   527             assert len(self._state.actual_tables) == 1, self._state.actual_tables
   598             assert len(self._state.actual_tables) == 1, self._state.actual_tables
   528             tables = self._state.actual_tables[-1]
   599             tables = self._state.actual_tables[-1]
   529             if tables:
   600             if tables:
   530                 # sort for test predictability
   601                 # sort for test predictability
   893             try:
   964             try:
   894                 lhssql = self._varmap['%s.%s' % (lhs.name, rel.r_type)]
   965                 lhssql = self._varmap['%s.%s' % (lhs.name, rel.r_type)]
   895             except KeyError:
   966             except KeyError:
   896                 mapkey = '%s.%s' % (self._state.solution[lhs.name], rel.r_type)
   967                 mapkey = '%s.%s' % (self._state.solution[lhs.name], rel.r_type)
   897                 if mapkey in self.attr_map:
   968                 if mapkey in self.attr_map:
   898                     lhssql = self.attr_map[mapkey](self, lhs.variable, rel)
   969                     cb, sourcecb = self.attr_map[mapkey]
       
   970                     if sourcecb:
       
   971                         # callback is a source callback, we can't use this
       
   972                         # attribute in restriction
       
   973                         raise QueryError("can't use %s (%s) in restriction"
       
   974                                          % (mapkey, rel.as_string()))
       
   975                     lhssql = cb(self, lhs.variable, rel)
   899                 elif rel.r_type == 'eid':
   976                 elif rel.r_type == 'eid':
   900                     lhssql = lhs.variable._q_sql
   977                     lhssql = lhs.variable._q_sql
   901                 else:
   978                 else:
   902                     lhssql = '%s.%s%s' % (table, SQL_PREFIX, rel.r_type)
   979                     lhssql = '%s.%s%s' % (table, SQL_PREFIX, rel.r_type)
   903         try:
   980         try:
   985             pass
  1062             pass
   986         return '(%s %s %s)'% (lhs.accept(self), operator, rhs.accept(self))
  1063         return '(%s %s %s)'% (lhs.accept(self), operator, rhs.accept(self))
   987 
  1064 
   988     def visit_function(self, func):
  1065     def visit_function(self, func):
   989         """generate SQL name for a function"""
  1066         """generate SQL name for a function"""
   990         # func_sql_call will check function is supported by the backend
  1067         args = [c.accept(self) for c in func.children]
   991         return self.dbms_helper.func_as_sql(func.name,
  1068         if func in self._state.source_cb_funcs:
   992                                             [c.accept(self) for c in func.children])
  1069             # function executed as a callback on the source
       
  1070             assert len(args) == 1
       
  1071             return args[0]
       
  1072         # func_as_sql will check function is supported by the backend
       
  1073         return self.dbhelper.func_as_sql(func.name, args)
   993 
  1074 
   994     def visit_constant(self, constant):
  1075     def visit_constant(self, constant):
   995         """generate SQL name for a constant"""
  1076         """generate SQL name for a constant"""
   996         value = constant.value
  1077         value = constant.value
   997         if constant.type is None:
  1078         if constant.type is None:
  1154         if rel.r_type == 'eid':
  1235         if rel.r_type == 'eid':
  1155             return linkedvar.accept(self)
  1236             return linkedvar.accept(self)
  1156         if isinstance(linkedvar, ColumnAlias):
  1237         if isinstance(linkedvar, ColumnAlias):
  1157             raise BadRQLQuery('variable %s should be selected by the subquery'
  1238             raise BadRQLQuery('variable %s should be selected by the subquery'
  1158                               % variable.name)
  1239                               % variable.name)
  1159         mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type)
       
  1160         if mapkey in self.attr_map:
       
  1161             return self.attr_map[mapkey](self, linkedvar, rel)
       
  1162         try:
  1240         try:
  1163             sql = self._varmap['%s.%s' % (linkedvar.name, rel.r_type)]
  1241             sql = self._varmap['%s.%s' % (linkedvar.name, rel.r_type)]
  1164         except KeyError:
  1242         except KeyError:
       
  1243             mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type)
       
  1244             if mapkey in self.attr_map:
       
  1245                 cb, sourcecb = self.attr_map[mapkey]
       
  1246                 if not sourcecb:
       
  1247                     return cb(self, linkedvar, rel)
       
  1248                 # attribute mapped at the source level (bfss for instance)
       
  1249                 stmt = rel.stmt
       
  1250                 for selectidx, vref in iter_mapped_var_sels(stmt, variable):
       
  1251                     stack = [cb]
       
  1252                     update_source_cb_stack(self._state, stmt, vref, stack)
       
  1253                     self._state._needs_source_cb[selectidx] = stack
  1165             linkedvar.accept(self)
  1254             linkedvar.accept(self)
  1166             sql = '%s.%s%s' % (linkedvar._q_sqltable, SQL_PREFIX, rel.r_type)
  1255             sql = '%s.%s%s' % (linkedvar._q_sqltable, SQL_PREFIX, rel.r_type)
  1167         return sql
  1256         return sql
  1168 
  1257 
  1169     # tables handling #########################################################
  1258     # tables handling #########################################################