server/rqlrewrite.py
author Adrien Di Mascio <Adrien.DiMascio@logilab.fr>
Tue, 17 Feb 2009 16:33:52 +0100
changeset 676 270eb87a768a
parent 0 b97547f5f1fa
child 1132 96752791c2b6
permissions -rw-r--r--
Provide a new add_cubes() migration function for cases where the new cubes are linked together by new relations. In this case, we need to add all new cubes at once.

"""RQL rewriting utilities, used for read security checking

:organization: Logilab
:copyright: 2007-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
"""

from rql import nodes, stmts, TypeResolverException
from cubicweb import Unauthorized, server, typed_eid
from cubicweb.server.ssplanner import add_types_restriction

def remove_solutions(origsolutions, solutions, defined):
    """when a rqlst has been generated from another by introducing security
    assertions, this method returns solutions which are contained in orig
    solutions

    A new solution is kept when, for every variable of an original solution,
    it gives the same type (variables missing from the new solution, because
    they have been rewritten, are ignored).

    :param origsolutions: solutions (dicts variable name -> type) of the
      original syntax tree
    :param solutions: solutions of the rewritten syntax tree. This list is
      modified in place: every returned solution is removed from it.
    :param defined: mapping variable name -> variable. Each time a new
      solution is found to give a different type than an original one for
      some variable, that type is discarded from the variable's
      ``stinfo['possibletypes']`` (ignored if the variable or the type is
      missing).
    :return: the list of kept solutions
    """
    newsolutions = []
    for origsol in origsolutions:
        # iterate over a copy since matching solutions are removed from the
        # list as they are found
        for newsol in solutions[:]:
            for var, etype in origsol.items():
                try:
                    newetype = newsol[var]
                except KeyError:
                    # variable has been rewritten
                    continue
                if newetype != etype:
                    try:
                        defined[var].stinfo['possibletypes'].remove(newetype)
                    except KeyError:
                        pass
                    break
            else:
                newsolutions.append(newsol)
                solutions.remove(newsol)
    return newsolutions

class Unsupported(Exception):
    """an rql expression can't be inserted: solutions can't be computed for
    the rewritten tree, or some solutions have been lost"""
        
