# server/rqlrewrite.py -- reconstructed from the added lines of the Mercurial
# diff "-r 000000000000 -r b97547f5f1fa" (new file, Wed Nov 05 15:52:50 2008 +0100).
"""RQL rewriting utilities, used for read security checking

:organization: Logilab
:copyright: 2007-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
"""

from rql import nodes, stmts, TypeResolverException
from cubicweb import Unauthorized, server, typed_eid
from cubicweb.server.ssplanner import add_types_restriction


def remove_solutions(origsolutions, solutions, defined):
    """when a rqlst has been generated from another by introducing security
    assertions, this method returns solutions which are contained in orig
    solutions

    :param origsolutions: solutions (dict variable name -> entity type) of
      the original query
    :param solutions: solutions of the rewritten query; every solution that
      is kept is removed from this list **in place**
    :param defined: mapping variable name -> variable object; when a new
      solution disagrees with an original one on a variable, the offending
      type is removed from that variable's `stinfo['possibletypes']`
    :return: the list of new solutions which agree with an original solution
      on every variable they have in common
    """
    newsolutions = []
    for origsol in origsolutions:
        # iterate on a copy since accepted solutions are removed on the fly
        for newsol in solutions[:]:
            for var, etype in origsol.items():
                try:
                    if newsol[var] != etype:
                        try:
                            defined[var].stinfo['possibletypes'].remove(newsol[var])
                        except KeyError:
                            # variable unknown or type already removed
                            pass
                        break
                except KeyError:
                    # variable has been rewritten
                    continue
            else:
                # no disagreement found with this original solution
                newsolutions.append(newsol)
                solutions.remove(newsol)
    return newsolutions


class Unsupported(Exception):
    """raised when a snippet can't be applied to the query being rewritten
    (type resolution failed or some original solutions would be lost)
    """


