rqlrewrite.py
author Pierre-Yves David <pierre-yves.david@logilab.fr>
Fri, 21 Jun 2013 15:47:01 +0200
changeset 9049 9d62d53b49df
parent 8748 f5027f8d2478
child 9167 c05652b108ce
permissions -rw-r--r--
[server/session] allow access to session id using sessionid. session.sessionid is a DBAPISession attribute. Having it on the server-side session will help the rework of the API to access the repository. The new schema drops the concept of DBAPISession and uses the server-side session for the same purpose. Related to #2503918

# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
"""RQL rewriting utilities : insert rql expression snippets into rql syntax
tree.

This is used for instance for read security checking in the repository.
"""
__docformat__ = "restructuredtext en"

from rql import nodes as n, stmts, TypeResolverException
from rql.utils import common_parent

from yams import BadSchemaDefinition

from logilab.common import tempattr
from logilab.common.graph import has_path

from cubicweb import Unauthorized


def add_types_restriction(schema, rqlst, newroot=None, solutions=None):
    """Restrict variables of a syntax tree to the types found in solutions.

    For each non-final variable defined on the target node, the entity types
    it takes across the solutions are collected, then the tree is updated so
    that it carries a matching type restriction:

    * an existing `is_instance_of` restriction is turned into an `is`
      restriction listing exactly the possible types,
    * an existing strict restriction gets its no more possible types removed,
    * otherwise a new type restriction is added.

    The computed set is also stored in the variable's
    `stinfo['possibletypes']`.

    :param schema: object providing `eschema(etype)`
    :param rqlst: when `newroot` is None, the statement to restrict (its own
      `solutions` are used and it is processed only once); otherwise a node
      whose `stmt` attribute is the statement to look type relations up in
    :param newroot: node holding the variables to restrict, or None to use
      `rqlst` itself
    :param solutions: solutions to consider; must be given if and only if
      `newroot` is given
    """
    if newroot is None:
        assert solutions is None
        if hasattr(rqlst, '_types_restr_added'):
            # this statement has already been processed
            return
        solutions = rqlst.solutions
        newroot = rqlst
        rqlst._types_restr_added = True
    else:
        assert solutions is not None
        rqlst = rqlst.stmt
    eschema = schema.eschema
    allpossibletypes = {}
    for solution in solutions:
        for varname, etype in solution.items():
            # XXX not considering aliases by design, right ?
            if varname not in newroot.defined_vars or eschema(etype).final:
                continue
            allpossibletypes.setdefault(varname, set()).add(etype)
    # XXX could be factorized with add_etypes_restriction from rql 0.31
    for varname in sorted(allpossibletypes):
        var = newroot.defined_vars[varname]
        stinfo = var.stinfo
        if stinfo.get('uidrel') is not None:
            continue # eid specified, no need for additional type specification
        try:
            typerel = rqlst.defined_vars[varname].stinfo.get('typerel')
        except KeyError:
            assert varname in rqlst.aliases
            continue
        if newroot is rqlst and typerel is not None:
            mytyperel = typerel
        else:
            # look for a types restriction among the variable's references
            mytyperel = None
            for vref in var.references():
                rel = vref.relation()
                if rel and rel.is_types_restriction():
                    mytyperel = rel
                    break
        possibletypes = allpossibletypes[varname]
        if mytyperel is not None:
            if mytyperel.r_type == 'is_instance_of':
                # turn is_instance_of relation into a is relation since we've
                # all possible solutions and don't want to bother with
                # potential is_instance_of incompatibility
                mytyperel.r_type = 'is'
                if len(possibletypes) > 1:
                    node = n.Function('IN')
                    # sorted so that the generated tree is deterministic
                    for etype in sorted(possibletypes):
                        node.append(n.Constant(etype, 'etype'))
                else:
                    # take the type from this variable's own set: the `etype`
                    # name left over from the solutions loop above may belong
                    # to another variable
                    etype = next(iter(possibletypes))
                    node = n.Constant(etype, 'etype')
                comp = mytyperel.children[1]
                comp.replace(comp.children[0], node)
            else:
                # variable has already some strict types restriction. new
                # possible types can only be a subset of existing ones, so only
                # remove no more possible types
                for cst in mytyperel.get_nodes(n.Constant):
                    if cst.value not in possibletypes:
                        cst.parent.remove(cst)
        else:
            # we have to add types restriction
            if stinfo.get('scope') is not None:
                rel = var.scope.add_type_restriction(var, possibletypes)
            else:
                # tree is not annotated yet, no scope set so add the restriction
                # to the root
                rel = newroot.add_type_restriction(var, possibletypes)
            stinfo['typerel'] = rel
        stinfo['possibletypes'] = possibletypes


