server/rqlrewrite.py
author sylvain.thenault@logilab.fr
Mon, 23 Mar 2009 17:55:59 +0100
branch tls-sprint
changeset 1132 96752791c2b6
parent 0 b97547f5f1fa
child 1138 22f634977c95
permissions -rw-r--r--
pylint cleanup

"""RQL rewriting utilities, used for read security checking

:organization: Logilab
:copyright: 2007-2009 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
"""

from rql import nodes, stmts, TypeResolverException
from cubicweb import Unauthorized, server, typed_eid
from cubicweb.server.ssplanner import add_types_restriction

def remove_solutions(origsolutions, solutions, defined):
    """filter the solutions of a rewritten rqlst against the original ones

    When a rqlst has been generated from another by introducing security
    assertions, this function returns the solutions which are contained in
    the original solutions.

    :param origsolutions: solutions of the original query ({varname: etype}
      dicts)
    :param solutions: solutions of the rewritten query; every returned
      solution is also removed from this list, in place
    :param defined: mapping of variable names to variables; each time a
      solution gives a variable a type different from an original solution,
      that type is discarded (when present) from the variable's
      `stinfo['possibletypes']`
    :return: the list of kept solutions
    """
    kept = []
    for reference in origsolutions:
        # iterate over a snapshot since matching solutions are removed below
        for candidate in list(solutions):
            compatible = True
            for varname, expected in reference.items():
                try:
                    actual = candidate[varname]
                except KeyError:
                    # variable has been rewritten
                    continue
                if actual == expected:
                    continue
                try:
                    defined[varname].stinfo['possibletypes'].remove(actual)
                except KeyError:
                    pass
                compatible = False
                break
            if compatible:
                kept.append(candidate)
                solutions.remove(candidate)
    return kept

class Unsupported(Exception):
    """raised when a rql expression can't be applied to the query being
    rewritten (type resolution failure or lost solutions)"""
        
