server/msplanner.py
changeset 393 45a675515427
parent 392 bccd9a788f7a
child 426 e96662444ec6
--- a/server/msplanner.py	Tue Jan 13 17:55:29 2009 +0100
+++ b/server/msplanner.py	Tue Jan 13 17:56:02 2009 +0100
@@ -54,6 +54,7 @@
 """
 __docformat__ = "restructuredtext en"
 
+from copy import deepcopy
 from itertools import imap, ifilterfalse
 
 from logilab.common.compat import any
@@ -75,6 +76,7 @@
 Constant._ms_table_key = lambda x: str(x.value)
 
 AbstractSource.dont_cross_relations = ()
+AbstractSource.cross_relations = ()
 
 def allequals(solutions):
     """return true if all solutions are identical"""
@@ -172,15 +174,23 @@
         self._session = plan.session
         self._solutions = rqlst.solutions
         self._solindices = range(len(self._solutions))
-        # source : {varname: [solution index, ]}
-        self._sourcesvars = {}
+        # source : {var: [solution index, ]}
+        self.sourcesvars = self._sourcesvars = {}
+        # source : {relation: set(child variable and constant)}
+        self._crossrelations = {}
         # dictionnary of variables which are linked to each other using a non
         # final relation which is supported by multiple sources
         self._linkedvars = {}
+        self._crosslinkedvars = {}
         # processing
         self._compute_sourcesvars()
         self._remove_invalid_sources()
         self._compute_needsplit()
+        self.sourcesvars = {}
+        for k, v in self._sourcesvars.iteritems():
+            self.sourcesvars[k] = {}
+            for k2, v2 in v.iteritems():
+                self.sourcesvars[k][k2] = v2.copy()
         self._inputmaps = {}
         if rqlhelper is not None: # else test
             self._insert_identity_variable = rqlhelper._annotator.rewrite_shared_optional
@@ -202,16 +212,17 @@
                          for solindex in self._solindices)        
        
     @cached
-    def _norel_support_set(self, rtype):
+    def _norel_support_set(self, relation):
         """return a set of (source, solindex) where source doesn't support the
         relation
         """
         return frozenset((source, solidx) for source in self._session.repo.sources
                          for solidx in self._solindices
-                         if not (source.support_relation(rtype)
-                                 or rtype in source.dont_cross_relations))
-        
-    
+                         if not ((source.support_relation(relation.r_type) and
+                                  not self.crossed_relation(source, relation))
+                                 or relation.r_type in source.dont_cross_relations))
+
+
     def _compute_sourcesvars(self):
         """compute for each variable/solution in the rqlst which sources support
         them
@@ -264,9 +275,50 @@
                         if not varobj._q_invariant and any(ifilterfalse(
                             source.support_relation, (r.r_type for r in rels))):
                             self.needsplit = True               
-            
+
+    def _handle_cross_relation(self, rel, relsources, vsources):
+        crossvars = None
+        for source in relsources:
+            if rel.r_type in source.cross_relations:
+                crossvars = set(x.variable for x in rel.get_nodes(VariableRef))
+                crossvars.update(frozenset(x for x in rel.get_nodes(Constant)))
+                assert len(crossvars) == 2
+                ssource = self._session.repo.system_source
+                needsplit = True
+                flag = 0
+                for v in crossvars:
+                    if isinstance(v, Constant):
+                        self._sourcesvars[ssource][v] = set(self._solindices)
+                    if len(vsources[v]) == 1:
+                        if iter(vsources[v]).next()[0].uri == 'system':
+                            flag = 1
+                            for ov in crossvars:
+                                if ov is not v and ov._q_invariant:
+                                    ssset = frozenset((ssource,))
+                                    self._remove_sources(ov, vsources[ov] - ssset)
+                        else:
+                            for ov in crossvars:
+                                if ov is not v and ov._q_invariant:
+                                    needsplit = False
+                                    break
+                            else:
+                                continue
+                        if not rel.neged(strict=True):
+                            break
+                else:
+                    self._crossrelations.setdefault(source, {})[rel] = crossvars
+                    if not flag:
+                        self._sourcesvars.setdefault(source, {})[rel] = set(self._solindices)
+                    self._sourcesvars.setdefault(ssource, {})[rel] = set(self._solindices)
+                    if needsplit:
+                        self.needsplit = True
+        return crossvars is None
+        
     def _remove_invalid_sources(self):
