--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/server/rqlrewrite.py Wed Nov 05 15:52:50 2008 +0100
@@ -0,0 +1,395 @@
+"""RQL rewriting utilities, used for read security checking
+
+:organization: Logilab
+:copyright: 2007-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
+"""
+
+from rql import nodes, stmts, TypeResolverException
+from cubicweb import Unauthorized, server, typed_eid
+from cubicweb.server.ssplanner import add_types_restriction
+
+def remove_solutions(origsolutions, solutions, defined):
+ """when a rqlst has been generated from another by introducing security
+ assertions, return the entries of `solutions` compatible with an entry of
+ `origsolutions` (same type for each shared variable), removing them from `solutions`
+ """
+ newsolutions = []
+ for origsol in origsolutions:
+ for newsol in solutions[:]:  # iterate on a copy, `solutions` is modified below
+ for var, etype in origsol.items():
+ try:
+ if newsol[var] != etype:
+ try:
+ defined[var].stinfo['possibletypes'].remove(newsol[var])  # forget the rejected type
+ except KeyError:
+ pass  # nothing to forget, go on
+ break  # type mismatch: newsol isn't compatible with origsol
+ except KeyError,ex:
+ # variable has been rewritten (it's not in newsol): ignore it
+ continue
+ else:  # presumably the else clause of the `for var` loop (indentation is lost in this patch): no mismatch found
+ newsolutions.append(newsol)
+ solutions.remove(newsol)
+ return newsolutions
+
+class Unsupported(Exception): pass  # raised by RQLRewriter.compute_solutions when solutions can't be computed or some are lost
+
+class RQLRewriter(object):
+ """insert some rql snippets into another rql syntax tree (the tree given to `rewrite` is modified in place)"""
+ def __init__(self, querier, session):  # keep the helpers needed from the querier and the session
+ self.session = session  # gives the user eid (session.user.eid), also given to the solutions computer
+ self.annotate = querier._rqlhelper.annotate  # called on the select before/after snippets insertion
+ self._compute_solutions = querier.solutions  # called as (session, select, kwargs)
+ self.schema = querier.schema
+
+ def compute_solutions(self):  # annotate self.select again and recompute its solutions, raising Unsupported on failure
+ self.annotate(self.select)
+ try:
+ self._compute_solutions(self.session, self.select, self.kwargs)
+ except TypeResolverException:
+ raise Unsupported()  # no solution at all for the rewritten tree
+ if len(self.select.solutions) < len(self.solutions):  # fewer solutions than the ones given to rewrite()
+ raise Unsupported()
+
+ def rewrite(self, select, snippets, solutions, kwargs):  # insert `snippets` into `select` (in place) and reset its solutions according to `solutions`
+ if server.DEBUG:
+ print '---- rewrite', select, snippets, solutions
+ self.select = select
+ self.solutions = solutions
+ self.kwargs = kwargs  # may get a new entry holding the user eid (see _get_varname_or_term)
+ self.u_varname = None  # name of the variable inserted for the snippets' U variable, if any
+ self.removing_ambiguity = False
+ self.exists_snippet = {}  # rql expression -> node returned by insert_snippet for it
+ # we have to annotate the rqlst before inserting snippets, even though
+ # we'll have to redo it later
+ self.annotate(select)
+ self.insert_snippets(snippets)
+ if not self.exists_snippet and self.u_varname:
+ # U has been inserted then cancelled, cleanup
+ select.undefine_variable(select.defined_vars[self.u_varname])
+ # clean solutions according to initial solutions
+ newsolutions = remove_solutions(solutions, select.solutions,
+ select.defined_vars)
+ assert len(newsolutions) >= len(solutions), \
+ 'rewritten rql %s has lost some solutions, there is probably something '\
+ 'wrong in your schema permission (for instance using a '\
+ 'RQLExpression which insert a relation which doesn\'t exists in '\
+ 'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
+ select, solutions, newsolutions)
+ if len(newsolutions) > len(solutions):
+ # the snippet has introduced some ambiguities, we have to resolve them
+ # "manually"
+ variantes = self.build_variantes(newsolutions)
+ # insert "is" where necessary
+ varexistsmap = {}  # main variable name -> scope of the variable inserted for it
+ self.removing_ambiguity = True  # insert_snippet won't recompute solutions from now on
+ for (erqlexpr, mainvar, oldvarname), etype in variantes[0].iteritems():
+ varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
+ var = select.defined_vars[varname]
+ exists = var.references()[0].scope
+ exists.add_constant_restriction(var, 'is', etype, 'etype')
+ varexistsmap[mainvar] = exists
+ # insert ORED exists where necessary
+ for variante in variantes[1:]:
+ self.insert_snippets(snippets, varexistsmap)  # resets self.rewritten, read just below
+ for (erqlexpr, mainvar, oldvarname), etype in variante.iteritems():
+ varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
+ try:
+ var = select.defined_vars[varname]
+ except KeyError:
+ # not a newly inserted variable
+ continue
+ exists = var.references()[0].scope
+ exists.add_constant_restriction(var, 'is', etype, 'etype')
+ # recompute solutions
+ #select.annotated = False # avoid assertion error
+ self.compute_solutions()
+ # clean solutions according to initial solutions
+ newsolutions = remove_solutions(solutions, select.solutions,
+ select.defined_vars)
+ select.solutions = newsolutions
+ add_types_restriction(self.schema, select)
+ if server.DEBUG:
+ print '---- rewriten', select
+
+ def build_variantes(self, newsolutions):  # return a list of dicts {key of self.rewritten: etype}, one per distinct typing found in `newsolutions`
+ variantes = set()  # a set, so that identical typings are kept once
+ for sol in newsolutions:
+ variante = []
+ for (erqlexpr, mainvar, oldvar), newvar in self.rewritten.iteritems():
+ variante.append( ((erqlexpr, mainvar, oldvar), sol[newvar]) )
+ variantes.add(tuple(variante))
+ # rebuild variantes as dict
+ variantes = [dict(variante) for variante in variantes]
+ # remove variable which have always the same type
+ for erqlexpr, mainvar, oldvar in self.rewritten:
+ it = iter(variantes)
+ etype = it.next()[(erqlexpr, mainvar, oldvar)]  # type in the first variante, compared to the others below
+ for variante in it:
+ if variante[(erqlexpr, mainvar, oldvar)] != etype:
+ break  # types differ, keep this variable
+ else:
+ for variante in variantes:
+ del variante[(erqlexpr, mainvar, oldvar)]
+ return variantes
+
+ def insert_snippets(self, snippets, varexistsmap=None):  # insert the rql expressions of each (varname, expressions) item of `snippets`
+ self.rewritten = {}  # (rql expression, main variable, snippet variable) -> new variable name or shared term
+ for varname, erqlexprs in snippets:
+ if varexistsmap is not None and not varname in varexistsmap:
+ continue  # only variables found in varexistsmap are handled when it's given
+ try:
+ self.const = typed_eid(varname)  # presumably succeeds when `varname` is actually an eid - confirm
+ self.varname = self.const
+ self.rhs_rels = self.lhs_rels = {}  # no relation to share
+ except ValueError:
+ self.varname = varname
+ self.const = None
+ self.varstinfo = stinfo = self.select.defined_vars[varname].stinfo
+ if varexistsmap is None:
+ self.rhs_rels = dict( (rel.r_type, rel) for rel in stinfo['rhsrelations'])
+ self.lhs_rels = dict( (rel.r_type, rel) for rel in stinfo['relations']
+ if not rel in stinfo['rhsrelations'])
+ else:
+ self.rhs_rels = self.lhs_rels = {}  # no relation to share
+ parent = None
+ inserted = False
+ for erqlexpr in erqlexprs:
+ self.current_expr = erqlexpr  # used in the keys of self.rewritten
+ if varexistsmap is None:
+ try:
+ new = self.insert_snippet(varname, erqlexpr.snippet_rqlst, parent)
+ except Unsupported:
+ continue  # this expression can't be applied, try the next one
+ inserted = True
+ if new is not None:
+ self.exists_snippet[erqlexpr] = new
+ parent = parent or new  # next expressions will be ORed with the first inserted node
+ else:
+ # called to reintroduce snippet due to ambiguity creation,
+ # so skip snippets which are not introducing this ambiguity
+ exists = varexistsmap[varname]
+ if self.exists_snippet[erqlexpr] is exists:
+ self.insert_snippet(varname, erqlexpr.snippet_rqlst, exists)
+ if varexistsmap is None and not inserted:
+ # no rql expression found matching rql solutions. User has no access right
+ raise Unauthorized()
+
+ def insert_snippet(self, varname, snippetrqlst, parent=None):  # rewrite the snippet and insert it into self.select; return the inserted node, or None
+ new = snippetrqlst.where.accept(self)  # restriction rebuilt by the visit_* methods, None if nothing is left
+ if new is not None:
+ try:
+ var = self.select.defined_vars[varname]
+ except KeyError:
+ # not a variable
+ pass
+ else:
+ if var.stinfo['optrelations']:  # NOTE(review): presumably relations where the variable is optional - confirm
+ # use a subquery
+ subselect = stmts.Select()
+ subselect.append_selected(nodes.VariableRef(subselect.get_variable(varname)))
+ subselect.add_restriction(new.copy(subselect))
+ aliases = [varname]
+ for rel in var.stinfo['relations']:
+ rschema = self.schema.rschema(rel.r_type)
+ if rschema.is_final() or (rschema.inlined and not rel in var.stinfo['rhsrelations']):
+ self.select.remove_node(rel)  # the relation is moved from the main query to the subquery
+ rel.children[0].name = varname
+ subselect.add_restriction(rel.copy(subselect))
+ for vref in rel.children[1].iget_nodes(nodes.VariableRef):
+ subselect.append_selected(vref.copy(subselect))
+ aliases.append(vref.name)
+ if self.u_varname:
+ # generate an identifier for the substitution
+ argname = subselect.allocate_varname()
+ while argname in self.kwargs:
+ argname = subselect.allocate_varname()
+ subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
+ 'eid', unicode(argname), 'Substitute')
+ self.kwargs[argname] = self.session.user.eid
+ add_types_restriction(self.schema, subselect, subselect, solutions=self.solutions)
+ assert parent is None  # ORing with a previously inserted node isn't handled for subqueries
+ myunion = stmts.Union()
+ myunion.append(subselect)
+ aliases = [nodes.VariableRef(self.select.get_variable(name, i))
+ for i, name in enumerate(aliases)]
+ self.select.add_subquery(nodes.SubQuery(aliases, myunion), check=False)
+ self._cleanup_inserted(new)  # `new` itself isn't inserted in the main query, only its copy in the subquery
+ try:
+ self.compute_solutions()
+ except Unsupported:
+ # some solutions have been lost, can't apply this rql expr
+ self.select.remove_subquery(new, undefine=True)  # NOTE(review): `new` is the restriction, not the SubQuery added above - confirm remove_subquery accepts it
+ raise
+ return  # nothing to OR following expressions with
+ new = nodes.Exists(new)
+ if parent is None:
+ self.select.add_restriction(new)
+ else:
+ grandpa = parent.parent
+ or_ = nodes.Or(parent, new)
+ grandpa.replace(parent, or_)
+ if not self.removing_ambiguity:  # when removing ambiguity, rewrite() recomputes solutions itself
+ try:
+ self.compute_solutions()
+ except Unsupported:
+ # some solutions have been lost, can't apply this rql expr
+ if parent is None:
+ self.select.remove_node(new, undefine=True)
+ else:
+ parent.parent.replace(or_, or_.children[0])  # presumably puts `parent` back in place of the OR
+ self._cleanup_inserted(new)
+ raise
+ return new
+
+ def _cleanup_inserted(self, node):  # unregister the variable references found under `node`
+ # cleanup inserted variable references
+ for vref in node.iget_nodes(nodes.VariableRef):
+ vref.unregister_reference()
+ if not vref.variable.stinfo['references']:
+ # no more references, undefine the variable
+ del self.select.defined_vars[vref.name]
+
+ def _visit_binary(self, node, cls):  # rebuild `node` as a `cls` node from its visited children
+ newnode = cls()
+ for c in node.children:
+ new = c.accept(self)
+ if new is None:
+ continue  # this child has been dropped by the visit
+ newnode.append(new)
+ if len(newnode.children) == 0:
+ return None  # every child has been dropped
+ if len(newnode.children) == 1:
+ return newnode.children[0]  # a single child left: return it without the `cls` wrapper
+ return newnode
+
+ def _visit_unary(self, node, cls):  # rebuild `node` as a `cls` node around its visited first child
+ newc = node.children[0].accept(self)
+ if newc is None:
+ return None  # the child has been dropped, so is this node
+ newnode = cls()
+ newnode.append(newc)
+ return newnode
+
+ def visit_and(self, et):  # visitor callback: rebuilt as a nodes.And by _visit_binary
+ return self._visit_binary(et, nodes.And)
+
+ def visit_or(self, ou):  # visitor callback: rebuilt as a nodes.Or by _visit_binary
+ return self._visit_binary(ou, nodes.Or)
+
+ def visit_not(self, node):  # visitor callback: rebuilt as a nodes.Not by _visit_unary
+ return self._visit_unary(node, nodes.Not)
+
+ def visit_exists(self, node):  # visitor callback: rebuilt as a nodes.Exists by _visit_unary
+ return self._visit_unary(node, nodes.Exists)
+
+ def visit_relation(self, relation):  # copy a snippet relation, or return None when a relation of the original query is used instead
+ lhs, rhs = relation.get_variable_parts()
+ if lhs.name == 'X':
+ # X is on the lhs of the snippet relation
+ # see if we can reuse this relation
+ if relation.r_type in self.lhs_rels and isinstance(rhs, nodes.VariableRef) and rhs.name != 'U':
+ if self._may_be_shared(relation, 'object'):
+ # ok, can share variable
+ term = self.lhs_rels[relation.r_type].children[1].children[0]
+ self._use_outer_term(rhs.name, term)
+ return
+ elif isinstance(rhs, nodes.VariableRef) and rhs.name == 'X' and lhs.name != 'U':
+ # X is on the rhs of the snippet relation
+ # see if we can reuse this relation
+ if relation.r_type in self.rhs_rels and self._may_be_shared(relation, 'subject'):
+ # ok, can share variable
+ term = self.rhs_rels[relation.r_type].children[0]
+ self._use_outer_term(lhs.name, term)
+ return
+ rel = nodes.Relation(relation.r_type, relation.optional)  # not shared: build a copy with visited children
+ for c in relation.children:
+ rel.append(c.accept(self))
+ return rel
+
+ def visit_comparison(self, cmp):  # new Comparison with the same operator and visited children (`cmp` hides the builtin here)
+ cmp_ = nodes.Comparison(cmp.operator)
+ for c in cmp.children:
+ cmp_.append(c.accept(self))
+ return cmp_
+
+ def visit_mathexpression(self, mexpr):  # new MathExpression with the same operator and visited children
+ cmp_ = nodes.MathExpression(mexpr.operator)  # was `cmp.operator`: `cmp` is the builtin here, not the visited node
+ for c in mexpr.children:
+ cmp_.append(c.accept(self))
+ return cmp_
+
+ def visit_function(self, function):
+ """return a new Function node with the same name and visited children"""
+ function_ = nodes.Function(function.name)
+ for c in function.children:
+ function_.append(c.accept(self))
+ return function_
+
+ def visit_constant(self, constant):
+ """return a new Constant node with the same value and type"""
+ return nodes.Constant(constant.value, constant.type)
+
+ def visit_variableref(self, vref):
+ """return the node replacing a snippet variable reference in the rewritten query"""
+ if vref.name == 'X':  # X is replaced by the constant or variable set up by insert_snippets
+ if self.const is not None:
+ return nodes.Constant(self.const, 'Int')
+ return nodes.VariableRef(self.select.get_variable(self.varname))
+ vname_or_term = self._get_varname_or_term(vref.name)
+ if isinstance(vname_or_term, basestring):  # a variable name
+ return nodes.VariableRef(self.select.get_variable(vname_or_term))
+ # shared term
+ return vname_or_term.copy(self.select)
+
+ def _may_be_shared(self, relation, target):
+ """return True if the snippet relation can be skipped to use a relation
+ from the original query (`target` is 'object', anything else is handled as 'subject')
+ """
+ # if cardinality is in '?1', we can ignore the relation and use variable
+ # from the original query
+ rschema = self.schema.rschema(relation.r_type)
+ if target == 'object':
+ cardindex = 0
+ ttypes_func = rschema.objects
+ rprop = rschema.rproperty
+ else: # target == 'subject':
+ cardindex = 1
+ ttypes_func = rschema.subjects
+ rprop = lambda x,y,z: rschema.rproperty(y, x, z)  # same call with the two types swapped
+ for etype in self.varstinfo['possibletypes']:
+ for ttype in ttypes_func(etype):
+ if rprop(etype, ttype, 'cardinality')[cardindex] in '+*':
+ return False  # a multiple cardinality is enough to forbid sharing
+ return True
+
+ def _use_outer_term(self, snippet_varname, term):  # make the snippet variable stand for `term` in self.rewritten
+ key = (self.current_expr, self.varname, snippet_varname)
+ if key in self.rewritten:  # a variable has already been inserted for it: replace its references by copies of `term`
+ insertedvar = self.select.defined_vars.pop(self.rewritten[key])
+ for inserted_vref in insertedvar.references():
+ inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
+ self.rewritten[key] = term
+
+ def _get_varname_or_term(self, vname):  # return the variable name (or shared term) standing for snippet variable `vname`
+ if vname == 'U':
+ if self.u_varname is None:  # first use of U: allocate its variable and restrict it on the user eid
+ select = self.select
+ self.u_varname = select.allocate_varname()
+ # generate an identifier for the substitution
+ argname = select.allocate_varname()
+ while argname in self.kwargs:
+ argname = select.allocate_varname()
+ # insert "U eid %(u)s"
+ var = select.get_variable(self.u_varname)
+ select.add_constant_restriction(var,
+ 'eid', unicode(argname), 'Substitute')
+ self.kwargs[argname] = self.session.user.eid
+ return self.u_varname
+ key = (self.current_expr, self.varname, vname)
+ try:
+ return self.rewritten[key]
+ except KeyError:
+ self.rewritten[key] = newvname = self.select.allocate_varname()  # first time this snippet variable is seen
+ return newvname