[rql rewrite] move some code from querier to rqlrewrite where it makes more sense. stable
authorSylvain Thénault <sylvain.thenault@logilab.fr>
Mon, 15 Jul 2013 10:59:34 +0200
branchstable
changeset 9167 c05652b108ce
parent 9166 e47e192ea0d9
child 9168 0fb4b67bde58
[rql rewrite] move some code from querier to rqlrewrite where it makes more sense. Also, make some minor cleanup/refactoring on the way and try to enhance docstrings. Relates to #3013535
rqlrewrite.py
server/msplanner.py
server/querier.py
server/sources/rql2sql.py
server/utils.py
--- a/rqlrewrite.py	Fri Jul 12 10:39:01 2013 +0200
+++ b/rqlrewrite.py	Mon Jul 15 10:59:34 2013 +0200
@@ -1,4 +1,4 @@
-# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
 #
 # This file is part of CubicWeb.
@@ -33,6 +33,13 @@
 from cubicweb import Unauthorized
 
 
+def cleanup_solutions(rqlst, solutions):
+    for sol in solutions:
+        for vname in list(sol):
+            if not (vname in rqlst.defined_vars or vname in rqlst.aliases):
+                del sol[vname]
+
+
 def add_types_restriction(schema, rqlst, newroot=None, solutions=None):
     if newroot is None:
         assert solutions is None
@@ -132,10 +139,35 @@
     return newsolutions
 
 
+def _add_noinvariant(noinvariant, restricted, select, nbtrees):
+    # a variable can actually be invariant if it has not been restricted for
+    # security reason or if security assertion hasn't modified the possible
+    # solutions for the query
+    for vname in restricted:
+        try:
+            var = select.defined_vars[vname]
+        except KeyError:
+            # this is an alias
+            continue
+        if nbtrees != 1 or len(var.stinfo['possibletypes']) != 1:
+            noinvariant.add(var)
+
+
+def _expand_selection(terms, selected, aliases, select, newselect):
+    for term in terms:
+        for vref in term.iget_nodes(n.VariableRef):
+            if not vref.name in selected:
+                select.append_selected(vref)
+                colalias = newselect.get_variable(vref.name, len(aliases))
+                aliases.append(n.VariableRef(colalias))
+                selected.add(vref.name)
+
+
 def iter_relations(stinfo):
     # this is a function so that test may return relation in a predictable order
     return stinfo['relations'] - stinfo['rhsrelations']
 
