[CWEP002] introduce RQLRelationRewriter
authorSylvain Thénault <sylvain.thenault@logilab.fr>
Mon, 16 Jun 2014 10:22:24 +0200
changeset 9953 643b19d79e4a
parent 9952 0f3f965b6365
child 9954 79d34ba48612
[CWEP002] introduce RQLRelationRewriter Refactor existing RQLRewriter for later reuse of rewriting relation as specified by CWEP002. Work is different because we simply want to replace a relation by another rql snippet and we don't have to bother with EXISTS, subqueries and all. This rewriter is not yet plugged into the querier. Depends on yams 0.40 API. Related to #3546717
rqlrewrite.py
test/data/rewrite/schema.py
test/unittest_rqlrewrite.py
--- a/rqlrewrite.py	Fri Sep 12 14:46:11 2014 +0200
+++ b/rqlrewrite.py	Mon Jun 16 10:22:24 2014 +0200
@@ -31,7 +31,7 @@
 from logilab.common.graph import has_path
 
 from cubicweb import Unauthorized
-
+from cubicweb.schema import RRQLExpression
 
 def cleanup_solutions(rqlst, solutions):
     for sol in solutions:
@@ -208,11 +208,21 @@
     because it create an unresolvable query (eg no solutions found)
     """
 
+class VariableFromSubQuery(Exception):
+    """flow control exception to indicate that a variable is coming from a
+    subquery, and let parent act accordingly
+    """
+    def __init__(self, variable):
+        self.variable = variable
+
 
 class RQLRewriter(object):
-    """insert some rql snippets into another rql syntax tree
+    """Insert some rql snippets into another rql syntax tree, for security /
+    relation vocabulary. This implies that it should only restrict results of
+    the original query, not generate new ones. Hence, inserted snippets are
+    inserted under an EXISTS node.
 
-    this class *isn't thread safe*
+    This class *isn't thread safe*.
     """
 
     def __init__(self, session):
@@ -338,7 +348,7 @@
     def rewrite(self, select, snippets, kwargs, existingvars=None):
         """
         snippets: (varmap, list of rql expression)
-                  with varmap a *tuple* (select var, snippet var)
+                  with varmap a *dict* {select var: snippet var}
         """
         self.select = select
         # remove_solutions used below require a copy
@@ -350,7 +360,7 @@
         self.pending_keys = []
         self.existingvars = existingvars
         # we have to annotate the rqlst before inserting snippets, even though
-        # we'll have to redo it latter
+        # we'll have to redo it later
         self.annotate(select)
         self.insert_snippets(snippets)
         if not self.exists_snippet and self.u_varname:
@@ -362,7 +372,7 @@
         assert len(newsolutions) >= len(solutions), (
             'rewritten rql %s has lost some solutions, there is probably '
             'something wrong in your schema permission (for instance using a '
-            'RQLExpression which insert a relation which doesn\'t exists in '
+            'RQLExpression which inserts a relation which doesn\'t exist in '
             'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
             select, solutions, newsolutions))
         if len(newsolutions) > len(solutions):
@@ -382,11 +392,10 @@
                 continue
             self.insert_varmap_snippets(varmap, rqlexprs, varexistsmap)
 
-    def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap):
+    def init_from_varmap(self, varmap, varexistsmap=None):
         self.varmap = varmap
         self.revvarmap = {}
         self.varinfos = []
-        self._insert_scope = None
         for i, (selectvar, snippetvar) in enumerate(varmap):
             assert snippetvar in 'SOX'
             self.revvarmap[snippetvar] = (selectvar, i)
@@ -399,25 +408,35 @@
                 try:
                     vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo
                 except KeyError:
-                    # variable may have been moved to a newly inserted subquery
-                    # we should insert snippet in that subquery
-                    subquery = self.select.aliases[selectvar].query
-                    assert len(subquery.children) == 1
-                    subselect = subquery.children[0]
-                    RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)],
-                                                      self.kwargs)
-                    return
+                    vi['stinfo'] = sti = self._subquery_variable(selectvar)
                 if varexistsmap is None:
                     # build an index for quick access to relations
                     vi['rhs_rels'] = {}
