rqlrewrite.py
branch3.5
changeset 3240 8604a15995d1
parent 1977 606923dff11b
child 3254 fe7ec595751c
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/rqlrewrite.py	Wed Sep 16 14:24:31 2009 +0200
@@ -0,0 +1,480 @@
+"""RQL rewriting utilities : insert rql expression snippets into rql syntax
+tree.
+
+This is used for instance for read security checking in the repository.
+
+:organization: Logilab
+:copyright: 2007-2009 LOGILAB S.A. (Paris, FRANCE), license is LGPL v2.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
+:license: GNU Lesser General Public License, v2.1 - http://www.gnu.org/licenses
+"""
+__docformat__ = "restructuredtext en"
+
+from rql import nodes as n, stmts, TypeResolverException
+
+from logilab.common.compat import any
+
+from cubicweb import Unauthorized, server, typed_eid
+from cubicweb.server.ssplanner import add_types_restriction
+
+
+def remove_solutions(origsolutions, solutions, defined):
+    """when a rqlst has been generated from another by introducing security
+    assertions, this method returns solutions which are contained in orig
+    solutions
+    """
+    newsolutions = []
+    for origsol in origsolutions:
+        for newsol in solutions[:]:
+            for var, etype in origsol.items():
+                try:
+                    if newsol[var] != etype:
+                        try:
+                            defined[var].stinfo['possibletypes'].remove(newsol[var])
+                        except KeyError:
+                            pass
+                        break
+                except KeyError:
+                    # variable has been rewritten
+                    continue
+            else:
+                newsolutions.append(newsol)
+                solutions.remove(newsol)
+    return newsolutions
+
+
+class Unsupported(Exception): pass
+
+
+class RQLRewriter(object):
+    """insert some rql snippets into another rql syntax tree
+
+    this class *isn't thread safe*
+    """
+
+    def __init__(self, session):
+        self.session = session
+        vreg = session.vreg
+        self.schema = vreg.schema
+        self.annotate = vreg.rqlhelper.annotate
+        self._compute_solutions = vreg.solutions
+
+    def compute_solutions(self):
+        self.annotate(self.select)
+        try:
+            self._compute_solutions(self.session, self.select, self.kwargs)
+        except TypeResolverException:
+            raise Unsupported(str(self.select))
+        if len(self.select.solutions) < len(self.solutions):
+            raise Unsupported()
+
+    def rewrite(self, select, snippets, solutions, kwargs):
+        """
+        snippets: (varmap, list of rql expression)
+                  with varmap a *tuple* (select var, snippet var)
+        """
+        if server.DEBUG:
+            print '---- rewrite', select, snippets, solutions
+        self.select = self.insert_scope = select
+        self.solutions = solutions
+        self.kwargs = kwargs
+        self.u_varname = None
+        self.removing_ambiguity = False
+        self.exists_snippet = {}
+        self.pending_keys = []
+        # we have to annotate the rqlst before inserting snippets, even though
+        # we'll have to redo it latter
+        self.annotate(select)
+        self.insert_snippets(snippets)
+        if not self.exists_snippet and self.u_varname:
+            # U has been inserted than cancelled, cleanup
+            select.undefine_variable(select.defined_vars[self.u_varname])
+        # clean solutions according to initial solutions
+        newsolutions = remove_solutions(solutions, select.solutions,
+                                        select.defined_vars)
+        assert len(newsolutions) >= len(solutions), (
+            'rewritten rql %s has lost some solutions, there is probably '
+            'something wrong in your schema permission (for instance using a '
+            'RQLExpression which insert a relation which doesn\'t exists in '
+            'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
+            select, solutions, newsolutions))
+        if len(newsolutions) > len(solutions):
+            newsolutions = self.remove_ambiguities(snippets, newsolutions)
+        select.solutions = newsolutions
+        add_types_restriction(self.schema, select)
+        if server.DEBUG:
+            print '---- rewriten', select
+
+    def insert_snippets(self, snippets, varexistsmap=None):
+        self.rewritten = {}
+        for varmap, rqlexprs in snippets:
+            if varexistsmap is not None and not varmap in varexistsmap:
+                continue
+            self.varmap = varmap
+            selectvar, snippetvar = varmap
+            assert snippetvar in 'SOX'
+            self.revvarmap = {snippetvar: selectvar}
+            self.varinfo = vi = {}
+            try:
+                vi['const'] = typed_eid(selectvar) # XXX gae
+                vi['rhs_rels'] = vi['lhs_rels'] = {}
+            except ValueError:
+                vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo
+                if varexistsmap is None:
+                    vi['rhs_rels'] = dict( (r.r_type, r) for r in sti['rhsrelations'])
+                    vi['lhs_rels'] = dict( (r.r_type, r) for r in sti['relations']
+                                           if not r in sti['rhsrelations'])
+                else:
+                    vi['rhs_rels'] = vi['lhs_rels'] = {}
+            parent = None
+            inserted = False
+            for rqlexpr in rqlexprs:
+                self.current_expr = rqlexpr
+                if varexistsmap is None:
+                    try:
+                        new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, parent)
+                    except Unsupported:
+                        import traceback
+                        traceback.print_exc()
+                        continue
+                    inserted = True
+                    if new is not None:
+                        self.exists_snippet[rqlexpr] = new
+                    parent = parent or new
+                else:
+                    # called to reintroduce snippet due to ambiguity creation,
+                    # so skip snippets which are not introducing this ambiguity
+                    exists = varexistsmap[varmap]
+                    if self.exists_snippet[rqlexpr] is exists:
+                        self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
+            if varexistsmap is None and not inserted:
+                # no rql expression found matching rql solutions. User has no access right
+                raise Unauthorized(str((varmap, str(self.select), [expr.expression for expr in rqlexprs])))
+
+    def insert_snippet(self, varmap, snippetrqlst, parent=None):
+        new = snippetrqlst.where.accept(self)
+        if new is not None:
+            if self.varinfo.get('stinfo', {}).get('optrelations'):
+                assert parent is None
+                self.insert_scope = self.snippet_subquery(varmap, new)
+                self.insert_pending()
+                self.insert_scope = self.select
+                return
+            new = n.Exists(new)
+            if parent is None:
+                self.insert_scope.add_restriction(new)
+            else:
+                grandpa = parent.parent
+                or_ = n.Or(parent, new)
+                grandpa.replace(parent, or_)
+            if not self.removing_ambiguity:
+                try:
+                    self.compute_solutions()
+                except Unsupported:
+                    # some solutions have been lost, can't apply this rql expr
+                    if parent is None:
+                        self.select.remove_node(new, undefine=True)
+                    else:
+                        parent.parent.replace(or_, or_.children[0])
+                        self._cleanup_inserted(new)
+                    raise
+                else:
+                    self.insert_scope = new
+                    self.insert_pending()
+                    self.insert_scope = self.select
+            return new
+        self.insert_pending()
+
+    def insert_pending(self):
+        """pending_keys hold variable referenced by U has_<action>_permission X
+        relation.
+
+        Once the snippet introducing this has been inserted and solutions
+        recomputed, we have to insert snippet defined for <action> of entity
+        types taken by X
+        """
+        while self.pending_keys:
+            key, action = self.pending_keys.pop()
+            try:
+                varname = self.rewritten[key]
+            except KeyError:
+                try:
+                    varname = self.revvarmap[key[-1]]
+                except KeyError:
+                    # variable isn't used anywhere else, we can't insert security
+                    raise Unauthorized()
+            ptypes = self.select.defined_vars[varname].stinfo['possibletypes']
+            if len(ptypes) > 1:
+                # XXX dunno how to handle this
+                self.session.error(
+                    'cant check security of %s, ambigous type for %s in %s',
+                    self.select, varname, key[0]) # key[0] == the rql expression
+                raise Unauthorized()
+            etype = iter(ptypes).next()
+            eschema = self.schema.eschema(etype)
+            if not eschema.has_perm(self.session, action):
+                rqlexprs = eschema.get_rqlexprs(action)
+                if not rqlexprs:
+                    raise Unauthorised()
+                self.insert_snippets([((varname, 'X'), rqlexprs)])
+
+    def snippet_subquery(self, varmap, transformedsnippet):
+        """introduce the given snippet in a subquery"""
+        subselect = stmts.Select()
+        selectvar, snippetvar = varmap
+        subselect.append_selected(n.VariableRef(
+            subselect.get_variable(selectvar)))
+        aliases = [selectvar]
+        subselect.add_restriction(transformedsnippet.copy(subselect))
+        stinfo = self.varinfo['stinfo']
+        for rel in stinfo['relations']:
+            rschema = self.schema.rschema(rel.r_type)
+            if rschema.is_final() or (rschema.inlined and
+                                      not rel in stinfo['rhsrelations']):
+                self.select.remove_node(rel)
+                rel.children[0].name = selectvar
+                subselect.add_restriction(rel.copy(subselect))
+                for vref in rel.children[1].iget_nodes(n.VariableRef):
+                    subselect.append_selected(vref.copy(subselect))
+                    aliases.append(vref.name)
+        if self.u_varname:
+            # generate an identifier for the substitution
+            argname = subselect.allocate_varname()
+            while argname in self.kwargs:
+                argname = subselect.allocate_varname()
+            subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
+                                               'eid', unicode(argname), 'Substitute')
+            self.kwargs[argname] = self.session.user.eid
+        add_types_restriction(self.schema, subselect, subselect,
+                              solutions=self.solutions)
+        myunion = stmts.Union()
+        myunion.append(subselect)
+        aliases = [n.VariableRef(self.select.get_variable(name, i))
+                   for i, name in enumerate(aliases)]
+        self.select.add_subquery(n.SubQuery(aliases, myunion), check=False)
+        self._cleanup_inserted(transformedsnippet)
+        try:
+            self.compute_solutions()
+        except Unsupported:
+            # some solutions have been lost, can't apply this rql expr
+            self.select.remove_subquery(new, undefine=True)
+            raise
+        return subselect
+
+    def remove_ambiguities(self, snippets, newsolutions):
+        # the snippet has introduced some ambiguities, we have to resolve them
+        # "manually"
+        variantes = self.build_variantes(newsolutions)
+        # insert "is" where necessary
+        varexistsmap = {}
+        self.removing_ambiguity = True
+        for (erqlexpr, varmap, oldvarname), etype in variantes[0].iteritems():
+            varname = self.rewritten[(erqlexpr, varmap, oldvarname)]
+            var = self.select.defined_vars[varname]
+            exists = var.references()[0].scope
+            exists.add_constant_restriction(var, 'is', etype, 'etype')
+            varexistsmap[varmap] = exists
+        # insert ORED exists where necessary
+        for variante in variantes[1:]:
+            self.insert_snippets(snippets, varexistsmap)
+            for key, etype in variante.iteritems():
+                varname = self.rewritten[key]
+                try:
+                    var = self.select.defined_vars[varname]
+                except KeyError:
+                    # not a newly inserted variable
+                    continue
+                exists = var.references()[0].scope
+                exists.add_constant_restriction(var, 'is', etype, 'etype')
+        # recompute solutions
+        #select.annotated = False # avoid assertion error
+        self.compute_solutions()
+        # clean solutions according to initial solutions
+        return remove_solutions(self.solutions, self.select.solutions,
+                                self.select.defined_vars)
+
+    def build_variantes(self, newsolutions):
+        variantes = set()
+        for sol in newsolutions:
+            variante = []
+            for key, newvar in self.rewritten.iteritems():
+                variante.append( (key, sol[newvar]) )
+            variantes.add(tuple(variante))
+        # rebuild variantes as dict
+        variantes = [dict(variante) for variante in variantes]
+        # remove variable which have always the same type
+        for key in self.rewritten:
+            it = iter(variantes)
+            etype = it.next()[key]
+            for variante in it:
+                if variante[key] != etype:
+                    break
+            else:
+                for variante in variantes:
+                    del variante[key]
+        return variantes
+
+    def _cleanup_inserted(self, node):
+        # cleanup inserted variable references
+        for vref in node.iget_nodes(n.VariableRef):
+            vref.unregister_reference()
+            if not vref.variable.stinfo['references']:
+                # no more references, undefine the variable
+                del self.select.defined_vars[vref.name]
+
+    def _may_be_shared(self, relation, target, searchedvarname):
+        """return True if the snippet relation can be skipped to use a relation
+        from the original query
+        """
+        # if cardinality is in '?1', we can ignore the relation and use variable
+        # from the original query
+        rschema = self.schema.rschema(relation.r_type)
+        if target == 'object':
+            cardindex = 0
+            ttypes_func = rschema.objects
+            rprop = rschema.rproperty
+        else: # target == 'subject':
+            cardindex = 1
+            ttypes_func = rschema.subjects
+            rprop = lambda x, y, z: rschema.rproperty(y, x, z)
+        for etype in self.varinfo['stinfo']['possibletypes']:
+            for ttype in ttypes_func(etype):
+                if rprop(etype, ttype, 'cardinality')[cardindex] in '+*':
+                    return False
+        return True
+
+    def _use_outer_term(self, snippet_varname, term):
+        key = (self.current_expr, self.varmap, snippet_varname)
+        if key in self.rewritten:
+            insertedvar = self.select.defined_vars.pop(self.rewritten[key])
+            for inserted_vref in insertedvar.references():
+                inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
+        self.rewritten[key] = term.name
+
+    def _get_varname_or_term(self, vname):
+        if vname == 'U':
+            if self.u_varname is None:
+                select = self.select
+                self.u_varname = select.allocate_varname()
+                # generate an identifier for the substitution
+                argname = select.allocate_varname()
+                while argname in self.kwargs:
+                    argname = select.allocate_varname()
+                # insert "U eid %(u)s"
+                var = select.get_variable(self.u_varname)
+                select.add_constant_restriction(select.get_variable(self.u_varname),
+                                                'eid', unicode(argname), 'Substitute')
+                self.kwargs[argname] = self.session.user.eid
+            return self.u_varname
+        key = (self.current_expr, self.varmap, vname)
+        try:
+            return self.rewritten[key]
+        except KeyError:
+            self.rewritten[key] = newvname = self.select.allocate_varname()
+            return newvname
+
+    # visitor methods ##########################################################
+
+    def _visit_binary(self, node, cls):
+        newnode = cls()
+        for c in node.children:
+            new = c.accept(self)
+            if new is None:
+                continue
+            newnode.append(new)
+        if len(newnode.children) == 0:
+            return None
+        if len(newnode.children) == 1:
+            return newnode.children[0]
+        return newnode
+
+    def _visit_unary(self, node, cls):
+        newc = node.children[0].accept(self)
+        if newc is None:
+            return None
+        newnode = cls()
+        newnode.append(newc)
+        return newnode
+
+    def visit_and(self, node):
+        return self._visit_binary(node, n.And)
+
+    def visit_or(self, node):
+        return self._visit_binary(node, n.Or)
+
+    def visit_not(self, node):
+        return self._visit_unary(node, n.Not)
+
+    def visit_exists(self, node):
+        return self._visit_unary(node, n.Exists)
+
+    def visit_relation(self, node):
+        lhs, rhs = node.get_variable_parts()
+        if node.r_type in ('has_add_permission', 'has_update_permission',
+                           'has_delete_permission', 'has_read_permission'):
+            assert lhs.name == 'U'
+            action = node.r_type.split('_')[1]
+            key = (self.current_expr, self.varmap, rhs.name)
+            self.pending_keys.append( (key, action) )
+            return
+        if lhs.name in self.revvarmap:
+            # on lhs
+            # see if we can reuse this relation
+            rels = self.varinfo['lhs_rels']
+            if (node.r_type in rels and isinstance(rhs, n.VariableRef)
+                and rhs.name != 'U' and not rels[node.r_type].neged(strict=True)
+                and self._may_be_shared(node, 'object', lhs.name)):
+                # ok, can share variable
+                term = rels[node.r_type].children[1].children[0]
+                self._use_outer_term(rhs.name, term)
+                return
+        elif isinstance(rhs, n.VariableRef) and rhs.name in self.revvarmap and lhs.name != 'U':
+            # on rhs
+            # see if we can reuse this relation
+            rels = self.varinfo['rhs_rels']
+            if (node.r_type in rels and not rels[node.r_type].neged(strict=True)
+                and self._may_be_shared(node, 'subject', rhs.name)):
+                # ok, can share variable
+                term = rels[node.r_type].children[0]
+                self._use_outer_term(lhs.name, term)
+                return
+        rel = n.Relation(node.r_type, node.optional)
+        for c in node.children:
+            rel.append(c.accept(self))
+        return rel
+
+    def visit_comparison(self, node):
+        cmp_ = n.Comparison(node.operator)
+        for c in node.children:
+            cmp_.append(c.accept(self))
+        return cmp_
+
+    def visit_mathexpression(self, node):
+        cmp_ = n.MathExpression(node.operator)
+        for c in cmp.children:
+            cmp_.append(c.accept(self))
+        return cmp_
+
+    def visit_function(self, node):
+        """generate filter name for a function"""
+        function_ = n.Function(node.name)
+        for c in node.children:
+            function_.append(c.accept(self))
+        return function_
+
+    def visit_constant(self, node):
+        """generate filter name for a constant"""
+        return n.Constant(node.value, node.type)
+
+    def visit_variableref(self, node):
+        """get the sql name for a variable reference"""
+        if node.name in self.revvarmap:
+            if self.varinfo.get('const') is not None:
+                return n.Constant(self.varinfo['const'], 'Int') # XXX gae
+            return n.VariableRef(self.select.get_variable(
+                self.revvarmap[node.name]))
+        vname_or_term = self._get_varname_or_term(node.name)
+        if isinstance(vname_or_term, basestring):
+            return n.VariableRef(self.select.get_variable(vname_or_term))
+        # shared term
+        return vname_or_term.copy(self.select)