rqlrewrite.py
author Sylvain Thénault <sylvain.thenault@logilab.fr>
Tue, 22 Dec 2009 21:02:37 +0100
branch stable
changeset 4195 86dcaf6bb92f
parent 3934 d9a29a1fbe43
child 3998 94cc7cad3d2d
child 4212 ab6573088b4a
permissions -rw-r--r--
closes #601987 1) sqlutils.restore_from_file has to use its confirm argument when a command fails, to propose to continue there (this can't be handled by the caller) 2) the source.restore method hence needs to take this confirmation callback as argument 3) properly fix places where 'drop' was given instead of 'confirm'

"""RQL rewriting utilities : insert rql expression snippets into rql syntax
tree.

This is used for instance for read security checking in the repository.

:organization: Logilab
:copyright: 2007-2009 LOGILAB S.A. (Paris, FRANCE), license is LGPL v2.
:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
:license: GNU Lesser General Public License, v2.1 - http://www.gnu.org/licenses
"""
__docformat__ = "restructuredtext en"

from rql import nodes as n, stmts, TypeResolverException

from logilab.common.compat import any
from logilab.common.graph import has_path

from cubicweb import Unauthorized, typed_eid


def add_types_restriction(schema, rqlst, newroot=None, solutions=None):
    """make variables' type restrictions agree with the given solutions

    Two calling modes are supported:

    * `newroot` and `solutions` are both None: `rqlst` is processed in place
      according to `rqlst.solutions`. The statement is flagged with a
      `_types_restr_added` attribute so a second call does nothing.

    * `newroot` and `solutions` are both given: restrictions are added to
      `newroot` according to `solutions`, and `rqlst.stmt` is the statement
      searched for already known type relations.

    Only variables defined in `newroot`, having at least one non final
    possible type and no 'uidrels' in their stinfo are considered. When such
    a variable already has a types restriction, constants which are not
    possible types anymore are removed from it, else a new types restriction
    is added.
    """
    if newroot is not None:
        assert solutions is not None
        rqlst = rqlst.stmt
    else:
        assert solutions is None
        if hasattr(rqlst, '_types_restr_added'):
            # already done for this statement
            return
        solutions = rqlst.solutions
        newroot = rqlst
        rqlst._types_restr_added = True
    get_eschema = schema.eschema
    # variable name -> set of non final types it may take
    types_by_var = {}
    for sol in solutions:
        for name, etype in sol.iteritems():
            if name in newroot.defined_vars and not get_eschema(etype).final:
                types_by_var.setdefault(name, set()).add(etype)
    for name in sorted(types_by_var):
        try:
            var = newroot.defined_vars[name]
        except KeyError:
            continue
        info = var.stinfo
        if info.get('uidrels'):
            # eid specified, no need for additional type specification
            continue
        try:
            typerels = rqlst.defined_vars[name].stinfo.get('typerels')
        except KeyError:
            assert name in rqlst.aliases
            continue
        # search for an existing types restriction on the variable
        typerel = None
        if newroot is rqlst and typerels:
            typerel = iter(typerels).next()
        else:
            for ref in var.references():
                candidate = ref.relation()
                if candidate and candidate.is_types_restriction():
                    typerel = candidate
                    break
        allowed = types_by_var[name]
        if typerel is None:
            # we have to add types restriction
            if info.get('scope') is not None:
                owner = var.scope
            else:
                # tree is not annotated yet, no scope set so add the
                # restriction to the root
                owner = newroot
            newrel = owner.add_type_restriction(var, allowed)
            info['typerels'] = frozenset((newrel,))
            info['possibletypes'] = allowed
            continue
        # variable has already some types restriction. New possible types
        # can only be a subset of existing ones, so only remove types which
        # are not possible anymore
        for const in typerel.get_nodes(n.Constant):
            if const.value in allowed:
                continue
            const.parent.remove(const)
            try:
                info['possibletypes'].remove(const.value)
            except KeyError:
                # restriction on a type not used by this query, may occur
                # with X is IN(...)
                pass


