--- a/server/msplanner.py Tue May 25 11:51:48 2010 +0200
+++ b/server/msplanner.py Wed May 26 12:33:48 2010 +0200
@@ -95,7 +95,8 @@
from logilab.common.decorators import cached
from rql.stmts import Union, Select
-from rql.nodes import VariableRef, Comparison, Relation, Constant, Variable
+from rql.nodes import (VariableRef, Comparison, Relation, Constant, Variable,
+ Not, Exists)
from cubicweb import server
from cubicweb.utils import make_uid
@@ -109,6 +110,40 @@
# str() Constant.value to ensure generated table name won't be unicode
Constant._ms_table_key = lambda x: str(x.value)
+def ms_scope(term):
+ rel = None
+ scope = term.scope
+ if isinstance(term, Variable) and len(term.stinfo['relations']) == 1:
+ rel = iter(term.stinfo['relations']).next().relation()
+ elif isinstance(term, Constant):
+ rel = term.relation()
+ elif isinstance(term, Relation):
+ rel = term
+ if rel is not None and (
+ rel.r_type != 'identity' and rel.scope is scope
+ and isinstance(rel.parent, Exists) and rel.parent.neged(strict=True)):
+ return scope.parent.scope
+ return scope
+
+def need_intersect(select, getrschema):
+ for rel in select.iget_nodes(Relation):
+ if isinstance(rel.parent, Exists) and rel.parent.neged(strict=True) and not rel.is_types_restriction():
+ rschema = getrschema(rel.r_type)
+ if not rschema.final:
+ # if one of the relation's variable is ambiguous but not
+ # invariant, an intersection will be necessary
+ for vref in rel.get_nodes(VariableRef):
+ var = vref.variable
+ if (var.valuable_references() == 1
+ and len(var.stinfo['possibletypes']) > 1):
+ return True
+ return False
+
+def neged_relation(rel):
+ parent = rel.parent
+ return isinstance(parent, Not) or (isinstance(parent, Exists) and
+ isinstance(parent.parent, Not))
+
def need_source_access_relation(vargraph):
if not vargraph:
return False
@@ -195,7 +230,7 @@
"""return true if the variable is used in an outer scope of the given scope
"""
for rel in var.stinfo['relations']:
- rscope = rel.scope
+ rscope = ms_scope(rel)
if not rscope is scope and is_ancestor(scope, rscope):
return True
return False
@@ -378,9 +413,9 @@
elif not self._sourcesterms:
self._set_source_for_term(source, const)
elif source in self._sourcesterms:
- source_scopes = frozenset(t.scope for t in self._sourcesterms[source])
+ source_scopes = frozenset(ms_scope(t) for t in self._sourcesterms[source])
for const in vconsts:
- if const.scope in source_scopes:
+ if ms_scope(const) in source_scopes:
self._set_source_for_term(source, const)
# if system source is used, add every rewritten constant
# to its supported terms even when associated entity
@@ -505,12 +540,15 @@
def _remove_sources_until_stable(self, term, termssources):
sourcesterms = self._sourcesterms
for oterm, rel in self._linkedterms.get(term, ()):
- if not term.scope is oterm.scope and rel.scope.neged(strict=True):
+ tscope = ms_scope(term)
+ otscope = ms_scope(oterm)
+ rscope = ms_scope(rel)
+ if not tscope is otscope and rscope.neged(strict=True):
# can't get information from relation inside a NOT exists
# where terms don't belong to the same scope
continue
need_ancestor_scope = False
- if not (term.scope is rel.scope and oterm.scope is rel.scope):
+ if not (tscope is rscope and otscope is rscope):
if rel.ored():
continue
if rel.ored(traverse_scope=True):
@@ -518,7 +556,7 @@
# propagate from parent scope to child scope, nothing else
need_ancestor_scope = True
relsources = self._repo.rel_type_sources(rel.r_type)
- if rel.neged(strict=True) and (
+ if neged_relation(rel) and (
len(relsources) < 2
or not isinstance(oterm, Variable)
or oterm.valuable_references() != 1
@@ -532,9 +570,9 @@
# Y)
continue
# compute invalid sources for terms and remove them
- if not need_ancestor_scope or is_ancestor(term.scope, oterm.scope):
+ if not need_ancestor_scope or is_ancestor(tscope, otscope):
self._remove_term_sources(term, rel, oterm, termssources)
- if not need_ancestor_scope or is_ancestor(oterm.scope, term.scope):
+ if not need_ancestor_scope or is_ancestor(otscope, tscope):
self._remove_term_sources(oterm, rel, term, termssources)
def _remove_term_sources(self, term, rel, oterm, termssources):
@@ -693,7 +731,7 @@
sourceterms.clear()
sources = [source]
else:
- scope = term.scope
+ scope = ms_scope(term)
# find which sources support the same term and solutions
sources = self._expand_sources(source, term, solindices)
# no try to get as much terms as possible
@@ -779,7 +817,7 @@
# `terms`, eg cross relations)
for c in vconsts:
rel = c.relation()
- if rel is None or not (rel in terms or rel.neged(strict=True)):
+ if rel is None or not (rel in terms or neged_relation(rel)):
final = False
break
break
@@ -802,13 +840,13 @@
# variable is refed by an outer scope and should be substituted
# using an 'identity' relation (else we'll get a conflict of
# temporary tables)
- if rhsvar in terms and not lhsvar in terms and lhsvar.scope is lhsvar.stmt:
+ if rhsvar in terms and not lhsvar in terms and ms_scope(lhsvar) is lhsvar.stmt:
self._identity_substitute(rel, lhsvar, terms, needsel)
- elif lhsvar in terms and not rhsvar in terms and rhsvar.scope is rhsvar.stmt:
+ elif lhsvar in terms and not rhsvar in terms and ms_scope(rhsvar) is rhsvar.stmt:
self._identity_substitute(rel, rhsvar, terms, needsel)
def _identity_substitute(self, relation, var, terms, needsel):
- newvar = self._insert_identity_variable(relation.scope, var)
+ newvar = self._insert_identity_variable(ms_scope(relation), var)
# ensure relation is using '=' operator, else we rely on a
# sqlgenerator side effect (it won't insert an inequality operator
# in this case)
@@ -824,14 +862,14 @@
if len(self._sourcesterms) > 1:
# priority to variable from subscopes
for term in sourceterms:
- if not term.scope is self.rqlst:
+ if not ms_scope(term) is self.rqlst:
if isinstance(term, Variable):
return term, sourceterms.pop(term)
secondchoice = term
else:
# priority to variable from outer scope
for term in sourceterms:
- if term.scope is self.rqlst:
+ if ms_scope(term) is self.rqlst:
if isinstance(term, Variable):
return term, sourceterms.pop(term)
secondchoice = term
@@ -881,7 +919,7 @@
# term has to belong to the same scope if there is more
# than the system source remaining
if len(sourcesterms) > 1 and not scope is self.rqlst:
- candidates = (t for t in sourceterms.keys() if scope is t.scope)
+ candidates = (t for t in sourceterms.keys() if scope is ms_scope(t))
else:
candidates = sourceterms #.iterkeys()
# we only want one unlinked term in each generated query
@@ -1200,9 +1238,10 @@
step = AggrStep(plan, selection, select, atemptable, temptable)
step.children = steps
elif len(steps) > 1:
- if select.need_intersect or any(select.need_intersect
- for step in steps
- for select in step.union.children):
+ getrschema = self.schema.rschema
+ if need_intersect(select, getrschema) or any(need_intersect(select, getrschema)
+ for step in steps
+ for select in step.union.children):
if temptable:
step = IntersectFetchStep(plan) # XXX not implemented
else: