server/msplanner.py
changeset 6427 c8a5ac2d1eaa
parent 6131 087c5a168010
child 6598 78eaccfbd2b7
--- a/server/msplanner.py	Sat Oct 09 00:05:50 2010 +0200
+++ b/server/msplanner.py	Sat Oct 09 00:05:52 2010 +0200
@@ -84,9 +84,8 @@
   1. return the result of Any X WHERE X owned_by Y from system source, that's
      enough (optimization of the sql querier will avoid join on CWUser, so we
      will directly get local eids)
-
+"""
 
-"""
 __docformat__ = "restructuredtext en"
 
 from itertools import imap, ifilterfalse
@@ -94,6 +93,7 @@
 from logilab.common.compat import any
 from logilab.common.decorators import cached
 
+from rql import BadRQLQuery
 from rql.stmts import Union, Select
 from rql.nodes import (VariableRef, Comparison, Relation, Constant, Variable,
                        Not, Exists, SortTerm, Function)
@@ -434,11 +434,14 @@
         # add source for relations
         rschema = self._schema.rschema
         termssources = {}
+        sourcerels = []
         for rel in self.rqlst.iget_nodes(Relation):
             # process non final relations only
             # note: don't try to get schema for 'is' relation (not available
             # during bootstrap)
-            if not (rel.is_types_restriction() or rschema(rel.r_type).final):
+            if rel.r_type == 'cw_source':
+                sourcerels.append(rel)
+            elif not (rel.is_types_restriction() or rschema(rel.r_type).final):
                 # nothing to do if relation is not supported by multiple sources
                 # or if some source has it listed in its cross_relations
                 # attribute
@@ -469,6 +472,64 @@
                 self._handle_cross_relation(rel, relsources, termssources)
                 self._linkedterms.setdefault(lhsv, set()).add((rhsv, rel))
                 self._linkedterms.setdefault(rhsv, set()).add((lhsv, rel))
+        # extract information from cw_source relation
+        for srel in sourcerels:
+            vref = srel.children[1].children[0]
+            sourceeids, sourcenames = [], []
+            if isinstance(vref, Constant):
+                # simplified variable
+                sourceeids = None, (vref.eval(self.plan.args),)
+            else:
+                var = vref.variable
+                for rel in var.stinfo['relations'] - var.stinfo['rhsrelations']:
+                    if rel.r_type in ('eid', 'name'):
+                        if rel.r_type == 'eid':
+                            slist = sourceeids
+                        else:
+                            slist = sourcenames
+                        sources = [cst.eval(self.plan.args)
+                                   for cst in rel.children[1].get_nodes(Constant)]
+                        if sources:
+                            if slist:
+                                # don't attempt to do anything
+                                sourcenames = sourceeids = None
+                                break
+                            slist[:] = (rel, sources)
+            if sourceeids:
+                rel, values = sourceeids
+                sourcesdict = self._repo.sources_by_eid
+            elif sourcenames:
+                rel, values = sourcenames
+                sourcesdict = self._repo.sources_by_uri
+            else:
+                sourcesdict = None
+            if sourcesdict is not None:
+                lhs = srel.children[0]
+                try:
+                    sources = [sourcesdict[key] for key in values]
+                except KeyError:
+                    raise BadRQLQuery('source conflict for term %s' % lhs.as_string())
+                if isinstance(lhs, Constant):
+                    source = self._session.source_from_eid(lhs.eval(self.plan.args))
+                    if not source in sources:
+                        raise BadRQLQuery('source conflict for term %s' % lhs.as_string())
+                else:
+                    lhs = getattr(lhs, 'variable', lhs)
+                # XXX NOT NOT
+                neged = srel.neged(traverse_scope=True) or (rel and rel.neged(strict=True))
+                if neged:
+                    for source in sources:
+                        self._remove_source_term(source, lhs, check=True)
+                else:
+                    for source, terms in sourcesterms.items():
+                        if lhs in terms and not source in sources:
+                            self._remove_source_term(source, lhs, check=True)
+                if rel is None:
+                    self._remove_source_term(self.system_source, vref)
+                    srel.parent.remove(srel)
+                elif len(var.stinfo['relations']) == 2 and not var.stinfo['selected']:
+                    self._remove_source_term(self.system_source, var)
+                    self.rqlst.undefine_variable(var)
         return termssources
 
     def _handle_cross_relation(self, rel, relsources, termssources):
@@ -713,9 +774,18 @@
                 assert isinstance(term, (rqlb.BaseNode, Variable)), repr(term)
                 continue # may occur with subquery column alias
             if not sourcesterms[source][term]:
-                del sourcesterms[source][term]
-                if not sourcesterms[source]:
-                    del sourcesterms[source]
+                self._remove_source_term(source, term)
+
+    def _remove_source_term(self, source, term, check=False):
+        poped = self._sourcesterms[source].pop(term, None)
+        if not self._sourcesterms[source]:
+            del self._sourcesterms[source]
+        if poped is not None and check:
+            for terms in self._sourcesterms.itervalues():
+                if term in terms:
+                    break
+            else:
+                raise BadRQLQuery('source conflict for term %s' % term.as_string())
 
     def crossed_relation(self, source, relation):
         return relation in self._crossrelations.get(source, ())