rqlrewrite.py
author Sylvain Thénault <sylvain.thenault@logilab.fr>
Wed, 24 Apr 2013 14:37:48 +0200
branchstable
changeset 8909 f46b017db2d9
parent 8694 d901c36bcfce
child 8748 f5027f8d2478
permissions -rw-r--r--
[test] fix skipIf import, currently failing if unittest2 is not there

# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
"""RQL rewriting utilities : insert rql expression snippets into rql syntax
tree.

This is used for instance for read security checking in the repository.
"""
__docformat__ = "restructuredtext en"

from rql import nodes as n, stmts, TypeResolverException
from rql.utils import common_parent

from yams import BadSchemaDefinition

from logilab.common import tempattr
from logilab.common.graph import has_path

from cubicweb import Unauthorized, typed_eid


def add_types_restriction(schema, rqlst, newroot=None, solutions=None):
    """make possible entity types of variables explicit in the syntax tree

    For each non final variable found in `solutions` and defined in the
    target node, a type restriction is either added, narrowed to the still
    possible types, or rewritten from an `is_instance_of` relation into an
    `is` relation. The set of possible types is then stored in the variable's
    ``stinfo['possibletypes']``.

    :param schema: object whose `eschema(etype)` gives an entity schema
      exposing a `final` attribute
    :param rqlst: the statement to restrict when `newroot` is None, else a
      node whose `stmt` attribute is the statement to look variables into
    :param newroot: node whose `defined_vars` are restricted. When None,
      `rqlst` is used and flagged with `_types_restr_added` so that a second
      call on the same statement is a no-op
    :param solutions: solutions to consider; must be given if and only if
      `newroot` is given (`rqlst.solutions` is used otherwise)
    """
    if newroot is None:
        assert solutions is None
        if hasattr(rqlst, '_types_restr_added'):
            return
        solutions = rqlst.solutions
        newroot = rqlst
        rqlst._types_restr_added = True
    else:
        assert solutions is not None
        rqlst = rqlst.stmt
    eschema = schema.eschema
    # gather every type each variable may take across all solutions
    allpossibletypes = {}
    for solution in solutions:
        for varname, etype in solution.items():
            # XXX not considering aliases by design, right ?
            if varname not in newroot.defined_vars or eschema(etype).final:
                continue
            allpossibletypes.setdefault(varname, set()).add(etype)
    # XXX could be factorized with add_etypes_restriction from rql 0.31
    for varname in sorted(allpossibletypes):
        var = newroot.defined_vars[varname]
        stinfo = var.stinfo
        if stinfo.get('uidrel') is not None:
            continue # eid specified, no need for additional type specification
        try:
            typerel = rqlst.defined_vars[varname].stinfo.get('typerel')
        except KeyError:
            assert varname in rqlst.aliases
            continue
        if newroot is rqlst and typerel is not None:
            mytyperel = typerel
        else:
            # search for a types restriction among the variable's references
            for vref in var.references():
                rel = vref.relation()
                if rel and rel.is_types_restriction():
                    mytyperel = rel
                    break
            else:
                mytyperel = None
        possibletypes = allpossibletypes[varname]
        if mytyperel is not None:
            if mytyperel.r_type == 'is_instance_of':
                # turn is_instance_of relation into a is relation since we've
                # all possible solutions and don't want to bother with
                # potential is_instance_of incompatibility
                mytyperel.r_type = 'is'
                if len(possibletypes) > 1:
                    node = n.Function('IN')
                    # sorted so the generated tree doesn't depend on set order
                    for ptype in sorted(possibletypes):
                        node.append(n.Constant(ptype, 'etype'))
                else:
                    # use this variable's own single type: the `etype` name
                    # bound by the solutions loop above is whatever type was
                    # seen last, possibly for another variable
                    node = n.Constant(next(iter(possibletypes)), 'etype')
                comp = mytyperel.children[1]
                comp.replace(comp.children[0], node)
            else:
                # variable has already some strict types restriction. new
                # possible types can only be a subset of existing ones, so only
                # remove no more possible types
                for cst in mytyperel.get_nodes(n.Constant):
                    if cst.value not in possibletypes:
                        cst.parent.remove(cst)
        else:
            # we have to add types restriction
            if stinfo.get('scope') is not None:
                rel = var.scope.add_type_restriction(var, possibletypes)
            else:
                # tree is not annotated yet, no scope set so add the restriction
                # to the root
                rel = newroot.add_type_restriction(var, possibletypes)
            stinfo['typerel'] = rel
        stinfo['possibletypes'] = possibletypes


