--- a/rqlrewrite.py Fri Sep 12 14:46:11 2014 +0200
+++ b/rqlrewrite.py Mon Jun 16 10:22:24 2014 +0200
@@ -31,7 +31,7 @@
from logilab.common.graph import has_path
from cubicweb import Unauthorized
-
+from cubicweb.schema import RRQLExpression
def cleanup_solutions(rqlst, solutions):
for sol in solutions:
@@ -208,11 +208,21 @@
because it create an unresolvable query (eg no solutions found)
"""
+class VariableFromSubQuery(Exception):
+ """flow control exception to indicate that a variable is coming from a
+ subquery, and let parent act accordingly
+ """
+ def __init__(self, variable):
+ self.variable = variable
+
class RQLRewriter(object):
- """insert some rql snippets into another rql syntax tree
+ """Insert some rql snippets into another rql syntax tree, for security /
+ relation vocabulary. This implies that it should only restrict results of
+ the original query, not generate new ones. Hence, inserted snippets are
+ inserted under an EXISTS node.
- this class *isn't thread safe*
+ This class *isn't thread safe*.
"""
def __init__(self, session):
@@ -338,7 +348,7 @@
def rewrite(self, select, snippets, kwargs, existingvars=None):
"""
snippets: (varmap, list of rql expression)
- with varmap a *tuple* (select var, snippet var)
+ with varmap a *dict* {select var: snippet var}
"""
self.select = select
# remove_solutions used below require a copy
@@ -350,7 +360,7 @@
self.pending_keys = []
self.existingvars = existingvars
# we have to annotate the rqlst before inserting snippets, even though
- # we'll have to redo it latter
+ # we'll have to redo it later
self.annotate(select)
self.insert_snippets(snippets)
if not self.exists_snippet and self.u_varname:
@@ -362,7 +372,7 @@
assert len(newsolutions) >= len(solutions), (
'rewritten rql %s has lost some solutions, there is probably '
'something wrong in your schema permission (for instance using a '
- 'RQLExpression which insert a relation which doesn\'t exists in '
+ 'RQLExpression which inserts a relation which doesn\'t exist in '
'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
select, solutions, newsolutions))
if len(newsolutions) > len(solutions):
@@ -382,11 +392,10 @@
continue
self.insert_varmap_snippets(varmap, rqlexprs, varexistsmap)
- def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap):
+ def init_from_varmap(self, varmap, varexistsmap=None):
self.varmap = varmap
self.revvarmap = {}
self.varinfos = []
- self._insert_scope = None
for i, (selectvar, snippetvar) in enumerate(varmap):
assert snippetvar in 'SOX'
self.revvarmap[snippetvar] = (selectvar, i)
@@ -399,25 +408,35 @@
try:
vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo
except KeyError:
- # variable may have been moved to a newly inserted subquery
- # we should insert snippet in that subquery
- subquery = self.select.aliases[selectvar].query
- assert len(subquery.children) == 1
- subselect = subquery.children[0]
- RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)],
- self.kwargs)
- return
+ vi['stinfo'] = sti = self._subquery_variable(selectvar)
if varexistsmap is None:
# build an index for quick access to relations
vi['rhs_rels'] = {}
- for rel in sti['rhsrelations']:
+ for rel in sti.get('rhsrelations', []):
vi['rhs_rels'].setdefault(rel.r_type, []).append(rel)
vi['lhs_rels'] = {}
- for rel in sti['relations']:
- if not rel in sti['rhsrelations']:
+ for rel in sti.get('relations', []):
+ if not rel in sti.get('rhsrelations', []):
vi['lhs_rels'].setdefault(rel.r_type, []).append(rel)
else:
vi['rhs_rels'] = vi['lhs_rels'] = {}
+
+ def _subquery_variable(self, selectvar):
+ raise VariableFromSubQuery(selectvar)
+
+ def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap):
+ try:
+ self.init_from_varmap(varmap, varexistsmap)
+ except VariableFromSubQuery, ex:
+ # variable may have been moved to a newly inserted subquery
+ # we should insert snippet in that subquery
+ subquery = self.select.aliases[ex.variable].query
+ assert len(subquery.children) == 1, subquery
+ subselect = subquery.children[0]
+ RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)],
+ self.kwargs)
+ return
+ self._insert_scope = None
previous = None
inserted = False
for rqlexpr in rqlexprs:
@@ -450,6 +469,11 @@
finally:
self.existingvars = existing
+ def _inserted_root(self, new):
+ if not isinstance(new, (n.Exists, n.Not)):
+ new = n.Exists(new)
+ return new
+
def _insert_snippet(self, varmap, previous, new):
"""insert `new` snippet into the syntax tree, which have been rewritten
using `varmap`. In cases where an action is protected by several rql
@@ -474,8 +498,7 @@
self.insert_pending()
#self._insert_scope = None
return new
- if not isinstance(new, (n.Exists, n.Not)):
- new = n.Exists(new)
+ new = self._inserted_root(new)
if previous is None:
insert_scope.add_restriction(new)
else:
@@ -869,3 +892,40 @@
if self._insert_scope is None:
return self.select
return self._insert_scope.stmt
+
+
+class RQLRelationRewriter(RQLRewriter):
+ """Insert some rql snippets into another rql syntax tree, replacing computed
+ relations by their associated rule.
+
+ This class *isn't thread safe*.
+ """
+ def __init__(self, session):
+ super(RQLRelationRewriter, self).__init__(session)
+ self.rules = {}
+ for rschema in self.schema.iter_computed_relations():
+ self.rules[rschema.type] = RRQLExpression(rschema.rule)
+
+ def rewrite(self, union, kwargs=None):
+ self.kwargs = kwargs
+ self.removing_ambiguity = False
+ self.existingvars = None
+ self.pending_keys = None
+ for relation in union.iget_nodes(n.Relation):
+ if relation.r_type in self.rules:
+ self.select = relation.stmt
+ self.solutions = solutions = self.select.solutions[:]
+ self.current_expr = self.rules[relation.r_type]
+ self._insert_scope = relation.scope
+ self.rewritten = {}
+ lhs, rhs = relation.get_variable_parts()
+ varmap = {lhs.name: 'S', rhs.name: 'O'}
+ self.init_from_varmap(tuple(sorted(varmap.items())))
+ self.insert_snippet(varmap, self.current_expr.snippet_rqlst)
+ self.select.remove_node(relation)
+
+ def _subquery_variable(self, selectvar):
+ return self.select.aliases[selectvar].stinfo
+
+ def _inserted_root(self, new):
+ return new
--- a/test/unittest_rqlrewrite.py Fri Sep 12 14:46:11 2014 +0200
+++ b/test/unittest_rqlrewrite.py Mon Jun 16 10:22:24 2014 +0200
@@ -19,6 +19,7 @@
from logilab.common.testlib import unittest_main, TestCase
from logilab.common.testlib import mock_object
from yams import BadSchemaDefinition
+from yams.buildobjs import RelationDefinition
from rql import parse, nodes, RQLHelper
from cubicweb import Unauthorized, rqlrewrite
@@ -31,10 +32,8 @@
config = TestServerConfiguration(RQLRewriteTC.datapath('rewrite'))
config.bootstrap_cubes()
schema = config.load_schema()
- from yams.buildobjs import RelationDefinition
schema.add_relation_def(RelationDefinition(subject='Card', name='in_state',
object='State', cardinality='1*'))
-
rqlhelper = RQLHelper(schema, special_relations={'eid': 'uid',
'has_text': 'fti'})
repotest.do_monkey_patch()
@@ -49,11 +48,11 @@
2: 'Card',
3: 'Affaire'}[eid]
-def rewrite(rqlst, snippets_map, kwargs, existingvars=None):
+def _prepare_rewriter(rewriter_cls, kwargs):
class FakeVReg:
schema = schema
@staticmethod
- def solutions(sqlcursor, mainrqlst, kwargs):
+ def solutions(sqlcursor, rqlst, kwargs):
rqlhelper.compute_solutions(rqlst, {'eid': eid_func_map}, kwargs=kwargs)
class rqlhelper:
@staticmethod
@@ -62,8 +61,10 @@
@staticmethod
def simplify(mainrqlst, needcopy=False):
rqlhelper.simplify(rqlst, needcopy)
- rewriter = rqlrewrite.RQLRewriter(
- mock_object(vreg=FakeVReg, user=(mock_object(eid=1))))
+ return rewriter_cls(mock_object(vreg=FakeVReg, user=(mock_object(eid=1))))
+
+def rewrite(rqlst, snippets_map, kwargs, existingvars=None):
+ rewriter = _prepare_rewriter(rqlrewrite.RQLRewriter, kwargs)
snippets = []
for v, exprs in sorted(snippets_map.items()):
rqlexprs = [isinstance(snippet, basestring)
@@ -87,7 +88,7 @@
except KeyError:
vrefmaps[stmt] = {vref.name: set( (vref,) )}
selects.append(stmt)
- assert node in selects
+ assert node in selects, (node, selects)
for stmt in selects:
for var in stmt.defined_vars.itervalues():
assert var.stinfo['references']
@@ -591,5 +592,204 @@
finally:
RQLRewriter.insert_snippets = orig_insert_snippets
+
+class RQLRelationRewriterTC(TestCase):
+ # XXX valid rules: S and O specified, not in a SET, INSERT, DELETE scope
+ # valid uses: no outer join
+
+ # Basic tests
+ def test_base_rule(self):
+ rules = {'participated_in': 'S contributor O'}
+ rqlst = rqlhelper.parse('Any X WHERE X participated_in S')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any X WHERE X contributor S',
+ rqlst.as_string())
+
+ def test_complex_rule_1(self):
+ rules = {'illustrator_of': ('C is Contribution, C contributor S, '
+ 'C manifestation O, C role R, '
+ 'R name "illustrator"')}
+ rqlst = rqlhelper.parse('Any A,B WHERE A illustrator_of B')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any A,B WHERE C is Contribution, '
+ 'C contributor A, C manifestation B, '
+ 'C role D, D name "illustrator"',
+ rqlst.as_string())
+
+ def test_complex_rule_2(self):
+ rules = {'illustrator_of': ('C is Contribution, C contributor S, '
+ 'C manifestation O, C role R, '
+ 'R name "illustrator"')}
+ rqlst = rqlhelper.parse('Any A WHERE EXISTS(A illustrator_of B)')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any A WHERE EXISTS(C is Contribution, '
+ 'C contributor A, C manifestation B, '
+ 'C role D, D name "illustrator")',
+ rqlst.as_string())
+
+
+ def test_rewrite2(self):
+ rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+ 'C manifestation O, C role R, R name "illustrator"'}
+ rqlst = rqlhelper.parse('Any A,B WHERE A illustrator_of B, C require_permission R, S'
+ 'require_state O')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any A,B WHERE C require_permission R, S require_state O, '
+ 'D is Contribution, D contributor A, D manifestation B, D role E, '
+ 'E name "illustrator"',
+ rqlst.as_string())
+
+ def test_rewrite3(self):
+ rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+ 'C manifestation O, C role R, R name "illustrator"'}
+ rqlst = rqlhelper.parse('Any A,B WHERE E require_permission T, A illustrator_of B')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any A,B WHERE E require_permission T, '
+ 'C is Contribution, C contributor A, C manifestation B, '
+ 'C role D, D name "illustrator"',
+ rqlst.as_string())
+
+ def test_rewrite4(self):
+ rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+ 'C manifestation O, C role R, R name "illustrator"'}
+ rqlst = rqlhelper.parse('Any A,B WHERE C require_permission R, A illustrator_of B')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any A,B WHERE C require_permission R, '
+ 'D is Contribution, D contributor A, D manifestation B, '
+ 'D role E, E name "illustrator"',
+ rqlst.as_string())
+
+ def test_rewrite5(self):
+ rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+ 'C manifestation O, C role R, R name "illustrator"'}
+ rqlst = rqlhelper.parse('Any A,B WHERE C require_permission R, A illustrator_of B, '
+ 'S require_state O')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any A,B WHERE C require_permission R, S require_state O, '
+ 'D is Contribution, D contributor A, D manifestation B, D role E, '
+ 'E name "illustrator"',
+ rqlst.as_string())
+
+ # Tests for the with clause
+ def test_rewrite_with(self):
+ rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+ 'C manifestation O, C role R, R name "illustrator"'}
+ rqlst = rqlhelper.parse('Any A,B WITH A, B BEING(Any X, Y WHERE X illustrator_of Y)')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any A,B WITH A,B BEING '
+ '(Any X,Y WHERE A is Contribution, A contributor X, '
+ 'A manifestation Y, A role B, B name "illustrator")',
+ rqlst.as_string())
+
+ def test_rewrite_with2(self):
+ rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+ 'C manifestation O, C role R, R name "illustrator"'}
+ rqlst = rqlhelper.parse('Any A,B WHERE T require_permission C WITH A, B BEING(Any X, Y WHERE X illustrator_of Y)')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any A,B WHERE T require_permission C '
+ 'WITH A,B BEING (Any X,Y WHERE A is Contribution, '
+ 'A contributor X, A manifestation Y, A role B, B name "illustrator")',
+ rqlst.as_string())
+
+ def test_rewrite_with3(self):
+ rules = {'participated_in': 'S contributor O'}
+ rqlst = rqlhelper.parse('Any A,B WHERE A participated_in B '
+ 'WITH A, B BEING(Any X,Y WHERE X contributor Y)')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any A,B WHERE A contributor B WITH A,B BEING '
+ '(Any X,Y WHERE X contributor Y)',
+ rqlst.as_string())
+
+ def test_rewrite_with4(self):
+ rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+ 'C manifestation O, C role R, R name "illustrator"'}
+ rqlst = rqlhelper.parse('Any A,B WHERE A illustrator_of B '
+ 'WITH A, B BEING(Any X, Y WHERE X illustrator_of Y)')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any A,B WHERE C is Contribution, '
+ 'C contributor A, C manifestation B, C role D, '
+ 'D name "illustrator" WITH A,B BEING '
+ '(Any X,Y WHERE A is Contribution, A contributor X, '
+ 'A manifestation Y, A role B, B name "illustrator")',
+ rqlst.as_string())
+
+ # Tests for the union
+ def test_rewrite_union(self):
+ rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+ 'C manifestation O, C role R, R name "illustrator"'}
+ rqlst = rqlhelper.parse('(Any A,B WHERE A illustrator_of B) UNION'
+ '(Any X,Y WHERE X is CWUser, Z manifestation Y)')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('(Any A,B WHERE C is Contribution, '
+ 'C contributor A, C manifestation B, C role D, '
+ 'D name "illustrator") UNION (Any X,Y WHERE X is CWUser, Z manifestation Y)',
+ rqlst.as_string())
+
+ def test_rewrite_union2(self):
+ rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+ 'C manifestation O, C role R, R name "illustrator"'}
+ rqlst = rqlhelper.parse('(Any Y WHERE Y match W) UNION '
+ '(Any A WHERE A illustrator_of B) UNION '
+ '(Any Y WHERE Y is ArtWork)')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('(Any Y WHERE Y match W) '
+ 'UNION (Any A WHERE C is Contribution, C contributor A, '
+ 'C manifestation B, C role D, D name "illustrator") '
+ 'UNION (Any Y WHERE Y is ArtWork)',
+ rqlst.as_string())
+
+ # Tests for the exists clause
+ def test_rewrite_exists(self):
+ rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+ 'C manifestation O, C role R, R name "illustrator"'}
+ rqlst = rqlhelper.parse('(Any A,B WHERE A illustrator_of B, '
+ 'EXISTS(B is ArtWork))')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any A,B WHERE EXISTS(B is ArtWork), '
+ 'C is Contribution, C contributor A, C manifestation B, C role D, '
+ 'D name "illustrator"',
+ rqlst.as_string())
+
+ def test_rewrite_exists2(self):
+ rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+ 'C manifestation O, C role R, R name "illustrator"'}
+ rqlst = rqlhelper.parse('(Any A,B WHERE B contributor A, EXISTS(A illustrator_of W))')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any A,B WHERE B contributor A, '
+ 'EXISTS(C is Contribution, C contributor A, C manifestation W, '
+ 'C role D, D name "illustrator")',
+ rqlst.as_string())
+
+ def test_rewrite_exists3(self):
+ rules = {'illustrator_of': 'C is Contribution, C contributor S, '
+ 'C manifestation O, C role R, R name "illustrator"'}
+ rqlst = rqlhelper.parse('(Any A,B WHERE A illustrator_of B, EXISTS(A illustrator_of W))')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any A,B WHERE EXISTS(C is Contribution, C contributor A, '
+ 'C manifestation W, C role D, D name "illustrator"), '
+ 'E is Contribution, E contributor A, E manifestation B, E role F, '
+ 'F name "illustrator"',
+ rqlst.as_string())
+
+ # Test for GROUPBY
+ def test_rewrite_groupby(self):
+ rules = {'participated_in': 'S contributor O'}
+ rqlst = rqlhelper.parse('Any SUM(SA) GROUPBY S WHERE P participated_in S, P manifestation SA')
+ rule_rewrite(rqlst, rules)
+ self.assertEqual('Any SUM(SA) GROUPBY S WHERE P manifestation SA, P contributor S',
+ rqlst.as_string())
+
+
+
+def rule_rewrite(rqlst, kwargs=None):
+ rewriter = _prepare_rewriter(rqlrewrite.RQLRelationRewriter, kwargs)
+ rqlhelper.compute_solutions(rqlst.children[0], {'eid': eid_func_map},
+ kwargs=kwargs)
+ rewriter.rewrite(rqlst)
+ for select in rqlst.children:
+ test_vrefs(select)
+ return rewriter.rewritten
+
+
if __name__ == '__main__':
unittest_main()