server/rqlrewrite.py
changeset 0 b97547f5f1fa
child 1132 96752791c2b6
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/server/rqlrewrite.py	Wed Nov 05 15:52:50 2008 +0100
@@ -0,0 +1,395 @@
+"""RQL rewriting utilities, used for read security checking
+
+:organization: Logilab
+:copyright: 2007-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
+"""
+
+from rql import nodes, stmts, TypeResolverException
+from cubicweb import Unauthorized, server, typed_eid
+from cubicweb.server.ssplanner import add_types_restriction
+
+def remove_solutions(origsolutions, solutions, defined):
+    """when a rqlst has been generated from another by introducing security
+    assertions, this function returns solutions which are contained in orig
+    solutions
+
+    :param origsolutions: solutions (variable name -> type dicts) of the
+      original rqlst
+    :param solutions: solutions of the rewritten rqlst; the ones which are
+      kept are removed from this list **in place**
+    :param defined: variable name -> variable mapping; a type rejected for a
+      variable is removed from its `stinfo['possibletypes']`
+    :return: the list of solutions compatible with one of `origsolutions`
+    """
+    newsolutions = []
+    for origsol in origsolutions:
+        # iterate on a copy since matching solutions are removed on the fly
+        for newsol in solutions[:]:
+            for var, etype in origsol.items():
+                try:
+                    newtype = newsol[var]
+                except KeyError:
+                    # variable has been rewritten
+                    continue
+                if newtype != etype:
+                    try:
+                        defined[var].stinfo['possibletypes'].remove(newtype)
+                    except KeyError:
+                        # best effort: unknown variable or type already
+                        # discarded
+                        pass
+                    break
+            else:
+                # every variable of the original solution has a compatible type
+                newsolutions.append(newsol)
+                solutions.remove(newsol)
+    return newsolutions
+
+class Unsupported(Exception):
+    """raised when a snippet can't be inserted without losing solutions"""
+        
+class RQLRewriter(object):
+    """insert some rql snippets into another rql syntax tree
+
+    `rewrite` is the entry point: it stores the per-query state (select,
+    solutions, kwargs...) on the instance, then the visit_* methods build a
+    copy of each snippet's restriction with its variables renamed.
+    """
+    def __init__(self, querier, session):
+        self.session = session
+        self.annotate = querier._rqlhelper.annotate
+        self._compute_solutions = querier.solutions
+        self.schema = querier.schema
+
+    def compute_solutions(self):
+        """annotate the select being rewritten and recompute its solutions
+
+        raise Unsupported if solutions can't be computed or if there are fewer
+        of them than in `self.solutions`
+        """
+        self.annotate(self.select)
+        try:
+            self._compute_solutions(self.session, self.select, self.kwargs)
+        except TypeResolverException:
+            raise Unsupported()
+        if len(self.select.solutions) < len(self.solutions):
+            raise Unsupported()
+
+    def rewrite(self, select, snippets, solutions, kwargs):
+        """insert `snippets` into `select`
+
+        :param select: the select node to rewrite, modified in place
+        :param snippets: sequence of (variable name or eid, rql expressions)
+        :param solutions: solutions of the select before rewriting
+        :param kwargs: query arguments, modified in place when a substitution
+          has to be added for the user's eid
+        """
+        if server.DEBUG:
+            print '---- rewrite', select, snippets, solutions
+        self.select = select
+        self.solutions = solutions
+        self.kwargs = kwargs
+        self.u_varname = None
+        self.removing_ambiguity = False
+        self.exists_snippet = {}
+        # we have to annotate the rqlst before inserting snippets, even though
+        # we'll have to redo it later
+        self.annotate(select)
+        self.insert_snippets(snippets)
+        if not self.exists_snippet and self.u_varname:
+            # U has been inserted then cancelled, cleanup
+            select.undefine_variable(select.defined_vars[self.u_varname])
+        # clean solutions according to initial solutions
+        newsolutions = remove_solutions(solutions, select.solutions,
+                                        select.defined_vars)
+        assert len(newsolutions) >= len(solutions), \
+               'rewritten rql %s has lost some solutions, there is probably something '\
+               'wrong in your schema permission (for instance using a '\
+               'RQLExpression which insert a relation which doesn\'t exists in '\
+               'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
+            select, solutions, newsolutions)
+        if len(newsolutions) > len(solutions):
+            # the snippet has introduced some ambiguities, we have to resolve them
+            # "manually"
+            variantes = self.build_variantes(newsolutions)
+            # insert "is" where necessary
+            varexistsmap = {}
+            self.removing_ambiguity = True
+            for (erqlexpr, mainvar, oldvarname), etype in variantes[0].iteritems():
+                varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
+                var = select.defined_vars[varname]
+                exists = var.references()[0].scope
+                exists.add_constant_restriction(var, 'is', etype, 'etype')
+                varexistsmap[mainvar] = exists
+            # insert ORED exists where necessary
+            for variante in variantes[1:]:
+                self.insert_snippets(snippets, varexistsmap)
+                for (erqlexpr, mainvar, oldvarname), etype in variante.iteritems():
+                    varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
+                    try:
+                        var = select.defined_vars[varname]
+                    except KeyError:
+                        # not a newly inserted variable
+                        continue
+                    exists = var.references()[0].scope
+                    exists.add_constant_restriction(var, 'is', etype, 'etype')
+            # recompute solutions
+            self.compute_solutions()
+            # clean solutions according to initial solutions
+            newsolutions = remove_solutions(solutions, select.solutions,
+                                            select.defined_vars)
+        select.solutions = newsolutions
+        add_types_restriction(self.schema, select)
+        if server.DEBUG:
+            print '---- rewriten', select
+
+    def build_variantes(self, newsolutions):
+        """return the distinct typings of rewritten variables found in
+        `newsolutions`
+
+        The result is a list of dicts mapping keys of `self.rewritten` to a
+        type, from which variables having the same type in every typing have
+        been removed.
+        """
+        variantes = set()
+        for sol in newsolutions:
+            variantes.add(tuple((key, sol[newvar])
+                                for key, newvar in self.rewritten.iteritems()))
+        # rebuild variantes as dict
+        variantes = [dict(variante) for variante in variantes]
+        # remove variable which have always the same type
+        for key in self.rewritten:
+            it = iter(variantes)
+            etype = it.next()[key]
+            for variante in it:
+                if variante[key] != etype:
+                    break
+            else:
+                for variante in variantes:
+                    del variante[key]
+        return variantes
+
+    def insert_snippets(self, snippets, varexistsmap=None):
+        """insert the rql expressions of `snippets` into the select
+
+        :param snippets: sequence of (variable name or eid, rql expressions)
+        :param varexistsmap: when given (main variable -> Exists node), we are
+          called to remove ambiguities: only snippets which have introduced
+          one of those nodes are inserted again, ORed with it
+        :raise Unauthorized: on first insertion, if none of the expressions
+          of an entry could be inserted
+        """
+        self.rewritten = {}
+        for varname, erqlexprs in snippets:
+            if varexistsmap is not None and varname not in varexistsmap:
+                continue
+            try:
+                self.const = typed_eid(varname)
+                self.varname = self.const
+                self.rhs_rels = self.lhs_rels = {}
+            except ValueError:
+                # not an eid but a variable name
+                self.varname = varname
+                self.const = None
+                self.varstinfo = stinfo = self.select.defined_vars[varname].stinfo
+                if varexistsmap is None:
+                    # relations of the original query which may be shared with
+                    # the snippets, by relation type
+                    self.rhs_rels = dict( (rel.r_type, rel) for rel in stinfo['rhsrelations'])
+                    self.lhs_rels = dict( (rel.r_type, rel) for rel in stinfo['relations']
+                                                  if rel not in stinfo['rhsrelations'])
+                else:
+                    self.rhs_rels = self.lhs_rels = {}
+            parent = None
+            inserted = False
+            for erqlexpr in erqlexprs:
+                self.current_expr = erqlexpr
+                if varexistsmap is None:
+                    try:
+                        new = self.insert_snippet(varname, erqlexpr.snippet_rqlst, parent)
+                    except Unsupported:
+                        continue
+                    inserted = True
+                    if new is not None:
+                        self.exists_snippet[erqlexpr] = new
+                    parent = parent or new
+                else:
+                    # called to reintroduce snippet due to ambiguity creation,
+                    # so skip snippets which are not introducing this ambiguity.
+                    # Use .get since expressions which were unsupported or
+                    # didn't produce an Exists node have not been recorded
+                    exists = varexistsmap[varname]
+                    if self.exists_snippet.get(erqlexpr) is exists:
+                        self.insert_snippet(varname, erqlexpr.snippet_rqlst, exists)
+            if varexistsmap is None and not inserted:
+                # no rql expression found matching rql solutions. User has no access right
+                raise Unauthorized()
+
+    def insert_snippet(self, varname, snippetrqlst, parent=None):
+        """insert a rewritten copy of the snippet's restriction into the select
+
+        The restriction is wrapped in an Exists node, ORed with `parent` when
+        given, unless `varname` is a variable used in optional relations, in
+        which case a subquery is added instead.
+
+        :return: the inserted Exists node, or None if there was nothing to
+          insert or if a subquery has been used
+        :raise Unsupported: if some solutions are lost once the snippet is
+          inserted; the insertion is undone first
+        """
+        new = snippetrqlst.where.accept(self)
+        if new is not None:
+            try:
+                var = self.select.defined_vars[varname]
+            except KeyError:
+                # not a variable
+                pass
+            else:
+                if var.stinfo['optrelations']:
+                    # use a subquery
+                    subselect = stmts.Select()
+                    subselect.append_selected(nodes.VariableRef(subselect.get_variable(varname)))
+                    subselect.add_restriction(new.copy(subselect))
+                    aliases = [varname]
+                    # move final / inlined relations of the variable into the
+                    # subquery, selecting variables found on their rhs
+                    for rel in var.stinfo['relations']:
+                        rschema = self.schema.rschema(rel.r_type)
+                        if rschema.is_final() or (rschema.inlined and rel not in var.stinfo['rhsrelations']):
+                            self.select.remove_node(rel)
+                            rel.children[0].name = varname
+                            subselect.add_restriction(rel.copy(subselect))
+                            for vref in rel.children[1].iget_nodes(nodes.VariableRef):
+                                subselect.append_selected(vref.copy(subselect))
+                                aliases.append(vref.name)
+                    if self.u_varname:
+                        # generate an identifier for the substitution
+                        argname = subselect.allocate_varname()
+                        while argname in self.kwargs:
+                            argname = subselect.allocate_varname()
+                        subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
+                                                        'eid', unicode(argname), 'Substitute')
+                        self.kwargs[argname] = self.session.user.eid
+                    add_types_restriction(self.schema, subselect, subselect, solutions=self.solutions)
+                    assert parent is None
+                    myunion = stmts.Union()
+                    myunion.append(subselect)
+                    aliases = [nodes.VariableRef(self.select.get_variable(name, i))
+                               for i, name in enumerate(aliases)]
+                    subquery = nodes.SubQuery(aliases, myunion)
+                    self.select.add_subquery(subquery, check=False)
+                    self._cleanup_inserted(new)
+                    try:
+                        self.compute_solutions()
+                    except Unsupported:
+                        # some solutions have been lost, can't apply this rql
+                        # expr: remove the subquery we've just added (and not
+                        # `new`, which has never been added to the select)
+                        self.select.remove_subquery(subquery, undefine=True)
+                        raise
+                    return
+            new = nodes.Exists(new)
+            if parent is None:
+                self.select.add_restriction(new)
+            else:
+                grandpa = parent.parent
+                or_ = nodes.Or(parent, new)
+                grandpa.replace(parent, or_)
+            if not self.removing_ambiguity:
+                try:
+                    self.compute_solutions()
+                except Unsupported:
+                    # some solutions have been lost, can't apply this rql expr
+                    if parent is None:
+                        self.select.remove_node(new, undefine=True)
+                    else:
+                        # put `parent` back where the Or node has been
+                        # inserted. NOTE(review): go through `grandpa` since
+                        # `parent.parent` is presumably the Or node itself at
+                        # this point -- confirm against rql's node API
+                        grandpa.replace(or_, parent)
+                        self._cleanup_inserted(new)
+                    raise
+            return new
+
+    def _cleanup_inserted(self, node):
+        """unregister variable references found under `node` and undefine
+        variables of the select which have no reference left
+        """
+        for vref in node.iget_nodes(nodes.VariableRef):
+            vref.unregister_reference()
+            if not vref.variable.stinfo['references']:
+                # no more references, undefine the variable
+                del self.select.defined_vars[vref.name]
+
+    def _visit_binary(self, node, cls):
+        """return a `cls` node holding the rewritten children of `node`
+
+        Children rewritten to None are dropped; if a single child is left it
+        is returned as is, and None is returned if there is none left.
+        """
+        newnode = cls()
+        for c in node.children:
+            new = c.accept(self)
+            if new is None:
+                continue
+            newnode.append(new)
+        if len(newnode.children) == 0:
+            return None
+        if len(newnode.children) == 1:
+            return newnode.children[0]
+        return newnode
+
+    def _visit_unary(self, node, cls):
+        """return a `cls` node wrapping the rewritten child of `node`, or None
+        if the child has been rewritten to None
+        """
+        newc = node.children[0].accept(self)
+        if newc is None:
+            return None
+        newnode = cls()
+        newnode.append(newc)
+        return newnode
+
+    def visit_and(self, et):
+        """rewrite an And node (see `_visit_binary`)"""
+        return self._visit_binary(et, nodes.And)
+
+    def visit_or(self, ou):
+        """rewrite an Or node (see `_visit_binary`)"""
+        return self._visit_binary(ou, nodes.Or)
+
+    def visit_not(self, node):
+        """rewrite a Not node (see `_visit_unary`)"""
+        return self._visit_unary(node, nodes.Not)
+
+    def visit_exists(self, node):
+        """rewrite an Exists node (see `_visit_unary`)"""
+        return self._visit_unary(node, nodes.Exists)
+
+    def visit_relation(self, relation):
+        """return a rewritten copy of a snippet relation, or None when it can
+        be dropped
+
+        The relation is dropped when the original query already has a relation
+        of the same type on the main variable which may be shared: the term on
+        its other side is then used in place of the snippet's variable.
+        """
+        lhs, rhs = relation.get_variable_parts()
+        if lhs.name == 'X':
+            # on lhs
+            # see if we can reuse this relation
+            if relation.r_type in self.lhs_rels and isinstance(rhs, nodes.VariableRef) and rhs.name != 'U':
+                if self._may_be_shared(relation, 'object'):
+                    # ok, can share variable
+                    term = self.lhs_rels[relation.r_type].children[1].children[0]
+                    self._use_outer_term(rhs.name, term)
+                    return
+        elif isinstance(rhs, nodes.VariableRef) and rhs.name == 'X' and lhs.name != 'U':
+            # on rhs
+            # see if we can reuse this relation
+            if relation.r_type in self.rhs_rels and self._may_be_shared(relation, 'subject'):
+                # ok, can share variable
+                term = self.rhs_rels[relation.r_type].children[0]
+                self._use_outer_term(lhs.name, term)
+                return
+        rel = nodes.Relation(relation.r_type, relation.optional)
+        for c in relation.children:
+            rel.append(c.accept(self))
+        return rel
+
+    def visit_comparison(self, cmp):
+        """return a copy of the comparison with rewritten children"""
+        cmp_ = nodes.Comparison(cmp.operator)
+        for c in cmp.children:
+            cmp_.append(c.accept(self))
+        return cmp_
+
+    def visit_mathexpression(self, mexpr):
+        """return a copy of the math expression with rewritten children"""
+        # use the visited node: `cmp` is the builtin function in this scope
+        mexpr_ = nodes.MathExpression(mexpr.operator)
+        for c in mexpr.children:
+            mexpr_.append(c.accept(self))
+        return mexpr_
+
+    def visit_function(self, function):
+        """return a copy of the function node with rewritten children"""
+        function_ = nodes.Function(function.name)
+        for c in function.children:
+            function_.append(c.accept(self))
+        return function_
+
+    def visit_constant(self, constant):
+        """return a copy of the constant node"""
+        return nodes.Constant(constant.value, constant.type)
+
+    def visit_variableref(self, vref):
+        """return the node standing for a snippet variable in the select
+
+        'X' becomes the eid constant or a reference to the main variable;
+        other names become a reference to the variable allocated for them, or
+        a copy of the term of the original query they share.
+        """
+        if vref.name == 'X':
+            if self.const is not None:
+                return nodes.Constant(self.const, 'Int')
+            return nodes.VariableRef(self.select.get_variable(self.varname))
+        vname_or_term = self._get_varname_or_term(vref.name)
+        if isinstance(vname_or_term, basestring):
+            return nodes.VariableRef(self.select.get_variable(vname_or_term))
+        # shared term
+        return vname_or_term.copy(self.select)
+
+    def _may_be_shared(self, relation, target):
+        """return True if the snippet relation can be skipped to use a relation
+        from the original query
+
+        `target` is 'object' when the main variable is the subject of the
+        relation, 'subject' otherwise.
+        """
+        # if cardinality is in '?1', we can ignore the relation and use variable
+        # from the original query
+        rschema = self.schema.rschema(relation.r_type)
+        if target == 'object':
+            cardindex = 0
+            ttypes_func = rschema.objects
+            rprop = rschema.rproperty
+        else: # target == 'subject':
+            cardindex = 1
+            ttypes_func = rschema.subjects
+            rprop = lambda x,y,z: rschema.rproperty(y, x, z)
+        for etype in self.varstinfo['possibletypes']:
+            for ttype in ttypes_func(etype):
+                if rprop(etype, ttype, 'cardinality')[cardindex] in '+*':
+                    return False
+        return True
+
+    def _use_outer_term(self, snippet_varname, term):
+        """record that `term`, taken from the original query, stands for the
+        snippet variable `snippet_varname` of the current expression
+
+        If a variable has already been inserted for it, it's removed from the
+        select's defined variables and its references are replaced by copies
+        of `term`.
+        """
+        key = (self.current_expr, self.varname, snippet_varname)
+        if key in self.rewritten:
+            insertedvar = self.select.defined_vars.pop(self.rewritten[key])
+            for inserted_vref in insertedvar.references():
+                inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
+        self.rewritten[key] = term
+
+    def _get_varname_or_term(self, vname):
+        """return what stands for the snippet variable `vname` in the select
+
+        This is either a variable name, allocated on first use, or the term
+        recorded by `_use_outer_term`. 'U' is always mapped to a single
+        variable, restricted on first use to the eid of the session's user
+        through a substitution added to `self.kwargs`.
+        """
+        if vname == 'U':
+            if self.u_varname is None:
+                select = self.select
+                self.u_varname = select.allocate_varname()
+                # generate an identifier for the substitution
+                argname = select.allocate_varname()
+                while argname in self.kwargs:
+                    argname = select.allocate_varname()
+                # insert "U eid %(u)s"
+                var = select.get_variable(self.u_varname)
+                select.add_constant_restriction(var, 'eid', unicode(argname),
+                                                'Substitute')
+                self.kwargs[argname] = self.session.user.eid
+            return self.u_varname
+        key = (self.current_expr, self.varname, vname)
+        try:
+            return self.rewritten[key]
+        except KeyError:
+            self.rewritten[key] = newvname = self.select.allocate_varname()
+            return newvname