[rqlrewrite] test and fix rql snippets insertion when several snippets match an optional variable stable
authorSylvain Thénault <sylvain.thenault@logilab.fr>
Fri, 17 Jun 2011 18:51:01 +0200
branchstable
changeset 7535 d5725a89dac9
parent 7530 15178bf89fb6
child 7536 29961a416faa
child 7537 1af162bd78b8
[rqlrewrite] test and fix rql snippets insertion when several snippets match an optional variable
rqlrewrite.py
test/unittest_rqlrewrite.py
--- a/rqlrewrite.py	Thu Jun 09 16:41:41 2011 +0200
+++ b/rqlrewrite.py	Fri Jun 17 18:51:01 2011 +0200
@@ -20,12 +20,16 @@
 
 This is used for instance for read security checking in the repository.
 """
+from __future__ import with_statement
 
 __docformat__ = "restructuredtext en"
 
 from rql import nodes as n, stmts, TypeResolverException
 from rql.utils import common_parent
+
 from yams import BadSchemaDefinition
+
+from logilab.common import tempattr
 from logilab.common.graph import has_path
 
 from cubicweb import Unauthorized, typed_eid
@@ -156,7 +160,6 @@
         self.exists_snippet = {}
         self.pending_keys = []
         self.existingvars = existingvars
-        self._insert_scope = None
         # we have to annotate the rqlst before inserting snippets, even though
         # we'll have to redo it latter
         self.annotate(select)
@@ -193,6 +196,7 @@
         self.varmap = varmap
         self.revvarmap = {}
         self.varinfos = []
+        self._insert_scope = None
         for i, (selectvar, snippetvar) in enumerate(varmap):
             assert snippetvar in 'SOX'
             self.revvarmap[snippetvar] = (selectvar, i)
@@ -229,7 +233,7 @@
                 except Unsupported:
                     continue
                 inserted = True
-                if new is not None:
+                if new is not None and self._insert_scope is None:
                     self.exists_snippet[rqlexpr] = new
                 parent = parent or new
             else:
@@ -263,11 +267,12 @@
                         insert_scope = common_parent(scope, insert_scope)
             else:
                 insert_scope = self._insert_scope
-            if any(vi.get('stinfo', {}).get('optrelations') for vi in self.varinfos):
+            if self._insert_scope is None and any(vi.get('stinfo', {}).get('optrelations')
+                                                  for vi in self.varinfos):
                 assert parent is None
                 self._insert_scope = self.snippet_subquery(varmap, new)
                 self.insert_pending()
-                self._insert_scope = None
+                #self._insert_scope = None
                 return
             if not isinstance(new, (n.Exists, n.Not)):
                 new = n.Exists(new)
@@ -283,15 +288,14 @@
                 except Unsupported:
                     # some solutions have been lost, can't apply this rql expr
                     if parent is None:
-                        self.select.remove_node(new, undefine=True)
+                        self.current_statement().remove_node(new, undefine=True)
                     else:
                         parent.parent.replace(or_, or_.children[0])
                         self._cleanup_inserted(new)
                     raise
                 else:
-                    self._insert_scope = new
-                    self.insert_pending()
-                    self._insert_scope = None
+                    with tempattr(self, '_insert_scope', new):
+                        self.insert_pending()
             return new
         self.insert_pending()
 
@@ -303,6 +307,7 @@
         recomputed, we have to insert snippet defined for <action> of entity
         types taken by X
         """
+        stmt = self.current_statement()
         while self.pending_keys:
             key, action = self.pending_keys.pop()
             try:
@@ -313,12 +318,12 @@
                 except KeyError:
                     # variable isn't used anywhere else, we can't insert security
                     raise Unauthorized()
-            ptypes = self.select.defined_vars[varname].stinfo['possibletypes']
+            ptypes = stmt.defined_vars[varname].stinfo['possibletypes']
             if len(ptypes) > 1:
                 # XXX dunno how to handle this
                 self.session.error(
                     'cant check security of %s, ambigous type for %s in %s',
-                    self.select, varname, key[0]) # key[0] == the rql expression
+                    stmt, varname, key[0]) # key[0] == the rql expression
                 raise Unauthorized()
             etype = iter(ptypes).next()
             eschema = self.schema.eschema(etype)
@@ -499,17 +504,18 @@
         self.rewritten[key] = term.name
 
     def _get_varname_or_term(self, vname):
+        stmt = self.current_statement()
         if vname == 'U':
+            stmt = self.select
             if self.u_varname is None:
-                select = self.select
-                self.u_varname = select.allocate_varname()
+                self.u_varname = stmt.allocate_varname()
                 # generate an identifier for the substitution
-                argname = select.allocate_varname()
+                argname = stmt.allocate_varname()
                 while argname in self.kwargs:
-                    argname = select.allocate_varname()
+                    argname = stmt.allocate_varname()
                 # insert "U eid %(u)s"
-                select.add_constant_restriction(
-                    select.get_variable(self.u_varname),
+                stmt.add_constant_restriction(
+                    stmt.get_variable(self.u_varname),
                     'eid', unicode(argname), 'Substitute')
                 self.kwargs[argname] = self.session.user.eid
             return self.u_varname
@@ -517,7 +523,7 @@
         try:
             return self.rewritten[key]
         except KeyError:
-            self.rewritten[key] = newvname = self.select.allocate_varname()
+            self.rewritten[key] = newvname = stmt.allocate_varname()
             return newvname
 
     # visitor methods ##########################################################
@@ -625,14 +631,20 @@
 
     def visit_variableref(self, node):
         """get the sql name for a variable reference"""
+        stmt = self.current_statement()
         if node.name in self.revvarmap:
             selectvar, index = self.revvarmap[node.name]
             vi = self.varinfos[index]
             if vi.get('const') is not None:
                 return n.Constant(vi['const'], 'Int') # XXX gae
-            return n.VariableRef(self.select.get_variable(selectvar))
+            return n.VariableRef(stmt.get_variable(selectvar))
         vname_or_term = self._get_varname_or_term(node.name)
         if isinstance(vname_or_term, basestring):
-            return n.VariableRef(self.select.get_variable(vname_or_term))
+            return n.VariableRef(stmt.get_variable(vname_or_term))
         # shared term
-        return vname_or_term.copy(self.select)
+        return vname_or_term.copy(stmt)
+
+    def current_statement(self):
+        if self._insert_scope is None:
+            return self.select
+        return self._insert_scope.stmt
--- a/test/unittest_rqlrewrite.py	Thu Jun 09 16:41:41 2011 +0200
+++ b/test/unittest_rqlrewrite.py	Fri Jun 17 18:51:01 2011 +0200
@@ -82,8 +82,9 @@
     for vref in node.iget_nodes(nodes.VariableRef):
         vrefmap.setdefault(vref.name, set()).add(vref)
     for var in node.defined_vars.itervalues():
-        assert not (var.stinfo['references'] ^ vrefmap[var.name])
-        assert (var.stinfo['references'])
+        assert var.stinfo['references']
+        assert not (var.stinfo['references'] ^ vrefmap[var.name]), (node.as_string(), var.stinfo['references'], vrefmap[var.name])
+
 
 class RQLRewriteTC(TestCase):
     """a faire:
@@ -95,10 +96,10 @@
     """
 
     def test_base_var(self):
-        card_constraint = ('X in_state S, U in_group G, P require_state S,'
+        constraint = ('X in_state S, U in_group G, P require_state S,'
                            'P name "read", P require_group G')
         rqlst = parse('Card C')
-        rewrite(rqlst, {('C', 'X'): (card_constraint,)}, {})
+        rewrite(rqlst, {('C', 'X'): (constraint,)}, {})
         self.failUnlessEqual(rqlst.as_string(),
                              u"Any C WHERE C is Card, B eid %(D)s, "
                              "EXISTS(C in_state A, B in_group E, F require_state A, "
@@ -130,27 +131,31 @@
                              "E in_state D, D name 'subscribed'), D is State, E is CWUser)")
 
     def test_simplified_rqlst(self):
-        card_constraint = ('X in_state S, U in_group G, P require_state S,'
+        constraint = ('X in_state S, U in_group G, P require_state S,'
                            'P name "read", P require_group G')
         rqlst = parse('Any 2') # this is the simplified rql st for Any X WHERE X eid 12
-        rewrite(rqlst, {('2', 'X'): (card_constraint,)}, {})
+        rewrite(rqlst, {('2', 'X'): (constraint,)}, {})
         self.failUnlessEqual(rqlst.as_string(),
                              u"Any 2 WHERE B eid %(C)s, "
                              "EXISTS(2 in_state A, B in_group D, E require_state A, "
                              "E name 'read', E require_group D, A is State, D is CWGroup, E is CWPermission)")
 
-    def test_optional_var_base(self):
-        card_constraint = ('X in_state S, U in_group G, P require_state S,'
+    def test_optional_var_base_1(self):
+        constraint = ('X in_state S, U in_group G, P require_state S,'
                            'P name "read", P require_group G')
         rqlst = parse('Any A,C WHERE A documented_by C?')
-        rewrite(rqlst, {('C', 'X'): (card_constraint,)}, {})
+        rewrite(rqlst, {('C', 'X'): (constraint,)}, {})
         self.failUnlessEqual(rqlst.as_string(),
                              "Any A,C WHERE A documented_by C?, A is Affaire "
                              "WITH C BEING "
                              "(Any C WHERE EXISTS(C in_state B, D in_group F, G require_state B, G name 'read', "
                              "G require_group F), D eid %(A)s, C is Card)")
+
+    def test_optional_var_base_2(self):
+        constraint = ('X in_state S, U in_group G, P require_state S,'
+                           'P name "read", P require_group G')
         rqlst = parse('Any A,C,T WHERE A documented_by C?, C title T')
-        rewrite(rqlst, {('C', 'X'): (card_constraint,)}, {})
+        rewrite(rqlst, {('C', 'X'): (constraint,)}, {})
         self.failUnlessEqual(rqlst.as_string(),
                              "Any A,C,T WHERE A documented_by C?, A is Affaire "
                              "WITH C,T BEING "
@@ -158,6 +163,19 @@
                              "G require_state B, G name 'read', G require_group F), "
                              "D eid %(A)s, C is Card)")
 
+    def test_optional_var_base_3(self):
+        constraint1 = ('X in_state S, U in_group G, P require_state S,'
+                       'P name "read", P require_group G')
+        constraint2 = 'X in_state S, S name "public"'
+        rqlst = parse('Any A,C,T WHERE A documented_by C?, C title T')
+        rewrite(rqlst, {('C', 'X'): (constraint1, constraint2)}, {})
+        self.failUnlessEqual(rqlst.as_string(),
+                             "Any A,C,T WHERE A documented_by C?, A is Affaire "
+                             "WITH C,T BEING (Any C,T WHERE C title T, "
+                             "EXISTS(C in_state B, D in_group F, G require_state B, G name 'read', G require_group F), "
+                             "D eid %(A)s, C is Card, "
+                             "EXISTS(C in_state E, E name 'public'))")
+
     def test_optional_var_inlined(self):
         c1 = ('X require_permission P')
         c2 = ('X inlined_card O, O require_permission P')