def remove_solutions(origsolutions, solutions, defined):
    """filter `solutions` computed for a tree rewritten with security
    assertions, returning those which are compatible with some original
    solution

    Kept solutions are removed from the `solutions` list, which is hence
    modified in place. When a candidate solution disagrees with an original
    one on a variable's type, that type is also discarded (when possible) from
    the `possibletypes` of the matching variable in `defined`.
    """
    kept = []
    for reference in origsolutions:
        # work on a snapshot since accepted candidates are removed on the fly
        for candidate in list(solutions):
            compatible = True
            for varname, expected in reference.items():
                if varname not in candidate:
                    # variable has been rewritten
                    continue
                actual = candidate[varname]
                if actual != expected:
                    try:
                        defined[varname].stinfo['possibletypes'].remove(actual)
                    except KeyError:
                        pass
                    compatible = False
                    break
            if compatible:
                kept.append(candidate)
                solutions.remove(candidate)
    return kept


def iter_relations(stinfo):
    """return relations of `stinfo` which aren't listed in its 'rhsrelations'

    this is a function so that tests may return relations in a predictable
    order
    """
    every_relation = stinfo['relations']
    rhs_relations = stinfo['rhsrelations']
    return every_relation - rhs_relations

class Unsupported(Exception):
    """raised when an rql expression can't be inserted in some rql query
    because it would create an unresolvable query (e.g. no solutions found or
    some of the original solutions lost)
    """