class RQLRewriter(object):
    """insert some rql snippets into another rql syntax tree

    A rewriter instance keeps the state of the rewrite in progress as
    attributes (`select`, `solutions`, `kwargs`, `rewritten`...), which are
    (re)set by `rewrite` and `insert_snippets`.
    """

    def __init__(self, querier, session):
        """
        :param querier: object providing `_rqlhelper.annotate`, `solutions`
          and `schema`
        :param session: session on behalf of which the query is rewritten;
          `session.user.eid` is used as value for the snippets' `U` variable
        """
        self.session = session
        self.annotate = querier._rqlhelper.annotate
        self._compute_solutions = querier.solutions
        self.schema = querier.schema

    def compute_solutions(self):
        """annotate the select being rewritten and recompute its solutions

        :raise Unsupported: if type resolution fails or if the rewritten
          select has fewer solutions than the original query
        """
        self.annotate(self.select)
        try:
            self._compute_solutions(self.session, self.select, self.kwargs)
        except TypeResolverException:
            raise Unsupported()
        if len(self.select.solutions) < len(self.solutions):
            raise Unsupported()

    def rewrite(self, select, snippets, solutions, kwargs):
        """insert the given snippets into `select`, modifying it in place

        :param select: the select syntax tree to rewrite
        :param snippets: iterable of `(varname, erqlexprs)` pairs, where
          `varname` is a variable name of `select` (or something
          `typed_eid` can convert) and `erqlexprs` the rql expressions to
          insert for it
        :param solutions: solutions of the original query
        :param kwargs: dictionary of query arguments; an entry holding the
          user's eid is added to it when a snippet references `U`
        :raise Unauthorized: if none of the rql expressions given for a
          variable can be applied
        """
        if server.DEBUG:
            print('---- rewrite %s %s %s' % (select, snippets, solutions))
        self.select = select
        self.solutions = solutions
        self.kwargs = kwargs
        self.u_varname = None
        self.removing_ambiguity = False
        self.exists_snippet = {}
        # we have to annotate the rqlst before inserting snippets, even though
        # we'll have to redo it later
        self.annotate(select)
        self.insert_snippets(snippets)
        if not self.exists_snippet and self.u_varname:
            # U has been inserted then cancelled, cleanup
            select.undefine_variable(select.defined_vars[self.u_varname])
        # clean solutions according to initial solutions
        newsolutions = remove_solutions(solutions, select.solutions,
                                        select.defined_vars)
        assert len(newsolutions) >= len(solutions), (
            'rewritten rql %s has lost some solutions, there is probably something '
            'wrong in your schema permission (for instance using a '
            'RQLExpression which insert a relation which doesn\'t exists in '
            'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
                select, solutions, newsolutions))
        if len(newsolutions) > len(solutions):
            # the snippet has introduced some ambiguities, we have to resolve them
            # "manually"
            variantes = self.build_variantes(newsolutions)
            # insert "is" where necessary
            varexistsmap = {}
            self.removing_ambiguity = True
            for (erqlexpr, mainvar, oldvarname), etype in variantes[0].iteritems():
                varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
                var = select.defined_vars[varname]
                exists = var.references()[0].scope
                exists.add_constant_restriction(var, 'is', etype, 'etype')
                varexistsmap[mainvar] = exists
            # insert ORED exists where necessary
            for variante in variantes[1:]:
                # insert_snippets resets self.rewritten, so the lookups below
                # use the names allocated for this variante
                self.insert_snippets(snippets, varexistsmap)
                for (erqlexpr, mainvar, oldvarname), etype in variante.iteritems():
                    varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
                    try:
                        var = select.defined_vars[varname]
                    except KeyError:
                        # not a newly inserted variable
                        continue
                    exists = var.references()[0].scope
                    exists.add_constant_restriction(var, 'is', etype, 'etype')
            # recompute solutions
            #select.annotated = False # avoid assertion error
            self.compute_solutions()
            # clean solutions according to initial solutions
            newsolutions = remove_solutions(solutions, select.solutions,
                                            select.defined_vars)
        select.solutions = newsolutions
        add_types_restriction(self.schema, select)
        if server.DEBUG:
            print('---- rewriten %s' % (select,))

    def build_variantes(self, newsolutions):
        """return the distinct typings of the inserted variables

        :param newsolutions: non empty list of solutions of the rewritten
          query
        :return: list of dictionaries mapping keys of `self.rewritten`
          (`(erqlexpr, mainvar, oldvar)` tuples) to the type found in a
          solution; keys having the same type in every variante are removed
        """
        variantes = set()
        for sol in newsolutions:
            variante = []
            for (erqlexpr, mainvar, oldvar), newvar in self.rewritten.iteritems():
                variante.append(((erqlexpr, mainvar, oldvar), sol[newvar]))
            # tuples are hashable, the set removes duplicated variantes
            variantes.add(tuple(variante))
        # rebuild variantes as dict
        variantes = [dict(variante) for variante in variantes]
        # remove variable which have always the same type
        for erqlexpr, mainvar, oldvar in self.rewritten:
            it = iter(variantes)
            etype = it.next()[(erqlexpr, mainvar, oldvar)]
            for variante in it:
                if variante[(erqlexpr, mainvar, oldvar)] != etype:
                    break
            else:
                for variante in variantes:
                    del variante[(erqlexpr, mainvar, oldvar)]
        return variantes

    def insert_snippets(self, snippets, varexistsmap=None):
        """insert the rql expressions of `snippets` into the current select

        :param snippets: iterable of `(varname, erqlexprs)` pairs
        :param varexistsmap: None on the first pass. Otherwise a dictionary
          mapping variable names to the node under which snippets have to be
          inserted again (ORed) to resolve an ambiguity; variables missing
          from it are skipped
        :raise Unauthorized: on the first pass, if no expression could be
          inserted for a variable
        """
        self.rewritten = {}
        for varname, erqlexprs in snippets:
            if varexistsmap is not None and varname not in varexistsmap:
                continue
            try:
                self.const = typed_eid(varname)
                self.varname = self.const
                self.rhs_rels = self.lhs_rels = {}
            except ValueError:
                # not an eid, this is a variable of the select
                self.varname = varname
                self.const = None
                self.varstinfo = stinfo = self.select.defined_vars[varname].stinfo
                if varexistsmap is None:
                    # relations of the original query, indexed by type, which
                    # may be reused instead of inserting the snippet's ones
                    self.rhs_rels = dict((rel.r_type, rel)
                                         for rel in stinfo['rhsrelations'])
                    self.lhs_rels = dict((rel.r_type, rel)
                                         for rel in stinfo['relations']
                                         if rel not in stinfo['rhsrelations'])
                else:
                    self.rhs_rels = self.lhs_rels = {}
            parent = None
            inserted = False
            for erqlexpr in erqlexprs:
                self.current_expr = erqlexpr
                if varexistsmap is None:
                    try:
                        new = self.insert_snippet(varname, erqlexpr.snippet_rqlst, parent)
                    except Unsupported:
                        continue
                    inserted = True
                    if new is not None:
                        self.exists_snippet[erqlexpr] = new
                    # following expressions are ORed with the first inserted one
                    parent = parent or new
                else:
                    # called to reintroduce snippet due to ambiguity creation,
                    # so skip snippets which are not introducing this ambiguity
                    exists = varexistsmap[varname]
                    if self.exists_snippet[erqlexpr] is exists:
                        self.insert_snippet(varname, erqlexpr.snippet_rqlst, exists)
            if varexistsmap is None and not inserted:
                # no rql expression found matching rql solutions. User has no access right
                raise Unauthorized()

    def insert_snippet(self, varname, snippetrqlst, parent=None):
        """insert one snippet into the current select

        :param varname: name of the variable (or eid) the snippet applies to
        :param snippetrqlst: syntax tree of the snippet; its `where` clause
          is copied using this rewriter as visitor
        :param parent: node the new EXISTS has to be ORed with, if any
        :return: the inserted EXISTS node, or None if nothing remains of the
          snippet once rewritten or if it has been inserted as a subquery
        :raise Unsupported: if solutions are lost once the snippet is
          inserted (the insertion is cancelled first)
        """
        new = snippetrqlst.where.accept(self)
        if new is not None:
            try:
                var = self.select.defined_vars[varname]
            except KeyError:
                # not a variable
                pass
            else:
                if var.stinfo['optrelations']:
                    # use a subquery
                    subselect = stmts.Select()
                    subselect.append_selected(
                        nodes.VariableRef(subselect.get_variable(varname)))
                    subselect.add_restriction(new.copy(subselect))
                    aliases = [varname]
                    for rel in var.stinfo['relations']:
                        rschema = self.schema.rschema(rel.r_type)
                        if rschema.is_final() or (
                                rschema.inlined
                                and rel not in var.stinfo['rhsrelations']):
                            # move the relation from the select to the subquery
                            self.select.remove_node(rel)
                            rel.children[0].name = varname
                            subselect.add_restriction(rel.copy(subselect))
                            for vref in rel.children[1].iget_nodes(nodes.VariableRef):
                                subselect.append_selected(vref.copy(subselect))
                                aliases.append(vref.name)
                    if self.u_varname:
                        # generate an identifier for the substitution
                        argname = subselect.allocate_varname()
                        while argname in self.kwargs:
                            argname = subselect.allocate_varname()
                        subselect.add_constant_restriction(
                            subselect.get_variable(self.u_varname),
                            'eid', unicode(argname), 'Substitute')
                        self.kwargs[argname] = self.session.user.eid
                    add_types_restriction(self.schema, subselect, subselect,
                                          solutions=self.solutions)
                    assert parent is None
                    myunion = stmts.Union()
                    myunion.append(subselect)
                    aliases = [nodes.VariableRef(self.select.get_variable(name, i))
                               for i, name in enumerate(aliases)]
                    self.select.add_subquery(nodes.SubQuery(aliases, myunion),
                                             check=False)
                    self._cleanup_inserted(new)
                    try:
                        self.compute_solutions()
                    except Unsupported:
                        # some solutions have been lost, can't apply this rql expr
                        # NOTE(review): `new` is the rewritten restriction, not
                        # the SubQuery node given to add_subquery above --
                        # confirm this is what remove_subquery expects
                        self.select.remove_subquery(new, undefine=True)
                        raise
                    return
            new = nodes.Exists(new)
            if parent is None:
                self.select.add_restriction(new)
            else:
                grandpa = parent.parent
                or_ = nodes.Or(parent, new)
                grandpa.replace(parent, or_)
            if not self.removing_ambiguity:
                try:
                    self.compute_solutions()
                except Unsupported:
                    # some solutions have been lost, can't apply this rql expr
                    if parent is None:
                        self.select.remove_node(new, undefine=True)
                    else:
                        # NOTE(review): nodes.Or(parent, new) presumably makes
                        # `or_` the parent of `parent`; if so `parent.parent`
                        # is `or_` here rather than `grandpa` -- confirm
                        # against rql's node append/replace implementation
                        parent.parent.replace(or_, or_.children[0])
                        self._cleanup_inserted(new)
                    raise
            return new

    def _cleanup_inserted(self, node):
        """unregister variable references found under `node` and drop from
        the select's defined variables those which are not referenced anymore
        """
        # cleanup inserted variable references
        for vref in node.iget_nodes(nodes.VariableRef):
            vref.unregister_reference()
            if not vref.variable.stinfo['references']:
                # no more references, undefine the variable
                del self.select.defined_vars[vref.name]

    # snippet tree visitor ####################################################

    def _visit_binary(self, node, cls):
        """return a rewritten copy of a n-ary node (And / Or)

        Children whose rewrite is None are skipped. Return None if no child
        remains, the child itself if a single one remains, else a new `cls`
        node holding the rewritten children.
        """
        newnode = cls()
        for c in node.children:
            new = c.accept(self)
            if new is None:
                continue
            newnode.append(new)
        if len(newnode.children) == 0:
            return None
        if len(newnode.children) == 1:
            return newnode.children[0]
        return newnode

    def _visit_unary(self, node, cls):
        """return a new `cls` node wrapping the rewritten first child of
        `node`, or None if nothing remains of that child
        """
        newc = node.children[0].accept(self)
        if newc is None:
            return None
        newnode = cls()
        newnode.append(newc)
        return newnode

    def visit_and(self, et):
        """rewrite an And node (see `_visit_binary`)"""
        return self._visit_binary(et, nodes.And)

    def visit_or(self, ou):
        """rewrite an Or node (see `_visit_binary`)"""
        return self._visit_binary(ou, nodes.Or)

    def visit_not(self, node):
        """rewrite a Not node (see `_visit_unary`)"""
        return self._visit_unary(node, nodes.Not)

    def visit_exists(self, node):
        """rewrite an Exists node (see `_visit_unary`)"""
        return self._visit_unary(node, nodes.Exists)

    def visit_relation(self, relation):
        """return a rewritten copy of a snippet relation, or None if the
        relation is dropped because a relation of the same type from the
        original query can be used instead (the snippet variable on the other
        side is then mapped to the original query's term)
        """
        lhs, rhs = relation.get_variable_parts()
        if lhs.name == 'X':
            # on lhs
            # see if we can reuse this relation
            if (relation.r_type in self.lhs_rels
                    and isinstance(rhs, nodes.VariableRef) and rhs.name != 'U'):
                if self._may_be_shared(relation, 'object'):
                    # ok, can share variable
                    term = self.lhs_rels[relation.r_type].children[1].children[0]
                    self._use_outer_term(rhs.name, term)
                    return
        elif isinstance(rhs, nodes.VariableRef) and rhs.name == 'X' and lhs.name != 'U':
            # on rhs
            # see if we can reuse this relation
            if relation.r_type in self.rhs_rels and self._may_be_shared(relation, 'subject'):
                # ok, can share variable
                term = self.rhs_rels[relation.r_type].children[0]
                self._use_outer_term(lhs.name, term)
                return
        rel = nodes.Relation(relation.r_type, relation.optional)
        for c in relation.children:
            rel.append(c.accept(self))
        return rel

    def visit_comparison(self, cmp):
        """return a copy of the comparison node with rewritten children"""
        cmp_ = nodes.Comparison(cmp.operator)
        for c in cmp.children:
            cmp_.append(c.accept(self))
        return cmp_

    def visit_mathexpression(self, mexpr):
        """return a copy of the math expression node with rewritten children"""
        # read operator and children from the visited node (`mexpr`); the
        # name `cmp` is not bound to a node in this method
        mexpr_ = nodes.MathExpression(mexpr.operator)
        for c in mexpr.children:
            mexpr_.append(c.accept(self))
        return mexpr_

    def visit_function(self, function):
        """return a copy of the function node with rewritten children"""
        function_ = nodes.Function(function.name)
        for c in function.children:
            function_.append(c.accept(self))
        return function_

    def visit_constant(self, constant):
        """return a copy of the constant node"""
        return nodes.Constant(constant.value, constant.type)

    def visit_variableref(self, vref):
        """return the node replacing a snippet variable reference

        `X` is replaced by the eid constant or by a reference to the variable
        the snippet applies to. Other variables are replaced by a reference
        to the variable allocated for them in the select, or by a copy of the
        original query's term they share.
        """
        if vref.name == 'X':
            if self.const is not None:
                return nodes.Constant(self.const, 'Int')
            return nodes.VariableRef(self.select.get_variable(self.varname))
        vname_or_term = self._get_varname_or_term(vref.name)
        if isinstance(vname_or_term, basestring):
            return nodes.VariableRef(self.select.get_variable(vname_or_term))
        # shared term
        return vname_or_term.copy(self.select)

    def _may_be_shared(self, relation, target):
        """return True if the snippet relation can be skipped to use a relation
        from the original query

        :param relation: the snippet relation
        :param target: 'object' if `X` is the subject of the relation, else
          'subject'
        """
        # if cardinality is in '?1', we can ignore the relation and use variable
        # from the original query
        rschema = self.schema.rschema(relation.r_type)
        if target == 'object':
            cardindex = 0
            ttypes_func = rschema.objects
            rprop = rschema.rproperty
        else: # target == 'subject':
            cardindex = 1
            ttypes_func = rschema.subjects
            # rproperty is called with the subject type first
            rprop = lambda x, y, z: rschema.rproperty(y, x, z)
        for etype in self.varstinfo['possibletypes']:
            for ttype in ttypes_func(etype):
                if rprop(etype, ttype, 'cardinality')[cardindex] in '+*':
                    return False
        return True

    def _use_outer_term(self, snippet_varname, term):
        """record that `snippet_varname` of the current expression is
        rewritten as `term`, a term of the original query

        If a variable had already been allocated for it, that variable is
        removed from the select's defined variables and each of its
        references is replaced by a copy of `term`.
        """
        key = (self.current_expr, self.varname, snippet_varname)
        if key in self.rewritten:
            insertedvar = self.select.defined_vars.pop(self.rewritten[key])
            for inserted_vref in insertedvar.references():
                inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
        self.rewritten[key] = term

    def _get_varname_or_term(self, vname):
        """return what a snippet variable is rewritten to: a variable name of
        the select or, for shared variables, a term of the original query

        For `U`, a variable is allocated the first time it's needed, with an
        `eid` restriction whose value is given through a new entry of
        `self.kwargs` holding the session user's eid. For other variables, a
        new variable name is allocated unless one is already recorded in
        `self.rewritten`.
        """
        if vname == 'U':
            if self.u_varname is None:
                select = self.select
                self.u_varname = select.allocate_varname()
                # generate an identifier for the substitution
                argname = select.allocate_varname()
                while argname in self.kwargs:
                    argname = select.allocate_varname()
                # insert "U eid %(u)s"
                var = select.get_variable(self.u_varname)
                select.add_constant_restriction(var, 'eid', unicode(argname),
                                                'Substitute')
                self.kwargs[argname] = self.session.user.eid
            return self.u_varname
        key = (self.current_expr, self.varname, vname)
        try:
            return self.rewritten[key]
        except KeyError:
            self.rewritten[key] = newvname = self.select.allocate_varname()
            return newvname