server/msplanner.py
branchstable
changeset 6131 087c5a168010
parent 6129 fea746b60093
child 6427 c8a5ac2d1eaa
--- a/server/msplanner.py	Fri Aug 20 08:36:58 2010 +0200
+++ b/server/msplanner.py	Fri Aug 20 10:59:57 2010 +0200
@@ -110,6 +110,11 @@
 # str() Constant.value to ensure generated table name won't be unicode
 Constant._ms_table_key = lambda x: str(x.value)
 
+Variable._ms_may_be_processed = lambda x, terms, linkedterms: any(
+    t for t in terms if t in linkedterms.get(x, ()))
+Relation._ms_may_be_processed = lambda x, terms, linkedterms: all(
+    getattr(hs, 'variable', hs) in terms for hs in x.get_variable_parts())
+
 def ms_scope(term):
     rel = None
     scope = term.scope
@@ -411,7 +416,8 @@
                 for const in vconsts:
                     self._set_source_for_term(source, const)
             elif not self._sourcesterms:
-                self._set_source_for_term(source, const)
+                for const in vconsts:
+                    self._set_source_for_term(source, const)
             elif source in self._sourcesterms:
                 source_scopes = frozenset(ms_scope(t) for t in self._sourcesterms[source])
                 for const in vconsts:
@@ -480,6 +486,7 @@
                     # not supported by the source, so we can stop here
                     continue
                 self._sourcesterms.setdefault(ssource, {})[rel] = set(self._solindices)
+                solindices = None
                 for term in crossvars:
                     if len(termssources[term]) == 1 and iter(termssources[term]).next()[0].uri == 'system':
                         for ov in crossvars:
@@ -487,8 +494,14 @@
                                 ssset = frozenset((ssource,))
                                 self._remove_sources(ov, termssources[ov] - ssset)
                         break
+                    if solindices is None:
+                        solindices = set(sol for s, sol in termssources[term]
+                                         if s is source)
+                    else:
+                        solindices &= set(sol for s, sol in termssources[term]
+                                          if s is source)
                 else:
-                    self._sourcesterms.setdefault(source, {})[rel] = set(self._solindices)
+                    self._sourcesterms.setdefault(source, {})[rel] = solindices
 
     def _remove_invalid_sources(self, termssources):
         """removes invalid sources from `sourcesterms` member according to
@@ -801,10 +814,13 @@
                                     rhsvar = rhs.variable
                                 except AttributeError:
                                     rhsvar = rhs
-                                if lhsvar in terms and not rhsvar in terms:
-                                    needsel.add(lhsvar.name)
-                                elif rhsvar in terms and not lhsvar in terms:
-                                    needsel.add(rhsvar.name)
+                                try:
+                                    if lhsvar in terms and not rhsvar in terms:
+                                        needsel.add(lhsvar.name)
+                                    elif rhsvar in terms and not lhsvar in terms:
+                                        needsel.add(rhsvar.name)
+                                except AttributeError:
+                                    continue # not an attribute, no selection needed
                 if final and source.uri != 'system':
                     # check rewritten constants
                     for vconsts in select.stinfo['rewritten'].itervalues():
@@ -939,13 +955,14 @@
                 exclude[vars[1]] = vars[0]
             except IndexError:
                 pass
-        accept_term = lambda x: (not any(s for s in sources if not x in sourcesterms.get(s, ()))
-                                 and any(t for t in terms if t in linkedterms.get(x, ()))
+        accept_term = lambda x: (not any(s for s in sources
+                                         if not x in sourcesterms.get(s, ()))
+                                 and x._ms_may_be_processed(terms, linkedterms)
                                  and not exclude.get(x) in terms)
         if isinstance(term, Relation) and term in cross_rels:
             cross_terms = cross_rels.pop(term)
             base_accept_term = accept_term
-            accept_term = lambda x: (base_accept_term(x) or x in cross_terms)
+            accept_term = lambda x: (accept_term(x) or x in cross_terms)
             for refed in cross_terms:
                 if not refed in candidates:
                     terms.append(refed)
@@ -956,7 +973,11 @@
             modified = False
             for term in candidates[:]:
                 if isinstance(term, Constant):
-                    if sorted(set(x[0] for x in self._term_sources(term))) != sources:
+                    termsources = set(x[0] for x in self._term_sources(term))
+                    # ensure system source is there for constant
+                    if self.system_source in sources:
+                        termsources.add(self.system_source)
+                    if sorted(termsources) != sources:
                         continue
                     terms.append(term)
                     candidates.remove(term)
@@ -1614,6 +1635,8 @@
                 for vref in supportedvars:
                     if not vref in newroot.get_selected_variables():
                         newroot.append_selected(VariableRef(newroot.get_variable(vref.name)))
+            elif term in self.terms:
+                newroot.append_selected(term.copy(newroot))
 
     def add_necessary_selection(self, newroot, terms):
         selected = tuple(newroot.get_selected_variables())