def remove_solutions(origsolutions, solutions, defined):
    """return solutions from `solutions` which are compatible with one of
    `origsolutions`

    This is used when a rqlst has been generated from another one by
    introducing security assertions. A new solution is compatible with an
    original one when every variable they have in common has the same type
    in both.

    Side effects:

    * each returned solution is removed from the `solutions` list
    * when a new solution is discarded because of a variable's type, that
      type is removed from the 'possibletypes' of the matching variable in
      `defined` (if any)
    """
    kept = []
    for reference in origsolutions:
        # iterate on a copy since matching solutions are removed on the fly
        for candidate in list(solutions):
            compatible = True
            for varname, expected in reference.items():
                try:
                    actual = candidate[varname]
                except KeyError:
                    # variable has been rewritten
                    continue
                if actual != expected:
                    try:
                        defined[varname].stinfo['possibletypes'].remove(actual)
                    except KeyError:
                        pass
                    compatible = False
                    break
            if compatible:
                kept.append(candidate)
                solutions.remove(candidate)
    return kept


class Unsupported(Exception):
    """raised when a rql expression snippet can't be applied to the query
    being rewritten (solutions can't be computed or some have been lost)
    """


class RQLRewriter(object):
    """insert some rql snippets into another rql syntax tree

    A rewriter is bound to a session at construction time. Every call to
    `rewrite` stores the state of the rewriting being done (select node,
    solutions, query arguments...) on the instance, and the instance is also
    used as the visitor which copies snippet nodes, hence:

    this class *isn't thread safe*
    """

    def __init__(self, session):
        """bind the rewriter to `session` and fetch schema, annotation and
        solutions computation helpers from the session's registry
        """
        self.session = session
        vreg = session.vreg
        self.schema = vreg.schema
        self.annotate = vreg.rqlhelper.annotate
        self._compute_solutions = vreg.solutions

    def compute_solutions(self):
        """annotate the select node being rewritten and recompute its
        solutions

        :raise Unsupported:
          if the type resolver fails or if there are fewer solutions than
          the initial ones
        """
        self.annotate(self.select)
        try:
            self._compute_solutions(self.session, self.select, self.kwargs)
        except TypeResolverException:
            raise Unsupported(str(self.select))
        if len(self.select.solutions) < len(self.solutions):
            raise Unsupported()

    def rewrite(self, select, snippets, solutions, kwargs, existingvars=None):
        """insert `snippets` into `select` (modified in place), then clean
        its solutions and add types restrictions

        select: the select node to rewrite
        snippets: iterable of (varmap, list of rql expression)
                  with varmap a *tuple* (select var, snippet var)
        solutions: solutions of the select node before rewriting
        kwargs: query arguments dictionary, which may get a new entry
                holding the user's eid
        existingvars: when given, snippet relations using a variable for
                      which `keep_var` is false are not inserted

        :raise Unauthorized:
          if no rql expression can be inserted for one of the snippets
        """
        self.select = self.insert_scope = select
        self.solutions = solutions
        self.kwargs = kwargs
        self.u_varname = None
        self.removing_ambiguity = False
        self.exists_snippet = {}
        self.pending_keys = []
        self.existingvars = existingvars
        # we have to annotate the rqlst before inserting snippets, even though
        # we'll have to redo it latter
        self.annotate(select)
        self.insert_snippets(snippets)
        if not self.exists_snippet and self.u_varname:
            # U has been inserted than cancelled, cleanup
            select.undefine_variable(select.defined_vars[self.u_varname])
        # clean solutions according to initial solutions
        newsolutions = remove_solutions(solutions, select.solutions,
                                        select.defined_vars)
        assert len(newsolutions) >= len(solutions), (
            'rewritten rql %s has lost some solutions, there is probably '
            'something wrong in your schema permission (for instance using a '
            'RQLExpression which insert a relation which doesn\'t exists in '
            'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
            select, solutions, newsolutions))
        if len(newsolutions) > len(solutions):
            newsolutions = self.remove_ambiguities(snippets, newsolutions)
        select.solutions = newsolutions
        add_types_restriction(self.schema, select)

    def insert_snippets(self, snippets, varexistsmap=None):
        """insert rql expressions of each snippet into the select node

        When `varexistsmap` is None, every rql expression is tried and the
        ones raising `Unsupported` are skipped.

        When `varexistsmap` is given (a dictionary mapping a varmap to an
        EXISTS node, as built by `remove_ambiguities`), only snippets whose
        varmap is in the dictionary are considered, and only expressions
        which have introduced the mapped EXISTS node are inserted again.

        :raise Unauthorized:
          if `varexistsmap` is None and no expression of a snippet could be
          inserted
        """
        self.rewritten = {}
        for varmap, rqlexprs in snippets:
            if varexistsmap is not None and varmap not in varexistsmap:
                continue
            self.varmap = varmap
            selectvar, snippetvar = varmap
            assert snippetvar in 'SOX'
            self.revvarmap = {snippetvar: selectvar}
            self.varinfo = vi = {}
            try:
                # selectvar is actually an eid
                vi['const'] = typed_eid(selectvar) # XXX gae
                vi['rhs_rels'] = vi['lhs_rels'] = {}
            except ValueError:
                vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo
                if varexistsmap is None:
                    # index relations of the original query by relation type
                    # so snippet relations may be shared with them
                    vi['rhs_rels'] = dict( (r.r_type, r) for r in sti['rhsrelations'])
                    vi['lhs_rels'] = dict( (r.r_type, r) for r in sti['relations']
                                           if not r in sti['rhsrelations'])
                else:
                    vi['rhs_rels'] = vi['lhs_rels'] = {}
            parent = None
            inserted = False
            for rqlexpr in rqlexprs:
                self.current_expr = rqlexpr
                if varexistsmap is None:
                    try:
                        new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, parent)
                    except Unsupported:
                        continue
                    inserted = True
                    if new is not None:
                        self.exists_snippet[rqlexpr] = new
                    parent = parent or new
                else:
                    # called to reintroduce snippet due to ambiguity creation,
                    # so skip snippets which are not introducing this ambiguity
                    exists = varexistsmap[varmap]
                    # use .get since expressions which have been skipped (or
                    # inserted without introducing an EXISTS node) during the
                    # first pass have no entry in exists_snippet
                    if self.exists_snippet.get(rqlexpr) is exists:
                        self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
            if varexistsmap is None and not inserted:
                # no rql expression found matching rql solutions. User has no access right
                raise Unauthorized()

    def insert_snippet(self, varmap, snippetrqlst, parent=None):
        """copy the where clause of `snippetrqlst` using the visitor methods
        and insert the result into the select node

        `existingvars` is only taken into account while copying the snippet:
        it is reset to None during insertion and restored afterwards.

        Return what `_insert_snippet` returns.
        """
        new = snippetrqlst.where.accept(self)
        existing = self.existingvars
        self.existingvars = None
        try:
            return self._insert_snippet(varmap, parent, new)
        finally:
            self.existingvars = existing

    def _insert_snippet(self, varmap, parent, new):
        """insert the transformed snippet `new` into the select node

        * if `new` is None, only pending permission snippets are inserted
        * if the select variable is used in some optional relations, the
          snippet is inserted as a subquery
        * else the snippet is wrapped into an EXISTS node, which is either
          added as a new restriction (`parent` is None) or ORed with `parent`

        Return the inserted EXISTS node, or None in the first two cases.

        :raise Unsupported:
          if inserting the snippet makes some solutions disappear. The
          inserted nodes are removed before the exception is propagated.
        """
        if new is not None:
            if self.varinfo.get('stinfo', {}).get('optrelations'):
                assert parent is None
                self.insert_scope = self.snippet_subquery(varmap, new)
                self.insert_pending()
                self.insert_scope = self.select
                return
            new = n.Exists(new)
            if parent is None:
                self.insert_scope.add_restriction(new)
            else:
                grandpa = parent.parent
                or_ = n.Or(parent, new)
                grandpa.replace(parent, or_)
            if not self.removing_ambiguity:
                try:
                    self.compute_solutions()
                except Unsupported:
                    # some solutions have been lost, can't apply this rql expr
                    if parent is None:
                        self.select.remove_node(new, undefine=True)
                    else:
                        # put back the original node in place of the OR
                        parent.parent.replace(or_, or_.children[0])
                        self._cleanup_inserted(new)
                    raise
                else:
                    self.insert_scope = new
                    self.insert_pending()
                    self.insert_scope = self.select
            return new
        self.insert_pending()

    def insert_pending(self):
        """pending_keys hold variables referenced by a
        "U has_<action>_permission X" relation.

        Once the snippet introducing this has been inserted and solutions
        recomputed, we have to insert snippets defined for <action> of the
        entity type taken by X

        :raise Unauthorized:
          if the variable isn't used anywhere else, if it has several
          possible types, or if the entity type defines no rql expression
          for the action while the user doesn't have the permission
        """
        while self.pending_keys:
            key, action = self.pending_keys.pop()
            try:
                varname = self.rewritten[key]
            except KeyError:
                try:
                    # key[-1] is the snippet variable name
                    varname = self.revvarmap[key[-1]]
                except KeyError:
                    # variable isn't used anywhere else, we can't insert security
                    raise Unauthorized()
            ptypes = self.select.defined_vars[varname].stinfo['possibletypes']
            if len(ptypes) > 1:
                # XXX dunno how to handle this
                self.session.error(
                    'cant check security of %s, ambigous type for %s in %s',
                    self.select, varname, key[0]) # key[0] == the rql expression
                raise Unauthorized()
            etype = iter(ptypes).next()
            eschema = self.schema.eschema(etype)
            if not eschema.has_perm(self.session, action):
                rqlexprs = eschema.get_rqlexprs(action)
                if not rqlexprs:
                    raise Unauthorized()
                self.insert_snippets([((varname, 'X'), rqlexprs)])

    def snippet_subquery(self, varmap, transformedsnippet):
        """introduce the given snippet in a subquery

        Relations of the select variable which are final, or inlined with
        the variable not used as right hand side, are moved from the main
        query to the subquery.

        Return the select node of the subquery.

        :raise Unsupported:
          if some solutions are lost once the subquery has been added
        """
        subselect = stmts.Select()
        selectvar, snippetvar = varmap
        subselect.append_selected(n.VariableRef(
            subselect.get_variable(selectvar)))
        aliases = [selectvar]
        subselect.add_restriction(transformedsnippet.copy(subselect))
        stinfo = self.varinfo['stinfo']
        for rel in stinfo['relations']:
            rschema = self.schema.rschema(rel.r_type)
            if rschema.final or (rschema.inlined and
                                      not rel in stinfo['rhsrelations']):
                self.select.remove_node(rel)
                rel.children[0].name = selectvar
                subselect.add_restriction(rel.copy(subselect))
                for vref in rel.children[1].iget_nodes(n.VariableRef):
                    subselect.append_selected(vref.copy(subselect))
                    aliases.append(vref.name)
        if self.u_varname:
            # generate an identifier for the substitution
            argname = subselect.allocate_varname()
            while argname in self.kwargs:
                argname = subselect.allocate_varname()
            subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
                                               'eid', unicode(argname), 'Substitute')
            self.kwargs[argname] = self.session.user.eid
        add_types_restriction(self.schema, subselect, subselect,
                              solutions=self.solutions)
        myunion = stmts.Union()
        myunion.append(subselect)
        aliasrefs = [n.VariableRef(self.select.get_variable(name, i))
                     for i, name in enumerate(aliases)]
        # keep a reference to the subquery node so it can be removed if
        # solutions are lost
        subquery = n.SubQuery(aliasrefs, myunion)
        self.select.add_subquery(subquery, check=False)
        self._cleanup_inserted(transformedsnippet)
        try:
            self.compute_solutions()
        except Unsupported:
            # some solutions have been lost, can't apply this rql expr
            # NOTE(review): the `undefine` argument is kept from the original
            # call -- confirm it's supported by Select.remove_subquery
            self.select.remove_subquery(subquery, undefine=True)
            raise
        return subselect

    def remove_ambiguities(self, snippets, newsolutions):
        """add type restrictions on variables introduced by snippets so that
        each variante built from `newsolutions` is handled by its own EXISTS
        node, then recompute solutions

        Return the recomputed solutions which are compatible with the
        initial ones.
        """
        # the snippet has introduced some ambiguities, we have to resolve them
        # "manually"
        variantes = self.build_variantes(newsolutions)
        # insert "is" where necessary
        varexistsmap = {}
        self.removing_ambiguity = True
        for (erqlexpr, varmap, oldvarname), etype in variantes[0].iteritems():
            varname = self.rewritten[(erqlexpr, varmap, oldvarname)]
            var = self.select.defined_vars[varname]
            exists = var.references()[0].scope
            exists.add_constant_restriction(var, 'is', etype, 'etype')
            varexistsmap[varmap] = exists
        # insert ORED exists where necessary
        for variante in variantes[1:]:
            self.insert_snippets(snippets, varexistsmap)
            for key, etype in variante.iteritems():
                varname = self.rewritten[key]
                try:
                    var = self.select.defined_vars[varname]
                except KeyError:
                    # not a newly inserted variable
                    continue
                exists = var.references()[0].scope
                exists.add_constant_restriction(var, 'is', etype, 'etype')
        # recompute solutions
        self.compute_solutions()
        # clean solutions according to initial solutions
        return remove_solutions(self.solutions, self.select.solutions,
                                self.select.defined_vars)

    def build_variantes(self, newsolutions):
        """return a list of dictionaries, one per distinct combination of
        types taken in `newsolutions` by the rewritten variables

        Dictionaries map keys of `self.rewritten` to an entity type. Keys
        whose type is the same in every variante are removed.
        """
        variantes = set()
        for sol in newsolutions:
            variante = []
            for key, newvar in self.rewritten.iteritems():
                variante.append( (key, sol[newvar]) )
            variantes.add(tuple(variante))
        # rebuild variantes as dict
        variantes = [dict(variante) for variante in variantes]
        # remove variable which have always the same type
        for key in self.rewritten:
            it = iter(variantes)
            etype = it.next()[key]
            for variante in it:
                if variante[key] != etype:
                    break
            else:
                for variante in variantes:
                    del variante[key]
        return variantes

    def _cleanup_inserted(self, node):
        """unregister variable references found under `node` and undefine
        variables of the select node which aren't referenced anymore
        """
        # cleanup inserted variable references
        for vref in node.iget_nodes(n.VariableRef):
            vref.unregister_reference()
            if not vref.variable.stinfo['references']:
                # no more references, undefine the variable
                del self.select.defined_vars[vref.name]

    def _may_be_shared_with(self, sniprel, target, searchedvarname):
        """if the snippet relation can be skipped to use a relation from the
        original query, return that relation node

        sniprel: the snippet's relation node
        target: 'object' if the select variable is the subject of the
                snippet relation, anything else is handled as 'subject'
        searchedvarname: not used by this method

        Return None if there is no relation to share.
        """
        rschema = self.schema.rschema(sniprel.r_type)
        try:
            if target == 'object':
                orel = self.varinfo['lhs_rels'][sniprel.r_type]
                cardindex = 0
                ttypes_func = rschema.objects
                rprop = rschema.rproperty
            else: # target == 'subject':
                orel = self.varinfo['rhs_rels'][sniprel.r_type]
                cardindex = 1
                ttypes_func = rschema.subjects
                rprop = lambda x, y, z: rschema.rproperty(y, x, z)
        except KeyError:
            # may be raised by self.varinfo['xhs_rels'][sniprel.r_type]
            return None
        # can't share neged relation or relations with different outer join
        if (orel.neged(strict=True) or sniprel.neged(strict=True)
            or (orel.optional and orel.optional != sniprel.optional)):
            return None
        # if cardinality is in '?1', we can ignore the snippet relation and use
        # variable from the original query
        for etype in self.varinfo['stinfo']['possibletypes']:
            for ttype in ttypes_func(etype):
                if rprop(etype, ttype, 'cardinality')[cardindex] in '+*':
                    return None
        return orel

    def _use_orig_term(self, snippet_varname, term):
        """record that the snippet variable `snippet_varname` is mapped to
        `term`, a node of the original query

        If a variable has already been inserted for this snippet variable,
        it's undefined and each of its references is replaced by a copy of
        `term`.
        """
        key = (self.current_expr, self.varmap, snippet_varname)
        if key in self.rewritten:
            insertedvar = self.select.defined_vars.pop(self.rewritten[key])
            for inserted_vref in insertedvar.references():
                inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
        self.rewritten[key] = term.name

    def _get_varname_or_term(self, vname):
        """return what is recorded for the snippet variable `vname`,
        allocating a new variable name in the select node if needed

        The first time 'U' is asked, a variable is allocated for it and
        restricted to the eid of the session's user, given through a new
        entry of the query arguments.
        """
        if vname == 'U':
            if self.u_varname is None:
                select = self.select
                self.u_varname = select.allocate_varname()
                # generate an identifier for the substitution
                argname = select.allocate_varname()
                while argname in self.kwargs:
                    argname = select.allocate_varname()
                # insert "U eid %(u)s"
                var = select.get_variable(self.u_varname)
                select.add_constant_restriction(var, 'eid', unicode(argname),
                                                'Substitute')
                self.kwargs[argname] = self.session.user.eid
            return self.u_varname
        key = (self.current_expr, self.varmap, vname)
        try:
            return self.rewritten[key]
        except KeyError:
            self.rewritten[key] = newvname = self.select.allocate_varname()
            return newvname

    # visitor methods ##########################################################

    def _visit_binary(self, node, cls):
        """visit children of `node` and return a new `cls` node holding the
        results which aren't None

        Return None if every child has been dropped, or directly the
        remaining child if there is only one.
        """
        newnode = cls()
        for c in node.children:
            new = c.accept(self)
            if new is None:
                continue
            newnode.append(new)
        if len(newnode.children) == 0:
            return None
        if len(newnode.children) == 1:
            return newnode.children[0]
        return newnode

    def _visit_unary(self, node, cls):
        """visit the first child of `node` and return a new `cls` node
        holding the result, or None if the child has been dropped
        """
        newc = node.children[0].accept(self)
        if newc is None:
            return None
        newnode = cls()
        newnode.append(newc)
        return newnode

    def visit_and(self, node):
        """copy an AND node (see `_visit_binary`)"""
        return self._visit_binary(node, n.And)

    def visit_or(self, node):
        """copy an OR node (see `_visit_binary`)"""
        return self._visit_binary(node, n.Or)

    def visit_not(self, node):
        """copy a NOT node (see `_visit_unary`)"""
        return self._visit_unary(node, n.Not)

    def visit_exists(self, node):
        """copy an EXISTS node (see `_visit_unary`)"""
        return self._visit_unary(node, n.Exists)

    def keep_var(self, varname):
        """return true if relations using the snippet variable `varname`
        should be kept according to `self.existingvars`

        * 'S' and 'O' are kept only if they are in existingvars
        * 'U' is always kept
        * other variables are kept if the variables graph of the current rql
          expression has a path from them to one of the existing variables
        """
        if varname in 'SO':
            return varname in self.existingvars
        if varname == 'U':
            return True
        vargraph = self.current_expr.vargraph
        for existingvar in self.existingvars:
            if has_path(vargraph, varname, existingvar):
                return True
        # no path from this variable to an existing variable
        return False

    def visit_relation(self, node):
        """copy a relation of the snippet

        Return None (i.e. drop the relation) when:

        * one of its variables isn't kept (see `keep_var`)
        * this is a has_<action>_permission relation: it's recorded in
          `pending_keys` to be handled by `insert_pending`
        * a relation of the original query can be used instead
        """
        lhs, rhs = node.get_variable_parts()
        # remove relations where an unexistant variable and or a variable linked
        # to an unexistant variable is used.
        if self.existingvars:
            if not self.keep_var(lhs.name):
                return
        if node.r_type in ('has_add_permission', 'has_update_permission',
                           'has_delete_permission', 'has_read_permission'):
            assert lhs.name == 'U'
            action = node.r_type.split('_')[1]
            key = (self.current_expr, self.varmap, rhs.name)
            self.pending_keys.append( (key, action) )
            return
        if isinstance(rhs, n.VariableRef):
            if self.existingvars and not self.keep_var(rhs.name):
                return
            if lhs.name in self.revvarmap and rhs.name != 'U':
                orel = self._may_be_shared_with(node, 'object', lhs.name)
                if orel is not None:
                    self._use_orig_term(rhs.name, orel.children[1].children[0])
                    return
            elif rhs.name in self.revvarmap and lhs.name != 'U':
                orel = self._may_be_shared_with(node, 'subject', rhs.name)
                if orel is not None:
                    self._use_orig_term(lhs.name, orel.children[0])
                    return
        rel = n.Relation(node.r_type, node.optional)
        for c in node.children:
            rel.append(c.accept(self))
        return rel

    def visit_comparison(self, node):
        """copy a comparison node and its children"""
        cmp_ = n.Comparison(node.operator)
        for c in node.children:
            cmp_.append(c.accept(self))
        return cmp_

    def visit_mathexpression(self, node):
        """copy a mathematical expression node and its children"""
        mathexpr = n.MathExpression(node.operator)
        # children have to be taken from the visited node (this used to
        # wrongly look at the `cmp` builtin)
        for c in node.children:
            mathexpr.append(c.accept(self))
        return mathexpr

    def visit_function(self, node):
        """copy a function node and its children"""
        function_ = n.Function(node.name)
        for c in node.children:
            function_.append(c.accept(self))
        return function_

    def visit_constant(self, node):
        """copy a constant node"""
        return n.Constant(node.value, node.type)

    def visit_variableref(self, node):
        """return the node to use in the rewritten query for a variable
        reference of the snippet

        * for the snippet variable of the current varmap: a constant if the
          select "variable" is actually an eid, else a reference to the
          select variable
        * for other variables: a reference to the variable of the select
          node recorded or allocated for it
        """
        if node.name in self.revvarmap:
            if self.varinfo.get('const') is not None:
                return n.Constant(self.varinfo['const'], 'Int') # XXX gae
            return n.VariableRef(self.select.get_variable(
                self.revvarmap[node.name]))
        vname_or_term = self._get_varname_or_term(node.name)
        if isinstance(vname_or_term, basestring):
            return n.VariableRef(self.select.get_variable(vname_or_term))
        # shared term
        return vname_or_term.copy(self.select)