server/rqlrewrite.py
branch3.5
changeset 3240 8604a15995d1
parent 3239 1ceac4cd4fb7
child 3241 1a6f7a0e7dbd
--- a/server/rqlrewrite.py	Wed Sep 16 14:17:12 2009 +0200
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,396 +0,0 @@
-"""RQL rewriting utilities, used for read security checking
-
-:organization: Logilab
-:copyright: 2007-2009 LOGILAB S.A. (Paris, FRANCE), license is LGPL v2.
-:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
-:license: GNU Lesser General Public License, v2.1 - http://www.gnu.org/licenses
-"""
-
-from rql import nodes, stmts, TypeResolverException
-from cubicweb import Unauthorized, server, typed_eid
-from cubicweb.server.ssplanner import add_types_restriction
-
-def remove_solutions(origsolutions, solutions, defined):
-    """when a rqlst has been generated from another by introducing security
-    assertions, this method returns solutions which are contained in orig
-    solutions
-    """
-    newsolutions = []
-    for origsol in origsolutions:
-        for newsol in solutions[:]:
-            for var, etype in origsol.items():
-                try:
-                    if newsol[var] != etype:
-                        try:
-                            defined[var].stinfo['possibletypes'].remove(newsol[var])
-                        except KeyError:
-                            pass
-                        break
-                except KeyError:
-                    # variable has been rewritten
-                    continue
-            else:
-                newsolutions.append(newsol)
-                solutions.remove(newsol)
-    return newsolutions
-
-class Unsupported(Exception): pass
-
-class RQLRewriter(object):
-    """insert some rql snippets into another rql syntax tree"""
-    def __init__(self, querier, session):
-        self.session = session
-        self.annotate = querier._rqlhelper.annotate
-        self._compute_solutions = querier.solutions
-        self.schema = querier.schema
-
-    def compute_solutions(self):
-        self.annotate(self.select)
-        try:
-            self._compute_solutions(self.session, self.select, self.kwargs)
-        except TypeResolverException:
-            raise Unsupported()
-        if len(self.select.solutions) < len(self.solutions):
-            raise Unsupported()
-
-    def rewrite(self, select, snippets, solutions, kwargs):
-        if server.DEBUG:
-            print '---- rewrite', select, snippets, solutions
-        self.select = select
-        self.solutions = solutions
-        self.kwargs = kwargs
-        self.u_varname = None
-        self.removing_ambiguity = False
-        self.exists_snippet = {}
-        # we have to annotate the rqlst before inserting snippets, even though
-        # we'll have to redo it latter
-        self.annotate(select)
-        self.insert_snippets(snippets)
-        if not self.exists_snippet and self.u_varname:
-            # U has been inserted than cancelled, cleanup
-            select.undefine_variable(select.defined_vars[self.u_varname])
-        # clean solutions according to initial solutions
-        newsolutions = remove_solutions(solutions, select.solutions,
-                                        select.defined_vars)
-        assert len(newsolutions) >= len(solutions), \
-               'rewritten rql %s has lost some solutions, there is probably something '\
-               'wrong in your schema permission (for instance using a '\
-              'RQLExpression which insert a relation which doesn\'t exists in '\
-               'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
-            select, solutions, newsolutions)
-        if len(newsolutions) > len(solutions):
-            # the snippet has introduced some ambiguities, we have to resolve them
-            # "manually"
-            variantes = self.build_variantes(newsolutions)
-            # insert "is" where necessary
-            varexistsmap = {}
-            self.removing_ambiguity = True
-            for (erqlexpr, mainvar, oldvarname), etype in variantes[0].iteritems():
-                varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
-                var = select.defined_vars[varname]
-                exists = var.references()[0].scope
-                exists.add_constant_restriction(var, 'is', etype, 'etype')
-                varexistsmap[mainvar] = exists
-            # insert ORED exists where necessary
-            for variante in variantes[1:]:
-                self.insert_snippets(snippets, varexistsmap)
-                for (erqlexpr, mainvar, oldvarname), etype in variante.iteritems():
-                    varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
-                    try:
-                        var = select.defined_vars[varname]
-                    except KeyError:
-                        # not a newly inserted variable
-                        continue
-                    exists = var.references()[0].scope
-                    exists.add_constant_restriction(var, 'is', etype, 'etype')
-            # recompute solutions
-            #select.annotated = False # avoid assertion error
-            self.compute_solutions()
-            # clean solutions according to initial solutions
-            newsolutions = remove_solutions(solutions, select.solutions,
-                                            select.defined_vars)
-        select.solutions = newsolutions
-        add_types_restriction(self.schema, select)
-        if server.DEBUG:
-            print '---- rewriten', select
-
-    def build_variantes(self, newsolutions):
-        variantes = set()
-        for sol in newsolutions:
-            variante = []
-            for (erqlexpr, mainvar, oldvar), newvar in self.rewritten.iteritems():
-                variante.append( ((erqlexpr, mainvar, oldvar), sol[newvar]) )
-            variantes.add(tuple(variante))
-        # rebuild variantes as dict
-        variantes = [dict(variante) for variante in variantes]
-        # remove variable which have always the same type
-        for erqlexpr, mainvar, oldvar in self.rewritten:
-            it = iter(variantes)
-            etype = it.next()[(erqlexpr, mainvar, oldvar)]
-            for variante in it:
-                if variante[(erqlexpr, mainvar, oldvar)] != etype:
-                    break
-            else:
-                for variante in variantes:
-                    del variante[(erqlexpr, mainvar, oldvar)]
-        return variantes
-
-    def insert_snippets(self, snippets, varexistsmap=None):
-        self.rewritten = {}
-        for varname, erqlexprs in snippets:
-            if varexistsmap is not None and not varname in varexistsmap:
-                continue
-            try:
-                self.const = typed_eid(varname)
-                self.varname = self.const
-                self.rhs_rels = self.lhs_rels = {}
-            except ValueError:
-                self.varname = varname
-                self.const = None
-                self.varstinfo = stinfo = self.select.defined_vars[varname].stinfo
-                if varexistsmap is None:
-                    self.rhs_rels = dict( (rel.r_type, rel) for rel in stinfo['rhsrelations'])
-                    self.lhs_rels = dict( (rel.r_type, rel) for rel in stinfo['relations']
-                                                  if not rel in stinfo['rhsrelations'])
-                else:
-                    self.rhs_rels = self.lhs_rels = {}
-            parent = None
-            inserted = False
-            for erqlexpr in erqlexprs:
-                self.current_expr = erqlexpr
-                if varexistsmap is None:
-                    try:
-                        new = self.insert_snippet(varname, erqlexpr.snippet_rqlst, parent)
-                    except Unsupported:
-                        continue
-                    inserted = True
-                    if new is not None:
-                        self.exists_snippet[erqlexpr] = new
-                    parent = parent or new
-                else:
-                    # called to reintroduce snippet due to ambiguity creation,
-                    # so skip snippets which are not introducing this ambiguity
-                    exists = varexistsmap[varname]
-                    if self.exists_snippet[erqlexpr] is exists:
-                        self.insert_snippet(varname, erqlexpr.snippet_rqlst, exists)
-            if varexistsmap is None and not inserted:
-                # no rql expression found matching rql solutions. User has no access right
-                raise Unauthorized()
-
-    def insert_snippet(self, varname, snippetrqlst, parent=None):
-        new = snippetrqlst.where.accept(self)
-        if new is not None:
-            try:
-                var = self.select.defined_vars[varname]
-            except KeyError:
-                # not a variable
-                pass
-            else:
-                if var.stinfo['optrelations']:
-                    # use a subquery
-                    subselect = stmts.Select()
-                    subselect.append_selected(nodes.VariableRef(subselect.get_variable(varname)))
-                    subselect.add_restriction(new.copy(subselect))
-                    aliases = [varname]
-                    for rel in var.stinfo['relations']:
-                        rschema = self.schema.rschema(rel.r_type)
-                        if rschema.is_final() or (rschema.inlined and not rel in var.stinfo['rhsrelations']):
-                            self.select.remove_node(rel)
-                            rel.children[0].name = varname
-                            subselect.add_restriction(rel.copy(subselect))
-                            for vref in rel.children[1].iget_nodes(nodes.VariableRef):
-                                subselect.append_selected(vref.copy(subselect))
-                                aliases.append(vref.name)
-                    if self.u_varname:
-                        # generate an identifier for the substitution
-                        argname = subselect.allocate_varname()
-                        while argname in self.kwargs:
-                            argname = subselect.allocate_varname()
-                        subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
-                                                        'eid', unicode(argname), 'Substitute')
-                        self.kwargs[argname] = self.session.user.eid
-                    add_types_restriction(self.schema, subselect, subselect, solutions=self.solutions)
-                    assert parent is None
-                    myunion = stmts.Union()
-                    myunion.append(subselect)
-                    aliases = [nodes.VariableRef(self.select.get_variable(name, i))
-                               for i, name in enumerate(aliases)]
-                    self.select.add_subquery(nodes.SubQuery(aliases, myunion), check=False)
-                    self._cleanup_inserted(new)
-                    try:
-                        self.compute_solutions()
-                    except Unsupported:
-                        # some solutions have been lost, can't apply this rql expr
-                        self.select.remove_subquery(new, undefine=True)
-                        raise
-                    return
-            new = nodes.Exists(new)
-            if parent is None:
-                self.select.add_restriction(new)
-            else:
-                grandpa = parent.parent
-                or_ = nodes.Or(parent, new)
-                grandpa.replace(parent, or_)
-            if not self.removing_ambiguity:
-                try:
-                    self.compute_solutions()
-                except Unsupported:
-                    # some solutions have been lost, can't apply this rql expr
-                    if parent is None:
-                        self.select.remove_node(new, undefine=True)
-                    else:
-                        parent.parent.replace(or_, or_.children[0])
-                        self._cleanup_inserted(new)
-                    raise
-            return new
-
-    def _cleanup_inserted(self, node):
-        # cleanup inserted variable references
-        for vref in node.iget_nodes(nodes.VariableRef):
-            vref.unregister_reference()
-            if not vref.variable.stinfo['references']:
-                # no more references, undefine the variable
-                del self.select.defined_vars[vref.name]
-
-    def _visit_binary(self, node, cls):
-        newnode = cls()
-        for c in node.children:
-            new = c.accept(self)
-            if new is None:
-                continue
-            newnode.append(new)
-        if len(newnode.children) == 0:
-            return None
-        if len(newnode.children) == 1:
-            return newnode.children[0]
-        return newnode
-
-    def _visit_unary(self, node, cls):
-        newc = node.children[0].accept(self)
-        if newc is None:
-            return None
-        newnode = cls()
-        newnode.append(newc)
-        return newnode
-
-    def visit_and(self, et):
-        return self._visit_binary(et, nodes.And)
-
-    def visit_or(self, ou):
-        return self._visit_binary(ou, nodes.Or)
-
-    def visit_not(self, node):
-        return self._visit_unary(node, nodes.Not)
-
-    def visit_exists(self, node):
-        return self._visit_unary(node, nodes.Exists)
-
-    def visit_relation(self, relation):
-        lhs, rhs = relation.get_variable_parts()
-        if lhs.name == 'X':
-            # on lhs
-            # see if we can reuse this relation
-            if relation.r_type in self.lhs_rels and isinstance(rhs, nodes.VariableRef) and rhs.name != 'U':
-                if self._may_be_shared(relation, 'object'):
-                    # ok, can share variable
-                    term = self.lhs_rels[relation.r_type].children[1].children[0]
-                    self._use_outer_term(rhs.name, term)
-                    return
-        elif isinstance(rhs, nodes.VariableRef) and rhs.name == 'X' and lhs.name != 'U':
-            # on rhs
-            # see if we can reuse this relation
-            if relation.r_type in self.rhs_rels and self._may_be_shared(relation, 'subject'):
-                # ok, can share variable
-                term = self.rhs_rels[relation.r_type].children[0]
-                self._use_outer_term(lhs.name, term)
-                return
-        rel = nodes.Relation(relation.r_type, relation.optional)
-        for c in relation.children:
-            rel.append(c.accept(self))
-        return rel
-
-    def visit_comparison(self, cmp):
-        cmp_ = nodes.Comparison(cmp.operator)
-        for c in cmp.children:
-            cmp_.append(c.accept(self))
-        return cmp_
-
-    def visit_mathexpression(self, mexpr):
-        cmp_ = nodes.MathExpression(mexpr.operator)
-        for c in cmp.children:
-            cmp_.append(c.accept(self))
-        return cmp_
-
-    def visit_function(self, function):
-        """generate filter name for a function"""
-        function_ = nodes.Function(function.name)
-        for c in function.children:
-            function_.append(c.accept(self))
-        return function_
-
-    def visit_constant(self, constant):
-        """generate filter name for a constant"""
-        return nodes.Constant(constant.value, constant.type)
-
-    def visit_variableref(self, vref):
-        """get the sql name for a variable reference"""
-        if vref.name == 'X':
-            if self.const is not None:
-                return nodes.Constant(self.const, 'Int')
-            return nodes.VariableRef(self.select.get_variable(self.varname))
-        vname_or_term = self._get_varname_or_term(vref.name)
-        if isinstance(vname_or_term, basestring):
-            return nodes.VariableRef(self.select.get_variable(vname_or_term))
-        # shared term
-        return vname_or_term.copy(self.select)
-
-    def _may_be_shared(self, relation, target):
-        """return True if the snippet relation can be skipped to use a relation
-        from the original query
-        """
-        # if cardinality is in '?1', we can ignore the relation and use variable
-        # from the original query
-        rschema = self.schema.rschema(relation.r_type)
-        if target == 'object':
-            cardindex = 0
-            ttypes_func = rschema.objects
-            rprop = rschema.rproperty
-        else: # target == 'subject':
-            cardindex = 1
-            ttypes_func = rschema.subjects
-            rprop = lambda x, y, z: rschema.rproperty(y, x, z)
-        for etype in self.varstinfo['possibletypes']:
-            for ttype in ttypes_func(etype):
-                if rprop(etype, ttype, 'cardinality')[cardindex] in '+*':
-                    return False
-        return True
-
-    def _use_outer_term(self, snippet_varname, term):
-        key = (self.current_expr, self.varname, snippet_varname)
-        if key in self.rewritten:
-            insertedvar = self.select.defined_vars.pop(self.rewritten[key])
-            for inserted_vref in insertedvar.references():
-                inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
-        self.rewritten[key] = term
-
-    def _get_varname_or_term(self, vname):
-        if vname == 'U':
-            if self.u_varname is None:
-                select = self.select
-                self.u_varname = select.allocate_varname()
-                # generate an identifier for the substitution
-                argname = select.allocate_varname()
-                while argname in self.kwargs:
-                    argname = select.allocate_varname()
-                # insert "U eid %(u)s"
-                var = select.get_variable(self.u_varname)
-                select.add_constant_restriction(select.get_variable(self.u_varname),
-                                                'eid', unicode(argname), 'Substitute')
-                self.kwargs[argname] = self.session.user.eid
-            return self.u_varname
-        key = (self.current_expr, self.varname, vname)
-        try:
-            return self.rewritten[key]
-        except KeyError:
-            self.rewritten[key] = newvname = self.select.allocate_varname()
-            return newvname