server/msplanner.py
branchstable
changeset 5582 3e133b29a1a4
parent 5426 0d4853a6e5ee
child 5768 1e73a466aa69
--- a/server/msplanner.py	Tue May 25 12:21:17 2010 +0200
+++ b/server/msplanner.py	Wed May 26 10:28:48 2010 +0200
@@ -95,7 +95,8 @@
 from logilab.common.decorators import cached
 
 from rql.stmts import Union, Select
-from rql.nodes import VariableRef, Comparison, Relation, Constant, Variable
+from rql.nodes import (VariableRef, Comparison, Relation, Constant, Variable,
+                       Not, Exists)
 
 from cubicweb import server
 from cubicweb.utils import make_uid
@@ -109,6 +110,40 @@
 # str() Constant.value to ensure generated table name won't be unicode
 Constant._ms_table_key = lambda x: str(x.value)
 
+def ms_scope(term):
+    rel = None
+    scope = term.scope
+    if isinstance(term, Variable) and len(term.stinfo['relations']) == 1:
+        rel = iter(term.stinfo['relations']).next().relation()
+    elif isinstance(term, Constant):
+        rel = term.relation()
+    elif isinstance(term, Relation):
+        rel = term
+    if rel is not None and (
+        rel.r_type != 'identity' and rel.scope is scope
+        and isinstance(rel.parent, Exists) and rel.parent.neged(strict=True)):
+        return scope.parent.scope
+    return scope
+
+def need_intersect(select, getrschema):
+    for rel in select.iget_nodes(Relation):
+        if isinstance(rel.parent, Exists) and rel.parent.neged(strict=True) and not rel.is_types_restriction():
+            rschema = getrschema(rel.r_type)
+            if not rschema.final:
+                # if one of the relation's variable is ambiguous but not
+                # invariant, an intersection will be necessary
+                for vref in rel.get_nodes(VariableRef):
+                    var = vref.variable
+                    if (var.valuable_references() == 1
+                        and len(var.stinfo['possibletypes']) > 1):
+                        return True
+    return False
+
+def neged_relation(rel):
+    parent = rel.parent
+    return isinstance(parent, Not) or (isinstance(parent, Exists) and
+                                       isinstance(parent.parent, Not))
+
 def need_source_access_relation(vargraph):
     if not vargraph:
         return False
@@ -195,7 +230,7 @@
     """return true if the variable is used in an outer scope of the given scope
     """
     for rel in var.stinfo['relations']:
-        rscope = rel.scope
+        rscope = ms_scope(rel)
         if not rscope is scope and is_ancestor(scope, rscope):
             return True
     return False
@@ -378,9 +413,9 @@
             elif not self._sourcesterms:
                 self._set_source_for_term(source, const)
             elif source in self._sourcesterms:
-                source_scopes = frozenset(t.scope for t in self._sourcesterms[source])
+                source_scopes = frozenset(ms_scope(t) for t in self._sourcesterms[source])
                 for const in vconsts:
-                    if const.scope in source_scopes:
+                    if ms_scope(const) in source_scopes:
                         self._set_source_for_term(source, const)
                         # if system source is used, add every rewritten constant
                         # to its supported terms even when associated entity
@@ -505,12 +540,15 @@
     def _remove_sources_until_stable(self, term, termssources):
         sourcesterms = self._sourcesterms
         for oterm, rel in self._linkedterms.get(term, ()):
-            if not term.scope is oterm.scope and rel.scope.neged(strict=True):
+            tscope = ms_scope(term)
+            otscope = ms_scope(oterm)
+            rscope = ms_scope(rel)
+            if not tscope is otscope and rscope.neged(strict=True):
                 # can't get information from relation inside a NOT exists
                 # where terms don't belong to the same scope
                 continue
             need_ancestor_scope = False
-            if not (term.scope is rel.scope and oterm.scope is rel.scope):
+            if not (tscope is rscope and otscope is rscope):
                 if rel.ored():
                     continue
                 if rel.ored(traverse_scope=True):
@@ -518,7 +556,7 @@
                     # propagate from parent scope to child scope, nothing else
                     need_ancestor_scope = True
             relsources = self._repo.rel_type_sources(rel.r_type)