def remove_solutions(origsolutions, solutions, defined):
    """Return the solutions of a rewritten rqlst which match original ones.

    When a rqlst has been generated from another by introducing security
    assertions, this returns the entries of `solutions` which are contained
    in `origsolutions`.

    Side effects: returned solutions are removed from the `solutions` list,
    and each time a solution is rejected because of a variable's type, that
    type is removed from `defined[var].stinfo['possibletypes']` when present.
    """
    kept = []
    for reference in origsolutions:
        # walk a snapshot: matching solutions are removed from the live list
        for candidate in list(solutions):
            matches = True
            for varname, expected in reference.items():
                try:
                    actual = candidate[varname]
                except KeyError:
                    # variable has been rewritten
                    continue
                if actual != expected:
                    try:
                        defined[varname].stinfo['possibletypes'].remove(actual)
                    except KeyError:
                        pass
                    matches = False
                    break
            if matches:
                kept.append(candidate)
                solutions.remove(candidate)
    return kept


def iter_relations(stinfo):
    """Return entries of `stinfo['relations']` not in `stinfo['rhsrelations']`.

    This is a function so that tests may return relations in a predictable
    order.
    """
    relations = stinfo['relations']
    return relations - stinfo['rhsrelations']

class Unsupported(Exception):
    """Raised when an rql expression can't be inserted into some rql query
    because it would create an unresolvable query (e.g. no solutions found).
    """


