# HG changeset patch # User Sylvain Thénault # Date 1308329461 -7200 # Node ID d5725a89dac9f6663be87390e28d4c60a9d4ace4 # Parent 15178bf89fb6ab8e51afa5ee3a81888a87ee9d5c [rqlrewrite] test and fix rql snippets insertion when several snippets match an optional variable diff -r 15178bf89fb6 -r d5725a89dac9 rqlrewrite.py --- a/rqlrewrite.py Thu Jun 09 16:41:41 2011 +0200 +++ b/rqlrewrite.py Fri Jun 17 18:51:01 2011 +0200 @@ -20,12 +20,16 @@ This is used for instance for read security checking in the repository. """ +from __future__ import with_statement __docformat__ = "restructuredtext en" from rql import nodes as n, stmts, TypeResolverException from rql.utils import common_parent + from yams import BadSchemaDefinition + +from logilab.common import tempattr from logilab.common.graph import has_path from cubicweb import Unauthorized, typed_eid @@ -156,7 +160,6 @@ self.exists_snippet = {} self.pending_keys = [] self.existingvars = existingvars - self._insert_scope = None # we have to annotate the rqlst before inserting snippets, even though # we'll have to redo it latter self.annotate(select) @@ -193,6 +196,7 @@ self.varmap = varmap self.revvarmap = {} self.varinfos = [] + self._insert_scope = None for i, (selectvar, snippetvar) in enumerate(varmap): assert snippetvar in 'SOX' self.revvarmap[snippetvar] = (selectvar, i) @@ -229,7 +233,7 @@ except Unsupported: continue inserted = True - if new is not None: + if new is not None and self._insert_scope is None: self.exists_snippet[rqlexpr] = new parent = parent or new else: @@ -263,11 +267,12 @@ insert_scope = common_parent(scope, insert_scope) else: insert_scope = self._insert_scope - if any(vi.get('stinfo', {}).get('optrelations') for vi in self.varinfos): + if self._insert_scope is None and any(vi.get('stinfo', {}).get('optrelations') + for vi in self.varinfos): assert parent is None self._insert_scope = self.snippet_subquery(varmap, new) self.insert_pending() - self._insert_scope = None + #self._insert_scope = None return if not isinstance(new, (n.Exists, n.Not)): new = n.Exists(new) @@ -283,15 +288,14 @@ except Unsupported: # some solutions have been lost, can't apply this rql expr if parent is None: - self.select.remove_node(new, undefine=True) + self.current_statement().remove_node(new, undefine=True) else: parent.parent.replace(or_, or_.children[0]) self._cleanup_inserted(new) raise else: - self._insert_scope = new - self.insert_pending() - self._insert_scope = None + with tempattr(self, '_insert_scope', new): + self.insert_pending() return new self.insert_pending() @@ -303,6 +307,7 @@ recomputed, we have to insert snippet defined for of entity types taken by X """ + stmt = self.current_statement() while self.pending_keys: key, action = self.pending_keys.pop() try: @@ -313,12 +318,12 @@ except KeyError: # variable isn't used anywhere else, we can't insert security raise Unauthorized() - ptypes = self.select.defined_vars[varname].stinfo['possibletypes'] + ptypes = stmt.defined_vars[varname].stinfo['possibletypes'] if len(ptypes) > 1: # XXX dunno how to handle this self.session.error( 'cant check security of %s, ambigous type for %s in %s', - self.select, varname, key[0]) # key[0] == the rql expression + stmt, varname, key[0]) # key[0] == the rql expression raise Unauthorized() etype = iter(ptypes).next() eschema = self.schema.eschema(etype) @@ -499,17 +504,18 @@ self.rewritten[key] = term.name def _get_varname_or_term(self, vname): + stmt = self.current_statement() if vname == 'U': + stmt = self.select if self.u_varname is None: - select = self.select - self.u_varname = select.allocate_varname() + self.u_varname = stmt.allocate_varname() # generate an identifier for the substitution - argname = select.allocate_varname() + argname = stmt.allocate_varname() while argname in self.kwargs: - argname = select.allocate_varname() + argname = stmt.allocate_varname() # insert "U eid %(u)s" - select.add_constant_restriction( - select.get_variable(self.u_varname), + stmt.add_constant_restriction( + stmt.get_variable(self.u_varname), 'eid', unicode(argname), 'Substitute') self.kwargs[argname] = self.session.user.eid return self.u_varname @@ -517,7 +523,7 @@ try: return self.rewritten[key] except KeyError: - self.rewritten[key] = newvname = self.select.allocate_varname() + self.rewritten[key] = newvname = stmt.allocate_varname() return newvname # visitor methods ########################################################## @@ -625,14 +631,20 @@ def visit_variableref(self, node): """get the sql name for a variable reference""" + stmt = self.current_statement() if node.name in self.revvarmap: selectvar, index = self.revvarmap[node.name] vi = self.varinfos[index] if vi.get('const') is not None: return n.Constant(vi['const'], 'Int') # XXX gae - return n.VariableRef(self.select.get_variable(selectvar)) + return n.VariableRef(stmt.get_variable(selectvar)) vname_or_term = self._get_varname_or_term(node.name) if isinstance(vname_or_term, basestring): - return n.VariableRef(self.select.get_variable(vname_or_term)) + return n.VariableRef(stmt.get_variable(vname_or_term)) # shared term - return vname_or_term.copy(self.select) + return vname_or_term.copy(stmt) + + def current_statement(self): + if self._insert_scope is None: + return self.select + return self._insert_scope.stmt diff -r 15178bf89fb6 -r d5725a89dac9 test/unittest_rqlrewrite.py --- a/test/unittest_rqlrewrite.py Thu Jun 09 16:41:41 2011 +0200 +++ b/test/unittest_rqlrewrite.py Fri Jun 17 18:51:01 2011 +0200 @@ -82,8 +82,9 @@ for vref in node.iget_nodes(nodes.VariableRef): vrefmap.setdefault(vref.name, set()).add(vref) for var in node.defined_vars.itervalues(): - assert not (var.stinfo['references'] ^ vrefmap[var.name]) - assert (var.stinfo['references']) + assert var.stinfo['references'] + assert not (var.stinfo['references'] ^ vrefmap[var.name]), (node.as_string(), var.stinfo['references'], vrefmap[var.name]) + class RQLRewriteTC(TestCase): """a faire: @@ -95,10 +96,10 @@ """ def test_base_var(self): - card_constraint = ('X in_state S, U in_group G, P require_state S,' + constraint = ('X in_state S, U in_group G, P require_state S,' 'P name "read", P require_group G') rqlst = parse('Card C') - rewrite(rqlst, {('C', 'X'): (card_constraint,)}, {}) + rewrite(rqlst, {('C', 'X'): (constraint,)}, {}) self.failUnlessEqual(rqlst.as_string(), u"Any C WHERE C is Card, B eid %(D)s, " "EXISTS(C in_state A, B in_group E, F require_state A, " @@ -130,27 +131,31 @@ "E in_state D, D name 'subscribed'), D is State, E is CWUser)") def test_simplified_rqlst(self): - card_constraint = ('X in_state S, U in_group G, P require_state S,' + constraint = ('X in_state S, U in_group G, P require_state S,' 'P name "read", P require_group G') rqlst = parse('Any 2') # this is the simplified rql st for Any X WHERE X eid 12 - rewrite(rqlst, {('2', 'X'): (card_constraint,)}, {}) + rewrite(rqlst, {('2', 'X'): (constraint,)}, {}) self.failUnlessEqual(rqlst.as_string(), u"Any 2 WHERE B eid %(C)s, " "EXISTS(2 in_state A, B in_group D, E require_state A, " "E name 'read', E require_group D, A is State, D is CWGroup, E is CWPermission)") - def test_optional_var_base(self): - card_constraint = ('X in_state S, U in_group G, P require_state S,' + def test_optional_var_base_1(self): + constraint = ('X in_state S, U in_group G, P require_state S,' 'P name "read", P require_group G') rqlst = parse('Any A,C WHERE A documented_by C?') - rewrite(rqlst, {('C', 'X'): (card_constraint,)}, {}) + rewrite(rqlst, {('C', 'X'): (constraint,)}, {}) self.failUnlessEqual(rqlst.as_string(), "Any A,C WHERE A documented_by C?, A is Affaire " "WITH C BEING " "(Any C WHERE EXISTS(C in_state B, D in_group F, G require_state B, G name 'read', " "G require_group F), D eid %(A)s, C is Card)") + + def test_optional_var_base_2(self): + constraint = ('X in_state S, U in_group G, P require_state S,' + 'P name "read", P require_group G') rqlst = parse('Any A,C,T WHERE A documented_by C?, C title T') - rewrite(rqlst, {('C', 'X'): (card_constraint,)}, {}) + rewrite(rqlst, {('C', 'X'): (constraint,)}, {}) self.failUnlessEqual(rqlst.as_string(), "Any A,C,T WHERE A documented_by C?, A is Affaire " "WITH C,T BEING " @@ -158,6 +163,19 @@ "G require_state B, G name 'read', G require_group F), " "D eid %(A)s, C is Card)") + def test_optional_var_base_3(self): + constraint1 = ('X in_state S, U in_group G, P require_state S,' + 'P name "read", P require_group G') + constraint2 = 'X in_state S, S name "public"' + rqlst = parse('Any A,C,T WHERE A documented_by C?, C title T') + rewrite(rqlst, {('C', 'X'): (constraint1, constraint2)}, {}) + self.failUnlessEqual(rqlst.as_string(), + "Any A,C,T WHERE A documented_by C?, A is Affaire " + "WITH C,T BEING (Any C,T WHERE C title T, " + "EXISTS(C in_state B, D in_group F, G require_state B, G name 'read', G require_group F), " + "D eid %(A)s, C is Card, " + "EXISTS(C in_state E, E name 'public'))") + def test_optional_var_inlined(self): c1 = ('X require_permission P') c2 = ('X inlined_card O, O require_permission P')