-            if rel.neged(strict=True) and (
+            if neged_relation(rel) and (
                 len(relsources) < 2
                 or not isinstance(oterm, Variable)
                 or oterm.valuable_references() != 1
@@ -532,9 +570,9 @@
                 # Y)
                 continue
             # compute invalid sources for terms and remove them
-            if not need_ancestor_scope or is_ancestor(term.scope, oterm.scope):
+            if not need_ancestor_scope or is_ancestor(tscope, otscope):
                 self._remove_term_sources(term, rel, oterm, termssources)
-            if not need_ancestor_scope or is_ancestor(oterm.scope, term.scope):
+            if not need_ancestor_scope or is_ancestor(otscope, tscope):
                 self._remove_term_sources(oterm, rel, term, termssources)
 
     def _remove_term_sources(self, term, rel, oterm, termssources):
@@ -693,7 +731,7 @@
                     sourceterms.clear()
                     sources = [source]
                 else:
-                    scope = term.scope
+                    scope = ms_scope(term)
                     # find which sources support the same term and solutions
                     sources = self._expand_sources(source, term, solindices)
                     # no try to get as much terms as possible
@@ -779,7 +817,7 @@
                             # `terms`, eg cross relations)
                             for c in vconsts:
                                 rel = c.relation()
-                                if rel is None or not (rel in terms or rel.neged(strict=True)):
+                                if rel is None or not (rel in terms or neged_relation(rel)):
                                     final = False
                                     break
                             break
@@ -802,13 +840,13 @@
             # variable is refed by an outer scope and should be substituted
             # using an 'identity' relation (else we'll get a conflict of
             # temporary tables)
-            if rhsvar in terms and not lhsvar in terms and lhsvar.scope is lhsvar.stmt:
+            if rhsvar in terms and not lhsvar in terms and ms_scope(lhsvar) is lhsvar.stmt:
                 self._identity_substitute(rel, lhsvar, terms, needsel)
-            elif lhsvar in terms and not rhsvar in terms and rhsvar.scope is rhsvar.stmt:
+            elif lhsvar in terms and not rhsvar in terms and ms_scope(rhsvar) is rhsvar.stmt:
                 self._identity_substitute(rel, rhsvar, terms, needsel)
 
     def _identity_substitute(self, relation, var, terms, needsel):
-        newvar = self._insert_identity_variable(relation.scope, var)
+        newvar = self._insert_identity_variable(ms_scope(relation), var)
         # ensure relation is using '=' operator, else we rely on a
         # sqlgenerator side effect (it won't insert an inequality operator
         # in this case)
@@ -824,14 +862,14 @@
         if len(self._sourcesterms) > 1:
             # priority to variable from subscopes
             for term in sourceterms:
-                if not term.scope is self.rqlst:
+                if not ms_scope(term) is self.rqlst:
                     if isinstance(term, Variable):
                         return term, sourceterms.pop(term)
                     secondchoice = term
         else:
             # priority to variable from outer scope
             for term in sourceterms:
-                if term.scope is self.rqlst:
+                if ms_scope(term) is self.rqlst:
                     if isinstance(term, Variable):
                         return term, sourceterms.pop(term)
                     secondchoice = term
@@ -881,7 +919,7 @@
         # term has to belong to the same scope if there is more
         # than the system source remaining
         if len(sourcesterms) > 1 and not scope is self.rqlst:
-            candidates = (t for t in sourceterms.keys() if scope is t.scope)
+            candidates = (t for t in sourceterms.keys() if scope is ms_scope(t))
         else:
             candidates = sourceterms #.iterkeys()
         # we only want one unlinked term in each generated query
@@ -1200,9 +1238,10 @@
             step = AggrStep(plan, selection, select, atemptable, temptable)
             step.children = steps
         elif len(steps) > 1:
-            if select.need_intersect or any(select.need_intersect
-                                            for step in steps
-                                            for select in step.union.children):
+            getrschema = self.schema.rschema
+            if need_intersect(select, getrschema) or any(need_intersect(select, getrschema)
+                                                         for step in steps
+                                                         for select in step.union.children):
                 if temptable:
                     step = IntersectFetchStep(plan) # XXX not implemented
                 else: