[rql rewriter] to properly handle 'relation' rql expressions, rql rewriter must support multiple variables (eg S and O) at once to be given as varmap (branch stable)
author Sylvain Thénault <sylvain.thenault@logilab.fr>
Wed, 30 Mar 2011 11:08:15 +0200
branch stable
changeset 7139 20807d3d7cf6
parent 7138 9aba650eea6b
child 7140 ba51dac1115d
[rql rewriter] to properly handle 'relation' rql expressions, rql rewriter must support multiple variables (eg S and O) at once to be given as varmap
entity.py
rqlrewrite.py
server/querier.py
test/unittest_entity.py
test/unittest_rqlrewrite.py
--- a/entity.py	Wed Mar 30 11:07:16 2011 +0200
+++ b/entity.py	Wed Mar 30 11:08:15 2011 +0200
@@ -62,7 +62,6 @@
     return True
 
 
-
 class Entity(AppObject):
     """an entity instance has e_schema automagically set on
     the class and instances has access to their issuing cursor.
@@ -808,7 +807,11 @@
             else:
                 existant = None # instead of 'SO', improve perfs
             for select in rqlst.children:
-                rewriter.rewrite(select, [((searchedvar, searchedvar), rqlexprs)],
+                varmap = {}
+                for var in 'SO':
+                    if var in select.defined_vars:
+                        varmap[var] = var
+                rewriter.rewrite(select, [(varmap, rqlexprs)],
                                  select.solutions, args, existant)
             rql = rqlst.as_string()
         return rql, args
--- a/rqlrewrite.py	Wed Mar 30 11:07:16 2011 +0200
+++ b/rqlrewrite.py	Wed Mar 30 11:08:15 2011 +0200
@@ -24,6 +24,7 @@
 __docformat__ = "restructuredtext en"
 
 from rql import nodes as n, stmts, TypeResolverException
+from rql.utils import common_parent
 from yams import BadSchemaDefinition
 from logilab.common.graph import has_path
 
@@ -180,13 +181,23 @@
     def insert_snippets(self, snippets, varexistsmap=None):
         self.rewritten = {}
         for varmap, rqlexprs in snippets:
+            if isinstance(varmap, dict):
+                varmap = tuple(sorted(varmap.items()))
+            else:
+                assert isinstance(varmap, tuple), varmap
             if varexistsmap is not None and not varmap in varexistsmap:
                 continue
-            self.varmap = varmap
-            selectvar, snippetvar = varmap
+            self.insert_varmap_snippets(varmap, rqlexprs, varexistsmap)
+
+    def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap):
+        self.varmap = varmap
+        self.revvarmap = {}
+        self.varinfos = []
+        for i, (selectvar, snippetvar) in enumerate(varmap):
             assert snippetvar in 'SOX'
-            self.revvarmap = {snippetvar: selectvar}
-            self.varinfo = vi = {}
+            self.revvarmap[snippetvar] = (selectvar, i)
+            vi = {}
+            self.varinfos.append(vi)
             try:
                 vi['const'] = typed_eid(selectvar) # XXX gae
                 vi['rhs_rels'] = vi['lhs_rels'] = {}
@@ -194,42 +205,42 @@
                 try:
                     vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo
                 except KeyError:
-                    # variable has been moved to a newly inserted subquery
+                    # variable may have been moved to a newly inserted subquery
                     # we should insert snippet in that subquery
                     subquery = self.select.aliases[selectvar].query
                     assert len(subquery.children) == 1
                     subselect = subquery.children[0]
                     RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)],
                                                       subselect.solutions, self.kwargs)
-                    continue
+                    return
                 if varexistsmap is None:
                     vi['rhs_rels'] = dict( (r.r_type, r) for r in sti['rhsrelations'])
                     vi['lhs_rels'] = dict( (r.r_type, r) for r in sti['relations']
                                            if not r in sti['rhsrelations'])
                 else:
                     vi['rhs_rels'] = vi['lhs_rels'] = {}
-            parent = None
-            inserted = False
-            for rqlexpr in rqlexprs:
-                self.current_expr = rqlexpr
-                if varexistsmap is None:
-                    try:
-                        new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, parent)
-                    except Unsupported:
-                        continue
-                    inserted = True
-                    if new is not None:
-                        self.exists_snippet[rqlexpr] = new
-                    parent = parent or new
-                else:
-                    # called to reintroduce snippet due to ambiguity creation,
-                    # so skip snippets which are not introducing this ambiguity
-                    exists = varexistsmap[varmap]
-                    if self.exists_snippet[rqlexpr] is exists:
-                        self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
-            if varexistsmap is None and not inserted:
-                # no rql expression found matching rql solutions. User has no access right
-                raise Unauthorized()
+        parent = None
+        inserted = False
+        for rqlexpr in rqlexprs:
+            self.current_expr = rqlexpr
+            if varexistsmap is None:
+                try:
+                    new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, parent)
+                except Unsupported:
+                    continue
+                inserted = True
+                if new is not None:
+                    self.exists_snippet[rqlexpr] = new
+                parent = parent or new
+            else:
+                # called to reintroduce snippet due to ambiguity creation,
+                # so skip snippets which are not introducing this ambiguity
+                exists = varexistsmap[varmap]
+                if self.exists_snippet[rqlexpr] is exists:
+                    self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
+        if varexistsmap is None and not inserted:
+            # no rql expression found matching rql solutions. User has no access right
+            raise Unauthorized() # XXX bad constraint when inserting constraints
 
     def insert_snippet(self, varmap, snippetrqlst, parent=None):
         new = snippetrqlst.where.accept(self)
@@ -243,10 +254,16 @@
     def _insert_snippet(self, varmap, parent, new):
         if new is not None:
             if self._insert_scope is None:
-                insert_scope = self.varinfo.get('stinfo', {}).get('scope', self.select)
+                insert_scope = None
+                for vi in self.varinfos:
+                    scope = vi.get('stinfo', {}).get('scope', self.select)
+                    if insert_scope is None:
+                        insert_scope = scope
+                    else:
+                        insert_scope = common_parent(scope, insert_scope)
             else:
                 insert_scope = self._insert_scope
-            if self.varinfo.get('stinfo', {}).get('optrelations'):
+            if any(vi.get('stinfo', {}).get('optrelations') for vi in self.varinfos):
                 assert parent is None
                 self._insert_scope = self.snippet_subquery(varmap, new)
                 self.insert_pending()
@@ -292,7 +309,7 @@
                 varname = self.rewritten[key]
             except KeyError:
                 try:
-                    varname = self.revvarmap[key[-1]]
+                    varname = self.revvarmap[key[-1]][0]
                 except KeyError:
                     # variable isn't used anywhere else, we can't insert security
                     raise Unauthorized()
@@ -309,45 +326,51 @@
                 rqlexprs = eschema.get_rqlexprs(action)
                 if not rqlexprs:
                     raise Unauthorized()
-                self.insert_snippets([((varname, 'X'), rqlexprs)])
+                self.insert_snippets([({varname: 'X'}, rqlexprs)])
 
     def snippet_subquery(self, varmap, transformedsnippet):
         """introduce the given snippet in a subquery"""
         subselect = stmts.Select()
-        selectvar = varmap[0]
-        subselectvar = subselect.get_variable(selectvar)
-        subselect.append_selected(n.VariableRef(subselectvar))
         snippetrqlst = n.Exists(transformedsnippet.copy(subselect))
-        aliases = [selectvar]
-        stinfo = self.varinfo['stinfo']
-        need_null_test = False
-        for rel in stinfo['relations']:
-            rschema = self.schema.rschema(rel.r_type)
-            if rschema.final or (rschema.inlined and
-                                 not rel in stinfo['rhsrelations']):
-                rel.children[0].name = selectvar # XXX explain why
-                subselect.add_restriction(rel.copy(subselect))
-                for vref in rel.children[1].iget_nodes(n.VariableRef):
-                    if isinstance(vref.variable, n.ColumnAlias):
-                        # XXX could probably be handled by generating the subquery
-                        # into the detected subquery
-                        raise BadSchemaDefinition(
-                            "cant insert security because of usage two inlined "
-                            "relations in this query. You should probably at "
-                            "least uninline %s" % rel.r_type)
-                    subselect.append_selected(vref.copy(subselect))
-                    aliases.append(vref.name)
-                self.select.remove_node(rel)
-                # when some inlined relation has to be copied in the subquery,
-                # we need to test that either value is NULL or that the snippet
-                # condition is satisfied
-                if rschema.inlined and rel.optional:
-                    need_null_test = True
-        if need_null_test:
-            snippetrqlst = n.Or(
-                n.make_relation(subselectvar, 'is', (None, None), n.Constant,
-                                operator='='),
-                snippetrqlst)
+        aliases = []
+        rels_done = set()
+        for i, (selectvar, snippetvar) in enumerate(varmap):
+            subselectvar = subselect.get_variable(selectvar)
+            subselect.append_selected(n.VariableRef(subselectvar))
+            aliases.append(selectvar)
+            vi = self.varinfos[i]
+            need_null_test = False
+            stinfo = vi['stinfo']
+            for rel in stinfo['relations']:
+                if rel in rels_done:
+                    continue
+                rels_done.add(rel)
+                rschema = self.schema.rschema(rel.r_type)
+                if rschema.final or (rschema.inlined and
+                                     not rel in stinfo['rhsrelations']):
+                    rel.children[0].name = selectvar # XXX explain why
+                    subselect.add_restriction(rel.copy(subselect))
+                    for vref in rel.children[1].iget_nodes(n.VariableRef):
+                        if isinstance(vref.variable, n.ColumnAlias):
+                            # XXX could probably be handled by generating the
+                            # subquery into the detected subquery
+                            raise BadSchemaDefinition(
+                                "cant insert security because of usage two inlined "
+                                "relations in this query. You should probably at "
+                                "least uninline %s" % rel.r_type)
+                        subselect.append_selected(vref.copy(subselect))
+                        aliases.append(vref.name)
+                    self.select.remove_node(rel)
+                    # when some inlined relation has to be copied in the
+                    # subquery, we need to test that either value is NULL or
+                    # that the snippet condition is satisfied
+                    if rschema.inlined and rel.optional:
+                        need_null_test = True
+            if need_null_test:
+                snippetrqlst = n.Or(
+                    n.make_relation(subselectvar, 'is', (None, None), n.Constant,
+                                    operator='='),
+                    snippetrqlst)
         subselect.add_restriction(snippetrqlst)
         if self.u_varname:
             # generate an identifier for the substitution
@@ -439,30 +462,32 @@
         original query, return that relation node
         """
         rschema = self.schema.rschema(sniprel.r_type)
