rqlrewrite.py
author Sylvain Thénault <sylvain.thenault@logilab.fr>
Mon, 27 Jun 2011 18:46:08 +0200
branchstable
changeset 7564 1d64c8d33156
parent 7555 c3bf459268d7
child 7843 3b51806da60b
permissions -rw-r--r--
[server] "overrule" case insensitivity of database name (closes: #611294) The only instances where you are required to use quotes are either when a database object's identifier is identical to a keyword, or when the identifier has at least one capitalized letter in its name. In either of these circumstances, you must remember to quote the identifier both when creating the object, as well as in any subsequent references to that object

# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
"""RQL rewriting utilities : insert rql expression snippets into rql syntax
tree.

This is used for instance for read security checking in the repository.
"""
from __future__ import with_statement

__docformat__ = "restructuredtext en"

from rql import nodes as n, stmts, TypeResolverException
from rql.utils import common_parent

from yams import BadSchemaDefinition

from logilab.common import tempattr
from logilab.common.graph import has_path

from cubicweb import Unauthorized, typed_eid


def add_types_restriction(schema, rqlst, newroot=None, solutions=None):
    """add (or narrow) type restrictions on variables of an rql syntax tree

    For each variable defined in `newroot`, the set of non final entity types
    it takes in `solutions` is computed. Then:

    * if the variable already has a type restriction, constants which are not
      in this set are removed from it;
    * else a new type restriction is added, in the variable's scope when one
      is set (i.e. the tree has been annotated) or else in `newroot`, and is
      recorded as the variable's `stinfo['typerel']`.

    The set is finally stored as the variable's `stinfo['possibletypes']`.
    Variables whose eid is specified (`stinfo['uidrel']`) and variables
    which are actually aliases of `rqlst` are skipped.

    :param schema: the schema, used to skip final types
    :param rqlst: when `newroot` is None, the syntax tree to restrict, using
      its own solutions; this is then done only once per tree (flagged using a
      `_types_restr_added` attribute). Otherwise, a node whose `stmt` gives the
      statement where existing type restrictions are looked up.
    :param newroot: node where missing restrictions are added, defaults to
      `rqlst`
    :param solutions: solutions to consider; must be given if and only if
      `newroot` is given
    """
    if newroot is None:
        assert solutions is None
        if hasattr(rqlst, '_types_restr_added'):
            # already processed, don't add restrictions twice
            return
        solutions = rqlst.solutions
        newroot = rqlst
        rqlst._types_restr_added = True
    else:
        assert solutions is not None
        rqlst = rqlst.stmt
    eschema = schema.eschema
    # variable name -> set of non final entity types it takes in solutions
    allpossibletypes = {}
    for solution in solutions:
        for varname, etype in solution.iteritems():
            # XXX not considering aliases by design, right ?
            if varname not in newroot.defined_vars or eschema(etype).final:
                continue
            allpossibletypes.setdefault(varname, set()).add(etype)
    for varname in sorted(allpossibletypes):
        var = newroot.defined_vars[varname]
        stinfo = var.stinfo
        if stinfo.get('uidrel') is not None:
            continue # eid specified, no need for additional type specification
        try:
            typerel = rqlst.defined_vars[varname].stinfo.get('typerel')
        except KeyError:
            assert varname in rqlst.aliases
            continue
        if newroot is rqlst and typerel is not None:
            mytyperel = typerel
        else:
            # `typerel` is taken from `rqlst`'s variable; when restricting
            # another root, look for a type restriction among the references
            # of `newroot`'s variable instead
            for vref in var.references():
                rel = vref.relation()
                if rel and rel.is_types_restriction():
                    mytyperel = rel
                    break
            else:
                mytyperel = None
        possibletypes = allpossibletypes[varname]
        if mytyperel is not None:
            # variable has already some types restriction. new possible types
            # can only be a subset of existing ones, so only remove no more
            # possible types
            for cst in mytyperel.get_nodes(n.Constant):
                if not cst.value in possibletypes:
                    cst.parent.remove(cst)
        else:
            # we have to add types restriction
            if stinfo.get('scope') is not None:
                rel = var.scope.add_type_restriction(var, possibletypes)
            else:
                # tree is not annotated yet, no scope set so add the restriction
                # to the root
                rel = newroot.add_type_restriction(var, possibletypes)
            stinfo['typerel'] = rel
        stinfo['possibletypes'] = possibletypes


def remove_solutions(origsolutions, solutions, defined):
    """return the solutions of a rewritten rqlst which agree with original ones

    When a rqlst has been generated from another by introducing security
    assertions, each solution in `solutions` is compared with each solution in
    `origsolutions`. It matches when every variable it shares with the
    original solution has the same type; variables missing from it (because
    they have been rewritten) are ignored.

    Matching solutions are removed from `solutions`, which is modified in
    place, and returned in a new list. On the first type mismatch against an
    original solution, the mismatching type is discarded from
    `defined[var].stinfo['possibletypes']` when it can be found there.
    """
    kept = []
    for origsol in origsolutions:
        # iterate over a snapshot: matching solutions are removed from the list
        for candidate in list(solutions):
            mismatch = False
            for varname, etype in origsol.items():
                if varname not in candidate:
                    # variable has been rewritten
                    continue
                if candidate[varname] != etype:
                    try:
                        defined[varname].stinfo['possibletypes'].remove(
                            candidate[varname])
                    except KeyError:
                        pass
                    mismatch = True
                    break
            if not mismatch:
                kept.append(candidate)
                solutions.remove(candidate)
    return kept


class Unsupported(Exception):
    """Signal that an rql expression can't be inserted into some rql query.

    Raised when the insertion would make the query unresolvable, e.g. when
    no solutions (or fewer solutions than originally) can be found anymore.
    """


class RQLRewriter(object):
    """insert some rql snippets into another rql syntax tree

    this class *isn't thread safe*: the state of the rewriting in progress
    (select node, solutions, kwargs, rewritten variables...) is stored as
    instance attributes by :meth:`rewrite`.
    """

    def __init__(self, session):
        # the session gives access to the registry (schema, rql helper,
        # solutions computation) and to the user whose eid is substituted for
        # the `U` snippet variable
        self.session = session
        vreg = session.vreg
        self.schema = vreg.schema
        self.annotate = vreg.rqlhelper.annotate
        self._compute_solutions = vreg.solutions

    def compute_solutions(self):
        """annotate `self.select` then recompute its solutions

        :raise Unsupported: if types can't be resolved, or if the select now
          has fewer solutions than the original ones (`self.solutions`)
        """
        self.annotate(self.select)
        try:
            self._compute_solutions(self.session, self.select, self.kwargs)
        except TypeResolverException:
            raise Unsupported(str(self.select))
        if len(self.select.solutions) < len(self.solutions):
            raise Unsupported()

    def rewrite(self, select, snippets, solutions, kwargs, existingvars=None):
        """insert rql expression snippets into `select`, modified in place

        snippets: (varmap, list of rql expression)
                  with varmap a *tuple* (select var, snippet var)
        solutions: original solutions of `select`; the rewritten tree must
                   not lose any of them
        kwargs: query arguments, updated in place when the user's eid has to
                be substituted
        existingvars: when given, names of snippet variables; snippet
                      relations whose variables aren't linked to one of them
                      are dropped (see :meth:`keep_var`)

        :raise Unauthorized: e.g. when no rql expression of a snippet could be
          inserted
        """
        self.select = select
        self.solutions = solutions
        self.kwargs = kwargs
        self.u_varname = None
        self.removing_ambiguity = False
        self.exists_snippet = {}
        self.pending_keys = []
        self.existingvars = existingvars
        # we have to annotate the rqlst before inserting snippets, even though
        # we'll have to redo it later
        self.annotate(select)
        self.insert_snippets(snippets)
        if not self.exists_snippet and self.u_varname:
            # U has been inserted then cancelled, cleanup
            select.undefine_variable(select.defined_vars[self.u_varname])
        # clean solutions according to initial solutions
        newsolutions = remove_solutions(solutions, select.solutions,
                                        select.defined_vars)
        assert len(newsolutions) >= len(solutions), (
            'rewritten rql %s has lost some solutions, there is probably '
            'something wrong in your schema permission (for instance using a '
            'RQLExpression which insert a relation which doesn\'t exists in '
            'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
            select, solutions, newsolutions))
        if len(newsolutions) > len(solutions):
            # snippets introduced some type ambiguities, resolve them
            newsolutions = self.remove_ambiguities(snippets, newsolutions)
        select.solutions = newsolutions
        add_types_restriction(self.schema, select)

    def insert_snippets(self, snippets, varexistsmap=None):
        """insert each (varmap, rqlexprs) snippet of `snippets`

        A varmap given as a dict is turned into a sorted tuple of its items.
        When `varexistsmap` is given (see :meth:`remove_ambiguities`), only
        snippets whose varmap is a key of it are inserted.

        Note this resets `self.rewritten`.
        """
        self.rewritten = {}
        for varmap, rqlexprs in snippets:
            if isinstance(varmap, dict):
                varmap = tuple(sorted(varmap.items()))
            else:
                assert isinstance(varmap, tuple), varmap
            if varexistsmap is not None and not varmap in varexistsmap:
                continue
            self.insert_varmap_snippets(varmap, rqlexprs, varexistsmap)

    def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap):
        """insert rql expressions `rqlexprs`, all using variables mapping
        `varmap`

        Successfully inserted expressions are OR-ed together. If a mapped
        variable is not defined in the select but is an alias of a subquery,
        the snippets are inserted into that subquery instead.

        :raise Unauthorized: when `varexistsmap` is None and none of the
          expressions could be inserted
        """
        self.varmap = varmap
        self.revvarmap = {}
        self.varinfos = []
        self._insert_scope = None
        for i, (selectvar, snippetvar) in enumerate(varmap):
            assert snippetvar in 'SOX'
            self.revvarmap[snippetvar] = (selectvar, i)
            vi = {}
            self.varinfos.append(vi)
            try:
                # the select "variable" may actually be an eid
                vi['const'] = typed_eid(selectvar) # XXX gae
                vi['rhs_rels'] = vi['lhs_rels'] = {}
            except ValueError:
                try:
                    vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo
                except KeyError:
                    # variable may have been moved to a newly inserted subquery
                    # we should insert snippet in that subquery
                    subquery = self.select.aliases[selectvar].query
                    assert len(subquery.children) == 1
                    subselect = subquery.children[0]
                    RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)],
                                                      subselect.solutions, self.kwargs)
                    return
                if varexistsmap is None:
                    # relations of the original query indexed by type, which
                    # may be shared with snippet relations (see
                    # _may_be_shared_with)
                    vi['rhs_rels'] = dict( (r.r_type, r) for r in sti['rhsrelations'])
                    vi['lhs_rels'] = dict( (r.r_type, r) for r in sti['relations']
                                           if not r in sti['rhsrelations'])
                else:
                    vi['rhs_rels'] = vi['lhs_rels'] = {}
        parent = None
        inserted = False
        for rqlexpr in rqlexprs:
            self.current_expr = rqlexpr
            if varexistsmap is None:
                try:
                    new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, parent)
                except Unsupported:
                    continue
                inserted = True
                if new is not None and self._insert_scope is None:
                    self.exists_snippet[rqlexpr] = new
                parent = parent or new
            else:
                # called to reintroduce snippet due to ambiguity creation,
                # so skip snippets which are not introducing this ambiguity
                exists = varexistsmap[varmap]
                if self.exists_snippet[rqlexpr] is exists:
                    self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
        if varexistsmap is None and not inserted:
            # no rql expression found matching rql solutions. User has no access right
            raise Unauthorized() # XXX bad constraint when inserting constraints

    def insert_snippet(self, varmap, snippetrqlst, parent=None):
        """transform the restriction of `snippetrqlst` using this visitor,
        then insert it (see :meth:`_insert_snippet`) and return what the
        insertion returns

        `self.existingvars` only applies while transforming the snippet: it
        is disabled during the insertion and restored afterwards.
        """
        new = snippetrqlst.where.accept(self)
        existing = self.existingvars
        self.existingvars = None
        try:
            return self._insert_snippet(varmap, parent, new)
        finally:
            self.existingvars = existing

    def _insert_snippet(self, varmap, parent, new):
        """insert the transformed snippet `new` into the current statement

        `new` is wrapped into an EXISTS node (unless it's already an EXISTS or
        NOT node), then either added to the insertion scope (the common parent
        of the mapped variables' scopes, or the current insertion scope if
        any) or OR-ed with `parent` when given. If some mapped variable has
        optional relations, the snippet is rather inserted into a subquery
        (see :meth:`snippet_subquery`) and None is returned. Pending
        permission snippets are then inserted (see :meth:`insert_pending`).

        :raise Unsupported: if the insertion makes some solutions lost; the
          insertion is undone before raising
        """
        if new is not None:
            if self._insert_scope is None:
                insert_scope = None
                for vi in self.varinfos:
                    scope = vi.get('stinfo', {}).get('scope', self.select)
                    if insert_scope is None:
                        insert_scope = scope
                    else:
                        insert_scope = common_parent(scope, insert_scope)
            else:
                insert_scope = self._insert_scope
            if self._insert_scope is None and any(vi.get('stinfo', {}).get('optrelations')
                                                  for vi in self.varinfos):
                assert parent is None
                self._insert_scope = self.snippet_subquery(varmap, new)
                self.insert_pending()
                #self._insert_scope = None
                return
            if not isinstance(new, (n.Exists, n.Not)):
                new = n.Exists(new)
            if parent is None:
                insert_scope.add_restriction(new)
            else:
                grandpa = parent.parent
                or_ = n.Or(parent, new)
                grandpa.replace(parent, or_)
            if not self.removing_ambiguity:
                try:
                    self.compute_solutions()
                except Unsupported:
                    # some solutions have been lost, can't apply this rql expr
                    if parent is None:
                        self.current_statement().remove_node(new, undefine=True)
                    else:
                        parent.parent.replace(or_, or_.children[0])
                        self._cleanup_inserted(new)
                    raise
                else:
                    with tempattr(self, '_insert_scope', new):
                        self.insert_pending()
            return new
        self.insert_pending()

    def insert_pending(self):
        """pending_keys hold variables referenced by U has_<action>_permission X
        relations.

        Once the snippet introducing this has been inserted and solutions
        recomputed, we have to insert snippets defined for <action> of entity
        types taken by X

        :raise Unauthorized: if the variable can't be found, has an ambiguous
          type, or if the user has no permission and there is no rql
          expression for the action
        """
        stmt = self.current_statement()
        while self.pending_keys:
            key, action = self.pending_keys.pop()
            try:
                varname = self.rewritten[key]
            except KeyError:
                try:
                    varname = self.revvarmap[key[-1]][0]
                except KeyError:
                    # variable isn't used anywhere else, we can't insert security
                    raise Unauthorized()
            ptypes = stmt.defined_vars[varname].stinfo['possibletypes']
            if len(ptypes) > 1:
                # XXX dunno how to handle this
                self.session.error(
                    'cant check security of %s, ambigous type for %s in %s',
                    stmt, varname, key[0]) # key[0] == the rql expression
                raise Unauthorized()
            etype = iter(ptypes).next()
            eschema = self.schema.eschema(etype)
            if not eschema.has_perm(self.session, action):
                rqlexprs = eschema.get_rqlexprs(action)
                if not rqlexprs:
                    raise Unauthorized()
                self.insert_snippets([({varname: 'X'}, rqlexprs)])

    def snippet_subquery(self, varmap, transformedsnippet):
        """introduce the given snippet in a subquery

        The mapped variables are selected by a new subquery restricted by the
        snippet (wrapped into an EXISTS node), and aliased in the main select.
        Final or inlined (subject) relations of these variables are moved into
        the subquery; when such an inlined relation is optional, the snippet
        condition is OR-ed with a NULL test on the variable. Return the
        subquery's select node.

        :raise Unsupported: if some solutions are lost, in which case the
          subquery is removed before raising
        """
        subselect = stmts.Select()
        snippetrqlst = n.Exists(transformedsnippet.copy(subselect))
        aliases = []
        rels_done = set()
        for i, (selectvar, snippetvar) in enumerate(varmap):
            subselectvar = subselect.get_variable(selectvar)
            subselect.append_selected(n.VariableRef(subselectvar))
            aliases.append(selectvar)
            vi = self.varinfos[i]
            need_null_test = False
            stinfo = vi['stinfo']
            for rel in stinfo['relations']:
                if rel in rels_done:
                    continue
                rels_done.add(rel)
                rschema = self.schema.rschema(rel.r_type)
                if rschema.final or (rschema.inlined and
                                     not rel in stinfo['rhsrelations']):
                    rel.children[0].name = selectvar # XXX explain why
                    subselect.add_restriction(rel.copy(subselect))
                    for vref in rel.children[1].iget_nodes(n.VariableRef):
                        if isinstance(vref.variable, n.ColumnAlias):
                            # XXX could probably be handled by generating the
                            # subquery into the detected subquery
                            raise BadSchemaDefinition(
                                "cant insert security because of usage two inlined "
                                "relations in this query. You should probably at "
                                "least uninline %s" % rel.r_type)
                        subselect.append_selected(vref.copy(subselect))
                        aliases.append(vref.name)
                    self.select.remove_node(rel)
                    # when some inlined relation has to be copied in the
                    # subquery, we need to test that either value is NULL or
                    # that the snippet condition is satisfied
                    if rschema.inlined and rel.optional:
                        need_null_test = True
            if need_null_test:
                snippetrqlst = n.Or(
                    n.make_relation(subselectvar, 'is', (None, None), n.Constant,
                                    operator='='),
                    snippetrqlst)
        subselect.add_restriction(snippetrqlst)
        if self.u_varname:
            # generate an identifier for the substitution
            argname = subselect.allocate_varname()
            while argname in self.kwargs:
                argname = subselect.allocate_varname()
            subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
                                               'eid', unicode(argname), 'Substitute')
            self.kwargs[argname] = self.session.user.eid
        add_types_restriction(self.schema, subselect, subselect,
                              solutions=self.solutions)
        myunion = stmts.Union()
        myunion.append(subselect)
        aliases = [n.VariableRef(self.select.get_variable(name, i))
                   for i, name in enumerate(aliases)]
        self.select.add_subquery(n.SubQuery(aliases, myunion), check=False)
        self._cleanup_inserted(transformedsnippet)
        try:
            self.compute_solutions()
        except Unsupported:
            # some solutions have been lost, can't apply this rql expr
            self.select.remove_subquery(self.select.with_[-1])
            raise
        return subselect

    def remove_ambiguities(self, snippets, newsolutions):
        """resolve type ambiguities introduced by inserted snippets

        For each combination of types taken by rewritten variables (see
        :meth:`build_variantes`), `is` restrictions are added: the first
        combination is applied to the already inserted EXISTS nodes, the other
        ones to snippets reinserted and OR-ed with them. Return the recomputed
        solutions which match the original ones.
        """
        # the snippet has introduced some ambiguities, we have to resolve them
        # "manually"
        variantes = self.build_variantes(newsolutions)
        # insert "is" where necessary
        varexistsmap = {}
        self.removing_ambiguity = True
        for (erqlexpr, varmap, oldvarname), etype in variantes[0].iteritems():
            varname = self.rewritten[(erqlexpr, varmap, oldvarname)]
            var = self.select.defined_vars[varname]
            exists = var.references()[0].scope
            exists.add_constant_restriction(var, 'is', etype, 'etype')
            varexistsmap[varmap] = exists
        # insert ORED exists where necessary
        for variante in variantes[1:]:
            self.insert_snippets(snippets, varexistsmap)
            for key, etype in variante.iteritems():
                varname = self.rewritten[key]
                try:
                    var = self.select.defined_vars[varname]
                except KeyError:
                    # not a newly inserted variable
                    continue
                exists = var.references()[0].scope
                exists.add_constant_restriction(var, 'is', etype, 'etype')
        # recompute solutions
        #select.annotated = False # avoid assertion error
        self.compute_solutions()
        # clean solutions according to initial solutions
        return remove_solutions(self.solutions, self.select.solutions,
                                self.select.defined_vars)

    def build_variantes(self, newsolutions):
        """return the distinct combinations of types taken by rewritten
        variables in `newsolutions`

        Each combination is a dict mapping keys of `self.rewritten` to an
        entity type; keys having the same type in every combination are
        removed from all of them.
        """
        variantes = set()
        for sol in newsolutions:
            variante = []
            for key, newvar in self.rewritten.iteritems():
                variante.append( (key, sol[newvar]) )
            variantes.add(tuple(variante))
        # rebuild variantes as dict
        variantes = [dict(variante) for variante in variantes]
        # remove variable which have always the same type
        for key in self.rewritten:
            it = iter(variantes)
            etype = it.next()[key]
            for variante in it:
                if variante[key] != etype:
                    break
            else:
                for variante in variantes:
                    del variante[key]
        return variantes

    def _cleanup_inserted(self, node):
        """unregister variable references found in `node`, undefining from
        `self.select` variables which are no more referenced
        """
        # cleanup inserted variable references
        for vref in node.iget_nodes(n.VariableRef):
            vref.unregister_reference()
            if not vref.variable.stinfo['references']:
                # no more references, undefine the variable
                del self.select.defined_vars[vref.name]

    def _may_be_shared_with(self, sniprel, target):
        """if the snippet relation can be skipped to use a relation from the
        original query, return that relation node

        `target` is 'object' when the snippet relation's subject is a mapped
        variable, 'subject' when its object is. Sharing is refused (None
        returned) when the original relation belongs to another statement,
        when either relation is negated, when outer joins differ, or when the
        cardinality on the target side may be '+' or '*'.
        """
        rschema = self.schema.rschema(sniprel.r_type)
        stmt = self.current_statement()
        # NOTE(review): every path of this loop body returns or breaks, so only
        # the first mapped variable (self.varinfos[0]) is ever considered;
        # confirm this is intended
        for vi in self.varinfos:
            try:
                if target == 'object':
                    orel = vi['lhs_rels'][sniprel.r_type]
                    cardindex = 0
                    ttypes_func = rschema.objects
                    rdef = rschema.rdef
                else: # target == 'subject':
                    orel = vi['rhs_rels'][sniprel.r_type]
                    cardindex = 1
                    ttypes_func = rschema.subjects
                    rdef = lambda x, y: rschema.rdef(y, x)
            except KeyError:
                # may be raised by vi['xhs_rels'][sniprel.r_type]
                return None
            # don't share if relation's statement is not the current statement
            if orel.stmt is not stmt:
                return None
            # can't share neged relation or relations with different outer join
            if (orel.neged(strict=True) or sniprel.neged(strict=True)
                or (orel.optional and orel.optional != sniprel.optional)):
                return None
            # if cardinality is in '?1', we can ignore the snippet relation and use
            # variable from the original query
            for etype in vi['stinfo']['possibletypes']:
                for ttype in ttypes_func(etype):
                    if rdef(etype, ttype).cardinality[cardindex] in '+*':
                        return None
            break
        return orel

    def _use_orig_term(self, snippet_varname, term):
        """map snippet variable `snippet_varname` to `term` from the original
        query

        If a variable had already been inserted for it, its references are
        replaced by copies of `term` and it is removed from the current
        statement's defined variables.
        """
        key = (self.current_expr, self.varmap, snippet_varname)
        if key in self.rewritten:
            stmt = self.current_statement()
            insertedvar = stmt.defined_vars.pop(self.rewritten[key])
            for inserted_vref in insertedvar.references():
                inserted_vref.parent.replace(inserted_vref, term.copy(stmt))
        self.rewritten[key] = term.name

    def _get_varname_or_term(self, vname):
        """return the name of the variable standing for snippet variable
        `vname`

        'U' maps to a variable of `self.select` restricted to the current
        user's eid (inserted once, the eid being passed through a newly
        allocated `self.kwargs` key). Other snippet variables get a new
        variable of the current statement, allocated on first use and
        remembered in `self.rewritten`.
        """
        stmt = self.current_statement()
        if vname == 'U':
            stmt = self.select
            if self.u_varname is None:
                self.u_varname = stmt.allocate_varname()
                # generate an identifier for the substitution
                argname = stmt.allocate_varname()
                while argname in self.kwargs:
                    argname = stmt.allocate_varname()
                # insert "U eid %(u)s"
                stmt.add_constant_restriction(
                    stmt.get_variable(self.u_varname),
                    'eid', unicode(argname), 'Substitute')
                self.kwargs[argname] = self.session.user.eid
            return self.u_varname
        key = (self.current_expr, self.varmap, vname)
        try:
            return self.rewritten[key]
        except KeyError:
            self.rewritten[key] = newvname = stmt.allocate_varname()
            return newvname

    # visitor methods ##########################################################
    #
    # each visitor returns a copy of the visited snippet node suitable for the
    # current statement, or None when the node has to be dropped

    def _visit_binary(self, node, cls):
        """copy an AND/OR node, dropping children visited to None and
        collapsing it when zero or one child remains
        """
        newnode = cls()
        for c in node.children:
            new = c.accept(self)
            if new is None:
                continue
            newnode.append(new)
        if len(newnode.children) == 0:
            return None
        if len(newnode.children) == 1:
            return newnode.children[0]
        return newnode

    def _visit_unary(self, node, cls):
        """copy a NOT/EXISTS node, or return None if its child is dropped"""
        newc = node.children[0].accept(self)
        if newc is None:
            return None
        newnode = cls()
        newnode.append(newc)
        return newnode

    def visit_and(self, node):
        return self._visit_binary(node, n.And)

    def visit_or(self, node):
        return self._visit_binary(node, n.Or)

    def visit_not(self, node):
        return self._visit_unary(node, n.Not)

    def visit_exists(self, node):
        return self._visit_unary(node, n.Exists)

    def keep_var(self, varname):
        """tell whether snippet relations using snippet variable `varname`
        should be kept according to `self.existingvars`

        S and O are kept only if they are existing variables, U is always
        kept, any other variable is kept when the snippet's variable graph
        has a path from it to an existing variable.
        """
        if varname in 'SO':
            return varname in self.existingvars
        if varname == 'U':
            return True
        vargraph = self.current_expr.vargraph
        for existingvar in self.existingvars:
            if has_path(vargraph, varname, existingvar):
                return True
        # no path from this variable to an existing variable
        return False

    def visit_relation(self, node):
        """copy a snippet relation, or return None when it is dropped

        Relations are dropped when using variables rejected by
        :meth:`keep_var`, when they are `has_<action>_permission` relations
        (recorded in `self.pending_keys` instead, see :meth:`insert_pending`)
        or when they can be shared with a relation of the original query.
        """
        lhs, rhs = node.get_variable_parts()
        # remove relations where an unexistant variable and or a variable linked
        # to an unexistant variable is used.
        if self.existingvars:
            if not self.keep_var(lhs.name):
                return
        if node.r_type in ('has_add_permission', 'has_update_permission',
                           'has_delete_permission', 'has_read_permission'):
            assert lhs.name == 'U'
            action = node.r_type.split('_')[1]
            key = (self.current_expr, self.varmap, rhs.name)
            self.pending_keys.append( (key, action) )
            return
        if isinstance(rhs, n.VariableRef):
            if self.existingvars and not self.keep_var(rhs.name):
                return
            if lhs.name in self.revvarmap and rhs.name != 'U':
                orel = self._may_be_shared_with(node, 'object')
                if orel is not None:
                    self._use_orig_term(rhs.name, orel.children[1].children[0])
                    return
            elif rhs.name in self.revvarmap and lhs.name != 'U':
                orel = self._may_be_shared_with(node, 'subject')
                if orel is not None:
                    self._use_orig_term(lhs.name, orel.children[0])
                    return
        rel = n.Relation(node.r_type, node.optional)
        for c in node.children:
            rel.append(c.accept(self))
        return rel

    def visit_comparison(self, node):
        """copy a comparison node, visiting its operands"""
        cmp_ = n.Comparison(node.operator)
        for c in node.children:
            cmp_.append(c.accept(self))
        return cmp_

    def visit_mathexpression(self, node):
        """copy a math expression node, visiting its operands"""
        mexpr = n.MathExpression(node.operator)
        # operands are the children of the visited node
        for c in node.children:
            mexpr.append(c.accept(self))
        return mexpr

    def visit_function(self, node):
        """copy a function node, visiting its arguments"""
        function_ = n.Function(node.name)
        for c in node.children:
            function_.append(c.accept(self))
        return function_

    def visit_constant(self, node):
        """copy a constant node"""
        return n.Constant(node.value, node.type)

    def visit_variableref(self, node):
        """return the node replacing a snippet variable reference

        A mapped snippet variable (see `self.revvarmap`) is replaced by a
        reference to its select variable, or by an Int constant when mapped
        to an eid; other variables by a reference to the variable allocated
        for them (see :meth:`_get_varname_or_term`).
        """
        stmt = self.current_statement()
        if node.name in self.revvarmap:
            selectvar, index = self.revvarmap[node.name]
            vi = self.varinfos[index]
            if vi.get('const') is not None:
                return n.Constant(vi['const'], 'Int') # XXX gae
            return n.VariableRef(stmt.get_variable(selectvar))
        vname_or_term = self._get_varname_or_term(node.name)
        if isinstance(vname_or_term, basestring):
            return n.VariableRef(stmt.get_variable(vname_or_term))
        # shared term
        return vname_or_term.copy(stmt)

    def current_statement(self):
        """return the statement snippets are currently inserted into:
        `self.select`, or the statement of the current insertion scope when
        one is set
        """
        if self._insert_scope is None:
            return self.select
        return self._insert_scope.stmt