-        """removes invalid sources from `sourcesvars` member"""
+        """removes invalid sources from `sourcesvars` member according to
+        traversed relations and their properties (which sources support them,
+        can they cross sources, etc...)
+        """
         repo = self._session.repo
         rschema = repo.schema.rschema
         vsources = {}
@@ -276,8 +328,18 @@
             # during bootstrap)
             if not rel.is_types_restriction() and not rschema(rel.r_type).is_final():
                 # nothing to do if relation is not supported by multiple sources
+                # or if some source has it listed in its cross_relations
+                # attribute
+                #
+                # XXX code below don't deal if some source allow relation
+                #     crossing but not another one
                 relsources = repo.rel_type_sources(rel.r_type)
+                crossvars = None
                 if len(relsources) < 2:
+                    # filter out sources being there because they have this
+                    # relation in their dont_cross_relations attribute
+                    relsources = [source for source in relsources
+                                  if source.support_relation(rel.r_type)]
                     if relsources:
                         # this means the relation is using a variable inlined as
                         # a constant and another unsupported variable, in which
@@ -291,8 +353,12 @@
                     vsources[lhsv] = self._term_sources(lhs)
                 if not rhsv in vsources:
                     vsources[rhsv] = self._term_sources(rhs)
-                self._linkedvars.setdefault(lhsv, set()).add((rhsv, rel))
-                self._linkedvars.setdefault(rhsv, set()).add((lhsv, rel))
+                if self._handle_cross_relation(rel, relsources, vsources):
+                    self._linkedvars.setdefault(lhsv, set()).add((rhsv, rel))
+                    self._linkedvars.setdefault(rhsv, set()).add((lhsv, rel))
+                else:
+                    self._crosslinkedvars.setdefault(lhsv, set()).add((rhsv, rel))
+                    self._crosslinkedvars.setdefault(rhsv, set()).add((lhsv, rel))
         for term in self._linkedvars:
             self._remove_sources_until_stable(term, vsources)
         if len(self._sourcesvars) > 1 and hasattr(self.plan.rqlst, 'main_relations'):
@@ -308,10 +374,7 @@
             for rel in self.plan.rqlst.main_relations:
                 if not rschema(rel.r_type).is_final():
                     # nothing to do if relation is not supported by multiple sources
-                    relsources = [source for source in repo.sources
-                                  if source.support_relation(rel.r_type)
-                                  or rel.r_type in source.dont_cross_relations]
-                    if len(relsources) < 2:
+                    if len(repo.rel_type_sources(rel.r_type)) < 2:
                         continue
                     lhs, rhs = rel.get_variable_parts()
                     try:
@@ -319,7 +382,7 @@
                         rhsv = self._extern_term(rhs, vsources, inserted)
                     except KeyError, ex:
                         continue
-                    norelsup = self._norel_support_set(rel.r_type)
+                    norelsup = self._norel_support_set(rel)
                     self._remove_var_sources(lhsv, norelsup, rhsv, vsources)
                     self._remove_var_sources(rhsv, norelsup, lhsv, vsources)
         # cleanup linked var
@@ -379,7 +442,7 @@
                 # on a multisource relation for a variable only used by this relation
                 # (eg "Any X WHERE NOT X multisource_rel Y" and over is Y), iif 
                 continue
-            norelsup = self._norel_support_set(rel.r_type)
+            norelsup = self._norel_support_set(rel)
             # compute invalid sources for variables and remove them
             self._remove_var_sources(var, norelsup, ovar, vsources)
             self._remove_var_sources(ovar, norelsup, var, vsources)
@@ -403,7 +466,7 @@
         * a source support an entity (non invariant) but doesn't support a
           relation on it
         * a source support an entity which is accessed by an optional relation
-        * there is more than one sources and either all sources'supported        
+        * there is more than one source and either all sources'supported        
           variable/solutions are not equivalent or multiple variables have to
           be fetched from some source
         """
@@ -416,9 +479,10 @@
             else:
                 sample = self._sourcesvars.itervalues().next()
                 if len(sample) > 1 and any(v for v in sample
-                                           if not v in self._linkedvars):
+                                           if not v in self._linkedvars
+                                           and not v in self._crosslinkedvars):
                     self.needsplit = True
-
+            
     def _set_source_for_var(self, source, var):
         self._sourcesvars.setdefault(source, {})[var] = set(self._solindices)
 
@@ -429,10 +493,10 @@
             return set((source, solindex) for solindex in self._solindices)
         else:
             var = getattr(term, 'variable', term)
-            sources = [source for source, varobjs in self._sourcesvars.iteritems()
+            sources = [source for source, varobjs in self.sourcesvars.iteritems()
                        if var in varobjs]
             return set((source, solindex) for source in sources
-                       for solindex in self._sourcesvars[source][var])
+                       for solindex in self.sourcesvars[source][var])
 
     def _remove_sources(self, var, sources):
         """removes invalid sources (`sources`) from `sourcesvars`
