server/sources/rql2sql.py
branch stable
changeset 5013 ad91f93bbb93
parent 5010 b2c5aee8ca3f
child 5016 b3b0b808a0ed
child 5280 7e13bb484a19
--- a/server/sources/rql2sql.py	Thu Mar 25 13:49:07 2010 +0100
+++ b/server/sources/rql2sql.py	Thu Mar 25 13:59:47 2010 +0100
@@ -33,16 +33,30 @@
 
 import threading
 
+from logilab.database import FunctionDescr, SQL_FUNCTIONS_REGISTRY
+
 from rql import BadRQLQuery, CoercionError
 from rql.stmts import Union, Select
 from rql.nodes import (SortTerm, VariableRef, Constant, Function, Not,
                        Variable, ColumnAlias, Relation, SubQuery, Exists)
 
+from cubicweb import QueryError
 from cubicweb.server.sqlutils import SQL_PREFIX
 from cubicweb.server.utils import cleanup_solutions
 
 ColumnAlias._q_invariant = False # avoid to check for ColumnAlias / Variable
 
+FunctionDescr.source_execute = None  # default: no source side implementation
+
+def default_update_cb_stack(self, stack):
+    stack.append(self.source_execute)  # push this function's source callback
+FunctionDescr.update_cb_stack = default_update_cb_stack  # default for every FunctionDescr
+
+LENGTH = SQL_FUNCTIONS_REGISTRY.get_function('LENGTH')
+def length_source_execute(source, value):  # `source` is not used here
+    return len(value.getvalue())  # presumably value is a file-like/Binary object -- TODO confirm
+LENGTH.source_execute = length_source_execute  # lets LENGTH be called on a mapped attribute
+
 def _new_var(select, varname):
     newvar = select.get_variable(varname)
     if not 'relations' in newvar.stinfo:
@@ -252,14 +266,44 @@
                         selectedidx.append(vref.name)
                         rqlst.selection.append(vref)
 
-# IGenerator implementation for RQL->SQL ######################################
+def iter_mapped_var_sels(stmt, variable):
+    """Yield (selection index, VariableRef) for each selected term using
+    `variable`, a Variable or ColumnAlias mapped to a source side callback."""
+    if not (len(variable.stinfo['rhsrelations']) <= 1 and # < 1 on column alias
+            variable.stinfo['selected']):
+        raise QueryError("can't use %s as a restriction variable"
+                         % variable.name)
+    for selectidx in variable.stinfo['selected']:
+        vrefs = stmt.selection[selectidx].get_nodes(VariableRef)
+        if len(vrefs) != 1:  # the term must reference this variable, once, and nothing else
+            raise QueryError('%s not alone in selected term' % variable.name)
+        yield selectidx, vrefs[0]
 
+def update_source_cb_stack(state, stmt, node, stack):
+    """Push on `stack` the source callback of each function wrapping `node`."""
+    node = node.parent
+    while node is not stmt:  # climb from the variable reference up to the statement
+        if not isinstance(node, Function):
+            raise QueryError('only functions can wrap a mapped attribute')
+        func = SQL_FUNCTIONS_REGISTRY.get_function(node.name)
+        if func.source_execute is None:
+            raise QueryError('%s can not be called on mapped attribute'
+                             % node.name)
+        state.source_cb_funcs.add(node)  # visit_function checks this set to skip the SQL call
+        func.update_cb_stack(stack)
+        node = node.parent
+
+
+# IGenerator implementation for RQL->SQL #######################################
 
 class StateInfo(object):
     def __init__(self, existssols, unstablevars):
         self.existssols = existssols
         self.unstablevars = unstablevars
         self.subtables = {}
+        self.needs_source_cb = None  # None until merge_source_cbs stores a mapping
+        self.subquery_source_cb = None  # {selection index: callback stack} set for subqueries
+        self.source_cb_funcs = set()  # Function nodes executed as source callbacks
 
     def reset(self, solution):
         """reset some visit variables"""
@@ -276,6 +320,17 @@
         self.restrictions = []
         self._restr_stack = []
         self.ignore_varmap = False
+        self._needs_source_cb = {}
+
+    def merge_source_cbs(self, needs_source_cb):
+        if self.needs_source_cb is None:  # first merge: adopt the given mapping as is
+            self.needs_source_cb = needs_source_cb
+        elif needs_source_cb != self.needs_source_cb:  # later merges must be equal to it
+            raise QueryError('query fetch some source mapped attribute, some not')
+
+    def finalize_source_cbs(self):
+        if self.subquery_source_cb is not None:  # callbacks collected for subqueries
+            self.needs_source_cb.update(self.subquery_source_cb)  # NOTE(review): assumes a dict here, not None
 
     def add_restriction(self, restr):
         if restr:
@@ -373,7 +428,7 @@
             # union query for each rqlst / solution
             sql = self.union_sql(union)
             # we are done