-                    for rel in sti['rhsrelations']:
+                    for rel in sti.get('rhsrelations', []):
                         vi['rhs_rels'].setdefault(rel.r_type, []).append(rel)
                     vi['lhs_rels'] = {}
-                    for rel in sti['relations']:
-                        if not rel in sti['rhsrelations']:
+                    for rel in sti.get('relations', []):
+                        if not rel in sti.get('rhsrelations', []):
                             vi['lhs_rels'].setdefault(rel.r_type, []).append(rel)
                 else:
                     vi['rhs_rels'] = vi['lhs_rels'] = {}
+
+    def _subquery_variable(self, selectvar):
+        raise VariableFromSubQuery(selectvar)
+
+    def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap):
+        try:
+            self.init_from_varmap(varmap, varexistsmap)
+        except VariableFromSubQuery, ex:
+            # variable may have been moved to a newly inserted subquery
+            # we should insert snippet in that subquery
+            subquery = self.select.aliases[ex.variable].query
+            assert len(subquery.children) == 1, subquery
+            subselect = subquery.children[0]
+            RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)],
+                                              self.kwargs)
+            return
+        self._insert_scope = None
         previous = None
         inserted = False
         for rqlexpr in rqlexprs:
@@ -450,6 +469,11 @@
         finally:
             self.existingvars = existing
 
