rqlrewrite.py
branchstable
changeset 7139 20807d3d7cf6
parent 7138 9aba650eea6b
child 7176 f2a976cf7dac
--- a/rqlrewrite.py	Wed Mar 30 11:07:16 2011 +0200
+++ b/rqlrewrite.py	Wed Mar 30 11:08:15 2011 +0200
@@ -24,6 +24,7 @@
 __docformat__ = "restructuredtext en"
 
 from rql import nodes as n, stmts, TypeResolverException
+from rql.utils import common_parent
 from yams import BadSchemaDefinition
 from logilab.common.graph import has_path
 
@@ -180,13 +181,23 @@
     def insert_snippets(self, snippets, varexistsmap=None):
         self.rewritten = {}
         for varmap, rqlexprs in snippets:
+            if isinstance(varmap, dict):
+                varmap = tuple(sorted(varmap.items()))
+            else:
+                assert isinstance(varmap, tuple), varmap
             if varexistsmap is not None and not varmap in varexistsmap:
                 continue
-            self.varmap = varmap
-            selectvar, snippetvar = varmap
+            self.insert_varmap_snippets(varmap, rqlexprs, varexistsmap)
+
+    def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap):
+        self.varmap = varmap
+        self.revvarmap = {}
+        self.varinfos = []
+        for i, (selectvar, snippetvar) in enumerate(varmap):
             assert snippetvar in 'SOX'
-            self.revvarmap = {snippetvar: selectvar}
-            self.varinfo = vi = {}
+            self.revvarmap[snippetvar] = (selectvar, i)
+            vi = {}
+            self.varinfos.append(vi)
             try:
                 vi['const'] = typed_eid(selectvar) # XXX gae
                 vi['rhs_rels'] = vi['lhs_rels'] = {}
@@ -194,42 +205,42 @@
                 try:
                     vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo
                 except KeyError:
-                    # variable has been moved to a newly inserted subquery
+                    # variable may have been moved to a newly inserted subquery
                     # we should insert snippet in that subquery
                     subquery = self.select.aliases[selectvar].query
                     assert len(subquery.children) == 1
                     subselect = subquery.children[0]
                     RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)],
                                                       subselect.solutions, self.kwargs)
-                    continue
+                    return
                 if varexistsmap is None:
                     vi['rhs_rels'] = dict( (r.r_type, r) for r in sti['rhsrelations'])
                     vi['lhs_rels'] = dict( (r.r_type, r) for r in sti['relations']
                                            if not r in sti['rhsrelations'])
                 else:
                     vi['rhs_rels'] = vi['lhs_rels'] = {}
-            parent = None
-            inserted = False
-            for rqlexpr in rqlexprs:
-                self.current_expr = rqlexpr
-                if varexistsmap is None:
-                    try:
-                        new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, parent)
-                    except Unsupported:
-                        continue
-                    inserted = True
-                    if new is not None:
-                        self.exists_snippet[rqlexpr] = new
-                    parent = parent or new
-                else:
-                    # called to reintroduce snippet due to ambiguity creation,
-                    # so skip snippets which are not introducing this ambiguity
-                    exists = varexistsmap[varmap]
-                    if self.exists_snippet[rqlexpr] is exists:
-                        self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
-            if varexistsmap is None and not inserted:
-                # no rql expression found matching rql solutions. User has no access right
-                raise Unauthorized()
+        parent = None
+        inserted = False
+        for rqlexpr in rqlexprs:
+            self.current_expr = rqlexpr
+            if varexistsmap is None:
+                try:
+                    new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, parent)
+                except Unsupported:
+                    continue
+                inserted = True
+                if new is not None:
+                    self.exists_snippet[rqlexpr] = new
+                parent = parent or new
+            else:
+                # called to reintroduce snippet due to ambiguity creation,
+                # so skip snippets which are not introducing this ambiguity
+                exists = varexistsmap[varmap]
+                if self.exists_snippet[rqlexpr] is exists:
+                    self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
+        if varexistsmap is None and not inserted:
+            # no rql expression found matching rql solutions. User has no access right
+            raise Unauthorized() # XXX bad constraint when inserting constraints
 
     def insert_snippet(self, varmap, snippetrqlst, parent=None):
         new = snippetrqlst.where.accept(self)