class RQLRewriter(object):
    """insert some rql snippets into another rql syntax tree

    this class *isn't thread safe*
    """

    def __init__(self, session):
        self.session = session
        vreg = session.vreg
        self.schema = vreg.schema
        self.annotate = vreg.rqlhelper.annotate
        self._compute_solutions = vreg.solutions

    def compute_solutions(self):
        self.annotate(self.select)
        try:
            self._compute_solutions(self.session, self.select, self.kwargs)
        except TypeResolverException:
            raise Unsupported(str(self.select))
        if len(self.select.solutions) < len(self.solutions):
            raise Unsupported()

    def rewrite(self, select, snippets, solutions, kwargs, existingvars=None):
        """insert `snippets` into the `select` syntax tree, in place

        :param select: the syntax tree to rewrite; its `solutions` attribute
          is updated once snippets have been inserted
        :param snippets: sequence of (varmap, list of rql expression), where
          varmap is either a dictionary or a *tuple* of
          (select var, snippet var) pairs
        :param solutions: solutions of `select` before rewriting
        :param kwargs: arguments dictionary of the query; an entry holding the
          session's user eid may be added to it
        :param existingvars: optional collection of snippet variable names;
          when given, snippet relations using a variable for which `keep_var`
          returns false are dropped

        :raise Unauthorized: if none of the rql expressions of some snippet
          could be inserted
        """
        # per-rewrite state, reset on each call (instances aren't thread safe)
        self.select = select
        self.solutions = solutions
        self.kwargs = kwargs
        self.u_varname = None
        self.removing_ambiguity = False
        self.exists_snippet = {}
        self.pending_keys = []
        self.existingvars = existingvars
        # we have to annotate the rqlst before inserting snippets, even though
        # we'll have to redo it later
        self.annotate(select)
        self.insert_snippets(snippets)
        if not self.exists_snippet and self.u_varname:
            # U has been inserted then cancelled, cleanup
            select.undefine_variable(select.defined_vars[self.u_varname])
        # clean solutions according to initial solutions
        newsolutions = remove_solutions(solutions, select.solutions,
                                        select.defined_vars)
        assert len(newsolutions) >= len(solutions), (
            'rewritten rql %s has lost some solutions, there is probably '
            'something wrong in your schema permission (for instance using a '
            'RQLExpression which insert a relation which doesn\'t exists in '
            'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
            select, solutions, newsolutions))
        if len(newsolutions) > len(solutions):
            # inserted variables may take several types: disambiguate
            newsolutions = self.remove_ambiguities(snippets, newsolutions)
        select.solutions = newsolutions
        add_types_restriction(self.schema, select)

    def insert_snippets(self, snippets, varexistsmap=None):
        self.rewritten = {}
        for varmap, rqlexprs in snippets:
            if isinstance(varmap, dict):
                varmap = tuple(sorted(varmap.items()))
            else:
                assert isinstance(varmap, tuple), varmap
            if varexistsmap is not None and not varmap in varexistsmap:
                continue
            self.insert_varmap_snippets(varmap, rqlexprs, varexistsmap)

    def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap):
        """insert the rql expressions `rqlexprs` using the `varmap` mapping

        :param varmap: tuple of (select var, snippet var) pairs; the select
          var may also be something `typed_eid` accepts, in which case it's
          handled as a constant
        :param rqlexprs: rql expressions to insert; each one exposes its
          syntax tree as `snippet_rqlst`
        :param varexistsmap: None for a regular insertion, else a dictionary
          mapping varmap to the node returned by a previous insertion, in
          which case only expressions recorded with that node in
          `exists_snippet` are inserted again

        :raise Unauthorized: on regular insertion, if no expression could be
          inserted
        """
        self.varmap = varmap
        self.revvarmap = {}
        self.varinfos = []
        self._insert_scope = None
        # collect information about each variable of the mapping
        for i, (selectvar, snippetvar) in enumerate(varmap):
            assert snippetvar in 'SOX'
            self.revvarmap[snippetvar] = (selectvar, i)
            vi = {}
            self.varinfos.append(vi)
            try:
                vi['const'] = typed_eid(selectvar)
                vi['rhs_rels'] = vi['lhs_rels'] = {}
            except ValueError:
                # not an eid: `selectvar` is a variable name
                try:
                    vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo
                except KeyError:
                    # variable may have been moved to a newly inserted subquery
                    # we should insert snippet in that subquery
                    subquery = self.select.aliases[selectvar].query
                    assert len(subquery.children) == 1
                    subselect = subquery.children[0]
                    RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)],
                                                      subselect.solutions, self.kwargs)
                    return
                if varexistsmap is None:
                    # index relations of the variable by relation type, to be
                    # able to share them with snippet relations
                    vi['rhs_rels'] = dict( (r.r_type, r) for r in sti['rhsrelations'])
                    vi['lhs_rels'] = dict( (r.r_type, r) for r in sti['relations']
                                           if not r in sti['rhsrelations'])
                else:
                    vi['rhs_rels'] = vi['lhs_rels'] = {}
        # `previous` is the first successfully inserted node, following ones
        # are given to insert_snippet together with it
        previous = None
        inserted = False
        for rqlexpr in rqlexprs:
            self.current_expr = rqlexpr
            if varexistsmap is None:
                try:
                    new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, previous)
                except Unsupported:
                    # this expression can't be applied, try the next one
                    continue
                inserted = True
                if new is not None and self._insert_scope is None:
                    self.exists_snippet[rqlexpr] = new
                previous = previous or new
            else:
                # called to reintroduce snippet due to ambiguity creation,
                # so skip snippets which are not introducing this ambiguity
                exists = varexistsmap[varmap]
                if self.exists_snippet.get(rqlexpr) is exists:
                    self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
        if varexistsmap is None and not inserted:
            # no rql expression found matching rql solutions. User has no access right
            raise Unauthorized() # XXX may also be because of bad constraints in schema definition

    def insert_snippet(self, varmap, snippetrqlst, previous=None):
        new = snippetrqlst.where.accept(self)
        existing = self.existingvars
        self.existingvars = None
        try:
            return self._insert_snippet(varmap, previous, new)
        finally:
            self.existingvars = existing

    def _insert_snippet(self, varmap, previous, new):
        """insert `new` snippet into the syntax tree, which has been rewritten
        using `varmap`. In cases where an action is protected by several rql
        expressions, `previous` will be the first rql expression which has been
        inserted, and so should be ORed with the following expressions.

        Return the inserted node, or None when `new` is None (in which case
        only pending snippets are inserted).

        :raise Unsupported: if solutions can't be recomputed once the snippet
          has been inserted; the insertion is undone first
        """
        if new is not None:
            if self._insert_scope is None:
                # insert in the common parent of the mapped variables' scopes,
                # defaulting to the select for those without scope information
                insert_scope = None
                for vi in self.varinfos:
                    scope = vi.get('stinfo', {}).get('scope', self.select)
                    if insert_scope is None:
                        insert_scope = scope
                    else:
                        insert_scope = common_parent(scope, insert_scope)
            else:
                insert_scope = self._insert_scope
            if self._insert_scope is None and any(vi.get('stinfo', {}).get('optrelations')
                                                  for vi in self.varinfos):
                # some mapped variable is used in an optional relation: the
                # snippet goes into a subquery
                assert previous is None
                self._insert_scope, new = self.snippet_subquery(varmap, new)
                self.insert_pending()
                #self._insert_scope = None
                return new
            if not isinstance(new, (n.Exists, n.Not)):
                new = n.Exists(new)
            if previous is None:
                insert_scope.add_restriction(new)
            else:
                # OR the new snippet with the previously inserted one
                grandpa = previous.parent
                or_ = n.Or(previous, new)
                grandpa.replace(previous, or_)
            if not self.removing_ambiguity:
                try:
                    self.compute_solutions()
                except Unsupported:
                    # some solutions have been lost, can't apply this rql expr
                    if previous is None:
                        self.current_statement().remove_node(new, undefine=True)
                    else:
                        grandpa.replace(or_, previous)
                        self._cleanup_inserted(new)
                    raise
                else:
                    # pending snippets are inserted with the new node as scope
                    with tempattr(self, '_insert_scope', new):
                        self.insert_pending()
            return new
        self.insert_pending()

    def insert_pending(self):
        """pending_keys hold variables referenced by a
        U has_<action>_permission X relation.

        Once the snippet introducing this has been inserted and solutions
        recomputed, we have to insert snippets defined for <action> on the
        entity type taken by X.

        :raise Unauthorized: if the variable can't be found, if it may take
          several types, or if the session lacks the permission and the entity
          type defines no rql expression for the action
        """
        # the statement is resolved once, before any pending snippet insertion
        stmt = self.current_statement()
        while self.pending_keys:
            key, action = self.pending_keys.pop()
            try:
                varname = self.rewritten[key]
            except KeyError:
                # not a rewritten variable, look for it in the variables map
                # (key[-1] is the snippet variable name)
                try:
                    varname = self.revvarmap[key[-1]][0]
                except KeyError:
                    # variable isn't used anywhere else, we can't insert security
                    raise Unauthorized()
            ptypes = stmt.defined_vars[varname].stinfo['possibletypes']
            if len(ptypes) > 1:
                # XXX dunno how to handle this
                self.session.error(
                    'cant check security of %s, ambigous type for %s in %s',
                    stmt, varname, key[0]) # key[0] == the rql expression
                raise Unauthorized()
            # at most one possible type at this point, pick it
            etype = iter(ptypes).next()
            eschema = self.schema.eschema(etype)
            if not eschema.has_perm(self.session, action):
                rqlexprs = eschema.get_rqlexprs(action)
                if not rqlexprs:
                    raise Unauthorized()
                self.insert_snippets([({varname: 'X'}, rqlexprs)])

    def snippet_subquery(self, varmap, transformedsnippet):
        """introduce the given snippet in a subquery

        A new select is built: it selects the variables of `varmap`, is
        restricted by a copy of `transformedsnippet` wrapped in an EXISTS
        node, and receives the final or inlined relations of those variables,
        which are removed from the main query. The subquery is then added to
        `self.select` and solutions are recomputed.

        :param varmap: tuple of (select var, snippet var) pairs
        :param transformedsnippet: snippet tree, already rewritten for the
          current variables mapping
        :return: a 2-uple (subquery's select, restriction added to it)

        :raise BadSchemaDefinition: if a relation moved to the subquery
          references a column alias
        :raise Unsupported: if solutions can't be computed once the subquery
          has been added (the subquery is removed first)
        """
        subselect = stmts.Select()
        snippetrqlst = n.Exists(transformedsnippet.copy(subselect))
        get_rschema = self.schema.rschema
        aliases = []
        # holds both variable names and relation nodes already handled
        done = set()
        for i, (selectvar, _) in enumerate(varmap):
            need_null_test = False
            subselectvar = subselect.get_variable(selectvar)
            subselect.append_selected(n.VariableRef(subselectvar))
            aliases.append(selectvar)
            todo = [(selectvar, self.varinfos[i]['stinfo'])]
            while todo:
                varname, stinfo = todo.pop()
                done.add(varname)
                for rel in iter_relations(stinfo):
                    if rel in done:
                        continue
                    done.add(rel)
                    rschema = get_rschema(rel.r_type)
                    if rschema.final or rschema.inlined:
                        rel.children[0].name = varname # XXX explain why
                        subselect.add_restriction(rel.copy(subselect))
                        # reset `vref` so that a relation whose right hand
                        # side holds no variable reference neither crashes on
                        # an unbound name nor reuses the reference left by a
                        # previously handled relation
                        vref = None
                        for vref in rel.children[1].iget_nodes(n.VariableRef):
                            if isinstance(vref.variable, n.ColumnAlias):
                                # XXX could probably be handled by generating the
                                # subquery into the detected subquery
                                raise BadSchemaDefinition(
                                    "cant insert security because of usage two inlined "
                                    "relations in this query. You should probably at "
                                    "least uninline %s" % rel.r_type)
                            subselect.append_selected(vref.copy(subselect))
                            aliases.append(vref.name)
                        self.select.remove_node(rel)
                        # when some inlined relation has to be copied in the
                        # subquery and that relation is optional, we need to
                        # test that either value is NULL or that the snippet
                        # condition is satisfied
                        if varname == selectvar and rel.optional and rschema.inlined:
                            need_null_test = True
                        # also, if some attributes or inlined relation of the
                        # object variable are accessed, we need to get all those
                        # from the subquery as well
                        if (rschema.inlined and vref is not None
                            and vref.name not in done):
                            # `vref` is the last reference seen in the loop above
                            ostinfo = vref.variable.stinfo
                            for orel in iter_relations(ostinfo):
                                orschema = get_rschema(orel.r_type)
                                if orschema.final or orschema.inlined:
                                    todo.append( (vref.name, ostinfo) )
                                    break
            if need_null_test:
                snippetrqlst = n.Or(
                    n.make_relation(subselect.get_variable(selectvar), 'is',
                                    (None, None), n.Constant,
                                    operator='='),
                    snippetrqlst)
        subselect.add_restriction(snippetrqlst)
        if self.u_varname:
            # generate an identifier for the substitution
            argname = subselect.allocate_varname()
            while argname in self.kwargs:
                argname = subselect.allocate_varname()
            subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
                                               'eid', unicode(argname), 'Substitute')
            self.kwargs[argname] = self.session.user.eid
        add_types_restriction(self.schema, subselect, subselect,
                              solutions=self.solutions)
        myunion = stmts.Union()
        myunion.append(subselect)
        aliases = [n.VariableRef(self.select.get_variable(name, i))
                   for i, name in enumerate(aliases)]
        self.select.add_subquery(n.SubQuery(aliases, myunion), check=False)
        # the snippet now lives in the subquery (as a copy), drop references
        # held by the original one
        self._cleanup_inserted(transformedsnippet)
        try:
            self.compute_solutions()
        except Unsupported:
            # some solutions have been lost, can't apply this rql expr
            self.select.remove_subquery(self.select.with_[-1])
            raise
        return subselect, snippetrqlst

    def remove_ambiguities(self, snippets, newsolutions):
        """resolve type ambiguities introduced by inserted snippets

        Variables inserted by snippets may take several types, leading to
        more solutions than the original query had. For each distinct
        combination of types ("variante") an explicit `is` restriction is
        added: to the already inserted nodes for the first one, to
        reinserted snippets for each of the following ones.

        Return the solutions recomputed then filtered against the original
        ones.
        """
        # the snippet has introduced some ambiguities, we have to resolve them
        # "manually"
        variantes = self.build_variantes(newsolutions)
        # insert "is" where necessary
        varexistsmap = {}
        # while set, _insert_snippet doesn't recompute solutions
        self.removing_ambiguity = True
        for (erqlexpr, varmap, oldvarname), etype in variantes[0].iteritems():
            varname = self.rewritten[(erqlexpr, varmap, oldvarname)]
            var = self.select.defined_vars[varname]
            exists = var.references()[0].scope
            exists.add_constant_restriction(var, 'is', etype, 'etype')
            varexistsmap[varmap] = exists
        # insert ORED exists where necessary
        for variante in variantes[1:]:
            # this resets self.rewritten, hence the lookups below get
            # variables inserted for this variante
            self.insert_snippets(snippets, varexistsmap)
            for key, etype in variante.iteritems():
                varname = self.rewritten[key]
                try:
                    var = self.select.defined_vars[varname]
                except KeyError:
                    # not a newly inserted variable
                    continue
                exists = var.references()[0].scope
                exists.add_constant_restriction(var, 'is', etype, 'etype')
        # recompute solutions
        #select.annotated = False # avoid assertion error
        self.compute_solutions()
        # clean solutions according to initial solutions
        return remove_solutions(self.solutions, self.select.solutions,
                                self.select.defined_vars)

    def build_variantes(self, newsolutions):
        variantes = set()
        for sol in newsolutions:
            variante = []
            for key, newvar in self.rewritten.iteritems():
                variante.append( (key, sol[newvar]) )
            variantes.add(tuple(variante))
        # rebuild variantes as dict
        variantes = [dict(variante) for variante in variantes]
        # remove variable which have always the same type
        for key in self.rewritten:
            it = iter(variantes)
            etype = it.next()[key]
            for variante in it:
                if variante[key] != etype:
                    break
            else:
                for variante in variantes:
                    del variante[key]
        return variantes

    def _cleanup_inserted(self, node):
        """unregister every variable reference found under `node`

        Variables left without any reference are removed from the select's
        defined variables, as well as from the `rewritten` mapping.
        """
        dropped = set()
        for vref in node.iget_nodes(n.VariableRef):
            vref.unregister_reference()
            if vref.variable.stinfo['references']:
                continue
            # no more references, undefine the variable
            varname = vref.name
            del self.select.defined_vars[varname]
            dropped.add(varname)
        # collect keys first since the mapping is altered afterwards
        obsolete_keys = [key for key, newvar in self.rewritten.items()
                         if newvar in dropped]
        for key in obsolete_keys:
            del self.rewritten[key]


    def _may_be_shared_with(self, sniprel, target):
        """if the snippet relation can be skipped to use a relation from the
        original query, return that relation node

        :param sniprel: relation node from the snippet
        :param target: 'object' to search among relations where the mapped
          variable is the subject ('lhs_rels'), anything else ('subject') to
          search among those where it's the object ('rhs_rels')
        :return: the relation from the original query, or None if nothing can
          be shared
        """
        rschema = self.schema.rschema(sniprel.r_type)
        stmt = self.current_statement()
        # NOTE(review): every path of the loop body either returns or breaks,
        # so only the first item of self.varinfos is ever examined -- confirm
        # this is intended
        for vi in self.varinfos:
            try:
                if target == 'object':
                    orel = vi['lhs_rels'][sniprel.r_type]
                    cardindex = 0
                    ttypes_func = rschema.objects
                    rdef = rschema.rdef
                else: # target == 'subject':
                    orel = vi['rhs_rels'][sniprel.r_type]
                    cardindex = 1
                    ttypes_func = rschema.subjects
                    # swap arguments so that both cases are called the same way
                    rdef = lambda x, y: rschema.rdef(y, x)
            except KeyError:
                # may be raised by vi['xhs_rels'][sniprel.r_type]
                return None
            # don't share if relation's statement is not the current statement
            if orel.stmt is not stmt:
                return None
            # can't share neged relation or relations with different outer join
            if (orel.neged(strict=True) or sniprel.neged(strict=True)
                or (orel.optional and orel.optional != sniprel.optional)):
                return None
            # if cardinality is in '?1', we can ignore the snippet relation and use
            # variable from the original query
            for etype in vi['stinfo']['possibletypes']:
                for ttype in ttypes_func(etype):
                    if rdef(etype, ttype).cardinality[cardindex] in '+*':
                        return None
            break
        return orel

    def _use_orig_term(self, snippet_varname, term):
        key = (self.current_expr, self.varmap, snippet_varname)
        if key in self.rewritten:
            stmt = self.current_statement()
            insertedvar = stmt.defined_vars.pop(self.rewritten[key])
            for inserted_vref in insertedvar.references():
                inserted_vref.parent.replace(inserted_vref, term.copy(stmt))
        self.rewritten[key] = term.name

    def _get_varname_or_term(self, vname):
        stmt = self.current_statement()
        if vname == 'U':
            stmt = self.select
            if self.u_varname is None:
                self.u_varname = stmt.allocate_varname()
                # generate an identifier for the substitution
                argname = stmt.allocate_varname()
                while argname in self.kwargs:
                    argname = stmt.allocate_varname()
                # insert "U eid %(u)s"
                stmt.add_constant_restriction(
                    stmt.get_variable(self.u_varname),
                    'eid', unicode(argname), 'Substitute')
                self.kwargs[argname] = self.session.user.eid
            return self.u_varname
        key = (self.current_expr, self.varmap, vname)
        try:
            return self.rewritten[key]
        except KeyError:
            self.rewritten[key] = newvname = stmt.allocate_varname()
            return newvname

    # visitor methods ##########################################################

    def _visit_binary(self, node, cls):
        newnode = cls()
        for c in node.children:
            new = c.accept(self)
            if new is None:
                continue
            newnode.append(new)
        if len(newnode.children) == 0:
            return None
        if len(newnode.children) == 1:
            return newnode.children[0]
        return newnode

    def _visit_unary(self, node, cls):
        newc = node.children[0].accept(self)
        if newc is None:
            return None
        newnode = cls()
        newnode.append(newc)
        return newnode

    def visit_and(self, node):
        """rewrite an AND node (see `_visit_binary` for returned values)"""
        return self._visit_binary(node, n.And)

    def visit_or(self, node):
        """rewrite an OR node (see `_visit_binary` for returned values)"""
        return self._visit_binary(node, n.Or)

    def visit_not(self, node):
        """rewrite a NOT node (see `_visit_unary` for returned values)"""
        return self._visit_unary(node, n.Not)

    def visit_exists(self, node):
        """rewrite an EXISTS node (see `_visit_unary` for returned values)"""
        return self._visit_unary(node, n.Exists)

    def keep_var(self, varname):
        if varname in 'SO':
            return varname in self.existingvars
        if varname == 'U':
            return True
        vargraph = self.current_expr.vargraph
        for existingvar in self.existingvars:
            #path = has_path(vargraph, varname, existingvar)
            if has_path(vargraph, varname, existingvar):
                return True
        # no path from this variable to an existing variable
        return False

    def visit_relation(self, node):
        """rewrite a relation of the snippet

        Return a new relation node with visited children, or None when the
        relation doesn't have to be inserted, because:

        * it uses a variable discarded according to `existingvars`
        * it's a `has_<action>_permission` relation, which is recorded in
          `pending_keys` to be handled later by `insert_pending`
        * a relation of the original query can be used in its place
        """
        lhs, rhs = node.get_variable_parts()
        # remove relations where a nonexistent variable and / or a variable
        # linked to a nonexistent variable is used.
        if self.existingvars:
            if not self.keep_var(lhs.name):
                return
        if node.r_type in ('has_add_permission', 'has_update_permission',
                           'has_delete_permission', 'has_read_permission'):
            assert lhs.name == 'U'
            # e.g. 'has_read_permission' -> 'read'
            action = node.r_type.split('_')[1]
            key = (self.current_expr, self.varmap, rhs.name)
            self.pending_keys.append( (key, action) )
            return
        if isinstance(rhs, n.VariableRef):
            if self.existingvars and not self.keep_var(rhs.name):
                return
            # try to reuse a relation of the original query rather than
            # inserting this one
            if lhs.name in self.revvarmap and rhs.name != 'U':
                orel = self._may_be_shared_with(node, 'object')
                if orel is not None:
                    self._use_orig_term(rhs.name, orel.children[1].children[0])
                    return
            elif rhs.name in self.revvarmap and lhs.name != 'U':
                orel = self._may_be_shared_with(node, 'subject')
                if orel is not None:
                    self._use_orig_term(lhs.name, orel.children[0])
                    return
        rel = n.Relation(node.r_type, node.optional)
        for c in node.children:
            rel.append(c.accept(self))
        return rel

    def visit_comparison(self, node):
        """return a new comparison node with the operator of `node`, holding
        the result of visiting each of its children
        """
        comparison = n.Comparison(node.operator)
        for child in node.children:
            rewritten_child = child.accept(self)
            comparison.append(rewritten_child)
        return comparison

    def visit_mathexpression(self, node):
        """return a new math expression node with the operator of `node`,
        holding the result of visiting each of its children
        """
        expression = n.MathExpression(node.operator)
        for operand in node.children:
            rewritten_operand = operand.accept(self)
            expression.append(rewritten_operand)
        return expression

    def visit_function(self, node):
        """return a new function node with the name of `node`, holding the
        result of visiting each of its children
        """
        call = n.Function(node.name)
        for argument in node.children:
            rewritten_argument = argument.accept(self)
            call.append(rewritten_argument)
        return call

    def visit_constant(self, node):
        """return a new constant node with the value and type of `node`"""
        return n.Constant(node.value, node.type)

    def visit_variableref(self, node):
        """return the node to insert in place of a snippet variable reference

        A variable of the current variables map is turned into a reference to
        the matching variable of the current statement, or into an 'Int'
        constant when it's mapped to an eid. For other variables, the name or
        term given by `_get_varname_or_term` is used.
        """
        stmt = self.current_statement()
        name = node.name
        if name not in self.revvarmap:
            vname_or_term = self._get_varname_or_term(name)
            if not isinstance(vname_or_term, basestring):
                # shared term
                return vname_or_term.copy(stmt)
            return n.VariableRef(stmt.get_variable(vname_or_term))
        selectvar, index = self.revvarmap[name]
        varinfo = self.varinfos[index]
        if varinfo.get('const') is None:
            return n.VariableRef(stmt.get_variable(selectvar))
        return n.Constant(varinfo['const'], 'Int')

    def current_statement(self):
        if self._insert_scope is None:
            return self.select
        return self._insert_scope.stmt