+
 class Unsupported(Exception):
     """raised when an rql expression can't be inserted in some rql query
     because it create an unresolvable query (eg no solutions found)
@@ -164,6 +196,110 @@
         if len(self.select.solutions) < len(self.solutions):
             raise Unsupported()
 
+    def insert_local_checks(self, select, kwargs,
+                            localchecks, restricted, noinvariant):
+        """
+        select: the rql syntax tree Select node
+        kwargs: query arguments
+
+        localchecks: {(('Var name', (rqlexpr1, rqlexpr2)),
+                       ('Var name1', (rqlexpr1, rqlexpr23))): [solution]}
+
+              (see querier._check_permissions docstring for more information)
+
+        restricted: set of variable names to which an rql expression has to be
+              applied
+
+        noinvariant: set of variable names that can't be considered has
+              invariant due to security reason (will be filed by this method)
+        """
+        nbtrees = len(localchecks)
+        myunion = union = select.parent
+        # transform in subquery when len(localchecks)>1 and groups
+        if nbtrees > 1 and (select.orderby or select.groupby or
+                            select.having or select.has_aggregat or
+                            select.distinct or
+                            select.limit or select.offset):
+            newselect = stmts.Select()
+            # only select variables in subqueries
+            origselection = select.selection
+            select.select_only_variables()
+            select.has_aggregat = False
+            # create subquery first so correct node are used on copy
+            # (eg ColumnAlias instead of Variable)
+            aliases = [n.VariableRef(newselect.get_variable(vref.name, i))
+                       for i, vref in enumerate(select.selection)]
+            selected = set(vref.name for vref in aliases)
+            # now copy original selection and groups
+            for term in origselection:
+                newselect.append_selected(term.copy(newselect))
+            if select.orderby:
+                sortterms = []
+                for sortterm in select.orderby:
+                    sortterms.append(sortterm.copy(newselect))
+                    for fnode in sortterm.get_nodes(n.Function):
+                        if fnode.name == 'FTIRANK':
+                            # we've to fetch the has_text relation as well
+                            var = fnode.children[0].variable
+                            rel = iter(var.stinfo['ftirels']).next()
+                            assert not rel.ored(), 'unsupported'
+                            newselect.add_restriction(rel.copy(newselect))
+                            # remove relation from the orig select and
+                            # cleanup variable stinfo
+                            rel.parent.remove(rel)
+                            var.stinfo['ftirels'].remove(rel)
+                            var.stinfo['relations'].remove(rel)
+                            # XXX not properly re-annotated after security insertion?
+                            newvar = newselect.get_variable(var.name)
+                            newvar.stinfo.setdefault('ftirels', set()).add(rel)
+                            newvar.stinfo.setdefault('relations', set()).add(rel)
+                newselect.set_orderby(sortterms)
+                _expand_selection(select.orderby, selected, aliases, select, newselect)
+                select.orderby = () # XXX dereference?
+            if select.groupby:
+                newselect.set_groupby([g.copy(newselect) for g in select.groupby])
+                _expand_selection(select.groupby, selected, aliases, select, newselect)
+                select.groupby = () # XXX dereference?
+            if select.having:
+                newselect.set_having([g.copy(newselect) for g in select.having])
+                _expand_selection(select.having, selected, aliases, select, newselect)
+                select.having = () # XXX dereference?
+            if select.limit:
+                newselect.limit = select.limit
+                select.limit = None
+            if select.offset:
+                newselect.offset = select.offset
+                select.offset = 0
+            myunion = stmts.Union()
+            newselect.set_with([n.SubQuery(aliases, myunion)], check=False)
+            newselect.distinct = select.distinct
+            solutions = [sol.copy() for sol in select.solutions]
+            cleanup_solutions(newselect, solutions)
+            newselect.set_possible_types(solutions)
+            # if some solutions doesn't need rewriting, insert original
+            # select as first union subquery
+            if () in localchecks:
+                myunion.append(select)
+            # we're done, replace original select by the new select with
+            # subqueries (more added in the loop below)
+            union.replace(select, newselect)
+        elif not () in localchecks:
+            union.remove(select)
+        for lcheckdef, lchecksolutions in localchecks.iteritems():
+            if not lcheckdef:
+                continue
+            myrqlst = select.copy(solutions=lchecksolutions)
+            myunion.append(myrqlst)
+            # in-place rewrite + annotation / simplification
+            lcheckdef = [({var: 'X'}, rqlexprs) for var, rqlexprs in lcheckdef]
+            self.rewrite(myrqlst, lcheckdef, lchecksolutions, kwargs)
+            _add_noinvariant(noinvariant, restricted, myrqlst, nbtrees)
+        if () in localchecks:
+            select.set_possible_types(localchecks[()])
+            add_types_restriction(self.schema, select)
+            _add_noinvariant(noinvariant, restricted, select, nbtrees)
+        self.annotate(union)
+
     def rewrite(self, select, snippets, solutions, kwargs, existingvars=None):
         """
         snippets: (varmap, list of rql expression)
--- a/server/msplanner.py	Fri Jul 12 10:39:01 2013 +0200
+++ b/server/msplanner.py	Mon Jul 15 10:59:34 2013 +0200
@@ -100,8 +100,7 @@
 
 from cubicweb import server
 from cubicweb.utils import make_uid
-from cubicweb.rqlrewrite import add_types_restriction
-from cubicweb.server.utils import cleanup_solutions
+from cubicweb.rqlrewrite import add_types_restriction, cleanup_solutions
 from cubicweb.server.ssplanner import SSPlanner, OneFetchStep
 from cubicweb.server.mssteps import *
 
--- a/server/querier.py	Fri Jul 12 10:39:01 2013 +0200
+++ b/server/querier.py	Mon Jul 15 10:59:34 2013 +0200
@@ -1,4 +1,4 @@
-# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
 #
 # This file is part of CubicWeb.
@@ -24,18 +24,15 @@
 
 from logilab.common.compat import any
 from rql import RQLSyntaxError, CoercionError
