server/msplanner.py
changeset 1231 1457a545af03
parent 1230 232e16835fff
child 1237 c836bdb3b17b
--- a/server/msplanner.py	Sat Apr 04 13:21:56 2009 +0200
+++ b/server/msplanner.py	Sat Apr 04 14:44:04 2009 +0200
@@ -223,11 +223,13 @@
         self.needsplit = False
         self.temptable = None
         self.finaltable = None
+        # shortcuts
         self._schema = plan.schema
         self._session = plan.session
         self._repo = self._session.repo
         self._solutions = rqlst.solutions
         self._solindices = range(len(self._solutions))
+        self.system_source = self._repo.system_source
         # source : {term: [solution index, ]}
         self.sourcesterms = self._sourcesterms = {}
         # source : {relation: set(child variable and constant)}
@@ -270,12 +272,12 @@
     def part_sources(self):
         if self._sourcesterms:
             return tuple(sorted(self._sourcesterms))
-        return (self._repo.system_source,)
+        return (self.system_source,)
     
     @property
     @cached
     def _sys_source_set(self):
-        return frozenset((self._repo.system_source, solindex)
+        return frozenset((self.system_source, solindex)
                          for solindex in self._solindices)        
        
     @cached
@@ -309,7 +311,7 @@
                         source = self._session.source_from_eid(eid)
                         if vrels and not any(source.support_relation(r.r_type)
                                              for r in vrels):
-                            self._set_source_for_term(repo.system_source, varobj)
+                            self._set_source_for_term(self.system_source, varobj)
                         else:
                             self._set_source_for_term(source, varobj)
                 continue
@@ -317,7 +319,7 @@
             if not rels and not varobj.stinfo['typerels']:
                 # (rare) case where the variable has no type specified nor
                 # relation accessed ex. "Any MAX(X)"
-                self._set_source_for_term(repo.system_source, varobj)
+                self._set_source_for_term(self.system_source, varobj)
                 continue
             for i, sol in enumerate(self._solutions):
                 vartype = sol[varname]
@@ -344,7 +346,7 @@
         for vconsts in self.rqlst.stinfo['rewritten'].itervalues():
             const = vconsts[0]
             source = self._session.source_from_eid(const.eval(self.plan.args))
-            if source is repo.system_source:
+            if source is self.system_source:
                 for const in vconsts:
                     self._set_source_for_term(source, const)
             elif source in self._sourcesterms:
@@ -357,8 +359,8 @@
                         # doesn't actually comes from it so we get a changes
                         # that allequals will return True as expected when
                         # computing needsplit
-                        if repo.system_source in sourcesterms:
-                            self._set_source_for_term(repo.system_source, const)
+                        if self.system_source in sourcesterms:
+                            self._set_source_for_term(self.system_source, const)
         # add source for relations
         rschema = self._schema.rschema
         termssources = {}
@@ -400,7 +402,7 @@
     def _handle_cross_relation(self, rel, relsources, termssources):
         for source in relsources:
             if rel.r_type in source.cross_relations:
-                ssource = self._repo.system_source
+                ssource = self.system_source
                 crossvars = set(x.variable for x in rel.get_nodes(VariableRef))
                 for const in rel.get_nodes(Constant):
                     if source.uri != 'system' and not const in self._sourcesterms.get(source, ()):
@@ -520,9 +522,9 @@
                 # if the term is a not invariant variable, we should filter out
                 # source where the relation is a cross relation from invalid
                 # sources
-                invalid_sources = frozenset([(s, solidx) for s, solidx in invalid_sources
-                                             if not (s in self._crossrelations and
-                                                     rel in self._crossrelations[s])])
+                invalid_sources = frozenset((s, solidx) for s, solidx in invalid_sources
+                                            if not (s in self._crossrelations and
+                                                    rel in self._crossrelations[s]))
         if invalid_sources:
             self._remove_sources(term, invalid_sources)
             termsources -= invalid_sources
@@ -620,7 +622,10 @@
         select = self.rqlst
         rschema = self._schema.rschema
         for source in self.part_sources:
-            sourceterms = self._sourcesterms[source]
+            try:
+                sourceterms = self._sourcesterms[source]
+            except KeyError:
+                continue # already proceed
             while sourceterms:
                 # take a term randomly, and all terms supporting the
                 # same solutions
@@ -646,25 +651,66 @@
                         # go to the next iteration directly!
                         continue
                     if not sourceterms:
-                        del self._sourcesterms[source]
-                # suppose this is a final step until the contrary is proven
-                final = scope is select
+                         try:
+                             del self._sourcesterms[source]
+                         except KeyError:
+                             # XXX already cleaned
+                             pass
                 # set of terms which should be additionaly selected when
                 # possible
                 needsel = set()
-                # add attribute variables and mark variables which should be
-                # additionaly selected when possible
-                for var in select.defined_vars.itervalues():
-                    if not var in terms:
-                        stinfo = var.stinfo
-                        for ovar, rtype in stinfo['attrvars']:
-                            if ovar in terms:
+                if not self._sourcesterms:
+                    terms += scope.defined_vars.values() + scope.aliases.values()
+                    final = True
+                else:
+                    # suppose this is a final step until the contrary is proven
+                    final = scope is select
+                    # add attribute variables and mark variables which should be
+                    # additionaly selected when possible
+                    for var in select.defined_vars.itervalues():
+                        if not var in terms:
+                            stinfo = var.stinfo
+                            for ovar, rtype in stinfo['attrvars']:
+                                if ovar in terms:
+                                    needsel.add(var.name)
+                                    terms.append(var)
+                                    break
+                            else:
                                 needsel.add(var.name)
-                                terms.append(var)
+                                final = False
+                    # check where all relations are supported by the sources
+                    for rel in scope.iget_nodes(Relation):
+                        if rel.is_types_restriction():
+                            continue
+                        # take care not overwriting the existing "source" identifier
+                        for _source in sources:
+                            if not _source.support_relation(rel.r_type) or (
+                                self.crossed_relation(_source, rel) and not rel in terms):
+                                for vref in rel.iget_nodes(VariableRef):
+                                    needsel.add(vref.name)
+                                final = False
                                 break
                         else:
-                            needsel.add(var.name)
-                            final = False
+                            if not scope is select:
+                                self._exists_relation(rel, terms, needsel)
+                            # if relation is supported by all sources and some of
+                            # its lhs/rhs variable isn't in "terms", and the
+                            # other end *is* in "terms", mark it have to be
+                            # selected
+                            if source.uri != 'system' and not rschema(rel.r_type).is_final():
+                                lhs, rhs = rel.get_variable_parts()
+                                try:
+                                    lhsvar = lhs.variable
+                                except AttributeError:
+                                    lhsvar = lhs
+                                try:
+                                    rhsvar = rhs.variable
+                                except AttributeError:
+                                    rhsvar = rhs
+                                if lhsvar in terms and not rhsvar in terms:
+                                    needsel.add(lhsvar.name)
+                                elif rhsvar in terms and not lhsvar in terms:
+                                    needsel.add(rhsvar.name)
                 if final and source.uri != 'system':
                     # check rewritten constants
                     for vconsts in select.stinfo['rewritten'].itervalues():
@@ -683,41 +729,6 @@
                                     final = False
                                     break
                             break
-                # check where all relations are supported by the sources
-                for rel in scope.iget_nodes(Relation):
-                    if rel.is_types_restriction():
-                        continue
-                    # take care not overwriting the existing "source" identifier
-                    for _source in sources:
-                        if not _source.support_relation(rel.r_type):
-                            for vref in rel.iget_nodes(VariableRef):
-                                needsel.add(vref.name)
-                            final = False
-                            break
-                        elif self.crossed_relation(_source, rel) and not rel in terms:
-                            final = False
-                            break
-                    else:
-                        if not scope is select:
-                            self._exists_relation(rel, terms, needsel)
-                        # if relation is supported by all sources and some of
-                        # its lhs/rhs variable isn't in "terms", and the
-                        # other end *is* in "terms", mark it have to be
-                        # selected
-                        if source.uri != 'system' and not rschema(rel.r_type).is_final():
-                            lhs, rhs = rel.get_variable_parts()
-                            try:
-                                lhsvar = lhs.variable
-                            except AttributeError:
-                                lhsvar = lhs
-                            try:
-                                rhsvar = rhs.variable
-                            except AttributeError:
-                                rhsvar = rhs
-                            if lhsvar in terms and not rhsvar in terms:
-                                needsel.add(lhsvar.name)
-                            elif rhsvar in terms and not lhsvar in terms:
-                                needsel.add(rhsvar.name)
                 if final:
                     self._cleanup_sourcesterms(sources, solindices)
                 steps.append((sources, terms, solindices, scope, needsel, final)
@@ -759,75 +770,93 @@
         secondchoice = None
         if len(self._sourcesterms) > 1:
             # priority to variable from subscopes
-            for var in sourceterms:
-                if not var.scope is self.rqlst:
-                    if isinstance(var, Variable):
-                        return var, sourceterms.pop(var)
-                    secondchoice = var
+            for term in sourceterms:
+                if not term.scope is self.rqlst:
+                    if isinstance(term, Variable):
+                        return term, sourceterms.pop(term)
+                    secondchoice = term
         else:
-            # priority to variable outer scope
-            for var in sourceterms:
-                if var.scope is self.rqlst:
-                    if isinstance(var, Variable):
-                        return var, sourceterms.pop(var)
-                    secondchoice = var
+            # priority to variable from outer scope
+            for term in sourceterms:
+                if term.scope is self.rqlst:
+                    if isinstance(term, Variable):
+                        return term, sourceterms.pop(term)
+                    secondchoice = term
         if secondchoice is not None:
             return secondchoice, sourceterms.pop(secondchoice)
-        # priority to variable
-        for var in sourceterms:
-            if isinstance(var, Variable):
-                return var, sourceterms.pop(var)
-        # whatever
-        var = iter(sourceterms).next()
-        return var, sourceterms.pop(var)
+        # priority to variable with the less solutions supported and with the
+        # most valuable refs
+        variables = sorted([(var, sols) for (var, sols) in sourceterms.items()
+                            if isinstance(var, Variable)],
+                           key=lambda (v, s): (len(s), -v.valuable_references()))
+        if variables:
+            var = variables[0][0]
+            return var, sourceterms.pop(var)
+        # priority to constant
+        for term in sourceterms:
+            if isinstance(term, Constant):
+                return term, sourceterms.pop(term)
+        # whatever (relation)
+        term = iter(sourceterms).next()
+        return term, sourceterms.pop(term)
             
     def _expand_sources(self, selected_source, term, solindices):
         """return all sources supporting given term / solindices"""
         sources = [selected_source]
         sourcesterms = self._sourcesterms
-        for source in sourcesterms:
+        for source in sourcesterms.keys():
             if source is selected_source:
                 continue
             if not (term in sourcesterms[source] and 
                     solindices.issubset(sourcesterms[source][term])):
                 continue
             sources.append(source)
-            if source.uri != 'system':
+            if source.uri != 'system' or not (isinstance(term, Variable) and not term in self._linkedterms):
                 termsolindices = sourcesterms[source][term]
                 termsolindices -= solindices
                 if not termsolindices:
-                    del sourcesterms[source][term]                
+                    del sourcesterms[source][term]
+                    if not sourcesterms[source]:
+                        del sourcesterms[source]
         return sources
             
     def _expand_terms(self, term, sources, sourceterms, scope, solindices):
         terms = [term]
         sources = sorted(sources)
+        sourcesterms = self._sourcesterms
         nbunlinked = 1
         linkedterms = self._linkedterms
         # term has to belong to the same scope if there is more
         # than the system source remaining
-        if len(self._sourcesterms) > 1 and not scope is self.rqlst:
+        if len(sourcesterms) > 1 and not scope is self.rqlst:
             candidates = (t for t in sourceterms.keys() if scope is t.scope)
         else:
             candidates = sourceterms #.iterkeys()
         # we only want one unlinked term in each generated query
         candidates = [t for t in candidates
-                      if isinstance(t, Constant) or
+                      if isinstance(t, (Constant, Relation)) or
                       (solindices.issubset(sourceterms[t]) and t in linkedterms)]
-        accept_term = lambda x: (not any(s for s in sources if not x in self._sourcesterms[s])
-                                and any(t for t in terms if t in linkedterms.get(x, ())))
-        source_cross_rels = {}
+        cross_rels = {}
         for source in sources:
-            source_cross_rels.update(self._crossrelations.get(source, {}))
-        if isinstance(term, Relation) and term in source_cross_rels:
-            cross_terms = source_cross_rels.pop(term)
+            cross_rels.update(self._crossrelations.get(source, {}))
+        exclude = {}
+        for rel, crossvars in cross_rels.iteritems():
+            vars = [t for t in crossvars if isinstance(t, Variable)]
+            try:
+                exclude[vars[0]] = vars[1]
+                exclude[vars[1]] = vars[0]
+            except IndexError:
+                pass
+        accept_term = lambda x: (not any(s for s in sources if not x in sourcesterms.get(s, ()))
+                                 and any(t for t in terms if t in linkedterms.get(x, ()))
+                                 and not exclude.get(x) in terms)
+        if isinstance(term, Relation) and term in cross_rels:
+            cross_terms = cross_rels.pop(term)
             base_accept_term = accept_term
             accept_term = lambda x: (base_accept_term(x) or x in cross_terms)
             for refed in cross_terms:
                 if not refed in candidates:
-                    candidates.append(refed)
-        else:
-            cross_terms = ()
+                    terms.append(refed)
         # repeat until no term can't be added, since addition of a new
         # term may permit to another one to be added
         modified = True
@@ -846,32 +875,34 @@
                     terms.append(term)
                     candidates.remove(term)
                     modified = True
-                    for source in sources:
-                        sourceterms = self._sourcesterms[source]
-                        # terms should be deleted once all possible solutions
-                        # indices have been consumed
-                        try:
-                            sourceterms[term] -= solindices
-                            if not sourceterms[term]:
-                                del sourceterms[term]
-                        except KeyError:
-                            assert term in cross_terms
+                    self._cleanup_sourcesterms(sources, solindices, term)
         return terms
     
-    def _cleanup_sourcesterms(self, sources, solindices):
-        """on final parts, remove solutions so we know they are already processed"""
+    def _cleanup_sourcesterms(self, sources, solindices, term=None):
+        """remove solutions so we know they are already processed"""
         for source in sources:
             try:
                 sourceterms = self._sourcesterms[source]
             except KeyError:
                 continue
-            for term, termsolindices in sourceterms.items():
-                if isinstance(term, Relation) and self.crossed_relation(source, term):
-                    continue
-                termsolindices -= solindices
-                if not termsolindices:
-                    del sourceterms[term]
-                    
+            if term is None:
+                for term, termsolindices in sourceterms.items():
+                    if isinstance(term, Relation) and self.crossed_relation(source, term):
+                        continue
+                    termsolindices -= solindices
+                    if not termsolindices:
+                        del sourceterms[term]
+            else:
+                try:
+                    sourceterms[term] -= solindices
+                    if not sourceterms[term]:
+                        del sourceterms[term]
+                except KeyError:
+                    pass
+                    #assert term in cross_terms
+            if not sourceterms:
+                del self._sourcesterms[source]
+                
     def merge_input_maps(self, allsolindices):
         """inputmaps is a dictionary with tuple of solution indices as key with
         an associated input map as value. This function compute for each
@@ -1016,7 +1047,7 @@
             byinputmap = {}
             for ppi in cppis:
                 select = ppi.rqlst
-                if sources != (plan.session.repo.system_source,):
+                if sources != (ppi.system_source,):
                     add_types_restriction(self.schema, select)
                 # part plan info for subqueries
                 inputmap = self._ppi_subqueries(ppi)
@@ -1104,8 +1135,15 @@
                             inputmap = subinputmap
                         else:
                             inputmap.update(subinputmap)
-                        steps.append(ppi.build_final_part(minrqlst, solindices, inputmap,
-                                                          sources, insertedvars))
+                        if inputmap and len(sources) > 1:
+                            sources.remove(ppi.system_source)
+                            steps.append(ppi.build_final_part(minrqlst, solindices, None,
+                                                              sources, insertedvars))
+                            steps.append(ppi.build_final_part(minrqlst, solindices, inputmap,
+                                                              [ppi.system_source], insertedvars))
+                        else:
+                            steps.append(ppi.build_final_part(minrqlst, solindices, inputmap,
+                                                              sources, insertedvars))
                 else:
                     table = '_T%s%s' % (''.join(sorted(v._ms_table_key() for v in terms)),
                                         ''.join(sorted(str(i) for i in solindices)))