class RQLRewriter(object):
    """insert some rql snippets into another rql syntax tree"""
    def __init__(self, querier, session):
        self.session = session
        self.annotate = querier._rqlhelper.annotate
        self._compute_solutions = querier.solutions
        self.schema = querier.schema

    def compute_solutions(self):
        self.annotate(self.select)
        try:
            self._compute_solutions(self.session, self.select, self.kwargs)
        except TypeResolverException:
            raise Unsupported()
        if len(self.select.solutions) < len(self.solutions):
            raise Unsupported()
        
    def rewrite(self, select, snippets, solutions, kwargs):
        if server.DEBUG:
            print '---- rewrite', select, snippets, solutions
        self.select = select
        self.solutions = solutions
        self.kwargs = kwargs
        self.u_varname = None
        self.removing_ambiguity = False
        self.exists_snippet = {}
        # we have to annotate the rqlst before inserting snippets, even though
        # we'll have to redo it latter
        self.annotate(select)
        self.insert_snippets(snippets)
        if not self.exists_snippet and self.u_varname:
            # U has been inserted than cancelled, cleanup
            select.undefine_variable(select.defined_vars[self.u_varname])
        # clean solutions according to initial solutions
        newsolutions = remove_solutions(solutions, select.solutions,
                                        select.defined_vars)
        assert len(newsolutions) >= len(solutions), \
               'rewritten rql %s has lost some solutions, there is probably something '\
               'wrong in your schema permission (for instance using a '\
              'RQLExpression which insert a relation which doesn\'t exists in '\
               'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
            select, solutions, newsolutions)
        if len(newsolutions) > len(solutions):
            # the snippet has introduced some ambiguities, we have to resolve them
            # "manually"
            variantes = self.build_variantes(newsolutions)
            # insert "is" where necessary
            varexistsmap = {}
            self.removing_ambiguity = True
            for (erqlexpr, mainvar, oldvarname), etype in variantes[0].iteritems():
                varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
                var = select.defined_vars[varname]
                exists = var.references()[0].scope
                exists.add_constant_restriction(var, 'is', etype, 'etype')
                varexistsmap[mainvar] = exists
            # insert ORED exists where necessary
            for variante in variantes[1:]:
                self.insert_snippets(snippets, varexistsmap)
                for (erqlexpr, mainvar, oldvarname), etype in variante.iteritems():
                    varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
                    try:
                        var = select.defined_vars[varname]
                    except KeyError:
                        # not a newly inserted variable
                        continue
                    exists = var.references()[0].scope
                    exists.add_constant_restriction(var, 'is', etype, 'etype')
            # recompute solutions
            #select.annotated = False # avoid assertion error
            self.compute_solutions()
            # clean solutions according to initial solutions
            newsolutions = remove_solutions(solutions, select.solutions,
                                            select.defined_vars)
        select.solutions = newsolutions
        add_types_restriction(self.schema, select)
        if server.DEBUG:
            print '---- rewriten', select
            
    def build_variantes(self, newsolutions):
        variantes = set()
        for sol in newsolutions:
            variante = []
            for (erqlexpr, mainvar, oldvar), newvar in self.rewritten.iteritems():
                variante.append( ((erqlexpr, mainvar, oldvar), sol[newvar]) )
            variantes.add(tuple(variante))
        # rebuild variantes as dict
        variantes = [dict(variante) for variante in variantes]
        # remove variable which have always the same type
        for erqlexpr, mainvar, oldvar in self.rewritten:
            it = iter(variantes)
            etype = it.next()[(erqlexpr, mainvar, oldvar)]
            for variante in it:
                if variante[(erqlexpr, mainvar, oldvar)] != etype:
                    break
            else:
                for variante in variantes:
                    del variante[(erqlexpr, mainvar, oldvar)]
        return variantes
    
    def insert_snippets(self, snippets, varexistsmap=None):
        self.rewritten = {}
        for varname, erqlexprs in snippets:
            if varexistsmap is not None and not varname in varexistsmap:
                continue
            try:
                self.const = typed_eid(varname)
                self.varname = self.const
                self.rhs_rels = self.lhs_rels = {}
            except ValueError:
                self.varname = varname
                self.const = None
                self.varstinfo = stinfo = self.select.defined_vars[varname].stinfo
                if varexistsmap is None:
                    self.rhs_rels = dict( (rel.r_type, rel) for rel in stinfo['rhsrelations'])
                    self.lhs_rels = dict( (rel.r_type, rel) for rel in stinfo['relations']
                                                  if not rel in stinfo['rhsrelations'])
                else:
                    self.rhs_rels = self.lhs_rels = {}
            parent = None
            inserted = False
            for erqlexpr in erqlexprs:
                self.current_expr = erqlexpr
                if varexistsmap is None:
                    try:
                        new = self.insert_snippet(varname, erqlexpr.snippet_rqlst, parent)
                    except Unsupported:
                        continue
                    inserted = True
                    if new is not None:
                        self.exists_snippet[erqlexpr] = new
                    parent = parent or new
                else:
                    # called to reintroduce snippet due to ambiguity creation,
                    # so skip snippets which are not introducing this ambiguity
                    exists = varexistsmap[varname]
                    if self.exists_snippet[erqlexpr] is exists:
                        self.insert_snippet(varname, erqlexpr.snippet_rqlst, exists)
            if varexistsmap is None and not inserted:
                # no rql expression found matching rql solutions. User has no access right
                raise Unauthorized()
            
    def insert_snippet(self, varname, snippetrqlst, parent=None):
        new = snippetrqlst.where.accept(self)
        if new is not None:
            try:
                var = self.select.defined_vars[varname]
            except KeyError:
                # not a variable
                pass
            else:
                if var.stinfo['optrelations']:
                    # use a subquery
                    subselect = stmts.Select()
                    subselect.append_selected(nodes.VariableRef(subselect.get_variable(varname)))
                    subselect.add_restriction(new.copy(subselect))
                    aliases = [varname]
                    for rel in var.stinfo['relations']:
                        rschema = self.schema.rschema(rel.r_type)
                        if rschema.is_final() or (rschema.inlined and not rel in var.stinfo['rhsrelations']):
                            self.select.remove_node(rel)
                            rel.children[0].name = varname
                            subselect.add_restriction(rel.copy(subselect))
                            for vref in rel.children[1].iget_nodes(nodes.VariableRef):
                                subselect.append_selected(vref.copy(subselect))
                                aliases.append(vref.name)
                    if self.u_varname:
                        # generate an identifier for the substitution
                        argname = subselect.allocate_varname()
                        while argname in self.kwargs:
                            argname = subselect.allocate_varname()
                        subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
                                                        'eid', unicode(argname), 'Substitute')
                        self.kwargs[argname] = self.session.user.eid
                    add_types_restriction(self.schema, subselect, subselect, solutions=self.solutions)
                    assert parent is None
                    myunion = stmts.Union()
                    myunion.append(subselect)
                    aliases = [nodes.VariableRef(self.select.get_variable(name, i))
                               for i, name in enumerate(aliases)]
                    self.select.add_subquery(nodes.SubQuery(aliases, myunion), check=False)
                    self._cleanup_inserted(new)
                    try:
                        self.compute_solutions()
                    except Unsupported:
                        # some solutions have been lost, can't apply this rql expr
                        self.select.remove_subquery(new, undefine=True)
                        raise
                    return
            new = nodes.Exists(new)
            if parent is None:
                self.select.add_restriction(new)
            else:
                grandpa = parent.parent
                or_ = nodes.Or(parent, new)
                grandpa.replace(parent, or_)
            if not self.removing_ambiguity:
                try:
                    self.compute_solutions()
                except Unsupported:
                    # some solutions have been lost, can't apply this rql expr
                    if parent is None:
                        self.select.remove_node(new, undefine=True)
                    else:
                        parent.parent.replace(or_, or_.children[0])
                        self._cleanup_inserted(new)
                    raise 
            return new

    def _cleanup_inserted(self, node):
        # cleanup inserted variable references
        for vref in node.iget_nodes(nodes.VariableRef):
            vref.unregister_reference()
            if not vref.variable.stinfo['references']:
                # no more references, undefine the variable
                del self.select.defined_vars[vref.name]
        
    def _visit_binary(self, node, cls):
        newnode = cls()
        for c in node.children:
            new = c.accept(self)
            if new is None:
                continue
            newnode.append(new)
        if len(newnode.children) == 0:
            return None
        if len(newnode.children) == 1:
            return newnode.children[0]
        return newnode

    def _visit_unary(self, node, cls):
        newc = node.children[0].accept(self)
        if newc is None:
            return None
        newnode = cls()
        newnode.append(newc)
        return newnode 
        
    def visit_and(self, et):
        return self._visit_binary(et, nodes.And)

    def visit_or(self, ou):
        return self._visit_binary(ou, nodes.Or)
        
    def visit_not(self, node):
        return self._visit_unary(node, nodes.Not)

    def visit_exists(self, node):
        return self._visit_unary(node, nodes.Exists)
   
    def visit_relation(self, relation):
        lhs, rhs = relation.get_variable_parts()
        if lhs.name == 'X':
            # on lhs
            # see if we can reuse this relation
            if relation.r_type in self.lhs_rels and isinstance(rhs, nodes.VariableRef) and rhs.name != 'U':
                if self._may_be_shared(relation, 'object'):
                    # ok, can share variable
                    term = self.lhs_rels[relation.r_type].children[1].children[0]
                    self._use_outer_term(rhs.name, term)
                    return
        elif isinstance(rhs, nodes.VariableRef) and rhs.name == 'X' and lhs.name != 'U':
            # on rhs
            # see if we can reuse this relation
            if relation.r_type in self.rhs_rels and self._may_be_shared(relation, 'subject'):
                # ok, can share variable
                term = self.rhs_rels[relation.r_type].children[0]
                self._use_outer_term(lhs.name, term)            
                return
        rel = nodes.Relation(relation.r_type, relation.optional)
        for c in relation.children:
            rel.append(c.accept(self))
        return rel

    def visit_comparison(self, cmp):
        cmp_ = nodes.Comparison(cmp.operator)
        for c in cmp.children:
            cmp_.append(c.accept(self))
        return cmp_

    def visit_mathexpression(self, mexpr):
        cmp_ = nodes.MathExpression(cmp.operator)
        for c in cmp.children:
            cmp_.append(c.accept(self))
        return cmp_
        
    def visit_function(self, function):
        """generate filter name for a function"""
        function_ = nodes.Function(function.name)
        for c in function.children:
            function_.append(c.accept(self))
        return function_

    def visit_constant(self, constant):
        """generate filter name for a constant"""
        return nodes.Constant(constant.value, constant.type)

    def visit_variableref(self, vref):
        """get the sql name for a variable reference"""
        if vref.name == 'X':
            if self.const is not None:
                return nodes.Constant(self.const, 'Int')
            return nodes.VariableRef(self.select.get_variable(self.varname))
        vname_or_term = self._get_varname_or_term(vref.name)
        if isinstance(vname_or_term, basestring):
            return nodes.VariableRef(self.select.get_variable(vname_or_term))
        # shared term
        return vname_or_term.copy(self.select)

    def _may_be_shared(self, relation, target):
        """return True if the snippet relation can be skipped to use a relation
        from the original query
        """
        # if cardinality is in '?1', we can ignore the relation and use variable
        # from the original query
        rschema = self.schema.rschema(relation.r_type)
        if target == 'object':
            cardindex = 0
            ttypes_func = rschema.objects
            rprop = rschema.rproperty
        else: # target == 'subject':
            cardindex = 1
            ttypes_func = rschema.subjects
            rprop = lambda x,y,z: rschema.rproperty(y, x, z)
        for etype in self.varstinfo['possibletypes']:
            for ttype in ttypes_func(etype):
                if rprop(etype, ttype, 'cardinality')[cardindex] in '+*':
                    return False
        return True

    def _use_outer_term(self, snippet_varname, term):
        key = (self.current_expr, self.varname, snippet_varname)
        if key in self.rewritten:
            insertedvar = self.select.defined_vars.pop(self.rewritten[key])
            for inserted_vref in insertedvar.references():
                inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
        self.rewritten[key] = term
        
    def _get_varname_or_term(self, vname):
        if vname == 'U':
            if self.u_varname is None:
                select = self.select
                self.u_varname = select.allocate_varname()
                # generate an identifier for the substitution
                argname = select.allocate_varname()
                while argname in self.kwargs:
                    argname = select.allocate_varname()
                # insert "U eid %(u)s"
                var = select.get_variable(self.u_varname)
                select.add_constant_restriction(select.get_variable(self.u_varname),
                                                'eid', unicode(argname), 'Substitute')
                self.kwargs[argname] = self.session.user.eid
            return self.u_varname
        key = (self.current_expr, self.varname, vname)
        try:
            return self.rewritten[key]
        except KeyError:
            self.rewritten[key] = newvname = self.select.allocate_varname()
            return newvname