rqlrewrite.py
branchstable
changeset 3437 a30b5b5138a4
parent 3254 fe7ec595751c
child 3443 34e451da9b5d
--- a/rqlrewrite.py	Wed Sep 23 08:16:06 2009 +0200
+++ b/rqlrewrite.py	Wed Sep 23 15:29:31 2009 +0200
@@ -14,8 +14,74 @@
 
 from logilab.common.compat import any
 
-from cubicweb import Unauthorized, server, typed_eid
-from cubicweb.server.ssplanner import add_types_restriction
+from cubicweb import Unauthorized, typed_eid
+
+
+def add_types_restriction(schema, rqlst, newroot=None, solutions=None):
+    if newroot is None:
+        assert solutions is None
+        if hasattr(rqlst, '_types_restr_added'):
+            return
+        solutions = rqlst.solutions
+        newroot = rqlst
+        rqlst._types_restr_added = True
+    else:
+        assert solutions is not None
+        rqlst = rqlst.stmt
+    eschema = schema.eschema
+    allpossibletypes = {}
+    for solution in solutions:
+        for varname, etype in solution.iteritems():
+            if not varname in newroot.defined_vars or eschema(etype).is_final():
+                continue
+            allpossibletypes.setdefault(varname, set()).add(etype)
+    for varname in sorted(allpossibletypes):
+        try:
+            var = newroot.defined_vars[varname]
+        except KeyError:
+            continue
+        stinfo = var.stinfo
+        if stinfo.get('uidrels'):
+            continue # eid specified, no need for additional type specification
+        try:
+            typerels = rqlst.defined_vars[varname].stinfo.get('typerels')
+        except KeyError:
+            assert varname in rqlst.aliases
+            continue
+        if newroot is rqlst and typerels:
+            mytyperel = iter(typerels).next()
+        else:
+            for vref in newroot.defined_vars[varname].references():
+                rel = vref.relation()
+                if rel and rel.is_types_restriction():
+                    mytyperel = rel
+                    break
+            else:
+                mytyperel = None
+        possibletypes = allpossibletypes[varname]
+        if mytyperel is not None:
+            # variable as already some types restriction. new possible types
+            # can only be a subset of existing ones, so only remove no more
+            # possible types
+            for cst in mytyperel.get_nodes(n.Constant):
+                if not cst.value in possibletypes:
+                    cst.parent.remove(cst)
+                    try:
+                        stinfo['possibletypes'].remove(cst.value)
+                    except KeyError:
+                        # restriction on a type not used by this query, may
+                        # occurs with X is IN(...)
+                        pass
+        else:
+            # we have to add types restriction
+            if stinfo.get('scope') is not None:
+                rel = var.scope.add_type_restriction(var, possibletypes)
+            else:
+                # tree is not annotated yet, no scope set so add the restriction
+                # to the root
+                rel = newroot.add_type_restriction(var, possibletypes)
+            stinfo['typerels'] = frozenset((rel,))
+            stinfo['possibletypes'] = possibletypes
 
 
 def remove_solutions(origsolutions, solutions, defined):
@@ -73,8 +139,6 @@
         snippets: (varmap, list of rql expression)
                   with varmap a *tuple* (select var, snippet var)
         """
-        if server.DEBUG:
-            print '---- rewrite', select, snippets, solutions
         self.select = self.insert_scope = select
         self.solutions = solutions
         self.kwargs = kwargs
@@ -102,8 +166,6 @@
             newsolutions = self.remove_ambiguities(snippets, newsolutions)
         select.solutions = newsolutions
         add_types_restriction(self.schema, select)
-        if server.DEBUG:
-            print '---- rewriten', select
 
     def insert_snippets(self, snippets, varexistsmap=None):
         self.rewritten = {}