rqlrewrite.py
changeset 9953 643b19d79e4a
parent 9593 48a84fb4f301
child 10249 e38b8d37c5d8
--- a/rqlrewrite.py	Fri Sep 12 14:46:11 2014 +0200
+++ b/rqlrewrite.py	Mon Jun 16 10:22:24 2014 +0200
@@ -31,7 +31,7 @@
 from logilab.common.graph import has_path
 
 from cubicweb import Unauthorized
-
+from cubicweb.schema import RRQLExpression
 
 def cleanup_solutions(rqlst, solutions):
     for sol in solutions:
@@ -208,11 +208,21 @@
     because it create an unresolvable query (eg no solutions found)
     """
 
+class VariableFromSubQuery(Exception):
+    """flow control exception to indicate that a variable is coming from a
+    subquery, and let parent act accordingly
+    """
+    def __init__(self, variable):
+        self.variable = variable
+
 
 class RQLRewriter(object):
-    """insert some rql snippets into another rql syntax tree
+    """Insert some rql snippets into another rql syntax tree, for security /
+    relation vocabulary. This implies that it should only restrict results of
+    the original query, not generate new ones. Hence, inserted snippets are
+    inserted under an EXISTS node.
 
-    this class *isn't thread safe*
+    This class *isn't thread safe*.
     """
 
     def __init__(self, session):
@@ -338,7 +348,7 @@
     def rewrite(self, select, snippets, kwargs, existingvars=None):
         """
         snippets: (varmap, list of rql expression)
-                  with varmap a *tuple* (select var, snippet var)
+                  with varmap a *dict* {select var: snippet var}
         """
         self.select = select
         # remove_solutions used below require a copy
@@ -350,7 +360,7 @@
         self.pending_keys = []
         self.existingvars = existingvars
         # we have to annotate the rqlst before inserting snippets, even though
-        # we'll have to redo it latter
+        # we'll have to redo it later
         self.annotate(select)
         self.insert_snippets(snippets)
         if not self.exists_snippet and self.u_varname:
@@ -362,7 +372,7 @@
         assert len(newsolutions) >= len(solutions), (
             'rewritten rql %s has lost some solutions, there is probably '
             'something wrong in your schema permission (for instance using a '
-            'RQLExpression which insert a relation which doesn\'t exists in '
+            'RQLExpression which inserts a relation which doesn\'t exist in '
             'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
             select, solutions, newsolutions))
         if len(newsolutions) > len(solutions):
@@ -382,11 +392,10 @@
                 continue
             self.insert_varmap_snippets(varmap, rqlexprs, varexistsmap)
 
-    def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap):
+    def init_from_varmap(self, varmap, varexistsmap=None):
         self.varmap = varmap
         self.revvarmap = {}
         self.varinfos = []
-        self._insert_scope = None
         for i, (selectvar, snippetvar) in enumerate(varmap):
             assert snippetvar in 'SOX'
             self.revvarmap[snippetvar] = (selectvar, i)
@@ -399,25 +408,35 @@
                 try:
                     vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo
                 except KeyError:
-                    # variable may have been moved to a newly inserted subquery
-                    # we should insert snippet in that subquery
-                    subquery = self.select.aliases[selectvar].query
-                    assert len(subquery.children) == 1
-                    subselect = subquery.children[0]
-                    RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)],
-                                                      self.kwargs)
-                    return
+                    vi['stinfo'] = sti = self._subquery_variable(selectvar)
                 if varexistsmap is None:
                     # build an index for quick access to relations
                     vi['rhs_rels'] = {}
-                    for rel in sti['rhsrelations']:
+                    for rel in sti.get('rhsrelations', []):
                         vi['rhs_rels'].setdefault(rel.r_type, []).append(rel)
                     vi['lhs_rels'] = {}
-                    for rel in sti['relations']:
-                        if not rel in sti['rhsrelations']:
+                    for rel in sti.get('relations', []):
+                        if not rel in sti.get('rhsrelations', []):
                             vi['lhs_rels'].setdefault(rel.r_type, []).append(rel)
                 else:
                     vi['rhs_rels'] = vi['lhs_rels'] = {}
+
+    def _subquery_variable(self, selectvar):
+        raise VariableFromSubQuery(selectvar)
+
+    def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap):
+        try:
+            self.init_from_varmap(varmap, varexistsmap)
+        except VariableFromSubQuery, ex:
+            # variable may have been moved to a newly inserted subquery
+            # we should insert snippet in that subquery
+            subquery = self.select.aliases[ex.variable].query
+            assert len(subquery.children) == 1, subquery
+            subselect = subquery.children[0]
+            RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)],
+                                              self.kwargs)
+            return
+        self._insert_scope = None
         previous = None
         inserted = False
         for rqlexpr in rqlexprs:
@@ -450,6 +469,11 @@
         finally:
             self.existingvars = existing
 
+    def _inserted_root(self, new):
+        if not isinstance(new, (n.Exists, n.Not)):
+            new = n.Exists(new)
+        return new
+
     def _insert_snippet(self, varmap, previous, new):
         """insert `new` snippet into the syntax tree, which have been rewritten
         using `varmap`. In cases where an action is protected by several rql
@@ -474,8 +498,7 @@
                 self.insert_pending()
                 #self._insert_scope = None
                 return new
-            if not isinstance(new, (n.Exists, n.Not)):
-                new = n.Exists(new)
+            new = self._inserted_root(new)
             if previous is None:
                 insert_scope.add_restriction(new)
             else:
@@ -869,3 +892,40 @@
         if self._insert_scope is None:
             return self.select
         return self._insert_scope.stmt
+
+
+class RQLRelationRewriter(RQLRewriter):
+    """Insert some rql snippets into another rql syntax tree, replacing computed
+    relations by their associated rule.
+
+    This class *isn't thread safe*.
+    """
+    def __init__(self, session):
+        super(RQLRelationRewriter, self).__init__(session)
+        self.rules = {}
+        for rschema in self.schema.iter_computed_relations():
+            self.rules[rschema.type] = RRQLExpression(rschema.rule)
+
+    def rewrite(self, union, kwargs=None):
+        self.kwargs = kwargs
+        self.removing_ambiguity = False
+        self.existingvars = None
+        self.pending_keys = None
+        for relation in union.iget_nodes(n.Relation):
+            if relation.r_type in self.rules:
+                self.select = relation.stmt
+                self.solutions = solutions = self.select.solutions[:]
+                self.current_expr = self.rules[relation.r_type]
+                self._insert_scope = relation.scope
+                self.rewritten = {}
+                lhs, rhs = relation.get_variable_parts()
+                varmap = {lhs.name: 'S', rhs.name: 'O'}
+                self.init_from_varmap(tuple(sorted(varmap.items())))
+                self.insert_snippet(varmap, self.current_expr.snippet_rqlst)
+                self.select.remove_node(relation)
+
+    def _subquery_variable(self, selectvar):
+        return self.select.aliases[selectvar].stinfo
+
+    def _inserted_root(self, new):
+        return new