--- a/server/sources/rql2sql.py Thu Mar 25 13:49:07 2010 +0100
+++ b/server/sources/rql2sql.py Thu Mar 25 13:59:47 2010 +0100
@@ -33,16 +33,30 @@
import threading
+from logilab.database import FunctionDescr, SQL_FUNCTIONS_REGISTRY
+
from rql import BadRQLQuery, CoercionError
from rql.stmts import Union, Select
from rql.nodes import (SortTerm, VariableRef, Constant, Function, Not,
Variable, ColumnAlias, Relation, SubQuery, Exists)
+from cubicweb import QueryError
from cubicweb.server.sqlutils import SQL_PREFIX
from cubicweb.server.utils import cleanup_solutions
ColumnAlias._q_invariant = False # avoid to check for ColumnAlias / Variable
+FunctionDescr.source_execute = None
+
+def default_update_cb_stack(self, stack):
+ stack.append(self.source_execute)
+FunctionDescr.update_cb_stack = default_update_cb_stack
+
+LENGTH = SQL_FUNCTIONS_REGISTRY.get_function('LENGTH')
+def length_source_execute(source, value):
+ return len(value.getvalue())
+LENGTH.source_execute = length_source_execute
+
def _new_var(select, varname):
newvar = select.get_variable(varname)
if not 'relations' in newvar.stinfo:
@@ -252,14 +266,44 @@
selectedidx.append(vref.name)
rqlst.selection.append(vref)
-# IGenerator implementation for RQL->SQL ######################################
+def iter_mapped_var_sels(stmt, variable):
+ # variable is a Variable or ColumnAlias node mapped to a source side
+ # callback
+ if not (len(variable.stinfo['rhsrelations']) <= 1 and # < 1 on column alias
+ variable.stinfo['selected']):
+ raise QueryError("can't use %s as a restriction variable"
+ % variable.name)
+ for selectidx in variable.stinfo['selected']:
+ vrefs = stmt.selection[selectidx].get_nodes(VariableRef)
+ if len(vrefs) != 1:
+ raise QueryError()
+ yield selectidx, vrefs[0]
+def update_source_cb_stack(state, stmt, node, stack):
+ while True:
+ node = node.parent
+ if node is stmt:
+ break
+ if not isinstance(node, Function):
+ raise QueryError()
+ func = SQL_FUNCTIONS_REGISTRY.get_function(node.name)
+ if func.source_execute is None:
+ raise QueryError('%s can not be called on mapped attribute'
+ % node.name)
+ state.source_cb_funcs.add(node)
+ func.update_cb_stack(stack)
+
+
+# IGenerator implementation for RQL->SQL #######################################
class StateInfo(object):
def __init__(self, existssols, unstablevars):
self.existssols = existssols
self.unstablevars = unstablevars
self.subtables = {}
+ self.needs_source_cb = None
+ self.subquery_source_cb = None
+ self.source_cb_funcs = set()
def reset(self, solution):
"""reset some visit variables"""
@@ -276,6 +320,17 @@
self.restrictions = []
self._restr_stack = []
self.ignore_varmap = False
+ self._needs_source_cb = {}
+
+ def merge_source_cbs(self, needs_source_cb):
+ if self.needs_source_cb is None:
+ self.needs_source_cb = needs_source_cb
+ elif needs_source_cb != self.needs_source_cb:
+ raise QueryError('query fetch some source mapped attribute, some not')
+
+ def finalize_source_cbs(self):
+ if self.subquery_source_cb is not None:
+ self.needs_source_cb.update(self.subquery_source_cb)
def add_restriction(self, restr):
if restr:
@@ -373,7 +428,7 @@
# union query for each rqlst / solution
sql = self.union_sql(union)
# we are done
- return sql, self._query_attrs
+ return sql, self._query_attrs, self._state.needs_source_cb
finally:
self._lock.release()
@@ -436,6 +491,9 @@
else:
existssols, unstable = {}, ()
state = StateInfo(existssols, unstable)
+ if self._state is not None:
+ # state from a previous unioned select
+ state.merge_source_cbs(self._state.needs_source_cb)
# treat subqueries
self._subqueries_sql(select, state)
# generate sql for this select node
@@ -491,6 +549,7 @@
if fneedwrap:
selection = ['T1.C%s' % i for i in xrange(len(origselection))]
sql = 'SELECT %s FROM (%s) AS T1' % (','.join(selection), sql)
+ state.finalize_source_cbs()
finally:
select.selection = origselection
# limit / offset
@@ -508,10 +567,21 @@
tablealias = '_T%s' % i # XXX nested subqueries
sql = '(%s) AS %s' % (sql, tablealias)
state.subtables[tablealias] = (0, sql)
+ latest_state = self._state
for vref in subquery.aliases:
alias = vref.variable
alias._q_sqltable = tablealias
alias._q_sql = '%s.C%s' % (tablealias, alias.colnum)
+ try:
+ stack = latest_state.needs_source_cb[alias.colnum]
+ if state.subquery_source_cb is None:
+ state.subquery_source_cb = {}
+ for selectidx, vref in iter_mapped_var_sels(select, alias):
+ stack = stack[:]
+ update_source_cb_stack(state, select, vref, stack)
+ state.subquery_source_cb[selectidx] = stack
+ except KeyError:
+ continue
def _solutions_sql(self, select, solutions, distinct, needalias):
sqls = []
@@ -523,6 +593,7 @@
sql = [self._selection_sql(select.selection, distinct, needalias)]
if self._state.restrictions:
sql.append('WHERE %s' % ' AND '.join(self._state.restrictions))
+ self._state.merge_source_cbs(self._state._needs_source_cb)
# add required tables
assert len(self._state.actual_tables) == 1, self._state.actual_tables
tables = self._state.actual_tables[-1]
@@ -895,7 +966,13 @@
except KeyError:
mapkey = '%s.%s' % (self._state.solution[lhs.name], rel.r_type)
if mapkey in self.attr_map:
- lhssql = self.attr_map[mapkey](self, lhs.variable, rel)
+ cb, sourcecb = self.attr_map[mapkey]
+ if sourcecb:
+ # callback is a source callback, we can't use this
+ # attribute in restriction
+ raise QueryError("can't use %s (%s) in restriction"
+ % (mapkey, rel.as_string()))
+ lhssql = cb(self, lhs.variable, rel)
elif rel.r_type == 'eid':
lhssql = lhs.variable._q_sql
else:
@@ -987,9 +1064,13 @@
def visit_function(self, func):
"""generate SQL name for a function"""
- # func_sql_call will check function is supported by the backend
- return self.dbms_helper.func_as_sql(func.name,
- [c.accept(self) for c in func.children])
+ args = [c.accept(self) for c in func.children]
+ if func in self._state.source_cb_funcs:
+ # function executed as a callback on the source
+ assert len(args) == 1
+ return args[0]
+ # func_as_sql will check function is supported by the backend
+ return self.dbhelper.func_as_sql(func.name, args)
def visit_constant(self, constant):
"""generate SQL name for a constant"""
@@ -1156,12 +1237,20 @@
if isinstance(linkedvar, ColumnAlias):
raise BadRQLQuery('variable %s should be selected by the subquery'
% variable.name)
- mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type)
- if mapkey in self.attr_map:
- return self.attr_map[mapkey](self, linkedvar, rel)
try:
sql = self._varmap['%s.%s' % (linkedvar.name, rel.r_type)]
except KeyError:
+ mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type)
+ if mapkey in self.attr_map:
+ cb, sourcecb = self.attr_map[mapkey]
+ if not sourcecb:
+ return cb(self, linkedvar, rel)
+ # attribute mapped at the source level (bfss for instance)
+ stmt = rel.stmt
+ for selectidx, vref in iter_mapped_var_sels(stmt, variable):
+ stack = [cb]
+ update_source_cb_stack(self._state, stmt, vref, stack)
+ self._state._needs_source_cb[selectidx] = stack
linkedvar.accept(self)
sql = '%s.%s%s' % (linkedvar._q_sqltable, SQL_PREFIX, rel.r_type)
return sql