@@ -243,10 +254,16 @@
     def _insert_snippet(self, varmap, parent, new):
         if new is not None:
             if self._insert_scope is None:
-                insert_scope = self.varinfo.get('stinfo', {}).get('scope', self.select)
+                insert_scope = None
+                for vi in self.varinfos:
+                    scope = vi.get('stinfo', {}).get('scope', self.select)
+                    if insert_scope is None:
+                        insert_scope = scope
+                    else:
+                        insert_scope = common_parent(scope, insert_scope)
             else:
                 insert_scope = self._insert_scope
-            if self.varinfo.get('stinfo', {}).get('optrelations'):
+            if any(vi.get('stinfo', {}).get('optrelations') for vi in self.varinfos):
                 assert parent is None
                 self._insert_scope = self.snippet_subquery(varmap, new)
                 self.insert_pending()
@@ -292,7 +309,7 @@
                 varname = self.rewritten[key]
             except KeyError:
                 try:
-                    varname = self.revvarmap[key[-1]]
+                    varname = self.revvarmap[key[-1]][0]
                 except KeyError:
                     # variable isn't used anywhere else, we can't insert security
                     raise Unauthorized()
@@ -309,45 +326,51 @@
                 rqlexprs = eschema.get_rqlexprs(action)
                 if not rqlexprs:
                     raise Unauthorized()
-                self.insert_snippets([((varname, 'X'), rqlexprs)])
+                self.insert_snippets([({varname: 'X'}, rqlexprs)])
 
     def snippet_subquery(self, varmap, transformedsnippet):
         """introduce the given snippet in a subquery"""
         subselect = stmts.Select()
-        selectvar = varmap[0]
-        subselectvar = subselect.get_variable(selectvar)
-        subselect.append_selected(n.VariableRef(subselectvar))
         snippetrqlst = n.Exists(transformedsnippet.copy(subselect))
-        aliases = [selectvar]
-        stinfo = self.varinfo['stinfo']
-        need_null_test = False
-        for rel in stinfo['relations']:
-            rschema = self.schema.rschema(rel.r_type)
-            if rschema.final or (rschema.inlined and
-                                 not rel in stinfo['rhsrelations']):
-                rel.children[0].name = selectvar # XXX explain why
-                subselect.add_restriction(rel.copy(subselect))
-                for vref in rel.children[1].iget_nodes(n.VariableRef):
-                    if isinstance(vref.variable, n.ColumnAlias):
-                        # XXX could probably be handled by generating the subquery
-                        # into the detected subquery
-                        raise BadSchemaDefinition(
-                            "cant insert security because of usage two inlined "
-                            "relations in this query. You should probably at "
-                            "least uninline %s" % rel.r_type)
-                    subselect.append_selected(vref.copy(subselect))
-                    aliases.append(vref.name)
-                self.select.remove_node(rel)
-                # when some inlined relation has to be copied in the subquery,
-                # we need to test that either value is NULL or that the snippet
-                # condition is satisfied
-                if rschema.inlined and rel.optional:
-                    need_null_test = True
-        if need_null_test:
-            snippetrqlst = n.Or(
-                n.make_relation(subselectvar, 'is', (None, None), n.Constant,
-                                operator='='),
-                snippetrqlst)
+        aliases = []
+        rels_done = set()
+        for i, (selectvar, snippetvar) in enumerate(varmap):
+            subselectvar = subselect.get_variable(selectvar)
+            subselect.append_selected(n.VariableRef(subselectvar))
+            aliases.append(selectvar)
+            vi = self.varinfos[i]
+            need_null_test = False
+            stinfo = vi['stinfo']
+            for rel in stinfo['relations']:
+                if rel in rels_done:
+                    continue
+                rels_done.add(rel)
+                rschema = self.schema.rschema(rel.r_type)
+                if rschema.final or (rschema.inlined and
+                                     not rel in stinfo['rhsrelations']):
+                    rel.children[0].name = selectvar # XXX explain why
+                    subselect.add_restriction(rel.copy(subselect))
+                    for vref in rel.children[1].iget_nodes(n.VariableRef):
+                        if isinstance(vref.variable, n.ColumnAlias):
+                            # XXX could probably be handled by generating the
+                            # subquery into the detected subquery
+                            raise BadSchemaDefinition(
+                                "cant insert security because of usage two inlined "
+                                "relations in this query. You should probably at "
+                                "least uninline %s" % rel.r_type)
+                        subselect.append_selected(vref.copy(subselect))
+                        aliases.append(vref.name)
+                    self.select.remove_node(rel)
+                    # when some inlined relation has to be copied in the
+                    # subquery, we need to test that either value is NULL or
+                    # that the snippet condition is satisfied
+                    if rschema.inlined and rel.optional:
+                        need_null_test = True
+            if need_null_test:
+                snippetrqlst = n.Or(
+                    n.make_relation(subselectvar, 'is', (None, None), n.Constant,
+                                    operator='='),
+                    snippetrqlst)
         subselect.add_restriction(snippetrqlst)
         if self.u_varname:
             # generate an identifier for the substitution