-from rql.stmts import Union, Select
-from rql.nodes import ETYPE_PYOBJ_MAP, etype_from_pyobj
-from rql.nodes import (Relation, VariableRef, Constant, SubQuery, Function,
-                       Exists, Not)
+from rql.stmts import Union
+from rql.nodes import ETYPE_PYOBJ_MAP, etype_from_pyobj, Relation, Exists, Not
 from yams import BASE_TYPES
 
-from cubicweb import ValidationError, Unauthorized, QueryError, UnknownEid
+from cubicweb import ValidationError, Unauthorized, UnknownEid
 from cubicweb import Binary, server
 from cubicweb.rset import ResultSet
 
 from cubicweb.utils import QueryCache, RepeatList
-from cubicweb.server.utils import cleanup_solutions
 from cubicweb.server.rqlannotation import SQLGenAnnotator, set_qdata
 from cubicweb.server.ssplanner import READ_ONLY_RTYPES, add_types_restriction
 from cubicweb.server.edition import EditedEntity
@@ -77,12 +74,13 @@
         return session.describe(term.eval(args))[0]
 
 def check_read_access(session, rqlst, solution, args):
-    """check that the given user has credentials to access data read the
-    query
+    """Check that the given user has credentials to access data read by the
+    query and return a dict defining necessary "local checks" (i.e. rql
+    expression in read permission defined in the schema) where no group grants
+    him the permission.
 
