rqlrewrite.py
branchstable
changeset 9167 c05652b108ce
parent 8748 f5027f8d2478
child 9169 544b22a3485b
--- a/rqlrewrite.py	Fri Jul 12 10:39:01 2013 +0200
+++ b/rqlrewrite.py	Mon Jul 15 10:59:34 2013 +0200
@@ -1,4 +1,4 @@
-# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
 #
 # This file is part of CubicWeb.
@@ -33,6 +33,13 @@
 from cubicweb import Unauthorized
 
 
+def cleanup_solutions(rqlst, solutions):
+    for sol in solutions:
+        for vname in list(sol):
+            if not (vname in rqlst.defined_vars or vname in rqlst.aliases):
+                del sol[vname]
+
+
 def add_types_restriction(schema, rqlst, newroot=None, solutions=None):
     if newroot is None:
         assert solutions is None
@@ -132,10 +139,35 @@
     return newsolutions
 
 
+def _add_noinvariant(noinvariant, restricted, select, nbtrees):
+    # a variable can actually be invariant if it has not been restricted for
+    # security reason or if security assertion hasn't modified the possible
+    # solutions for the query
+    for vname in restricted:
+        try:
+            var = select.defined_vars[vname]
+        except KeyError:
+            # this is an alias
+            continue
+        if nbtrees != 1 or len(var.stinfo['possibletypes']) != 1:
+            noinvariant.add(var)
+
+
+def _expand_selection(terms, selected, aliases, select, newselect):
+    for term in terms:
+        for vref in term.iget_nodes(n.VariableRef):
+            if not vref.name in selected:
+                select.append_selected(vref)
+                colalias = newselect.get_variable(vref.name, len(aliases))
+                aliases.append(n.VariableRef(colalias))
+                selected.add(vref.name)
+
+
 def iter_relations(stinfo):
     # this is a function so that test may return relation in a predictable order
     return stinfo['relations'] - stinfo['rhsrelations']
 
+
 class Unsupported(Exception):
     """raised when an rql expression can't be inserted in some rql query
     because it create an unresolvable query (eg no solutions found)
@@ -164,6 +196,110 @@
         if len(self.select.solutions) < len(self.solutions):
             raise Unsupported()
 
+    def insert_local_checks(self, select, kwargs,
+                            localchecks, restricted, noinvariant):
+        """
+        select: the rql syntax tree Select node
+        kwargs: query arguments
+
+        localchecks: {(('Var name', (rqlexpr1, rqlexpr2)),
+                       ('Var name1', (rqlexpr1, rqlexpr23))): [solution]}
+
+              (see querier._check_permissions docstring for more information)
+
+        restricted: set of variable names to which an rql expression has to be
+              applied
+
+        noinvariant: set of variable names that can't be considered has
+              invariant due to security reason (will be filed by this method)
+        """
+        nbtrees = len(localchecks)
+        myunion = union = select.parent
+        # transform in subquery when len(localchecks)>1 and groups
+        if nbtrees > 1 and (select.orderby or select.groupby or
+                            select.having or select.has_aggregat or
+                            select.distinct or
+                            select.limit or select.offset):
+            newselect = stmts.Select()
+            # only select variables in subqueries
+            origselection = select.selection
+            select.select_only_variables()
+            select.has_aggregat = False
+            # create subquery first so correct node are used on copy
+            # (eg ColumnAlias instead of Variable)
+            aliases = [n.VariableRef(newselect.get_variable(vref.name, i))
+                       for i, vref in enumerate(select.selection)]
+            selected = set(vref.name for vref in aliases)
+            # now copy original selection and groups
+            for term in origselection:
+                newselect.append_selected(term.copy(newselect))
+            if select.orderby:
+                sortterms = []
+                for sortterm in select.orderby:
+                    sortterms.append(sortterm.copy(newselect))
+                    for fnode in sortterm.get_nodes(n.Function):
+                        if fnode.name == 'FTIRANK':
+                            # we've to fetch the has_text relation as well
+                            var = fnode.children[0].variable
+                            rel = iter(var.stinfo['ftirels']).next()
+                            assert not rel.ored(), 'unsupported'
+                            newselect.add_restriction(rel.copy(newselect))
+                            # remove relation from the orig select and
+                            # cleanup variable stinfo
+                            rel.parent.remove(rel)
+                            var.stinfo['ftirels'].remove(rel)
+                            var.stinfo['relations'].remove(rel)
+                            # XXX not properly re-annotated after security insertion?
+                            newvar = newselect.get_variable(var.name)
+                            newvar.stinfo.setdefault('ftirels', set()).add(rel)
+                            newvar.stinfo.setdefault('relations', set()).add(rel)
+                newselect.set_orderby(sortterms)
+                _expand_selection(select.orderby, selected, aliases, select, newselect)
+                select.orderby = () # XXX dereference?
+            if select.groupby:
+                newselect.set_groupby([g.copy(newselect) for g in select.groupby])
+                _expand_selection(select.groupby, selected, aliases, select, newselect)
+                select.groupby = () # XXX dereference?
+            if select.having:
+                newselect.set_having([g.copy(newselect) for g in select.having])
+                _expand_selection(select.having, selected, aliases, select, newselect)
+                select.having = () # XXX dereference?
+            if select.limit:
+                newselect.limit = select.limit
+                select.limit = None
+            if select.offset:
+                newselect.offset = select.offset
+                select.offset = 0
+            myunion = stmts.Union()
+            newselect.set_with([n.SubQuery(aliases, myunion)], check=False)
+            newselect.distinct = select.distinct
+            solutions = [sol.copy() for sol in select.solutions]
+            cleanup_solutions(newselect, solutions)
+            newselect.set_possible_types(solutions)
+            # if some solutions doesn't need rewriting, insert original
+            # select as first union subquery
+            if () in localchecks:
+                myunion.append(select)
+            # we're done, replace original select by the new select with
+            # subqueries (more added in the loop below)
+            union.replace(select, newselect)
+        elif not () in localchecks:
+            union.remove(select)
+        for lcheckdef, lchecksolutions in localchecks.iteritems():
+            if not lcheckdef:
+                continue
+            myrqlst = select.copy(solutions=lchecksolutions)
+            myunion.append(myrqlst)
+            # in-place rewrite + annotation / simplification
+            lcheckdef = [({var: 'X'}, rqlexprs) for var, rqlexprs in lcheckdef]
+            self.rewrite(myrqlst, lcheckdef, lchecksolutions, kwargs)
+            _add_noinvariant(noinvariant, restricted, myrqlst, nbtrees)
+        if () in localchecks:
+            select.set_possible_types(localchecks[()])
+            add_types_restriction(self.schema, select)
+            _add_noinvariant(noinvariant, restricted, select, nbtrees)
+        self.annotate(union)
+
     def rewrite(self, select, snippets, solutions, kwargs, existingvars=None):
         """
         snippets: (varmap, list of rql expression)