class RQLRewriter(object):
    """insert some rql snippets into another rql syntax tree"""
    def __init__(self, querier, session):
        self.session = session
        self.annotate = querier._rqlhelper.annotate
        self._compute_solutions = querier.solutions
        self.schema = querier.schema

    def compute_solutions(self):
        self.annotate(self.select)
        try:
            self._compute_solutions(self.session, self.select, self.kwargs)
        except TypeResolverException:
            raise Unsupported()
        if len(self.select.solutions) < len(self.solutions):
            raise Unsupported()
        
    def rewrite(self, select, snippets, solutions, kwargs):
        if server.DEBUG:
            print '---- rewrite', select, snippets, solutions
        self.select = select
        self.solutions = solutions
        self.kwargs = kwargs
        self.u_varname = None
        self.removing_ambiguity = False
        self.exists_snippet = {}
        # we have to annotate the rqlst before inserting snippets, even though
        # we'll have to redo it latter
        self.annotate(select)
        self.insert_snippets(snippets)
        if not self.exists_snippet and self.u_varname:
            # U has been inserted than cancelled, cleanup
            select.undefine_variable(select.defined_vars[self.u_varname])
        # clean solutions according to initial solutions
        newsolutions = remove_solutions(solutions, select.solutions,
                                        select.defined_vars)
        assert len(newsolutions) >= len(solutions), \
               'rewritten rql %s has lost some solutions, there is probably something '\
               'wrong in your schema permission (for instance using a '\
              'RQLExpression which insert a relation which doesn\'t exists in '\
               'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
            select, solutions, newsolutions)
        if len(newsolutions) > len(solutions):
            # the snippet has introduced some ambiguities, we have to resolve them
            # "manually"
            variantes = self.build_variantes(newsolutions)
            # insert "is" where necessary
            varexistsmap = {}
            self.removing_ambiguity = True
            for (erqlexpr, mainvar, oldvarname), etype in variantes[0].iteritems():
                varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
                var = select.defined_vars[varname]
                exists = var.references()[0].scope
                exists.add_constant_restriction(var, 'is', etype, 'etype')
                varexistsmap[mainvar] = exists
            # insert ORED exists where necessary
            for variante in variantes[1:]:
                self.insert_snippets(snippets, varexistsmap)
                for (erqlexpr, mainvar, oldvarname), etype in variante.iteritems():
                    varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
                    try:
                        var = select.defined_vars[varname]
                    except KeyError:
                        # not a newly inserted variable
                        continue
                    exists = var.references()[0].scope
                    exists.add_constant_restriction(var, 'is', etype, 'etype')
            # recompute solutions
            #select.annotated = False # avoid assertion error
            self.compute_solutions()
            # clean solutions according to initial solutions
            newsolutions = remove_solutions(solutions, select.solutions,
                                            select.defined_vars)
        select.solutions = newsolutions
        add_types_restriction(self.schema, select)
        if server.DEBUG:
            print '---- rewriten', select
            
    def build_variantes(self, newsolutions):
        variantes = set()
        for sol in newsolutions:
            variante = []
            for (erqlexpr, mainvar, oldvar), newvar in self.rewritten.iteritems():
                variante.append( ((erqlexpr, mainvar, oldvar), sol[newvar]) )
            variantes.add(tuple(variante))
        # rebuild variantes as dict
        variantes = [dict(variante) for variante in variantes]
        # remove variable which have always the same type
        for erqlexpr, mainvar, oldvar in self.rewritten:
            it = iter(variantes)
            etype = it.next()[(erqlexpr, mainvar, oldvar)]
            for variante in it:
                if variante[(erqlexpr, mainvar, oldvar)] != etype:
                    break
            else:
                for variante in variantes:
                    del variante[(erqlexpr, mainvar, oldvar)]
        return variantes
    
    def insert_snippets(self, snippets, varexistsmap=None):
        self.rewritten = {}
        for varname, erqlexprs in snippets:
            if varexistsmap is not None and not varname in varexistsmap:
                continue
            try:
                self.const = typed_eid(varname)
                self.varname = self.const
                self.rhs_rels = self.lhs_rels = {}
            except ValueError:
                self.varname = varname
                self.const = None
                self.varstinfo = stinfo = self.select.defined_vars[varname].stinfo
                if varexistsmap is None:
                    self.rhs_rels = dict( (rel.r_type, rel) for rel in stinfo['rhsrelations'])
                    self.lhs_rels = dict( (rel.r_type, rel) for rel in stinfo['relations']
                                                  if not rel in stinfo['rhsrelations'])
                else:
                    self.rhs_rels = self.lhs_rels = {}
            parent = None
            inserted = False
            for erqlexpr in erqlexprs:
                self.current_expr = erqlexpr
                if varexistsmap is None:
                    try:
                        new = self.insert_snippet(varname, erqlexpr.snippet_rqlst, parent)
                    except Unsupported:
                        continue
                    inserted = True
                    if new is not None:
                        self.exists_snippet[erqlexpr] = new
                    parent = parent or new
                else:
                    # called to reintroduce snippet due to ambiguity creation,
                    # so skip snippets which are not introducing this ambiguity
                    exists = varexistsmap[varname]
                    if self.exists_snippet[erqlexpr] is exists:
                        self.insert_snippet(varname, erqlexpr.snippet_rqlst, exists)
            if varexistsmap is None and not inserted:
                # no rql expression found matching rql solutions. User has no access right
                raise Unauthorized()
            
    def insert_snippet(self, varname, snippetrqlst, parent=None):
        new = snippetrqlst.where.accept(self)
        if new is not None:
            try:
                var = self.select.defined_vars[varname]
            except KeyError:
                # not a variable
                pass
            else:
                if var.stinfo['optrelations']:
                    # use a subquery
                    subselect = stmts.Select()
                    subselect.append_selected(nodes.VariableRef(subselect.get_variable(varname)))
                    subselect.add_restriction(new.copy(subselect))
                    aliases = [varname]
                    for rel in var.stinfo['relations']:
                        rschema = self.schema.rschema(rel.r_type)
                        if rschema.is_final() or (rschema.inlined and not rel in var.stinfo['rhsrelations']):
                            self.select.remove_node(rel)
                            rel.children[0].name = varname
                            subselect.add_restriction(rel.copy(subselect))
                            for vref in rel.children[1].iget_nodes(nodes.VariableRef):
                                subselect.append_selected(vref.copy(subselect))
                                aliases.append(vref.name)
                    if self.u_varname:
                        # generate an identifier for the substitution
                        argname = subselect.allocate_varname()
                        while argname in self.kwargs:
                            argname = subselect.allocate_varname()
                        subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
                                                        'eid', unicode(argname), 'Substitute')
                        self.kwargs[argname] = self.session.user.eid
                    add_types_restriction(self.schema, subselect, subselect, solutions=self.solutions)
                    assert parent is None
                    myunion = stmts.Union()
                    myunion.append(subselect)
                    aliases = [nodes.VariableRef(self.select.get_variable(name, i))
                               for i, name in enumerate(aliases)]
                    self.select.add_subquery(nodes.SubQuery(aliases, myunion), check=False)
                    self._cleanup_inserted(new)
                    try:
                        self.compute_solutions()
                    except Unsupported:
                        # some solutions have been lost, can't apply this rql expr
                        self.select.remove_subquery(new, undefine=True)
                        raise
                    return
            new = nodes.Exists(new)
            if parent is None:
                self.select.add_restriction(new)
            else:
                grandpa = parent.parent
                or_ = nodes.Or(parent, new)
                grandpa.replace(parent, or_)
            if not self.removing_ambiguity:
                try:
                    self.compute_solutions()
                except Unsupported:
                    # some solutions have been lost, can't apply this rql expr
                    if parent is None:
                        self.select.remove_node(new, undefine=True)
                    else:
                        parent.parent.replace(or_, or_.children[0])
                        self._cleanup_inserted(new)
                    raise 
            return new

    def _cleanup_inserted(self, node):
        # cleanup inserted variable references
        for vref in node.iget_nodes(nodes.VariableRef):
            vref.unregister_reference()
            if not vref.variable.stinfo['references']:
                # no more references, undefine the variable
                del self.select.defined_vars[vref.name]
        
    def _visit_binary(self, node, cls):
        newnode = cls()
        for c in node.children:
            new = c.accept(self)
            if new is None:
                continue
            newnode.append(new)
        if len(newnode.children) == 0:
            return None
        if len(newnode.children) == 1:
            return newnode.children[0]
        return newnode

    def _visit_unary(self, node, cls):
        newc = node.children[0].accept(self)
        if newc is None:
            return None
        newnode = cls()
        newnode.append(newc)
        return newnode 
        
    def visit_and(self, et):
        return self._visit_binary(et, nodes.And)

    def visit_or(self, ou):
        return self._visit_binary(ou, nodes.Or)
        
    def visit_not(self, node):
        return self._visit_unary(node, nodes.Not)

    def visit_exists(self, node):
        return self._visit_unary(node, nodes.Exists)
   
    def visit_relation(self, relation):
        lhs, rhs = relation.get_variable_parts()
        if lhs.name == 'X':
            # on lhs
            # see if we can reuse this relation
            if relation.r_type in self.lhs_rels and isinstance(rhs, nodes.VariableRef) and rhs.name != 'U':
                if self._may_be_shared(relation, 'object'):
                    # ok, can share variable
                    term = self.lhs_rels[relation.r_type].children[1].children[0]
                    self._use_outer_term(rhs.name, term)
                    return
        elif isinstance(rhs, nodes.VariableRef) and rhs.name == 'X' and lhs.name != 'U':
            # on rhs
            # see if we can reuse this relation
            if relation.r_type in self.rhs_rels and self._may_be_shared(relation, 'subject'):
                # ok, can share variable
                term = self.rhs_rels[relation.r_type].children[0]
                self._use_outer_term(lhs.name, term)            
                return
        rel = nodes.Relation(relation.r_type, relation.optional)
        for c in relation.children:
            rel.append(c.accept(self))
        return rel

    def visit_comparison(self, cmp):
        cmp_ = nodes.Comparison(cmp.operator)
        for c in cmp.children:
            cmp_.append(c.accept(self))
        return cmp_

    def visit_mathexpression(self, mexpr):
        cmp_ = nodes.MathExpression(cmp.operator)
        for c in cmp.children:
            cmp_.append(c.accept(self))
        return cmp_
        
    def visit_function(self, function):
        """generate filter name for a function"""
        function_ = nodes.Function(function.name)
        for c in function.children:
            function_.append(c.accept(self))
        return function_

    def visit_constant(self, constant):
        """generate filter name for a constant"""
        return nodes.Constant(constant.value, constant.type)

    def visit_variableref(self, vref):
        """get the sql name for a variable reference"""
        if vref.name == 'X':
            if self.const is not None:
                return nodes.Constant(self.const, 'Int')
            return nodes.VariableRef(self.select.get_variable(self.varname))
        vname_or_term = self._get_varname_or_term(vref.name)
        if isinstance(vname_or_term, basestring):
            return nodes.VariableRef(self.select.get_variable(vname_or_term))
        # shared term
        return vname_or_term.copy(self.select)

    def _may_be_shared(self, relation, target):
        """return True if the snippet relation can be skipped to use a relation
        from the original query
        """
        # if cardinality is in '?1', we can ignore the relation and use variable
        # from the original query
        rschema = self.schema.rschema(relation.r_type)
        if target == 'object':
            cardindex = 0
            ttypes_func = rschema.objects
            rprop = rschema.rproperty
        else: # target == 'subject':
            cardindex = 1
            ttypes_func = rschema.subjects
            rprop = lambda x, y, z: rschema.rproperty(y, x, z)
        for etype in self.varstinfo['possibletypes']:
            for ttype in ttypes_func(etype):
                if rprop(etype, ttype, 'cardinality')[cardindex] in '+*':
                    return False
        return True

    def _use_outer_term(self, snippet_varname, term):
        key = (self.current_expr, self.varname, snippet_varname)
        if key in self.rewritten:
            insertedvar = self.select.defined_vars.pop(self.rewritten[key])
            for inserted_vref in insertedvar.references():
                inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
        self.rewritten[key] = term
        
    def _get_varname_or_term(self, vname):
        if vname == 'U':
            if self.u_varname is None:
                select = self.select
                self.u_varname = select.allocate_varname()
                # generate an identifier for the substitution
                argname = select.allocate_varname()
                while argname in self.kwargs:
                    argname = select.allocate_varname()
                # insert "U eid %(u)s"
                var = select.get_variable(self.u_varname)
                select.add_constant_restriction(select.get_variable(self.u_varname),
                                                'eid', unicode(argname), 'Substitute')
                self.kwargs[argname] = self.session.user.eid
            return self.u_varname
        key = (self.current_expr, self.varname, vname)
        try:
            return self.rewritten[key]
        except KeyError:
            self.rewritten[key] = newvname = self.select.allocate_varname()
            return newvname