-    return a dict defining necessary local checks (due to use of rql expression
-    in the schema), keys are variable names and values associated rql expression
-    for the associated variable with the given solution
+    Returned dictionary's keys are variable names and values the rql expressions
+    for this variable (with the given solution).
     """
     # use `term_etype` since we've to deal with rewritten constants here,
     # when used as an external source by another repository.
@@ -130,35 +128,6 @@
                 localchecks[varname] = erqlexprs
     return localchecks
 
-def add_noinvariant(noinvariant, restricted, select, nbtrees):
-    # a variable can actually be invariant if it has not been restricted for
-    # security reason or if security assertion hasn't modified the possible
-    # solutions for the query
-    if nbtrees != 1:
-        for vname in restricted:
-            try:
-                noinvariant.add(select.defined_vars[vname])
-            except KeyError:
-                # this is an alias
-                continue
-    else:
-        for vname in restricted:
-            try:
-                var = select.defined_vars[vname]
-            except KeyError:
-                # this is an alias
-                continue
-            if len(var.stinfo['possibletypes']) != 1:
-                noinvariant.add(var)
-
-def _expand_selection(terms, selected, aliases, select, newselect):
-    for term in terms:
-        for vref in term.iget_nodes(VariableRef):
-            if not vref.name in selected:
-                select.append_selected(vref)
-                colalias = newselect.get_variable(vref.name, len(aliases))
-                aliases.append(VariableRef(colalias))
-                selected.add(vref.name)
 
 # Plans #######################################################################
 
@@ -258,9 +227,8 @@
                 self.args = args
                 cached = True
             else:
-                noinvariant = set()
                 with self.session.security_enabled(read=False):
-                    self._insert_security(union, noinvariant)
+                    noinvariant = self._insert_security(union)
                 if key is not None:
                     self.session.transaction_data[key] = (union, self.args)
         else:
@@ -272,121 +240,39 @@
         if union.has_text_query:
             self.cache_key = None
 
-    def _insert_security(self, union, noinvariant):
+    def _insert_security(self, union):
+        noinvariant = set()
         for select in union.children[:]:
             for subquery in select.with_:
-                self._insert_security(subquery.query, noinvariant)
+                self._insert_security(subquery.query)
             localchecks, restricted = self._check_permissions(select)
             if any(localchecks):
-                rewrite = self.session.rql_rewriter.rewrite
-                nbtrees = len(localchecks)
-                myunion = union
-                # transform in subquery when len(localchecks)>1 and groups
-                if nbtrees > 1 and (select.orderby or select.groupby or
-                                    select.having or select.has_aggregat or
-                                    select.distinct or
-                                    select.limit or select.offset):
-                    newselect = Select()
-                    # only select variables in subqueries
-                    origselection = select.selection
-                    select.select_only_variables()
-                    select.has_aggregat = False
-                    # create subquery first so correct node are used on copy
-                    # (eg ColumnAlias instead of Variable)
-                    aliases = [VariableRef(newselect.get_variable(vref.name, i))
-                               for i, vref in enumerate(select.selection)]
-                    selected = set(vref.name for vref in aliases)
-                    # now copy original selection and groups
-                    for term in origselection:
-                        newselect.append_selected(term.copy(newselect))
-                    if select.orderby:
-                        sortterms = []
-                        for sortterm in select.orderby:
-                            sortterms.append(sortterm.copy(newselect))
-                            for fnode in sortterm.get_nodes(Function):
-                                if fnode.name == 'FTIRANK':
-                                    # we've to fetch the has_text relation as well
-                                    var = fnode.children[0].variable
-                                    rel = iter(var.stinfo['ftirels']).next()
-                                    assert not rel.ored(), 'unsupported'
-                                    newselect.add_restriction(rel.copy(newselect))
-                                    # remove relation from the orig select and
-                                    # cleanup variable stinfo
-                                    rel.parent.remove(rel)
-                                    var.stinfo['ftirels'].remove(rel)
-                                    var.stinfo['relations'].remove(rel)
-                                    # XXX not properly re-annotated after security insertion?
-                                    newvar = newselect.get_variable(var.name)
-                                    newvar.stinfo.setdefault('ftirels', set()).add(rel)
-                                    newvar.stinfo.setdefault('relations', set()).add(rel)
-                        newselect.set_orderby(sortterms)
-                        _expand_selection(select.orderby, selected, aliases, select, newselect)
-                        select.orderby = () # XXX dereference?
-                    if select.groupby:
-                        newselect.set_groupby([g.copy(newselect) for g in select.groupby])
-                        _expand_selection(select.groupby, selected, aliases, select, newselect)
-                        select.groupby = () # XXX dereference?
-                    if select.having:
-                        newselect.set_having([g.copy(newselect) for g in select.having])
-                        _expand_selection(select.having, selected, aliases, select, newselect)
-                        select.having = () # XXX dereference?
-                    if select.limit:
-                        newselect.limit = select.limit
-                        select.limit = None
-                    if select.offset:
-                        newselect.offset = select.offset
-                        select.offset = 0
-                    myunion = Union()
-                    newselect.set_with([SubQuery(aliases, myunion)], check=False)
-                    newselect.distinct = select.distinct
-                    solutions = [sol.copy() for sol in select.solutions]
-                    cleanup_solutions(newselect, solutions)
-                    newselect.set_possible_types(solutions)
-                    # if some solutions doesn't need rewriting, insert original
-                    # select as first union subquery
-                    if () in localchecks:
-                        myunion.append(select)
-                    # we're done, replace original select by the new select with
-                    # subqueries (more added in the loop below)
-                    union.replace(select, newselect)
-                elif not () in localchecks:
-                    union.remove(select)
-                for lcheckdef, lchecksolutions in localchecks.iteritems():
-                    if not lcheckdef:
-                        continue
-                    myrqlst = select.copy(solutions=lchecksolutions)
-                    myunion.append(myrqlst)
-                    # in-place rewrite + annotation / simplification
-                    lcheckdef = [({var: 'X'}, rqlexprs) for var, rqlexprs in lcheckdef]
-                    rewrite(myrqlst, lcheckdef, lchecksolutions, self.args)
-                    add_noinvariant(noinvariant, restricted, myrqlst, nbtrees)
-                if () in localchecks:
-                    select.set_possible_types(localchecks[()])
-                    add_types_restriction(self.schema, select)
-                    add_noinvariant(noinvariant, restricted, select, nbtrees)
-                self.rqlhelper.annotate(union)
+                self.session.rql_rewriter.insert_local_checks(
+                    select, self.args, localchecks, restricted, noinvariant)
+        return noinvariant
 
     def _check_permissions(self, rqlst):
-        """return a dict defining "local checks", e.g. RQLExpression defined in
-        the schema that should be inserted in the original query
-
-        solutions where a variable has a type which the user can't definitly read
-        are removed, else if the user may read it (eg if an rql expression is
-        defined for the "read" permission of the related type), the local checks
-        dict for the solution is updated
+        """Return a dict defining "local checks", i.e. RQLExpression defined in
+        the schema that should be inserted in the original query, together with
+        a set of variable names which requires some security to be inserted.
 
-        return a dict with entries for each different local check necessary,
-        with associated solutions as value. A local check is defined by a list
-        of 2-uple, with variable name as first item and the necessary rql
-        expression as second item for each variable which has to be checked.
-        So solutions which don't require local checks will be associated to
-        the empty tuple key.
+        Solutions where a variable has a type which the user can't definitly
+        read are removed, else if the user *may* read it (i.e. if an rql
+        expression is defined for the "read" permission of the related type),
+        the local checks dict is updated.
 
-        note: rqlst should not have been simplified at this point
+        The local checks dict has entries for each different local check
+        necessary, with associated solutions as value, a local check being
+        defined by a list of 2-uple (variable name, rql expressions) for each
+        variable which has to be checked. Solutions which don't require local
+        checks will be associated to the empty tuple key.
+
+        Note rqlst should not have been simplified at this point.
         """
         session = self.session
         msgs = []
