server/sources/rql2sql.py
changeset 5016 b3b0b808a0ed
parent 5004 4cc020ee70e2
parent 5013 ad91f93bbb93
child 5302 dfd147de06b2
--- a/server/sources/rql2sql.py	Wed Mar 24 18:04:59 2010 +0100
+++ b/server/sources/rql2sql.py	Thu Mar 25 14:26:13 2010 +0100
@@ -33,16 +33,30 @@
 
 import threading
 
+from logilab.database import FunctionDescr, SQL_FUNCTIONS_REGISTRY
+
 from rql import BadRQLQuery, CoercionError
 from rql.stmts import Union, Select
 from rql.nodes import (SortTerm, VariableRef, Constant, Function, Not,
                        Variable, ColumnAlias, Relation, SubQuery, Exists)
 
+from cubicweb import QueryError
 from cubicweb.server.sqlutils import SQL_PREFIX
 from cubicweb.server.utils import cleanup_solutions
 
 ColumnAlias._q_invariant = False # avoid to check for ColumnAlias / Variable
 
+FunctionDescr.source_execute = None
+
+def default_update_cb_stack(self, stack):
+    stack.append(self.source_execute)
+FunctionDescr.update_cb_stack = default_update_cb_stack
+
+LENGTH = SQL_FUNCTIONS_REGISTRY.get_function('LENGTH')
+def length_source_execute(source, value):
+    return len(value.getvalue())
+LENGTH.source_execute = length_source_execute
+
 def _new_var(select, varname):
     newvar = select.get_variable(varname)
     if not 'relations' in newvar.stinfo:
@@ -252,14 +266,44 @@
                         selectedidx.append(vref.name)
                         rqlst.selection.append(vref)
 
-# IGenerator implementation for RQL->SQL ######################################
+def iter_mapped_var_sels(stmt, variable):
+    # variable is a Variable or ColumnAlias node mapped to a source side
+    # callback
+    if not (len(variable.stinfo['rhsrelations']) <= 1 and # < 1 on column alias
+            variable.stinfo['selected']):
+        raise QueryError("can't use %s as a restriction variable"
+                         % variable.name)
+    for selectidx in variable.stinfo['selected']:
+        vrefs = stmt.selection[selectidx].get_nodes(VariableRef)
+        if len(vrefs) != 1:
+            raise QueryError()
+        yield selectidx, vrefs[0]
 
+def update_source_cb_stack(state, stmt, node, stack):
+    while True:
+        node = node.parent
+        if node is stmt:
+            break
+        if not isinstance(node, Function):
+            raise QueryError()
+        func = SQL_FUNCTIONS_REGISTRY.get_function(node.name)
+        if func.source_execute is None:
+            raise QueryError('%s can not be called on mapped attribute'
+                             % node.name)
+        state.source_cb_funcs.add(node)
+        func.update_cb_stack(stack)
+
+
+# IGenerator implementation for RQL->SQL #######################################
 
 class StateInfo(object):
     def __init__(self, existssols, unstablevars):
         self.existssols = existssols
         self.unstablevars = unstablevars
         self.subtables = {}
+        self.needs_source_cb = None
+        self.subquery_source_cb = None
+        self.source_cb_funcs = set()
 
     def reset(self, solution):
         """reset some visit variables"""
@@ -276,6 +320,17 @@
         self.restrictions = []
         self._restr_stack = []
         self.ignore_varmap = False
+        self._needs_source_cb = {}
+
+    def merge_source_cbs(self, needs_source_cb):
+        if self.needs_source_cb is None:
+            self.needs_source_cb = needs_source_cb
+        elif needs_source_cb != self.needs_source_cb:
+            raise QueryError('query fetch some source mapped attribute, some not')
+
+    def finalize_source_cbs(self):
+        if self.subquery_source_cb is not None:
+            self.needs_source_cb.update(self.subquery_source_cb)
 
     def add_restriction(self, restr):
         if restr:
@@ -332,16 +387,16 @@
     protected by a lock
     """
 
-    def __init__(self, schema, dbms_helper, attrmap=None):
+    def __init__(self, schema, dbhelper, attrmap=None):
         self.schema = schema
-        self.dbms_helper = dbms_helper
-        self.dbencoding = dbms_helper.dbencoding
-        self.keyword_map = {'NOW' : self.dbms_helper.sql_current_timestamp,
-                            'TODAY': self.dbms_helper.sql_current_date,
+        self.dbhelper = dbhelper
+        self.dbencoding = dbhelper.dbencoding
+        self.keyword_map = {'NOW' : self.dbhelper.sql_current_timestamp,
+                            'TODAY': self.dbhelper.sql_current_date,
                             }
-        if not self.dbms_helper.union_parentheses_support:
+        if not self.dbhelper.union_parentheses_support:
             self.union_sql = self.noparen_union_sql
-        if self.dbms_helper.fti_need_distinct:
+        if self.dbhelper.fti_need_distinct:
             self.__union_sql = self.union_sql
             self.union_sql = self.has_text_need_distinct_union_sql
         self._lock = threading.Lock()
@@ -373,7 +428,7 @@
             # union query for each rqlst / solution
             sql = self.union_sql(union)
             # we are done
-            return sql, self._query_attrs
+            return sql, self._query_attrs, self._state.needs_source_cb
         finally:
             self._lock.release()
 
@@ -391,9 +446,10 @@
         return '\nUNION ALL\n'.join(sqls)
 
     def noparen_union_sql(self, union, needalias=False):
-        # needed for sqlite backend which doesn't like parentheses around
-        # union query. This may cause bug in some condition (sort in one of
-        # the subquery) but will work in most case
+        # needed for sqlite backend which doesn't like parentheses around union
+        # query. This may cause bug in some condition (sort in one of the
+        # subquery) but will work in most case
+        #
         # see http://www.sqlite.org/cvstrac/tktview?tn=3074
         sqls = (self.select_sql(select, needalias)
                 for i, select in enumerate(union.children))
@@ -435,6 +491,9 @@
         else:
             existssols, unstable = {}, ()
         state = StateInfo(existssols, unstable)
+        if self._state is not None:
+            # state from a previous unioned select
+            state.merge_source_cbs(self._state.needs_source_cb)
         # treat subqueries
         self._subqueries_sql(select, state)
         # generate sql for this select node
@@ -490,6 +549,7 @@
                 if fneedwrap:
                     selection = ['T1.C%s' % i for i in xrange(len(origselection))]
                     sql = 'SELECT %s FROM (%s) AS T1' % (','.join(selection), sql)
+            state.finalize_source_cbs()
         finally:
             select.selection = origselection
         # limit / offset
@@ -504,13 +564,24 @@
     def _subqueries_sql(self, select, state):
         for i, subquery in enumerate(select.with_):
             sql = self.union_sql(subquery.query, needalias=True)
-            tablealias = '_T%s' % i
+            tablealias = '_T%s' % i # XXX nested subqueries
             sql = '(%s) AS %s' % (sql, tablealias)
             state.subtables[tablealias] = (0, sql)
+            latest_state = self._state
             for vref in subquery.aliases:
                 alias = vref.variable
                 alias._q_sqltable = tablealias
                 alias._q_sql = '%s.C%s' % (tablealias, alias.colnum)
+                try:
+                    stack = latest_state.needs_source_cb[alias.colnum]
+                    if state.subquery_source_cb is None:
+                        state.subquery_source_cb = {}
+                    for selectidx, vref in iter_mapped_var_sels(select, alias):
+                        stack = stack[:]
+                        update_source_cb_stack(state, select, vref, stack)
+                        state.subquery_source_cb[selectidx] = stack
+                except KeyError:
+                    continue
 
     def _solutions_sql(self, select, solutions, distinct, needalias):
         sqls = []
@@ -522,17 +593,18 @@
             sql = [self._selection_sql(select.selection, distinct, needalias)]
             if self._state.restrictions:
                 sql.append('WHERE %s' % ' AND '.join(self._state.restrictions))
+            self._state.merge_source_cbs(self._state._needs_source_cb)
             # add required tables
             assert len(self._state.actual_tables) == 1, self._state.actual_tables
             tables = self._state.actual_tables[-1]
             if tables:
                 # sort for test predictability
                 sql.insert(1, 'FROM %s' % ', '.join(sorted(tables)))
-            elif self._state.restrictions and self.dbms_helper.needs_from_clause:
+            elif self._state.restrictions and self.dbhelper.needs_from_clause:
                 sql.insert(1, 'FROM (SELECT 1) AS _T')
             sqls.append('\n'.join(sql))
         if select.need_intersect:
-            #if distinct or not self.dbms_helper.intersect_all_support:
+            #if distinct or not self.dbhelper.intersect_all_support:
             return '\nINTERSECT\n'.join(sqls)
             #else:
             #    return '\nINTERSECT ALL\n'.join(sqls)
@@ -894,7 +966,13 @@
             except KeyError:
                 mapkey = '%s.%s' % (self._state.solution[lhs.name], rel.r_type)
                 if mapkey in self.attr_map:
-                    lhssql = self.attr_map[mapkey](self, lhs.variable, rel)
+                    cb, sourcecb = self.attr_map[mapkey]
+                    if sourcecb:
+                        # callback is a source callback, we can't use this
+                        # attribute in restriction
+                        raise QueryError("can't use %s (%s) in restriction"
+                                         % (mapkey, rel.as_string()))
+                    lhssql = cb(self, lhs.variable, rel)
                 elif rel.r_type == 'eid':
                     lhssql = lhs.variable._q_sql
                 else:
@@ -943,7 +1021,7 @@
             not_ = True
         else:
             not_ = False
-        return self.dbms_helper.fti_restriction_sql(alias, const.eval(self._args),
+        return self.dbhelper.fti_restriction_sql(alias, const.eval(self._args),
                                                     jointo, not_) + restriction
 
     def visit_comparison(self, cmp):
@@ -956,7 +1034,7 @@
             rhs = cmp.children[0]
         operator = cmp.operator
         if operator in ('IS', 'LIKE', 'ILIKE'):
-            if operator == 'ILIKE' and not self.dbms_helper.ilike_support:
+            if operator == 'ILIKE' and not self.dbhelper.ilike_support:
                 operator = ' LIKE '
             else:
                 operator = ' %s ' % operator
@@ -986,9 +1064,13 @@
 
     def visit_function(self, func):
         """generate SQL name for a function"""
-        # func_sql_call will check function is supported by the backend
-        return self.dbms_helper.func_as_sql(func.name,
-                                            [c.accept(self) for c in func.children])
+        args = [c.accept(self) for c in func.children]
+        if func in self._state.source_cb_funcs:
+            # function executed as a callback on the source
+            assert len(args) == 1
+            return args[0]
+        # func_as_sql will check function is supported by the backend
+        return self.dbhelper.func_as_sql(func.name, args)
 
     def visit_constant(self, constant):
         """generate SQL name for a constant"""
@@ -1003,7 +1085,7 @@
                 rel._q_needcast = value
             return self.keyword_map[value]()
         if constant.type == 'Boolean':
-            value = self.dbms_helper.boolean_value(value)
+            value = self.dbhelper.boolean_value(value)
         if constant.type == 'Substitute':
             _id = constant.value
             if isinstance(_id, unicode):
@@ -1065,7 +1147,7 @@
                     self._state.add_restriction(restr)
             elif principal.r_type == 'has_text':
                 sql = '%s.%s' % (self._fti_table(principal),
-                                 self.dbms_helper.fti_uid_attr)
+                                 self.dbhelper.fti_uid_attr)
             elif principal in variable.stinfo['rhsrelations']:
                 if self.schema.rschema(principal.r_type).inlined:
                     sql = self._linked_var_sql(variable)
@@ -1155,12 +1237,20 @@
         if isinstance(linkedvar, ColumnAlias):
             raise BadRQLQuery('variable %s should be selected by the subquery'
                               % variable.name)
-        mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type)
-        if mapkey in self.attr_map:
-            return self.attr_map[mapkey](self, linkedvar, rel)
         try:
             sql = self._varmap['%s.%s' % (linkedvar.name, rel.r_type)]
         except KeyError:
+            mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type)
+            if mapkey in self.attr_map:
+                cb, sourcecb = self.attr_map[mapkey]
+                if not sourcecb:
+                    return cb(self, linkedvar, rel)
+                # attribute mapped at the source level (bfss for instance)
+                stmt = rel.stmt
+                for selectidx, vref in iter_mapped_var_sels(stmt, variable):
+                    stack = [cb]
+                    update_source_cb_stack(self._state, stmt, vref, stack)
+                    self._state._needs_source_cb[selectidx] = stack
             linkedvar.accept(self)
             sql = '%s.%s%s' % (linkedvar._q_sqltable, SQL_PREFIX, rel.r_type)
         return sql
@@ -1267,7 +1357,7 @@
             except AttributeError:
                 pass
         self._state.done.add(relation)
-        alias = self.alias_and_add_table(self.dbms_helper.fti_table)
+        alias = self.alias_and_add_table(self.dbhelper.fti_table)
         relation._q_sqltable = alias
         return alias