-            return sql, self._query_attrs
+            return sql, self._query_attrs, self._state.needs_source_cb
         finally:
             self._lock.release()
 
@@ -436,6 +491,9 @@
         else:
             existssols, unstable = {}, ()
         state = StateInfo(existssols, unstable)
+        if self._state is not None:
+            # state from a previous unioned select
+            state.merge_source_cbs(self._state.needs_source_cb)
         # treat subqueries
         self._subqueries_sql(select, state)
         # generate sql for this select node
@@ -491,6 +549,7 @@
                 if fneedwrap:
                     selection = ['T1.C%s' % i for i in xrange(len(origselection))]
                     sql = 'SELECT %s FROM (%s) AS T1' % (','.join(selection), sql)
+            state.finalize_source_cbs()
         finally:
             select.selection = origselection
         # limit / offset
@@ -508,10 +567,21 @@
             tablealias = '_T%s' % i # XXX nested subqueries
             sql = '(%s) AS %s' % (sql, tablealias)
             state.subtables[tablealias] = (0, sql)
+            latest_state = self._state
             for vref in subquery.aliases:
                 alias = vref.variable
                 alias._q_sqltable = tablealias
                 alias._q_sql = '%s.C%s' % (tablealias, alias.colnum)
+                try:
+                    stack = latest_state.needs_source_cb[alias.colnum]
+                    if state.subquery_source_cb is None:
+                        state.subquery_source_cb = {}
+                    for selectidx, vref in iter_mapped_var_sels(select, alias):
+                        stack = stack[:]
+                        update_source_cb_stack(state, select, vref, stack)
+                        state.subquery_source_cb[selectidx] = stack
+                except KeyError:
+                    continue
 
     def _solutions_sql(self, select, solutions, distinct, needalias):
         sqls = []
@@ -523,6 +593,7 @@
             sql = [self._selection_sql(select.selection, distinct, needalias)]
             if self._state.restrictions:
                 sql.append('WHERE %s' % ' AND '.join(self._state.restrictions))
+            self._state.merge_source_cbs(self._state._needs_source_cb)
             # add required tables
             assert len(self._state.actual_tables) == 1, self._state.actual_tables
             tables = self._state.actual_tables[-1]
@@ -895,7 +966,13 @@
             except KeyError:
                 mapkey = '%s.%s' % (self._state.solution[lhs.name], rel.r_type)
                 if mapkey in self.attr_map:
-                    lhssql = self.attr_map[mapkey](self, lhs.variable, rel)
+                    cb, sourcecb = self.attr_map[mapkey]
+                    if sourcecb:
+                        # callback is a source callback, we can't use this
+                        # attribute in restriction
+                        raise QueryError("can't use %s (%s) in restriction"
+                                         % (mapkey, rel.as_string()))
+                    lhssql = cb(self, lhs.variable, rel)
                 elif rel.r_type == 'eid':
                     lhssql = lhs.variable._q_sql
                 else:
@@ -987,9 +1064,13 @@
 
     def visit_function(self, func):
         """generate SQL name for a function"""
-        # func_sql_call will check function is supported by the backend
-        return self.dbms_helper.func_as_sql(func.name,
-                                            [c.accept(self) for c in func.children])
+        args = [c.accept(self) for c in func.children]  # SQL of each argument
+        if func in self._state.source_cb_funcs:
+            # function executed as a callback on the source
+            assert len(args) == 1  # the SQL selects the bare argument instead
+            return args[0]
+        # func_as_sql will check function is supported by the backend
+        return self.dbms_helper.func_as_sql(func.name, args)  # same helper attribute as the replaced code
 
     def visit_constant(self, constant):
         """generate SQL name for a constant"""
@@ -1156,12 +1237,20 @@
         if isinstance(linkedvar, ColumnAlias):
             raise BadRQLQuery('variable %s should be selected by the subquery'
                               % variable.name)
-        mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type)
-        if mapkey in self.attr_map:
-            return self.attr_map[mapkey](self, linkedvar, rel)
         try:
             sql = self._varmap['%s.%s' % (linkedvar.name, rel.r_type)]
         except KeyError:
+            mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type)
+            if mapkey in self.attr_map:
+                cb, sourcecb = self.attr_map[mapkey]
+                if not sourcecb:
+                    return cb(self, linkedvar, rel)
+                # attribute mapped at the source level (bfss for instance)
+                stmt = rel.stmt
+                for selectidx, vref in iter_mapped_var_sels(stmt, variable):
+                    stack = [cb]
+                    update_source_cb_stack(self._state, stmt, vref, stack)
+                    self._state._needs_source_cb[selectidx] = stack
             linkedvar.accept(self)
             sql = '%s.%s%s' % (linkedvar._q_sqltable, SQL_PREFIX, rel.r_type)
         return sql