--- a/server/msplanner.py Sat Oct 09 00:05:50 2010 +0200
+++ b/server/msplanner.py Sat Oct 09 00:05:52 2010 +0200
@@ -84,9 +84,8 @@
1. return the result of Any X WHERE X owned_by Y from system source, that's
enough (optimization of the sql querier will avoid join on CWUser, so we
will directly get local eids)
-
+"""
-"""
__docformat__ = "restructuredtext en"
from itertools import imap, ifilterfalse
@@ -94,6 +93,7 @@
from logilab.common.compat import any
from logilab.common.decorators import cached
+from rql import BadRQLQuery
from rql.stmts import Union, Select
from rql.nodes import (VariableRef, Comparison, Relation, Constant, Variable,
Not, Exists, SortTerm, Function)
@@ -434,11 +434,14 @@
# add source for relations
rschema = self._schema.rschema
termssources = {}
+ sourcerels = []
for rel in self.rqlst.iget_nodes(Relation):
# process non final relations only
# note: don't try to get schema for 'is' relation (not available
# during bootstrap)
- if not (rel.is_types_restriction() or rschema(rel.r_type).final):
+ if rel.r_type == 'cw_source':
+ sourcerels.append(rel)
+ elif not (rel.is_types_restriction() or rschema(rel.r_type).final):
# nothing to do if relation is not supported by multiple sources
# or if some source has it listed in its cross_relations
# attribute
@@ -469,6 +472,64 @@
self._handle_cross_relation(rel, relsources, termssources)
self._linkedterms.setdefault(lhsv, set()).add((rhsv, rel))
self._linkedterms.setdefault(rhsv, set()).add((lhsv, rel))
+ # extract information from cw_source relation
+ for srel in sourcerels:
+ vref = srel.children[1].children[0]
+ sourceeids, sourcenames = [], []
+ if isinstance(vref, Constant):
+ # simplified variable
+ sourceeids = None, (vref.eval(self.plan.args),)
+ else:
+ var = vref.variable
+ for rel in var.stinfo['relations'] - var.stinfo['rhsrelations']:
+ if rel.r_type in ('eid', 'name'):
+ if rel.r_type == 'eid':
+ slist = sourceeids
+ else:
+ slist = sourcenames
+ sources = [cst.eval(self.plan.args)
+ for cst in rel.children[1].get_nodes(Constant)]
+ if sources:
+ if slist:
+ # don't attempt to do anything
+ sourcenames = sourceeids = None
+ break
+ slist[:] = (rel, sources)
+ if sourceeids:
+ rel, values = sourceeids
+ sourcesdict = self._repo.sources_by_eid
+ elif sourcenames:
+ rel, values = sourcenames
+ sourcesdict = self._repo.sources_by_uri
+ else:
+ sourcesdict = None
+ if sourcesdict is not None:
+ lhs = srel.children[0]
+ try:
+ sources = [sourcesdict[key] for key in values]
+ except KeyError:
+ raise BadRQLQuery('source conflict for term %s' % lhs.as_string())
+ if isinstance(lhs, Constant):
+ source = self._session.source_from_eid(lhs.eval(self.plan.args))
+ if not source in sources:
+ raise BadRQLQuery('source conflict for term %s' % lhs.as_string())
+ else:
+ lhs = getattr(lhs, 'variable', lhs)
+ # XXX NOT NOT
+ neged = srel.neged(traverse_scope=True) or (rel and rel.neged(strict=True))
+ if neged:
+ for source in sources:
+ self._remove_source_term(source, lhs, check=True)
+ else:
+ for source, terms in sourcesterms.items():
+ if lhs in terms and not source in sources:
+ self._remove_source_term(source, lhs, check=True)
+ if rel is None:
+ self._remove_source_term(self.system_source, vref)
+ srel.parent.remove(srel)
+ elif len(var.stinfo['relations']) == 2 and not var.stinfo['selected']:
+ self._remove_source_term(self.system_source, var)
+ self.rqlst.undefine_variable(var)
return termssources
def _handle_cross_relation(self, rel, relsources, termssources):
@@ -713,9 +774,18 @@
assert isinstance(term, (rqlb.BaseNode, Variable)), repr(term)
continue # may occur with subquery column alias
if not sourcesterms[source][term]:
- del sourcesterms[source][term]
- if not sourcesterms[source]:
- del sourcesterms[source]
+ self._remove_source_term(source, term)
+
+ def _remove_source_term(self, source, term, check=False):
+ poped = self._sourcesterms[source].pop(term, None)
+ if not self._sourcesterms[source]:
+ del self._sourcesterms[source]
+ if poped is not None and check:
+ for terms in self._sourcesterms.itervalues():
+ if term in terms:
+ break
+ else:
+ raise BadRQLQuery('source conflict for term %s' % term.as_string())
def crossed_relation(self, source, relation):
return relation in self._crossrelations.get(source, ())