@@ -451,6 +515,9 @@
                 if not sourcesvars[source]:
                     del sourcesvars[source]
 
+    def crossed_relation(self, source, relation):
+        return relation in self._crossrelations.get(source, ())
+    
     def part_steps(self):
         """precompute necessary part steps before generating actual rql for
         each step. This is necessary to know if an aggregate step will be
@@ -474,7 +541,7 @@
                     sourcevars.clear()
                 else:
                     scope = var.scope
-                    variables = self._expand_vars(var, sourcevars, scope, solindices)
+                    variables = self._expand_vars(var, source, sourcevars, scope, solindices)
                     if not sourcevars:
                         del self._sourcesvars[source]
                 # find which sources support the same variables/solutions
@@ -504,11 +571,15 @@
                         eid = const.eval(self.plan.args)
                         _source = self._session.source_from_eid(eid)
                         if len(sources) > 1 or not _source in sources:
-                            # if constant is only used by an identity relation,
-                            # skip
+                            # if there is some rewriten constant used by a
+                            # not neged relation while there are some source
+                            # not supporting the associated entity, this step
+                            # can't be final (unless the relation is explicitly
+                            # in `variables`, eg cross relations)
                             for c in vconsts:
                                 rel = c.relation()
-                                if rel is None or not rel.neged(strict=True):
+                                if rel is None or not (rel in variables or rel.neged(strict=True)):
+                                #if rel is not None and rel.r_type == 'identity' and not rel.neged(strict=True):
                                     final = False
                                     break
                             break
@@ -523,6 +594,9 @@
                                 needsel.add(vref.name)
                             final = False
                             break
+                        elif self.crossed_relation(_source, rel) and not rel in variables:
+                            final = False
+                            break
                     else:
                         if not scope is select:
                             self._exists_relation(rel, variables, needsel)
@@ -578,8 +652,6 @@
             relation.children[1].operator = '=' 
             variables.append(newvar)
             needsel.add(newvar.name)
-            #self.insertedvars.append((var.name, self.schema['identity'],
-            #                          newvar.name))
         
     def _choose_var(self, sourcevars):
         secondchoice = None
@@ -589,7 +661,7 @@
                 if not var.scope is self.rqlst:
                     if isinstance(var, Variable):
                         return var, sourcevars.pop(var)
-                    secondchoice = var
+                    secondchoice = var, sourcevars.pop(var)
         else:
             # priority to variable outer scope
             for var in sourcevars:
@@ -607,7 +679,8 @@
         var = iter(sourcevars).next()
         return var, sourcevars.pop(var)
             
-    def _expand_vars(self, var, sourcevars, scope, solindices):
+            
+    def _expand_vars(self, var, source, sourcevars, scope, solindices):
         variables = [var]
         nbunlinked = 1
         linkedvars = self._linkedvars
@@ -617,28 +690,41 @@
             candidates = (v for v in sourcevars.keys() if scope is v.scope)
         else:
             candidates = sourcevars #.iterkeys()
+        # we only want one unlinked variable in each generated query
         candidates = [v for v in candidates
                       if isinstance(v, Constant) or
                       (solindices.issubset(sourcevars[v]) and v in linkedvars)]
+        accept_var = lambda x: (isinstance(x, Constant) or any(v for v in variables if v in linkedvars.get(x, ())))
+        source_cross_rels = self._crossrelations.get(source, ())
+        if isinstance(var, Relation) and var in source_cross_rels:
+            cross_vars = source_cross_rels.pop(var)
+            base_accept_var = accept_var
+            accept_var = lambda x: (base_accept_var(x) or x in cross_vars)
+            for refed in cross_vars:
+                if not refed in candidates:
+                    candidates.append(refed)
+        else:
+            cross_vars = ()
         # repeat until no variable can't be added, since addition of a new
         # variable may permit to another one to be added
         modified = True
         while modified and candidates:
             modified = False
             for var in candidates[:]:
-                # we only want one unlinked variable in each generated query
-                if isinstance(var, Constant) or \
-                       any(v for v in variables if v in linkedvars[var]):
+                if accept_var(var):
                     variables.append(var)
-                    # constant nodes should be systematically deleted
-                    if isinstance(var, Constant):
-                        del sourcevars[var]
-                    # variable nodes should be deleted once all possible solution
-                    # indices have been consumed
-                    else:
-                        sourcevars[var] -= solindices
-                        if not sourcevars[var]:
+                    try:
+                        # constant nodes should be systematically deleted
+                        if isinstance(var, Constant):
                             del sourcevars[var]
+                        else:
+                            # variable nodes should be deleted once all possible
+                            # solutions indices have been consumed
+                            sourcevars[var] -= solindices
+                            if not sourcevars[var]:
+                                del sourcevars[var]
+                    except KeyError:
+                        assert var in cross_vars
                     candidates.remove(var)
                     modified = True
         return variables
@@ -660,21 +746,22 @@
                         varsolindices = sourcesvars[source][var]
                         varsolindices -= solindices
                         if not varsolindices:
-                            del sourcesvars[source][var]
-                
+                            del sourcesvars[source][var]                
         return sources
     
     def _cleanup_sourcesvars(self, sources, solindices):
         """on final parts, remove solutions so we know they are already processed"""
         for source in sources:
             try:
-                sourcevar = self._sourcesvars[source]
+                sourcevars = self._sourcesvars[source]
             except KeyError:
                 continue
-            for var, varsolindices in sourcevar.items():
+            for var, varsolindices in sourcevars.items():
+                if isinstance(var, Relation) and self.crossed_relation(source, var):
+                    continue
                 varsolindices -= solindices
                 if not varsolindices:
-                    del sourcevar[var]
+                    del sourcevars[var]
                     
     def merge_input_maps(self, allsolindices):
         """inputmaps is a dictionary with tuple of solution indices as key with an