+    def _inserted_root(self, new):
+        if not isinstance(new, (n.Exists, n.Not)):
+            new = n.Exists(new)
+        return new
+
     def _insert_snippet(self, varmap, previous, new):
         """insert `new` snippet into the syntax tree, which have been rewritten
         using `varmap`. In cases where an action is protected by several rql
@@ -474,8 +498,7 @@
                 self.insert_pending()
                 #self._insert_scope = None
                 return new
-            if not isinstance(new, (n.Exists, n.Not)):
-                new = n.Exists(new)
+            new = self._inserted_root(new)
             if previous is None:
                 insert_scope.add_restriction(new)
             else:
@@ -869,3 +892,40 @@
         if self._insert_scope is None:
             return self.select
         return self._insert_scope.stmt
+
+
+class RQLRelationRewriter(RQLRewriter):
+    """Insert some rql snippets into another rql syntax tree, replacing computed
+    relations by their associated rule.
+
+    This class *isn't thread safe*.
+    """
+    def __init__(self, session):
+        super(RQLRelationRewriter, self).__init__(session)
+        self.rules = {}
+        for rschema in self.schema.iter_computed_relations():
+            self.rules[rschema.type] = RRQLExpression(rschema.rule)
+
+    def rewrite(self, union, kwargs=None):
+        self.kwargs = kwargs
+        self.removing_ambiguity = False
+        self.existingvars = None
+        self.pending_keys = None
+        for relation in union.iget_nodes(n.Relation):
+            if relation.r_type in self.rules:
+                self.select = relation.stmt
+                self.solutions = solutions = self.select.solutions[:]
+                self.current_expr = self.rules[relation.r_type]
+                self._insert_scope = relation.scope
+                self.rewritten = {}
+                lhs, rhs = relation.get_variable_parts()
+                varmap = {lhs.name: 'S', rhs.name: 'O'}
+                self.init_from_varmap(tuple(sorted(varmap.items())))
+                self.insert_snippet(varmap, self.current_expr.snippet_rqlst)
+                self.select.remove_node(relation)
+
+    def _subquery_variable(self, selectvar):
+        return self.select.aliases[selectvar].stinfo
+
+    def _inserted_root(self, new):
+        return new
--- a/test/data/rewrite/schema.py	Fri Sep 12 14:46:11 2014 +0200
+++ b/test/data/rewrite/schema.py	Mon Jun 16 10:22:24 2014 +0200
@@ -1,4 +1,4 @@
-# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2003-2014 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
 #
 # This file is part of CubicWeb.
@@ -15,9 +15,15 @@
 #
 # You should have received a copy of the GNU Lesser General Public License along
 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
-from yams.buildobjs import EntityType, RelationDefinition, String, SubjectRelation
+from yams.buildobjs import (EntityType, RelationDefinition, String, SubjectRelation,
+                            ComputedRelation, Int)
 from cubicweb.schema import ERQLExpression
 
+
+class Person(EntityType):
+    name = String()
+
+
 class Affaire(EntityType):
     __permissions__ = {
         'read':   ('managers',
@@ -82,3 +88,37 @@
     object = 'CWUser'
     inlined = True
     cardinality = '1*'
+
+class Contribution(EntityType):
+    code = Int()
+
+class ArtWork(EntityType):
+    name = String()
+
+class Role(EntityType):
+    name = String()
+
+class contributor(RelationDefinition):
+    subject = 'Contribution'
+    object = 'Person'
+    cardinality = '1*'
+    inlined = True
+
+class manifestation(RelationDefinition):
+    subject = 'Contribution'
+    object = 'ArtWork'
+
+class role(RelationDefinition):
+    subject = 'Contribution'
+    object = 'Role'
+
+class illustrator_of(ComputedRelation):
+    rule = ('C is Contribution, C contributor S, C manifestation O, '
+            'C role R, R name "illustrator"')
+
+class participated_in(ComputedRelation):
+    rule = 'S contributor O'
+
+class match(RelationDefinition):
+    subject = 'ArtWork'
+    object = 'Note'
--- a/test/unittest_rqlrewrite.py	Fri Sep 12 14:46:11 2014 +0200
+++ b/test/unittest_rqlrewrite.py	Mon Jun 16 10:22:24 2014 +0200
@@ -19,6 +19,7 @@
 from logilab.common.testlib import unittest_main, TestCase
 from logilab.common.testlib import mock_object
 from yams import BadSchemaDefinition
+from yams.buildobjs import RelationDefinition
 from rql import parse, nodes, RQLHelper
 
 from cubicweb import Unauthorized, rqlrewrite
@@ -31,10 +32,8 @@
     config = TestServerConfiguration(RQLRewriteTC.datapath('rewrite'))
     config.bootstrap_cubes()
     schema = config.load_schema()
-    from yams.buildobjs import RelationDefinition
     schema.add_relation_def(RelationDefinition(subject='Card', name='in_state',
                                                object='State', cardinality='1*'))
-
     rqlhelper = RQLHelper(schema, special_relations={'eid': 'uid',
                                                      'has_text': 'fti'})
     repotest.do_monkey_patch()
@@ -49,11 +48,11 @@
             2: 'Card',
             3: 'Affaire'}[eid]
 
-def rewrite(rqlst, snippets_map, kwargs, existingvars=None):
+def _prepare_rewriter(rewriter_cls, kwargs):
     class FakeVReg:
         schema = schema
         @staticmethod
-        def solutions(sqlcursor, mainrqlst, kwargs):
+        def solutions(sqlcursor, rqlst, kwargs):
             rqlhelper.compute_solutions(rqlst, {'eid': eid_func_map}, kwargs=kwargs)
         class rqlhelper:
             @staticmethod
@@ -62,8 +61,10 @@
             @staticmethod
             def simplify(mainrqlst, needcopy=False):
                 rqlhelper.simplify(rqlst, needcopy)
-    rewriter = rqlrewrite.RQLRewriter(
-        mock_object(vreg=FakeVReg, user=(mock_object(eid=1))))
+    return rewriter_cls(mock_object(vreg=FakeVReg, user=(mock_object(eid=1))))
+
+def rewrite(rqlst, snippets_map, kwargs, existingvars=None):
+    rewriter = _prepare_rewriter(rqlrewrite.RQLRewriter, kwargs)
     snippets = []
     for v, exprs in sorted(snippets_map.items()):
         rqlexprs = [isinstance(snippet, basestring)
@@ -87,7 +88,7 @@
         except KeyError:
             vrefmaps[stmt] = {vref.name: set( (vref,) )}
             selects.append(stmt)
-    assert node in selects
+    assert node in selects, (node, selects)
     for stmt in selects:
         for var in stmt.defined_vars.itervalues():
             assert var.stinfo['references']
@@ -591,5 +592,204 @@
         finally:
             RQLRewriter.insert_snippets = orig_insert_snippets
 
+
+class RQLRelationRewriterTC(TestCase):
+    # XXX valid rules: S and O specified, not in a SET, INSERT, DELETE scope
+    #     valid uses: no outer join
+
+    # Basic tests
+    def test_base_rule(self):
+        rules = {'participated_in': 'S contributor O'}
+        rqlst = rqlhelper.parse('Any X WHERE X participated_in S')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any X WHERE X contributor S',
+                         rqlst.as_string())
+
+    def test_complex_rule_1(self):
+        rules = {'illustrator_of': ('C is Contribution, C contributor S, '
+                                    'C manifestation O, C role R, '
+                                    'R name "illustrator"')}
+        rqlst = rqlhelper.parse('Any A,B WHERE A illustrator_of B')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any A,B WHERE C is Contribution, '
+                         'C contributor A, C manifestation B, '
+                         'C role D, D name "illustrator"',
+                         rqlst.as_string())
+
+    def test_complex_rule_2(self):
+        rules = {'illustrator_of': ('C is Contribution, C contributor S, '
+                                    'C manifestation O, C role R, '
+                                    'R name "illustrator"')}
+        rqlst = rqlhelper.parse('Any A WHERE EXISTS(A illustrator_of B)')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any A WHERE EXISTS(C is Contribution, '
+                         'C contributor A, C manifestation B, '
+                         'C role D, D name "illustrator")',
+                         rqlst.as_string())
+
+
+    def test_rewrite2(self):
+        rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+                'C manifestation O, C role R, R name "illustrator"'}
+        rqlst = rqlhelper.parse('Any A,B WHERE A illustrator_of B, C require_permission R, S'
+                                'require_state O')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any A,B WHERE C require_permission R, S require_state O, '
+                         'D is Contribution, D contributor A, D manifestation B, D role E, '
+                         'E name "illustrator"',
+                          rqlst.as_string())
+
+    def test_rewrite3(self):
+        rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+                'C manifestation O, C role R, R name "illustrator"'}
+        rqlst = rqlhelper.parse('Any A,B WHERE E require_permission T, A illustrator_of B')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any A,B WHERE E require_permission T, '
+                         'C is Contribution, C contributor A, C manifestation B, '
+                         'C role D, D name "illustrator"',
+                         rqlst.as_string())
+
+    def test_rewrite4(self):
+        rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+                'C manifestation O, C role R, R name "illustrator"'}
+        rqlst = rqlhelper.parse('Any A,B WHERE C require_permission R, A illustrator_of B')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any A,B WHERE C require_permission R, '
+                         'D is Contribution, D contributor A, D manifestation B, '
+                         'D role E, E name "illustrator"',
+                         rqlst.as_string())
+
+    def test_rewrite5(self):
+        rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+                'C manifestation O, C role R, R name "illustrator"'}
+        rqlst = rqlhelper.parse('Any A,B WHERE C require_permission R, A illustrator_of B, '
+                                'S require_state O')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any A,B WHERE C require_permission R, S require_state O, '
+                         'D is Contribution, D contributor A, D manifestation B, D role E, '
+                         'E name "illustrator"',
+                         rqlst.as_string())
+
+    # Tests for the with clause
+    def test_rewrite_with(self):
+        rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+                'C manifestation O, C role R, R name "illustrator"'}
+        rqlst = rqlhelper.parse('Any A,B WITH A, B BEING(Any X, Y WHERE X illustrator_of Y)')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any A,B WITH A,B BEING '
+                         '(Any X,Y WHERE A is Contribution, A contributor X, '
+                         'A manifestation Y, A role B, B name "illustrator")',
+                         rqlst.as_string())
+
+    def test_rewrite_with2(self):
+        rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+                'C manifestation O, C role R, R name "illustrator"'}
+        rqlst = rqlhelper.parse('Any A,B WHERE T require_permission C WITH A, B BEING(Any X, Y WHERE X illustrator_of Y)')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any A,B WHERE T require_permission C '
+                         'WITH A,B BEING (Any X,Y WHERE A is Contribution, '
+                         'A contributor X, A manifestation Y, A role B, B name "illustrator")',
+                         rqlst.as_string())
+
+    def test_rewrite_with3(self):
+        rules = {'participated_in': 'S contributor O'}
+        rqlst = rqlhelper.parse('Any A,B WHERE A participated_in B '
+                                'WITH A, B BEING(Any X,Y WHERE X contributor Y)')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any A,B WHERE A contributor B WITH A,B BEING '
+                         '(Any X,Y WHERE X contributor Y)', 
+                         rqlst.as_string())
+
+    def test_rewrite_with4(self):
+        rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+                'C manifestation O, C role R, R name "illustrator"'}
+        rqlst = rqlhelper.parse('Any A,B WHERE A illustrator_of B '
+                               'WITH A, B BEING(Any X, Y WHERE X illustrator_of Y)')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any A,B WHERE C is Contribution, '
+                         'C contributor A, C manifestation B, C role D, '
+                         'D name "illustrator" WITH A,B BEING '
+                         '(Any X,Y WHERE A is Contribution, A contributor X, '
+                         'A manifestation Y, A role B, B name "illustrator")',
+                          rqlst.as_string()) 
+
+    # Tests for the union
+    def test_rewrite_union(self):
+        rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+                'C manifestation O, C role R, R name "illustrator"'}
+        rqlst = rqlhelper.parse('(Any A,B WHERE A illustrator_of B) UNION'
+                                '(Any X,Y WHERE X is CWUser, Z manifestation Y)')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('(Any A,B WHERE C is Contribution, '
+                         'C contributor A, C manifestation B, C role D, '
+                         'D name "illustrator") UNION (Any X,Y WHERE X is CWUser, Z manifestation Y)',
+                         rqlst.as_string())
+
+    def test_rewrite_union2(self):
+        rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+                'C manifestation O, C role R, R name "illustrator"'}
+        rqlst = rqlhelper.parse('(Any Y WHERE Y match W) UNION '
+                                '(Any A WHERE A illustrator_of B) UNION '
+                                '(Any Y WHERE Y is ArtWork)')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('(Any Y WHERE Y match W) '
+                         'UNION (Any A WHERE C is Contribution, C contributor A, '
+                         'C manifestation B, C role D, D name "illustrator") '
+                         'UNION (Any Y WHERE Y is ArtWork)',
+                         rqlst.as_string())
+
+    # Tests for the exists clause
+    def test_rewrite_exists(self):
+        rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+                'C manifestation O, C role R, R name "illustrator"'}
+        rqlst = rqlhelper.parse('(Any A,B WHERE A illustrator_of B, '
+                     'EXISTS(B is ArtWork))')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any A,B WHERE EXISTS(B is ArtWork), '
+                         'C is Contribution, C contributor A, C manifestation B, C role D, '
+                         'D name "illustrator"',
+                         rqlst.as_string())
+
+    def test_rewrite_exists2(self):
+        rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+                'C manifestation O, C role R, R name "illustrator"'}
+        rqlst = rqlhelper.parse('(Any A,B WHERE B contributor A, EXISTS(A illustrator_of W))')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any A,B WHERE B contributor A, '
+                         'EXISTS(C is Contribution, C contributor A, C manifestation W, '
+                         'C role D, D name "illustrator")',
+                         rqlst.as_string())
+
+    def test_rewrite_exists3(self):
+        rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+                'C manifestation O, C role R, R name "illustrator"'}
+        rqlst = rqlhelper.parse('(Any A,B WHERE A illustrator_of B, EXISTS(A illustrator_of W))')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any A,B WHERE EXISTS(C is Contribution, C contributor A, '
+                         'C manifestation W, C role D, D name "illustrator"), '
+                         'E is Contribution, E contributor A, E manifestation B, E role F, '
+                         'F name "illustrator"',
+                         rqlst.as_string())
+
+    # Test for GROUPBY
+    def test_rewrite_groupby(self):
+        rules = {'participated_in': 'S contributor O'}
+        rqlst = rqlhelper.parse('Any SUM(SA) GROUPBY S WHERE P participated_in S, P manifestation SA')
+        rule_rewrite(rqlst, rules)
+        self.assertEqual('Any SUM(SA) GROUPBY S WHERE P manifestation SA, P contributor S',
+                         rqlst.as_string())
+
+
+
+def rule_rewrite(rqlst, kwargs=None):
+    rewriter = _prepare_rewriter(rqlrewrite.RQLRelationRewriter, kwargs)
+    rqlhelper.compute_solutions(rqlst.children[0], {'eid': eid_func_map},
+                                kwargs=kwargs)
+    rewriter.rewrite(rqlst)
+    for select in rqlst.children:
+        test_vrefs(select)
+    return rewriter.rewritten
+
+
 if __name__ == '__main__':
     unittest_main()