-        neweids = session.transaction_data.get('neweids', ())
+        # dict(varname: eid), allowing to check rql expression for variables
+        # which have a known eid
         varkwargs = {}
         if not session.transaction_data.get('security-rqlst-cache'):
             for var in rqlst.defined_vars.itervalues():
@@ -414,20 +300,27 @@
                         rqlexprs = localcheck.pop(varname)
                     except KeyError:
                         continue
-                    if eid in neweids:
+                    # if entity has been added in the current transaction, the
+                    # user can read it whatever rql expressions are associated
+                    # to its type
+                    if session.added_in_transaction(eid):
                         continue
                     for rqlexpr in rqlexprs:
                         if rqlexpr.check(session, eid):
                             break
                     else:
                         raise Unauthorized('No read acces on %r with eid %i.' % (var, eid))
+                # mark variables protected by an rql expression
                 restricted_vars.update(localcheck)
-                localchecks.setdefault(tuple(localcheck.iteritems()), []).append(solution)
+                # turn local check into a dict key
+                localcheck = tuple(sorted(localcheck.iteritems()))
+                localchecks.setdefault(localcheck, []).append(solution)
         # raise Unautorized exception if the user can't access to any solution
         if not newsolutions:
             raise Unauthorized('\n'.join(msgs))
+        # if there is some message, solutions have been modified and must be
+        # reconsidered by the syntax treee
         if msgs:
-            # (else solutions have not been modified)
             rqlst.set_possible_types(newsolutions)
         return localchecks, restricted_vars
 
@@ -728,7 +621,7 @@
             if args:
                 # different SQL generated when some argument is None or not (IS
                 # NULL). This should be considered when computing sql cache key
-                cachekey += tuple(sorted([k for k,v in args.iteritems()
+                cachekey += tuple(sorted([k for k, v in args.iteritems()
                                           if v is None]))
         # make an execution plan
         plan = self.plan_factory(rqlst, args, session)
--- a/server/sources/rql2sql.py	Fri Jul 12 10:39:01 2013 +0200
+++ b/server/sources/rql2sql.py	Mon Jul 15 10:59:34 2013 +0200
@@ -1,4 +1,4 @@
-# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
 #
 # This file is part of CubicWeb.
@@ -62,8 +62,8 @@
                        Not, Comparison, ColumnAlias, Relation, SubQuery, Exists)
 
 from cubicweb import QueryError
+from cubicweb.rqlrewrite import cleanup_solutions
 from cubicweb.server.sqlutils import SQL_PREFIX
-from cubicweb.server.utils import cleanup_solutions
 
 ColumnAlias._q_invariant = False # avoid to check for ColumnAlias / Variable
 
--- a/server/utils.py	Fri Jul 12 10:39:01 2013 +0200
+++ b/server/utils.py	Mon Jul 15 10:59:34 2013 +0200
@@ -1,4 +1,4 @@
-# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
 #
 # This file is part of CubicWeb.
@@ -91,13 +91,6 @@
     return rloop(seqin, [])
 
 
-def cleanup_solutions(rqlst, solutions):
-    for sol in solutions:
-        for vname in list(sol):
-            if not (vname in rqlst.defined_vars or vname in rqlst.aliases):
-                del sol[vname]
-
-
 def eschema_eid(session, eschema):
     """get eid of the CWEType entity for the given yams type. You should use
     this because when schema has been loaded from the file-system, not from the