-        try:
-            if target == 'object':
-                orel = self.varinfo['lhs_rels'][sniprel.r_type]
-                cardindex = 0
-                ttypes_func = rschema.objects
-                rdef = rschema.rdef
-            else: # target == 'subject':
-                orel = self.varinfo['rhs_rels'][sniprel.r_type]
-                cardindex = 1
-                ttypes_func = rschema.subjects
-                rdef = lambda x, y: rschema.rdef(y, x)
-        except KeyError:
-            # may be raised by self.varinfo['xhs_rels'][sniprel.r_type]
-            return None
-        # can't share neged relation or relations with different outer join
-        if (orel.neged(strict=True) or sniprel.neged(strict=True)
-            or (orel.optional and orel.optional != sniprel.optional)):
-            return None
-        # if cardinality is in '?1', we can ignore the snippet relation and use
-        # variable from the original query
-        for etype in self.varinfo['stinfo']['possibletypes']:
-            for ttype in ttypes_func(etype):
-                if rdef(etype, ttype).cardinality[cardindex] in '+*':
-                    return None
+        for vi in self.varinfos:
+            try:
+                if target == 'object':
+                    orel = vi['lhs_rels'][sniprel.r_type]
+                    cardindex = 0
+                    ttypes_func = rschema.objects
+                    rdef = rschema.rdef
+                else: # target == 'subject':
+                    orel = vi['rhs_rels'][sniprel.r_type]
+                    cardindex = 1
+                    ttypes_func = rschema.subjects
+                    rdef = lambda x, y: rschema.rdef(y, x)
+            except KeyError:
+                # may be raised by vi['xhs_rels'][sniprel.r_type]
+                return None
+            # can't share neged relation or relations with different outer join
+            if (orel.neged(strict=True) or sniprel.neged(strict=True)
+                or (orel.optional and orel.optional != sniprel.optional)):
+                return None
+            # if cardinality is in '?1', we can ignore the snippet relation and use
+            # variable from the original query
+            for etype in vi['stinfo']['possibletypes']:
+                for ttype in ttypes_func(etype):
+                    if rdef(etype, ttype).cardinality[cardindex] in '+*':
+                        return None
+            break
         return orel
 
     def _use_orig_term(self, snippet_varname, term):