class RQLRewriter(object):
    """insert some rql snippets into another rql syntax tree

    this class *isn't thread safe*
    """

    def __init__(self, session):
        self.session = session
        vreg = session.vreg
        self.schema = vreg.schema
        self.annotate = vreg.rqlhelper.annotate
        self._compute_solutions = vreg.solutions

    def compute_solutions(self):
        self.annotate(self.select)
        try:
            self._compute_solutions(self.session, self.select, self.kwargs)
        except TypeResolverException:
            raise Unsupported(str(self.select))
        if len(self.select.solutions) < len(self.solutions):
            raise Unsupported()

    def rewrite(self, select, snippets, solutions, kwargs, existingvars=None):
        """Insert `snippets` into the `select` syntax tree, in place.

        :param select: the statement to rewrite; its `solutions` are replaced
          once snippets have been inserted
        :param snippets: iterable of `(varmap, rqlexprs)` pairs, with `varmap`
          a *tuple* of `(select var, snippet var)` pairs (a dict is accepted
          as well) and `rqlexprs` a list of rql expressions
        :param solutions: solutions of `select` before the rewrite
        :param kwargs: query arguments dictionary; a substitution for the
          user's eid may be added to it
        :param existingvars: stored on the rewriter and used to filter out
          snippet relations while visiting (see `keep_var`)
        """
        # reset the state used during a single rewrite
        self.select, self.solutions, self.kwargs = select, solutions, kwargs
        self.existingvars = existingvars
        self.u_varname = None
        self.removing_ambiguity = False
        self.exists_snippet = {}
        self.pending_keys = []
        # the rqlst has to be annotated before inserting snippets, even though
        # this will have to be redone later
        self.annotate(select)
        self.insert_snippets(snippets)
        if self.u_varname and not self.exists_snippet:
            # U has been inserted then cancelled, cleanup
            select.undefine_variable(select.defined_vars[self.u_varname])
        # clean solutions according to initial solutions
        kept = remove_solutions(solutions, select.solutions,
                                select.defined_vars)
        assert len(kept) >= len(solutions), (
            'rewritten rql %s has lost some solutions, there is probably '
            'something wrong in your schema permission (for instance using a '
            'RQLExpression which insert a relation which doesn\'t exists in '
            'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
            select, solutions, kept))
        if len(kept) > len(solutions):
            # inserted snippets introduced ambiguities
            kept = self.remove_ambiguities(snippets, kept)
        select.solutions = kept
        add_types_restriction(self.schema, select)

    def insert_snippets(self, snippets, varexistsmap=None):
        self.rewritten = {}
        for varmap, rqlexprs in snippets:
            if isinstance(varmap, dict):
                varmap = tuple(sorted(varmap.items()))
            else:
                assert isinstance(varmap, tuple), varmap
            if varexistsmap is not None and not varmap in varexistsmap:
                continue
            self.insert_varmap_snippets(varmap, rqlexprs, varexistsmap)

    def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap):
        self.varmap = varmap
        self.revvarmap = {}
        self.varinfos = []
        self._insert_scope = None
        for i, (selectvar, snippetvar) in enumerate(varmap):
            assert snippetvar in 'SOX'
            self.revvarmap[snippetvar] = (selectvar, i)
            vi = {}
            self.varinfos.append(vi)
            try:
                vi['const'] = int(selectvar)
                vi['rhs_rels'] = vi['lhs_rels'] = {}
            except ValueError:
                try:
                    vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo
                except KeyError:
                    # variable may have been moved to a newly inserted subquery
                    # we should insert snippet in that subquery
                    subquery = self.select.aliases[selectvar].query
                    assert len(subquery.children) == 1
                    subselect = subquery.children[0]
                    RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)],
                                                      subselect.solutions, self.kwargs)
                    return
                if varexistsmap is None:
                    vi['rhs_rels'] = dict( (r.r_type, r) for r in sti['rhsrelations'])
                    vi['lhs_rels'] = dict( (r.r_type, r) for r in sti['relations']
                                           if not r in sti['rhsrelations'])
                else:
                    vi['rhs_rels'] = vi['lhs_rels'] = {}
        previous = None
        inserted = False
        for rqlexpr in rqlexprs:
            self.current_expr = rqlexpr
            if varexistsmap is None:
                try:
                    new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, previous)
                except Unsupported:
                    continue
                inserted = True
                if new is not None and self._insert_scope is None:
                    self.exists_snippet[rqlexpr] = new
                previous = previous or new
            else:
                # called to reintroduce snippet due to ambiguity creation,
                # so skip snippets which are not introducing this ambiguity
                exists = varexistsmap[varmap]
                if self.exists_snippet.get(rqlexpr) is exists:
                    self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
        if varexistsmap is None and not inserted:
            # no rql expression found matching rql solutions. User has no access right
            raise Unauthorized() # XXX may also be because of bad constraints in schema definition

    def insert_snippet(self, varmap, snippetrqlst, previous=None):
        new = snippetrqlst.where.accept(self)
        existing = self.existingvars
        self.existingvars = None
        try:
            return self._insert_snippet(varmap, previous, new)
        finally:
            self.existingvars = existing

    def _insert_snippet(self, varmap, previous, new):
        """Insert the `new` snippet, which has been rewritten using `varmap`,
        into the syntax tree.

        In cases where an action is protected by several rql expressions,
        `previous` will be the first rql expression which has been inserted,
        and so should be ORed with the following expressions.

        Return the inserted node, or None when `new` is None (in which case
        only pending snippets are inserted).

        :raise Unsupported: propagated when recomputing solutions fails, once
          the insertion has been undone
        """
        if new is not None:
            if self._insert_scope is None:
                # no forced scope: use the common parent of the scopes of all
                # variables of the varmap (constants default to the select)
                insert_scope = None
                for vi in self.varinfos:
                    scope = vi.get('stinfo', {}).get('scope', self.select)
                    if insert_scope is None:
                        insert_scope = scope
                    else:
                        insert_scope = common_parent(scope, insert_scope)
            else:
                insert_scope = self._insert_scope
            # a variable of the varmap is used in some optional relation:
            # the snippet is introduced through a subquery instead
            if self._insert_scope is None and any(vi.get('stinfo', {}).get('optrelations')
                                                  for vi in self.varinfos):
                assert previous is None
                self._insert_scope, new = self.snippet_subquery(varmap, new)
                self.insert_pending()
                #self._insert_scope = None
                return new
            if not isinstance(new, (n.Exists, n.Not)):
                new = n.Exists(new)
            if previous is None:
                insert_scope.add_restriction(new)
            else:
                # OR the new snippet with the previously inserted one, in place
                grandpa = previous.parent
                or_ = n.Or(previous, new)
                grandpa.replace(previous, or_)
            if not self.removing_ambiguity:
                try:
                    self.compute_solutions()
                except Unsupported:
                    # some solutions have been lost, can't apply this rql expr
                    # so undo the insertion before propagating the error
                    if previous is None:
                        self.current_statement().remove_node(new, undefine=True)
                    else:
                        grandpa.replace(or_, previous)
                        self._cleanup_inserted(new)
                    raise
                else:
                    # pending snippets are inserted with the new node as
                    # insertion scope
                    with tempattr(self, '_insert_scope', new):
                        self.insert_pending()
            return new
        # nothing to insert for this snippet, still handle pending keys
        self.insert_pending()

    def insert_pending(self):
        """pending_keys hold variable referenced by U has_<action>_permission X
        relation.

        Once the snippet introducing this has been inserted and solutions
        recomputed, we have to insert snippet defined for <action> of entity
        types taken by X
        """
        stmt = self.current_statement()
        while self.pending_keys:
            key, action = self.pending_keys.pop()
            try:
                varname = self.rewritten[key]
            except KeyError:
                try:
                    varname = self.revvarmap[key[-1]][0]
                except KeyError:
                    # variable isn't used anywhere else, we can't insert security
                    raise Unauthorized()
            ptypes = stmt.defined_vars[varname].stinfo['possibletypes']
            if len(ptypes) > 1:
                # XXX dunno how to handle this
                self.session.error(
                    'cant check security of %s, ambigous type for %s in %s',
                    stmt, varname, key[0]) # key[0] == the rql expression
                raise Unauthorized()
            etype = iter(ptypes).next()
            eschema = self.schema.eschema(etype)
            if not eschema.has_perm(self.session, action):
                rqlexprs = eschema.get_rqlexprs(action)
                if not rqlexprs:
                    raise Unauthorized()
                self.insert_snippets([({varname: 'X'}, rqlexprs)])

    def snippet_subquery(self, varmap, transformedsnippet):
        """Introduce the given snippet in a subquery.

        A new select is built, holding a copy of the snippet wrapped in an
        EXISTS. Final and inlined relations of the variables of `varmap` are
        moved from the rewritten select to that subselect, which is then
        added to the rewritten select as a subquery.

        :param varmap: tuple of `(select variable, snippet variable)` pairs
        :param transformedsnippet: the snippet, already rewritten for the
          select
        :return: `(subselect, snippet node)` tuple
        :raise BadSchemaDefinition: when a moved relation refers to a column
          alias
        :raise Unsupported: when solutions can't be computed once the subquery
          has been added (the subquery is removed first)
        """
        subselect = stmts.Select()
        snippetrqlst = n.Exists(transformedsnippet.copy(subselect))
        get_rschema = self.schema.rschema
        aliases = []
        # holds both processed variable names and processed relations
        done = set()
        for i, (selectvar, _) in enumerate(varmap):
            need_null_test = False
            subselectvar = subselect.get_variable(selectvar)
            subselect.append_selected(n.VariableRef(subselectvar))
            aliases.append(selectvar)
            todo = [(selectvar, self.varinfos[i]['stinfo'])]
            while todo:
                varname, stinfo = todo.pop()
                done.add(varname)
                for rel in iter_relations(stinfo):
                    if rel in done:
                        continue
                    done.add(rel)
                    rschema = get_rschema(rel.r_type)
                    if rschema.final or rschema.inlined:
                        rel.children[0].name = varname # XXX explain why
                        subselect.add_restriction(rel.copy(subselect))
                        # reset for each relation: when the right hand side
                        # holds no variable reference (e.g. a constant), the
                        # name would otherwise be unbound or left over from a
                        # previously processed relation
                        vref = None
                        for vref in rel.children[1].iget_nodes(n.VariableRef):
                            if isinstance(vref.variable, n.ColumnAlias):
                                # XXX could probably be handled by generating the
                                # subquery into the detected subquery
                                raise BadSchemaDefinition(
                                    "cant insert security because of usage two inlined "
                                    "relations in this query. You should probably at "
                                    "least uninline %s" % rel.r_type)
                            subselect.append_selected(vref.copy(subselect))
                            aliases.append(vref.name)
                        self.select.remove_node(rel)
                        # when some inlined relation has to be copied in the
                        # subquery and that relation is optional, we need to
                        # test that either value is NULL or that the snippet
                        # condition is satisfied
                        if varname == selectvar and rel.optional and rschema.inlined:
                            need_null_test = True
                        # also, if some attributes or inlined relation of the
                        # object variable are accessed, we need to get all those
                        # from the subquery as well
                        if (rschema.inlined and vref is not None
                            and vref.name not in done):
                            # vref is the last variable reference found by the
                            # loop above
                            ostinfo = vref.variable.stinfo
                            for orel in iter_relations(ostinfo):
                                orschema = get_rschema(orel.r_type)
                                if orschema.final or orschema.inlined:
                                    todo.append((vref.name, ostinfo))
                                    break
            if need_null_test:
                snippetrqlst = n.Or(
                    n.make_relation(subselect.get_variable(selectvar), 'is',
                                    (None, None), n.Constant,
                                    operator='='),
                    snippetrqlst)
        subselect.add_restriction(snippetrqlst)
        if self.u_varname:
            # generate an identifier for the substitution
            argname = subselect.allocate_varname()
            while argname in self.kwargs:
                argname = subselect.allocate_varname()
            subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
                                               'eid', unicode(argname), 'Substitute')
            self.kwargs[argname] = self.session.user.eid
        add_types_restriction(self.schema, subselect, subselect,
                              solutions=self.solutions)
        myunion = stmts.Union()
        myunion.append(subselect)
        aliases = [n.VariableRef(self.select.get_variable(name, i))
                   for i, name in enumerate(aliases)]
        self.select.add_subquery(n.SubQuery(aliases, myunion), check=False)
        self._cleanup_inserted(transformedsnippet)
        try:
            self.compute_solutions()
        except Unsupported:
            # some solutions have been lost, can't apply this rql expr
            self.select.remove_subquery(self.select.with_[-1])
            raise
        return subselect, snippetrqlst

    def remove_ambiguities(self, snippets, newsolutions):
        """Resolve "manually" the ambiguities introduced by inserted snippets.

        Variantes (distinct type assignments of rewritten variables) are
        computed from `newsolutions`. An `is` restriction is added for the
        first one, then snippets are inserted again for each other variante,
        with their own `is` restrictions. Solutions are then recomputed and
        those matching the original solutions are returned.

        Note `self.removing_ambiguity` is set to True and not reset here.
        """
        # the snippet has introduced some ambiguities, we have to resolve them
        # "manually"
        variantes = self.build_variantes(newsolutions)
        # insert "is" where necessary
        varexistsmap = {}
        self.removing_ambiguity = True
        for (erqlexpr, varmap, oldvarname), etype in variantes[0].iteritems():
            varname = self.rewritten[(erqlexpr, varmap, oldvarname)]
            var = self.select.defined_vars[varname]
            # scope of the first reference to the inserted variable
            exists = var.references()[0].scope
            exists.add_constant_restriction(var, 'is', etype, 'etype')
            # remember where the snippet for this varmap has been inserted
            varexistsmap[varmap] = exists
        # insert ORED exists where necessary
        for variante in variantes[1:]:
            # insert_snippets resets self.rewritten, so keys below map to
            # variables inserted for this variante
            self.insert_snippets(snippets, varexistsmap)
            for key, etype in variante.iteritems():
                varname = self.rewritten[key]
                try:
                    var = self.select.defined_vars[varname]
                except KeyError:
                    # not a newly inserted variable
                    continue
                exists = var.references()[0].scope
                exists.add_constant_restriction(var, 'is', etype, 'etype')
        # recompute solutions
        #select.annotated = False # avoid assertion error
        self.compute_solutions()
        # clean solutions according to initial solutions
        return remove_solutions(self.solutions, self.select.solutions,
                                self.select.defined_vars)

    def build_variantes(self, newsolutions):
        variantes = set()
        for sol in newsolutions:
            variante = []
            for key, newvar in self.rewritten.iteritems():
                variante.append( (key, sol[newvar]) )
            variantes.add(tuple(variante))
        # rebuild variantes as dict
        variantes = [dict(variante) for variante in variantes]
        # remove variable which have always the same type
        for key in self.rewritten:
            it = iter(variantes)
            etype = it.next()[key]
            for variante in it:
                if variante[key] != etype:
                    break
            else:
                for variante in variantes:
                    del variante[key]
        return variantes

    def _cleanup_inserted(self, node):
        """Unregister the variable references found under `node`.

        Variables left without any reference are removed from the select's
        defined variables, and entries of `self.rewritten` pointing to them
        are dropped.
        """
        dropped = set()
        for vref in node.iget_nodes(n.VariableRef):
            vref.unregister_reference()
            if vref.variable.stinfo['references']:
                continue
            # no more references, undefine the variable
            del self.select.defined_vars[vref.name]
            dropped.add(vref.name)
        # collect keys first since the dictionary is altered
        stale = [key for key, newvar in self.rewritten.items()
                 if newvar in dropped]
        for key in stale:
            del self.rewritten[key]


    def _may_be_shared_with(self, sniprel, target):
        """If the snippet relation can be skipped to use a relation from the
        original query, return that relation node, else None.

        :param sniprel: relation node from the snippet
        :param target: 'object' to look among relations indexed in `lhs_rels`,
          anything else (expected: 'subject') to look in `rhs_rels`
        """
        rschema = self.schema.rschema(sniprel.r_type)
        stmt = self.current_statement()
        # NOTE(review): every path of the loop body either returns or reaches
        # the final `break`, so only the first entry of self.varinfos is ever
        # considered -- confirm this is intended
        for vi in self.varinfos:
            try:
                if target == 'object':
                    orel = vi['lhs_rels'][sniprel.r_type]
                    cardindex = 0
                    ttypes_func = rschema.objects
                    rdef = rschema.rdef
                else: # target == 'subject':
                    orel = vi['rhs_rels'][sniprel.r_type]
                    cardindex = 1
                    ttypes_func = rschema.subjects
                    # swap arguments so rdef is always called as (etype, ttype)
                    rdef = lambda x, y: rschema.rdef(y, x)
            except KeyError:
                # may be raised by vi['xhs_rels'][sniprel.r_type]
                return None
            # don't share if relation's statement is not the current statement
            if orel.stmt is not stmt:
                return None
            # can't share neged relation or relations with different outer join
            if (orel.neged(strict=True) or sniprel.neged(strict=True)
                or (orel.optional and orel.optional != sniprel.optional)):
                return None
            # if cardinality is in '?1', we can ignore the snippet relation and use
            # variable from the original query
            for etype in vi['stinfo']['possibletypes']:
                for ttype in ttypes_func(etype):
                    if rdef(etype, ttype).cardinality[cardindex] in '+*':
                        return None
            break
        # NOTE(review): `orel` is unbound here if self.varinfos is empty
        return orel

    def _use_orig_term(self, snippet_varname, term):
        key = (self.current_expr, self.varmap, snippet_varname)
        if key in self.rewritten:
            stmt = self.current_statement()
            insertedvar = stmt.defined_vars.pop(self.rewritten[key])
            for inserted_vref in insertedvar.references():
                inserted_vref.parent.replace(inserted_vref, term.copy(stmt))
        self.rewritten[key] = term.name

    def _get_varname_or_term(self, vname):
        stmt = self.current_statement()
        if vname == 'U':
            stmt = self.select
            if self.u_varname is None:
                self.u_varname = stmt.allocate_varname()
                # generate an identifier for the substitution
                argname = stmt.allocate_varname()
                while argname in self.kwargs:
                    argname = stmt.allocate_varname()
                # insert "U eid %(u)s"
                stmt.add_constant_restriction(
                    stmt.get_variable(self.u_varname),
                    'eid', unicode(argname), 'Substitute')
                self.kwargs[argname] = self.session.user.eid
            return self.u_varname
        key = (self.current_expr, self.varmap, vname)
        try:
            return self.rewritten[key]
        except KeyError:
            self.rewritten[key] = newvname = stmt.allocate_varname()
            return newvname

    # visitor methods ##########################################################

    def _visit_binary(self, node, cls):
        newnode = cls()
        for c in node.children:
            new = c.accept(self)
            if new is None:
                continue
            newnode.append(new)
        if len(newnode.children) == 0:
            return None
        if len(newnode.children) == 1:
            return newnode.children[0]
        return newnode

    def _visit_unary(self, node, cls):
        newc = node.children[0].accept(self)
        if newc is None:
            return None
        newnode = cls()
        newnode.append(newc)
        return newnode

    def visit_and(self, node):
        """Rewrite an AND node: visited children are gathered in a new
        `n.And` by `_visit_binary`.
        """
        return self._visit_binary(node, n.And)

    def visit_or(self, node):
        """Rewrite an OR node: visited children are gathered in a new `n.Or`
        by `_visit_binary`.
        """
        return self._visit_binary(node, n.Or)

    def visit_not(self, node):
        """Rewrite a NOT node: the visited child is wrapped in a new `n.Not`
        by `_visit_unary`.
        """
        return self._visit_unary(node, n.Not)

    def visit_exists(self, node):
        """Rewrite an EXISTS node: the visited child is wrapped in a new
        `n.Exists` by `_visit_unary`.
        """
        return self._visit_unary(node, n.Exists)

    def keep_var(self, varname):
        if varname in 'SO':
            return varname in self.existingvars
        if varname == 'U':
            return True
        vargraph = self.current_expr.vargraph
        for existingvar in self.existingvars:
            #path = has_path(vargraph, varname, existingvar)
            if has_path(vargraph, varname, existingvar):
                return True
        # no path from this variable to an existing variable
        return False

    def visit_relation(self, node):
        lhs, rhs = node.get_variable_parts()
        # remove relations where an unexistant variable and or a variable linked
        # to an unexistant variable is used.
        if self.existingvars:
            if not self.keep_var(lhs.name):
                return
        if node.r_type in ('has_add_permission', 'has_update_permission',
                           'has_delete_permission', 'has_read_permission'):
            assert lhs.name == 'U'
            action = node.r_type.split('_')[1]
            key = (self.current_expr, self.varmap, rhs.name)
            self.pending_keys.append( (key, action) )
            return
        if isinstance(rhs, n.VariableRef):
            if self.existingvars and not self.keep_var(rhs.name):
                return
            if lhs.name in self.revvarmap and rhs.name != 'U':
                orel = self._may_be_shared_with(node, 'object')
                if orel is not None:
                    self._use_orig_term(rhs.name, orel.children[1].children[0])
                    return
            elif rhs.name in self.revvarmap and lhs.name != 'U':
                orel = self._may_be_shared_with(node, 'subject')
                if orel is not None:
                    self._use_orig_term(lhs.name, orel.children[0])
                    return
        rel = n.Relation(node.r_type, node.optional)
        for c in node.children:
            rel.append(c.accept(self))
        return rel

    def visit_comparison(self, node):
        """Return a new `n.Comparison` with the same operator as `node`,
        holding the result of visiting each of its children.
        """
        cmp_ = n.Comparison(node.operator)
        for c in node.children:
            cmp_.append(c.accept(self))
        return cmp_

    def visit_mathexpression(self, node):
        """Return a new `n.MathExpression` with the same operator as `node`,
        holding the result of visiting each of its children.
        """
        cmp_ = n.MathExpression(node.operator)
        for c in node.children:
            cmp_.append(c.accept(self))
        return cmp_

    def visit_function(self, node):
        """Return a new `n.Function` with the same name as `node`, holding
        the result of visiting each of its children.
        """
        function_ = n.Function(node.name)
        for c in node.children:
            function_.append(c.accept(self))
        return function_

    def visit_constant(self, node):
        """Return a new `n.Constant` with the same value and type as `node`."""
        return n.Constant(node.value, node.type)

    def visit_variableref(self, node):
        """Return the node to use in the rewritten tree for a snippet
        variable reference.

        A variable of the varmap becomes an `Int` constant when mapped to a
        constant, else a reference to the matching variable of the current
        statement. Other variables become a reference to the variable named
        by `_get_varname_or_term`, or a copy of the term it returns when that
        isn't a string.
        """
        stmt = self.current_statement()
        name = node.name
        if name not in self.revvarmap:
            vname_or_term = self._get_varname_or_term(name)
            if not isinstance(vname_or_term, basestring):
                # shared term
                return vname_or_term.copy(stmt)
            return n.VariableRef(stmt.get_variable(vname_or_term))
        selectvar, index = self.revvarmap[name]
        const = self.varinfos[index].get('const')
        if const is not None:
            return n.Constant(const, 'Int')
        return n.VariableRef(stmt.get_variable(selectvar))

    def current_statement(self):
        if self._insert_scope is None:
            return self.select
        return self._insert_scope.stmt