@@ -439,30 +462,32 @@
         original query, return that relation node
         """
         rschema = self.schema.rschema(sniprel.r_type)
-        try:
-            if target == 'object':
-                orel = self.varinfo['lhs_rels'][sniprel.r_type]
-                cardindex = 0
-                ttypes_func = rschema.objects
-                rdef = rschema.rdef
-            else: # target == 'subject':
-                orel = self.varinfo['rhs_rels'][sniprel.r_type]
-                cardindex = 1
-                ttypes_func = rschema.subjects
-                rdef = lambda x, y: rschema.rdef(y, x)
-        except KeyError:
-            # may be raised by self.varinfo['xhs_rels'][sniprel.r_type]
-            return None
-        # can't share neged relation or relations with different outer join
-        if (orel.neged(strict=True) or sniprel.neged(strict=True)
-            or (orel.optional and orel.optional != sniprel.optional)):
-            return None
-        # if cardinality is in '?1', we can ignore the snippet relation and use
-        # variable from the original query
-        for etype in self.varinfo['stinfo']['possibletypes']:
-            for ttype in ttypes_func(etype):
-                if rdef(etype, ttype).cardinality[cardindex] in '+*':
-                    return None
+        for vi in self.varinfos:
+            try:
+                if target == 'object':
+                    orel = vi['lhs_rels'][sniprel.r_type]
+                    cardindex = 0
+                    ttypes_func = rschema.objects
+                    rdef = rschema.rdef
+                else: # target == 'subject':
+                    orel = vi['rhs_rels'][sniprel.r_type]
+                    cardindex = 1
+                    ttypes_func = rschema.subjects
+                    rdef = lambda x, y: rschema.rdef(y, x)
+            except KeyError:
+                # may be raised by vi['xhs_rels'][sniprel.r_type]
+                return None
+            # can't share neged relation or relations with different outer join
+            if (orel.neged(strict=True) or sniprel.neged(strict=True)
+                or (orel.optional and orel.optional != sniprel.optional)):
+                return None
+            # if cardinality is in '?1', we can ignore the snippet relation and use
+            # variable from the original query
+            for etype in vi['stinfo']['possibletypes']:
+                for ttype in ttypes_func(etype):
+                    if rdef(etype, ttype).cardinality[cardindex] in '+*':
+                        return None
+            break
         return orel
 
     def _use_orig_term(self, snippet_varname, term):
@@ -601,10 +626,11 @@
     def visit_variableref(self, node):
         """get the sql name for a variable reference"""
         if node.name in self.revvarmap:
-            if self.varinfo.get('const') is not None:
-                return n.Constant(self.varinfo['const'], 'Int') # XXX gae
-            return n.VariableRef(self.select.get_variable(
-                self.revvarmap[node.name]))
+            selectvar, index = self.revvarmap[node.name]
+            vi = self.varinfos[index]
+            if vi.get('const') is not None:
+                return n.Constant(vi['const'], 'Int') # XXX gae
+            return n.VariableRef(self.select.get_variable(selectvar))
         vname_or_term = self._get_varname_or_term(node.name)
         if isinstance(vname_or_term, basestring):
             return n.VariableRef(self.select.get_variable(vname_or_term))