@@ -601,10 +626,11 @@
     def visit_variableref(self, node):
         """get the sql name for a variable reference"""
         if node.name in self.revvarmap:
-            if self.varinfo.get('const') is not None:
-                return n.Constant(self.varinfo['const'], 'Int') # XXX gae
-            return n.VariableRef(self.select.get_variable(
-                self.revvarmap[node.name]))
+            selectvar, index = self.revvarmap[node.name]
+            vi = self.varinfos[index]
+            if vi.get('const') is not None:
+                return n.Constant(vi['const'], 'Int') # XXX gae
+            return n.VariableRef(self.select.get_variable(selectvar))
         vname_or_term = self._get_varname_or_term(node.name)
         if isinstance(vname_or_term, basestring):
             return n.VariableRef(self.select.get_variable(vname_or_term))
--- a/server/querier.py	Wed Mar 30 11:07:16 2011 +0200
+++ b/server/querier.py	Wed Mar 30 11:08:15 2011 +0200
@@ -354,7 +354,7 @@
                     myrqlst = select.copy(solutions=lchecksolutions)
                     myunion.append(myrqlst)
                     # in-place rewrite + annotation / simplification
-                    lcheckdef = [((var, 'X'), rqlexprs) for var, rqlexprs in lcheckdef]
+                    lcheckdef = [({var: 'X'}, rqlexprs) for var, rqlexprs in lcheckdef]
                     rewrite(myrqlst, lcheckdef, lchecksolutions, self.args)
                     add_noinvariant(noinvariant, restricted, myrqlst, nbtrees)
                 if () in localchecks:
--- a/test/unittest_entity.py	Wed Mar 30 11:07:16 2011 +0200
+++ b/test/unittest_entity.py	Wed Mar 30 11:08:15 2011 +0200
@@ -223,38 +223,48 @@
                           'Any X,AA ORDERBY AA DESC '
                           'WHERE E eid %(x)s, E tags X, X modification_date AA')
 
-    def test_unrelated_rql_security_1(self):
+    def test_unrelated_rql_security_1_manager(self):
         user = self.request().user
         rql = user.cw_unrelated_rql('use_email', 'EmailAddress', 'subject')[0]
         self.assertEqual(rql, 'Any O,AA,AB,AC ORDERBY AC DESC '
-                          'WHERE NOT S use_email O, S eid %(x)s, O is EmailAddress, O address AA, O alias AB, O modification_date AC')
+                         'WHERE NOT S use_email O, S eid %(x)s, '
+                         'O is EmailAddress, O address AA, O alias AB, O modification_date AC')
+
+    def test_unrelated_rql_security_1_user(self):
         self.create_user('toto')
         self.login('toto')
         user = self.request().user
         rql = user.cw_unrelated_rql('use_email', 'EmailAddress', 'subject')[0]
         self.assertEqual(rql, 'Any O,AA,AB,AC ORDERBY AC DESC '
-                          'WHERE NOT S use_email O, S eid %(x)s, O is EmailAddress, O address AA, O alias AB, O modification_date AC')
+                          'WHERE NOT S use_email O, S eid %(x)s, '
+                         'O is EmailAddress, O address AA, O alias AB, O modification_date AC')
         user = self.execute('Any X WHERE X login "admin"').get_entity(0, 0)
-        self.assertRaises(Unauthorized, user.cw_unrelated_rql, 'use_email', 'EmailAddress', 'subject')
+        rql = user.cw_unrelated_rql('use_email', 'EmailAddress', 'subject')[0]
+        self.assertEqual(rql, 'Any O,AA,AB,AC ORDERBY AC DESC WHERE '
+                         'NOT EXISTS(S use_email O), S eid %(x)s, '
+                         'O is EmailAddress, O address AA, O alias AB, O modification_date AC, '
+                         'A eid %(B)s, EXISTS(S identity A, NOT A in_group C, C name "guests", C is CWGroup)')
+
+    def test_unrelated_rql_security_1_anon(self):
         self.login('anon')
         user = self.request().user
