--- a/server/sources/rql2sql.py Wed Mar 24 18:04:59 2010 +0100
+++ b/server/sources/rql2sql.py Thu Mar 25 14:26:13 2010 +0100
@@ -33,16 +33,30 @@
import threading
+from logilab.database import FunctionDescr, SQL_FUNCTIONS_REGISTRY
+
from rql import BadRQLQuery, CoercionError
from rql.stmts import Union, Select
from rql.nodes import (SortTerm, VariableRef, Constant, Function, Not,
Variable, ColumnAlias, Relation, SubQuery, Exists)
+from cubicweb import QueryError
from cubicweb.server.sqlutils import SQL_PREFIX
from cubicweb.server.utils import cleanup_solutions
ColumnAlias._q_invariant = False # avoid to check for ColumnAlias / Variable
+FunctionDescr.source_execute = None
+
+def default_update_cb_stack(self, stack):
+ stack.append(self.source_execute)
+FunctionDescr.update_cb_stack = default_update_cb_stack
+
+LENGTH = SQL_FUNCTIONS_REGISTRY.get_function('LENGTH')
+def length_source_execute(source, value):
+ return len(value.getvalue())
+LENGTH.source_execute = length_source_execute
+
def _new_var(select, varname):
newvar = select.get_variable(varname)
if not 'relations' in newvar.stinfo:
@@ -252,14 +266,44 @@
selectedidx.append(vref.name)
rqlst.selection.append(vref)
-# IGenerator implementation for RQL->SQL ######################################
+def iter_mapped_var_sels(stmt, variable):
+ # variable is a Variable or ColumnAlias node mapped to a source side
+ # callback
+ if not (len(variable.stinfo['rhsrelations']) <= 1 and # < 1 on column alias
+ variable.stinfo['selected']):
+ raise QueryError("can't use %s as a restriction variable"
+ % variable.name)
+ for selectidx in variable.stinfo['selected']:
+ vrefs = stmt.selection[selectidx].get_nodes(VariableRef)
+ if len(vrefs) != 1:
+ raise QueryError()
+ yield selectidx, vrefs[0]
+def update_source_cb_stack(state, stmt, node, stack):
+ while True:
+ node = node.parent
+ if node is stmt:
+ break
+ if not isinstance(node, Function):
+ raise QueryError()
+ func = SQL_FUNCTIONS_REGISTRY.get_function(node.name)
+ if func.source_execute is None:
+ raise QueryError('%s can not be called on mapped attribute'
+ % node.name)
+ state.source_cb_funcs.add(node)
+ func.update_cb_stack(stack)
+
+
+# IGenerator implementation for RQL->SQL #######################################
class StateInfo(object):
def __init__(self, existssols, unstablevars):
self.existssols = existssols
self.unstablevars = unstablevars
self.subtables = {}
+ self.needs_source_cb = None
+ self.subquery_source_cb = None
+ self.source_cb_funcs = set()
def reset(self, solution):
"""reset some visit variables"""
@@ -276,6 +320,17 @@
self.restrictions = []
self._restr_stack = []
self.ignore_varmap = False
+ self._needs_source_cb = {}
+
+ def merge_source_cbs(self, needs_source_cb):
+ if self.needs_source_cb is None:
+ self.needs_source_cb = needs_source_cb
+ elif needs_source_cb != self.needs_source_cb:
+ raise QueryError('query fetch some source mapped attribute, some not')
+
+ def finalize_source_cbs(self):
+ if self.subquery_source_cb is not None:
+ self.needs_source_cb.update(self.subquery_source_cb)
def add_restriction(self, restr):
if restr:
@@ -332,16 +387,16 @@
protected by a lock
"""
- def __init__(self, schema, dbms_helper, attrmap=None):
+ def __init__(self, schema, dbhelper, attrmap=None):
self.schema = schema
- self.dbms_helper = dbms_helper
- self.dbencoding = dbms_helper.dbencoding
- self.keyword_map = {'NOW' : self.dbms_helper.sql_current_timestamp,
- 'TODAY': self.dbms_helper.sql_current_date,
+ self.dbhelper = dbhelper
+ self.dbencoding = dbhelper.dbencoding
+ self.keyword_map = {'NOW' : self.dbhelper.sql_current_timestamp,
+ 'TODAY': self.dbhelper.sql_current_date,
}
- if not self.dbms_helper.union_parentheses_support:
+ if not self.dbhelper.union_parentheses_support:
self.union_sql = self.noparen_union_sql
- if self.dbms_helper.fti_need_distinct:
+ if self.dbhelper.fti_need_distinct:
self.__union_sql = self.union_sql
self.union_sql = self.has_text_need_distinct_union_sql
self._lock = threading.Lock()
@@ -373,7 +428,7 @@
# union query for each rqlst / solution
sql = self.union_sql(union)
# we are done
- return sql, self._query_attrs
+ return sql, self._query_attrs, self._state.needs_source_cb
finally:
self._lock.release()
@@ -391,9 +446,10 @@
return '\nUNION ALL\n'.join(sqls)
def noparen_union_sql(self, union, needalias=False):
- # needed for sqlite backend which doesn't like parentheses around
- # union query. This may cause bug in some condition (sort in one of
- # the subquery) but will work in most case
+ # needed for sqlite backend which doesn't like parentheses around union
+ # query. This may cause bug in some condition (sort in one of the
+ # subquery) but will work in most case
+ #
# see http://www.sqlite.org/cvstrac/tktview?tn=3074
sqls = (self.select_sql(select, needalias)
for i, select in enumerate(union.children))
@@ -435,6 +491,9 @@
else:
existssols, unstable = {}, ()
state = StateInfo(existssols, unstable)
+ if self._state is not None:
+ # state from a previous unioned select
+ state.merge_source_cbs(self._state.needs_source_cb)
# treat subqueries
self._subqueries_sql(select, state)
# generate sql for this select node
@@ -490,6 +549,7 @@
if fneedwrap:
selection = ['T1.C%s' % i for i in xrange(len(origselection))]
sql = 'SELECT %s FROM (%s) AS T1' % (','.join(selection), sql)
+ state.finalize_source_cbs()
finally:
select.selection = origselection
# limit / offset
@@ -504,13 +564,24 @@
def _subqueries_sql(self, select, state):
for i, subquery in enumerate(select.with_):
sql = self.union_sql(subquery.query, needalias=True)
- tablealias = '_T%s' % i
+ tablealias = '_T%s' % i # XXX nested subqueries
sql = '(%s) AS %s' % (sql, tablealias)
state.subtables[tablealias] = (0, sql)
+ latest_state = self._state
for vref in subquery.aliases:
alias = vref.variable
alias._q_sqltable = tablealias
alias._q_sql = '%s.C%s' % (tablealias, alias.colnum)
+ try:
+ stack = latest_state.needs_source_cb[alias.colnum]
+ if state.subquery_source_cb is None:
+ state.subquery_source_cb = {}
+ for selectidx, vref in iter_mapped_var_sels(select, alias):
+ stack = stack[:]
+ update_source_cb_stack(state, select, vref, stack)
+ state.subquery_source_cb[selectidx] = stack
+ except KeyError:
+ continue
def _solutions_sql(self, select, solutions, distinct, needalias):
sqls = []
@@ -522,17 +593,18 @@
sql = [self._selection_sql(select.selection, distinct, needalias)]
if self._state.restrictions:
sql.append('WHERE %s' % ' AND '.join(self._state.restrictions))
+ self._state.merge_source_cbs(self._state._needs_source_cb)
# add required tables
assert len(self._state.actual_tables) == 1, self._state.actual_tables
tables = self._state.actual_tables[-1]
if tables:
# sort for test predictability
sql.insert(1, 'FROM %s' % ', '.join(sorted(tables)))
- elif self._state.restrictions and self.dbms_helper.needs_from_clause:
+ elif self._state.restrictions and self.dbhelper.needs_from_clause:
sql.insert(1, 'FROM (SELECT 1) AS _T')
sqls.append('\n'.join(sql))
if select.need_intersect:
- #if distinct or not self.dbms_helper.intersect_all_support:
+ #if distinct or not self.dbhelper.intersect_all_support:
return '\nINTERSECT\n'.join(sqls)
#else:
# return '\nINTERSECT ALL\n'.join(sqls)
@@ -894,7 +966,13 @@
except KeyError:
mapkey = '%s.%s' % (self._state.solution[lhs.name], rel.r_type)
if mapkey in self.attr_map:
- lhssql = self.attr_map[mapkey](self, lhs.variable, rel)
+ cb, sourcecb = self.attr_map[mapkey]
+ if sourcecb:
+ # callback is a source callback, we can't use this
+ # attribute in restriction
+ raise QueryError("can't use %s (%s) in restriction"
+ % (mapkey, rel.as_string()))
+ lhssql = cb(self, lhs.variable, rel)
elif rel.r_type == 'eid':
lhssql = lhs.variable._q_sql
else:
@@ -943,7 +1021,7 @@
not_ = True
else:
not_ = False
- return self.dbms_helper.fti_restriction_sql(alias, const.eval(self._args),
+ return self.dbhelper.fti_restriction_sql(alias, const.eval(self._args),
jointo, not_) + restriction
def visit_comparison(self, cmp):
@@ -956,7 +1034,7 @@
rhs = cmp.children[0]
operator = cmp.operator
if operator in ('IS', 'LIKE', 'ILIKE'):
- if operator == 'ILIKE' and not self.dbms_helper.ilike_support:
+ if operator == 'ILIKE' and not self.dbhelper.ilike_support:
operator = ' LIKE '
else:
operator = ' %s ' % operator
@@ -986,9 +1064,13 @@
def visit_function(self, func):
"""generate SQL name for a function"""
- # func_sql_call will check function is supported by the backend
- return self.dbms_helper.func_as_sql(func.name,
- [c.accept(self) for c in func.children])
+ args = [c.accept(self) for c in func.children]
+ if func in self._state.source_cb_funcs:
+ # function executed as a callback on the source
+ assert len(args) == 1
+ return args[0]
+ # func_as_sql will check function is supported by the backend
+ return self.dbhelper.func_as_sql(func.name, args)
def visit_constant(self, constant):
"""generate SQL name for a constant"""
@@ -1003,7 +1085,7 @@
rel._q_needcast = value
return self.keyword_map[value]()
if constant.type == 'Boolean':
- value = self.dbms_helper.boolean_value(value)
+ value = self.dbhelper.boolean_value(value)
if constant.type == 'Substitute':
_id = constant.value
if isinstance(_id, unicode):
@@ -1065,7 +1147,7 @@
self._state.add_restriction(restr)
elif principal.r_type == 'has_text':
sql = '%s.%s' % (self._fti_table(principal),
- self.dbms_helper.fti_uid_attr)
+ self.dbhelper.fti_uid_attr)
elif principal in variable.stinfo['rhsrelations']:
if self.schema.rschema(principal.r_type).inlined:
sql = self._linked_var_sql(variable)
@@ -1155,12 +1237,20 @@
if isinstance(linkedvar, ColumnAlias):
raise BadRQLQuery('variable %s should be selected by the subquery'
% variable.name)
- mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type)
- if mapkey in self.attr_map:
- return self.attr_map[mapkey](self, linkedvar, rel)
try:
sql = self._varmap['%s.%s' % (linkedvar.name, rel.r_type)]
except KeyError:
+ mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type)
+ if mapkey in self.attr_map:
+ cb, sourcecb = self.attr_map[mapkey]
+ if not sourcecb:
+ return cb(self, linkedvar, rel)
+ # attribute mapped at the source level (bfss for instance)
+ stmt = rel.stmt
+ for selectidx, vref in iter_mapped_var_sels(stmt, variable):
+ stack = [cb]
+ update_source_cb_stack(self._state, stmt, vref, stack)
+ self._state._needs_source_cb[selectidx] = stack
linkedvar.accept(self)
sql = '%s.%s%s' % (linkedvar._q_sqltable, SQL_PREFIX, rel.r_type)
return sql
@@ -1267,7 +1357,7 @@
except AttributeError:
pass
self._state.done.add(relation)
- alias = self.alias_and_add_table(self.dbms_helper.fti_table)
+ alias = self.alias_and_add_table(self.dbhelper.fti_table)
relation._q_sqltable = alias
return alias