@@ -954,7 +1041,8 @@
         if server.DEBUG:
             print 'filter', final and 'final' or '', sources, variables, rqlst, solindices, needsel
         newroot = Select()
-        self.sources = sources
+        self.sources = sorted(sources)
+        self.variables = variables
         self.solindices = solindices
         self.final = final
         # variables which appear in unsupported branches
@@ -1067,10 +1155,20 @@
 
     visit_or = visit_and
 
-    def _relation_supported(self, rtype):
+    def _relation_supported(self, relation):
+        rtype = relation.r_type
         for source in self.sources:
-            if not source.support_relation(rtype):
+            if not source.support_relation(rtype) \
+                   or (rtype in source.cross_relations and not relation in self.variables):#self.ppi.crossed_relation(source, relation):
                 return False
+        if not self.final:
+            rschema = self.schema.rschema(relation.r_type)
+            if not rschema.is_final():
+                for term in relation.get_nodes((VariableRef, Constant)):
+                    term = getattr(term, 'variable', term)
+                    termsources = sorted(set(x[0] for x in self.ppi._term_sources(term)))
+                    if termsources and termsources != self.sources:
+                        return False
         return True
         
     def visit_relation(self, node, newroot, variables):
@@ -1087,12 +1185,12 @@
                         return None, node
                 else:
                     return None, node
-            if not self._relation_supported(node.r_type):
+            if not self._relation_supported(node):
                 raise UnsupportedBranch()
         # don't copy type restriction unless this is the only relation for the
         # rhs variable, else they'll be reinserted later as needed (else we may
         # copy a type restriction while the variable is not actually used)
-        elif not any(self._relation_supported(rel.r_type)
+        elif not any(self._relation_supported(rel)
                      for rel in node.children[0].variable.stinfo['relations']):
             rel, node = self.visit_default(node, newroot, variables)
             return rel, node
@@ -1107,8 +1205,7 @@
                 if not ored:
                     self.skip.setdefault(node, set()).update(self.solindices)
                 else:
-                    self.mayneedvar.setdefault((node.children[0].name, rschema), []).append( (res, ored) )
-                    
+                    self.mayneedvar.setdefault((node.children[0].name, rschema), []).append( (res, ored) )                    
             else:
                 assert len(vrefs) == 1
                 vref = vrefs[0]