-        self.assertRaises(Unauthorized, user.cw_unrelated_rql, 'use_email', 'EmailAddress', 'subject')
+        rql = user.cw_unrelated_rql('use_email', 'EmailAddress', 'subject')[0]
+        self.assertEqual(rql, 'Any O,AA,AB,AC ORDERBY AC DESC WHERE '
+                         'NOT EXISTS(S use_email O), S eid %(x)s, '
+                         'O is EmailAddress, O address AA, O alias AB, O modification_date AC, '
+                         'A eid %(B)s, EXISTS(S identity A, NOT A in_group C, C name "guests", C is CWGroup)')
 
     def test_unrelated_rql_security_2(self):
         email = self.execute('INSERT EmailAddress X: X address "hop"').get_entity(0, 0)
         rql = email.cw_unrelated_rql('use_email', 'CWUser', 'object')[0]
         self.assertEqual(rql, 'Any S,AA,AB,AC,AD ORDERBY AA ASC '
                           'WHERE NOT S use_email O, O eid %(x)s, S is CWUser, S login AA, S firstname AB, S surname AC, S modification_date AD')
-        #rql = email.cw_unrelated_rql('use_email', 'Person', 'object')[0]
-        #self.assertEqual(rql, '')
         self.login('anon')
         email = self.execute('Any X WHERE X eid %(x)s', {'x': email.eid}).get_entity(0, 0)
         rql = email.cw_unrelated_rql('use_email', 'CWUser', 'object')[0]
         self.assertEqual(rql, 'Any S,AA,AB,AC,AD ORDERBY AA '
                           'WHERE NOT EXISTS(S use_email O), O eid %(x)s, S is CWUser, S login AA, S firstname AB, S surname AC, S modification_date AD, '
                           'A eid %(B)s, EXISTS(S identity A, NOT A in_group C, C name "guests", C is CWGroup)')
-        #rql = email.cw_unrelated_rql('use_email', 'Person', 'object')[0]
-        #self.assertEqual(rql, '')
 
     def test_unrelated_rql_security_nonexistant(self):
         self.login('anon')
--- a/test/unittest_rqlrewrite.py	Wed Mar 30 11:07:16 2011 +0200
+++ b/test/unittest_rqlrewrite.py	Wed Mar 30 11:08:15 2011 +0200
@@ -1,4 +1,4 @@
-# copyright 2003-2010 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
 #
 # This file is part of CubicWeb.
@@ -62,15 +62,17 @@
             def simplify(mainrqlst, needcopy=False):
                 rqlhelper.simplify(rqlst, needcopy)
     rewriter = RQLRewriter(mock_object(vreg=FakeVReg, user=(mock_object(eid=1))))
-    for v, snippets in snippets_map.items():
-        snippets_map[v] = [isinstance(snippet, basestring)
-                           and mock_object(snippet_rqlst=parse('Any X WHERE '+snippet).children[0],
-                                           expression='Any X WHERE '+snippet)
-                           or snippet
-                           for snippet in snippets]
+    snippets = []
+    for v, exprs in snippets_map.items():
+        rqlexprs = [isinstance(snippet, basestring)
+                    and mock_object(snippet_rqlst=parse('Any X WHERE '+snippet).children[0],
+                                    expression='Any X WHERE '+snippet)
+                    or snippet
+                    for snippet in exprs]
+        snippets.append((dict([v]), rqlexprs))
     rqlhelper.compute_solutions(rqlst.children[0], {'eid': eid_func_map}, kwargs=kwargs)
     solutions = rqlst.children[0].solutions
-    rewriter.rewrite(rqlst.children[0], snippets_map.items(), solutions, kwargs,
+    rewriter.rewrite(rqlst.children[0], snippets, solutions, kwargs,
                      existingvars)
     test_vrefs(rqlst.children[0])
     return rewriter.rewritten