diff -r 058bb3dc685f -r 0b59724cb3f2 cubicweb/rqlrewrite.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/cubicweb/rqlrewrite.py Sat Jan 16 13:48:51 2016 +0100 @@ -0,0 +1,933 @@ +# copyright 2003-2014 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of CubicWeb. +# +# CubicWeb is free software: you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) +# any later version. +# +# CubicWeb is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with CubicWeb. If not, see . +"""RQL rewriting utilities : insert rql expression snippets into rql syntax +tree. + +This is used for instance for read security checking in the repository. +""" +__docformat__ = "restructuredtext en" + +from six import text_type, string_types + +from rql import nodes as n, stmts, TypeResolverException +from rql.utils import common_parent + +from yams import BadSchemaDefinition + +from logilab.common import tempattr +from logilab.common.graph import has_path + +from cubicweb import Unauthorized +from cubicweb.schema import RRQLExpression + +def cleanup_solutions(rqlst, solutions): + for sol in solutions: + for vname in list(sol): + if not (vname in rqlst.defined_vars or vname in rqlst.aliases): + del sol[vname] + + +def add_types_restriction(schema, rqlst, newroot=None, solutions=None): + if newroot is None: + assert solutions is None + if hasattr(rqlst, '_types_restr_added'): + return + solutions = rqlst.solutions + newroot = rqlst + rqlst._types_restr_added = True + else: + assert solutions is not None + rqlst = rqlst.stmt + eschema = schema.eschema + allpossibletypes = {} + for solution in solutions: + for varname, etype in solution.items(): + # XXX not considering aliases by design, right ? + if varname not in newroot.defined_vars or eschema(etype).final: + continue + allpossibletypes.setdefault(varname, set()).add(etype) + # XXX could be factorized with add_etypes_restriction from rql 0.31 + for varname in sorted(allpossibletypes): + var = newroot.defined_vars[varname] + stinfo = var.stinfo + if stinfo.get('uidrel') is not None: + continue # eid specified, no need for additional type specification + try: + typerel = rqlst.defined_vars[varname].stinfo.get('typerel') + except KeyError: + assert varname in rqlst.aliases + continue + if newroot is rqlst and typerel is not None: + mytyperel = typerel + else: + for vref in var.references(): + rel = vref.relation() + if rel and rel.is_types_restriction(): + mytyperel = rel + break + else: + mytyperel = None + possibletypes = allpossibletypes[varname] + if mytyperel is not None: + if mytyperel.r_type == 'is_instance_of': + # turn is_instance_of relation into a is relation since we've + # all possible solutions and don't want to bother with + # potential is_instance_of incompatibility + mytyperel.r_type = 'is' + if len(possibletypes) > 1: + node = n.Function('IN') + for etype in sorted(possibletypes): + node.append(n.Constant(etype, 'etype')) + else: + etype = next(iter(possibletypes)) + node = n.Constant(etype, 'etype') + comp = mytyperel.children[1] + comp.replace(comp.children[0], node) + else: + # variable has already some strict types restriction. new + # possible types can only be a subset of existing ones, so only + # remove no more possible types + for cst in mytyperel.get_nodes(n.Constant): + if not cst.value in possibletypes: + cst.parent.remove(cst) + else: + # we have to add types restriction + if stinfo.get('scope') is not None: + rel = var.scope.add_type_restriction(var, possibletypes) + else: + # tree is not annotated yet, no scope set so add the restriction + # to the root + rel = newroot.add_type_restriction(var, possibletypes) + stinfo['typerel'] = rel + stinfo['possibletypes'] = possibletypes + + +def remove_solutions(origsolutions, solutions, defined): + """when a rqlst has been generated from another by introducing security + assertions, this method returns solutions which are contained in orig + solutions + """ + newsolutions = [] + for origsol in origsolutions: + for newsol in solutions[:]: + for var, etype in origsol.items(): + try: + if newsol[var] != etype: + try: + defined[var].stinfo['possibletypes'].remove(newsol[var]) + except KeyError: + pass + break + except KeyError: + # variable has been rewritten + continue + else: + newsolutions.append(newsol) + solutions.remove(newsol) + return newsolutions + + +def _add_noinvariant(noinvariant, restricted, select, nbtrees): + # a variable can actually be invariant if it has not been restricted for + # security reason or if security assertion hasn't modified the possible + # solutions for the query + for vname in restricted: + try: + var = select.defined_vars[vname] + except KeyError: + # this is an alias + continue + if nbtrees != 1 or len(var.stinfo['possibletypes']) != 1: + noinvariant.add(var) + + +def _expand_selection(terms, selected, aliases, select, newselect): + for term in terms: + for vref in term.iget_nodes(n.VariableRef): + if not vref.name in selected: + select.append_selected(vref) + colalias = newselect.get_variable(vref.name, len(aliases)) + aliases.append(n.VariableRef(colalias)) + selected.add(vref.name) + +def _has_multiple_cardinality(etypes, rdef, ttypes_func, cardindex): + """return True if relation definitions from entity types (`etypes`) to + target types returned by the `ttypes_func` function all have single (1 or ?) + cardinality. + """ + for etype in etypes: + for ttype in ttypes_func(etype): + if rdef(etype, ttype).cardinality[cardindex] in '+*': + return True + return False + +def _compatible_relation(relations, stmt, sniprel): + """Search among given rql relation nodes if there is one 'compatible' with the + snippet relation, and return it if any, else None. + + A relation is compatible if it: + * belongs to the currently processed statement, + * isn't negged (i.e. direct parent is a NOT node) + * isn't optional (outer join) or similarly as the snippet relation + """ + for rel in relations: + # don't share if relation's scope is not the current statement + if rel.scope is not stmt: + continue + # don't share neged relation + if rel.neged(strict=True): + continue + # don't share optional relation, unless the snippet relation is + # similarly optional + if rel.optional and rel.optional != sniprel.optional: + continue + return rel + return None + + +def iter_relations(stinfo): + # this is a function so that test may return relation in a predictable order + return stinfo['relations'] - stinfo['rhsrelations'] + + +class Unsupported(Exception): + """raised when an rql expression can't be inserted in some rql query + because it create an unresolvable query (eg no solutions found) + """ + +class VariableFromSubQuery(Exception): + """flow control exception to indicate that a variable is coming from a + subquery, and let parent act accordingly + """ + def __init__(self, variable): + self.variable = variable + + +class RQLRewriter(object): + """Insert some rql snippets into another rql syntax tree, for security / + relation vocabulary. This implies that it should only restrict results of + the original query, not generate new ones. Hence, inserted snippets are + inserted under an EXISTS node. + + This class *isn't thread safe*. + """ + + def __init__(self, session): + self.session = session + vreg = session.vreg + self.schema = vreg.schema + self.annotate = vreg.rqlhelper.annotate + self._compute_solutions = vreg.solutions + + def compute_solutions(self): + self.annotate(self.select) + try: + self._compute_solutions(self.session, self.select, self.kwargs) + except TypeResolverException: + raise Unsupported(str(self.select)) + if len(self.select.solutions) < len(self.solutions): + raise Unsupported() + + def insert_local_checks(self, select, kwargs, + localchecks, restricted, noinvariant): + """ + select: the rql syntax tree Select node + kwargs: query arguments + + localchecks: {(('Var name', (rqlexpr1, rqlexpr2)), + ('Var name1', (rqlexpr1, rqlexpr23))): [solution]} + + (see querier._check_permissions docstring for more information) + + restricted: set of variable names to which an rql expression has to be + applied + + noinvariant: set of variable names that can't be considered has + invariant due to security reason (will be filed by this method) + """ + nbtrees = len(localchecks) + myunion = union = select.parent + # transform in subquery when len(localchecks)>1 and groups + if nbtrees > 1 and (select.orderby or select.groupby or + select.having or select.has_aggregat or + select.distinct or + select.limit or select.offset): + newselect = stmts.Select() + # only select variables in subqueries + origselection = select.selection + select.select_only_variables() + select.has_aggregat = False + # create subquery first so correct node are used on copy + # (eg ColumnAlias instead of Variable) + aliases = [n.VariableRef(newselect.get_variable(vref.name, i)) + for i, vref in enumerate(select.selection)] + selected = set(vref.name for vref in aliases) + # now copy original selection and groups + for term in origselection: + newselect.append_selected(term.copy(newselect)) + if select.orderby: + sortterms = [] + for sortterm in select.orderby: + sortterms.append(sortterm.copy(newselect)) + for fnode in sortterm.get_nodes(n.Function): + if fnode.name == 'FTIRANK': + # we've to fetch the has_text relation as well + var = fnode.children[0].variable + rel = next(iter(var.stinfo['ftirels'])) + assert not rel.ored(), 'unsupported' + newselect.add_restriction(rel.copy(newselect)) + # remove relation from the orig select and + # cleanup variable stinfo + rel.parent.remove(rel) + var.stinfo['ftirels'].remove(rel) + var.stinfo['relations'].remove(rel) + # XXX not properly re-annotated after security insertion? + newvar = newselect.get_variable(var.name) + newvar.stinfo.setdefault('ftirels', set()).add(rel) + newvar.stinfo.setdefault('relations', set()).add(rel) + newselect.set_orderby(sortterms) + _expand_selection(select.orderby, selected, aliases, select, newselect) + select.orderby = () # XXX dereference? + if select.groupby: + newselect.set_groupby([g.copy(newselect) for g in select.groupby]) + _expand_selection(select.groupby, selected, aliases, select, newselect) + select.groupby = () # XXX dereference? + if select.having: + newselect.set_having([g.copy(newselect) for g in select.having]) + _expand_selection(select.having, selected, aliases, select, newselect) + select.having = () # XXX dereference? + if select.limit: + newselect.limit = select.limit + select.limit = None + if select.offset: + newselect.offset = select.offset + select.offset = 0 + myunion = stmts.Union() + newselect.set_with([n.SubQuery(aliases, myunion)], check=False) + newselect.distinct = select.distinct + solutions = [sol.copy() for sol in select.solutions] + cleanup_solutions(newselect, solutions) + newselect.set_possible_types(solutions) + # if some solutions doesn't need rewriting, insert original + # select as first union subquery + if () in localchecks: + myunion.append(select) + # we're done, replace original select by the new select with + # subqueries (more added in the loop below) + union.replace(select, newselect) + elif not () in localchecks: + union.remove(select) + for lcheckdef, lchecksolutions in localchecks.items(): + if not lcheckdef: + continue + myrqlst = select.copy(solutions=lchecksolutions) + myunion.append(myrqlst) + # in-place rewrite + annotation / simplification + lcheckdef = [({var: 'X'}, rqlexprs) for var, rqlexprs in lcheckdef] + self.rewrite(myrqlst, lcheckdef, kwargs) + _add_noinvariant(noinvariant, restricted, myrqlst, nbtrees) + if () in localchecks: + select.set_possible_types(localchecks[()]) + add_types_restriction(self.schema, select) + _add_noinvariant(noinvariant, restricted, select, nbtrees) + self.annotate(union) + + def rewrite(self, select, snippets, kwargs, existingvars=None): + """ + snippets: (varmap, list of rql expression) + with varmap a *dict* {select var: snippet var} + """ + self.select = select + # remove_solutions used below require a copy + self.solutions = solutions = select.solutions[:] + self.kwargs = kwargs + self.u_varname = None + self.removing_ambiguity = False + self.exists_snippet = {} + self.pending_keys = [] + self.existingvars = existingvars + # we have to annotate the rqlst before inserting snippets, even though + # we'll have to redo it later + self.annotate(select) + self.insert_snippets(snippets) + if not self.exists_snippet and self.u_varname: + # U has been inserted than cancelled, cleanup + select.undefine_variable(select.defined_vars[self.u_varname]) + # clean solutions according to initial solutions + newsolutions = remove_solutions(solutions, select.solutions, + select.defined_vars) + assert len(newsolutions) >= len(solutions), ( + 'rewritten rql %s has lost some solutions, there is probably ' + 'something wrong in your schema permission (for instance using a ' + 'RQLExpression which inserts a relation which doesn\'t exist in ' + 'the schema)\nOrig solutions: %s\nnew solutions: %s' % ( + select, solutions, newsolutions)) + if len(newsolutions) > len(solutions): + newsolutions = self.remove_ambiguities(snippets, newsolutions) + assert newsolutions + select.solutions = newsolutions + add_types_restriction(self.schema, select) + + def insert_snippets(self, snippets, varexistsmap=None): + self.rewritten = {} + for varmap, rqlexprs in snippets: + if isinstance(varmap, dict): + varmap = tuple(sorted(varmap.items())) + else: + assert isinstance(varmap, tuple), varmap + if varexistsmap is not None and not varmap in varexistsmap: + continue + self.insert_varmap_snippets(varmap, rqlexprs, varexistsmap) + + def init_from_varmap(self, varmap, varexistsmap=None): + self.varmap = varmap + self.revvarmap = {} + self.varinfos = [] + for i, (selectvar, snippetvar) in enumerate(varmap): + assert snippetvar in 'SOX' + self.revvarmap[snippetvar] = (selectvar, i) + vi = {} + self.varinfos.append(vi) + try: + vi['const'] = int(selectvar) + vi['rhs_rels'] = vi['lhs_rels'] = {} + except ValueError: + try: + vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo + except KeyError: + vi['stinfo'] = sti = self._subquery_variable(selectvar) + if varexistsmap is None: + # build an index for quick access to relations + vi['rhs_rels'] = {} + for rel in sti.get('rhsrelations', []): + vi['rhs_rels'].setdefault(rel.r_type, []).append(rel) + vi['lhs_rels'] = {} + for rel in sti.get('relations', []): + if not rel in sti.get('rhsrelations', []): + vi['lhs_rels'].setdefault(rel.r_type, []).append(rel) + else: + vi['rhs_rels'] = vi['lhs_rels'] = {} + + def _subquery_variable(self, selectvar): + raise VariableFromSubQuery(selectvar) + + def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap): + try: + self.init_from_varmap(varmap, varexistsmap) + except VariableFromSubQuery as ex: + # variable may have been moved to a newly inserted subquery + # we should insert snippet in that subquery + subquery = self.select.aliases[ex.variable].query + assert len(subquery.children) == 1, subquery + subselect = subquery.children[0] + RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)], + self.kwargs) + return + self._insert_scope = None + previous = None + inserted = False + for rqlexpr in rqlexprs: + self.current_expr = rqlexpr + if varexistsmap is None: + try: + new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, previous) + except Unsupported: + continue + inserted = True + if new is not None and self._insert_scope is None: + self.exists_snippet[rqlexpr] = new + previous = previous or new + else: + # called to reintroduce snippet due to ambiguity creation, + # so skip snippets which are not introducing this ambiguity + exists = varexistsmap[varmap] + if self.exists_snippet.get(rqlexpr) is exists: + self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists) + if varexistsmap is None and not inserted: + # no rql expression found matching rql solutions. User has no access right + raise Unauthorized() # XXX may also be because of bad constraints in schema definition + + def insert_snippet(self, varmap, snippetrqlst, previous=None): + new = snippetrqlst.where.accept(self) + existing = self.existingvars + self.existingvars = None + try: + return self._insert_snippet(varmap, previous, new) + finally: + self.existingvars = existing + + def _inserted_root(self, new): + if not isinstance(new, (n.Exists, n.Not)): + new = n.Exists(new) + return new + + def _insert_snippet(self, varmap, previous, new): + """insert `new` snippet into the syntax tree, which have been rewritten + using `varmap`. In cases where an action is protected by several rql + expresssion, `previous` will be the first rql expression which has been + inserted, and so should be ORed with the following expressions. + """ + if new is not None: + if self._insert_scope is None: + insert_scope = None + for vi in self.varinfos: + scope = vi.get('stinfo', {}).get('scope', self.select) + if insert_scope is None: + insert_scope = scope + else: + insert_scope = common_parent(scope, insert_scope) + else: + insert_scope = self._insert_scope + if self._insert_scope is None and any(vi.get('stinfo', {}).get('optrelations') + for vi in self.varinfos): + assert previous is None + self._insert_scope, new = self.snippet_subquery(varmap, new) + self.insert_pending() + #self._insert_scope = None + return new + new = self._inserted_root(new) + if previous is None: + insert_scope.add_restriction(new) + else: + grandpa = previous.parent + or_ = n.Or(previous, new) + grandpa.replace(previous, or_) + if not self.removing_ambiguity: + try: + self.compute_solutions() + except Unsupported: + # some solutions have been lost, can't apply this rql expr + if previous is None: + self.current_statement().remove_node(new, undefine=True) + else: + grandpa.replace(or_, previous) + self._cleanup_inserted(new) + raise + else: + with tempattr(self, '_insert_scope', new): + self.insert_pending() + return new + self.insert_pending() + + def insert_pending(self): + """pending_keys hold variable referenced by U has__permission X + relation. + + Once the snippet introducing this has been inserted and solutions + recomputed, we have to insert snippet defined for of entity + types taken by X + """ + stmt = self.current_statement() + while self.pending_keys: + key, action = self.pending_keys.pop() + try: + varname = self.rewritten[key] + except KeyError: + try: + varname = self.revvarmap[key[-1]][0] + except KeyError: + # variable isn't used anywhere else, we can't insert security + raise Unauthorized() + ptypes = stmt.defined_vars[varname].stinfo['possibletypes'] + if len(ptypes) > 1: + # XXX dunno how to handle this + self.session.error( + 'cant check security of %s, ambigous type for %s in %s', + stmt, varname, key[0]) # key[0] == the rql expression + raise Unauthorized() + etype = next(iter(ptypes)) + eschema = self.schema.eschema(etype) + if not eschema.has_perm(self.session, action): + rqlexprs = eschema.get_rqlexprs(action) + if not rqlexprs: + raise Unauthorized() + self.insert_snippets([({varname: 'X'}, rqlexprs)]) + + def snippet_subquery(self, varmap, transformedsnippet): + """introduce the given snippet in a subquery""" + subselect = stmts.Select() + snippetrqlst = n.Exists(transformedsnippet.copy(subselect)) + get_rschema = self.schema.rschema + aliases = [] + done = set() + for i, (selectvar, _) in enumerate(varmap): + need_null_test = False + subselectvar = subselect.get_variable(selectvar) + subselect.append_selected(n.VariableRef(subselectvar)) + aliases.append(selectvar) + todo = [(selectvar, self.varinfos[i]['stinfo'])] + while todo: + varname, stinfo = todo.pop() + done.add(varname) + for rel in iter_relations(stinfo): + if rel in done: + continue + done.add(rel) + rschema = get_rschema(rel.r_type) + if rschema.final or rschema.inlined: + rel.children[0].name = varname # XXX explain why + subselect.add_restriction(rel.copy(subselect)) + for vref in rel.children[1].iget_nodes(n.VariableRef): + if isinstance(vref.variable, n.ColumnAlias): + # XXX could probably be handled by generating the + # subquery into the detected subquery + raise BadSchemaDefinition( + "cant insert security because of usage two inlined " + "relations in this query. You should probably at " + "least uninline %s" % rel.r_type) + subselect.append_selected(vref.copy(subselect)) + aliases.append(vref.name) + self.select.remove_node(rel) + # when some inlined relation has to be copied in the + # subquery and that relation is optional, we need to + # test that either value is NULL or that the snippet + # condition is satisfied + if varname == selectvar and rel.optional and rschema.inlined: + need_null_test = True + # also, if some attributes or inlined relation of the + # object variable are accessed, we need to get all those + # from the subquery as well + if vref.name not in done and rschema.inlined: + # we can use vref here define in above for loop + ostinfo = vref.variable.stinfo + for orel in iter_relations(ostinfo): + orschema = get_rschema(orel.r_type) + if orschema.final or orschema.inlined: + todo.append( (vref.name, ostinfo) ) + break + if need_null_test: + snippetrqlst = n.Or( + n.make_relation(subselect.get_variable(selectvar), 'is', + (None, None), n.Constant, + operator='='), + snippetrqlst) + subselect.add_restriction(snippetrqlst) + if self.u_varname: + # generate an identifier for the substitution + argname = subselect.allocate_varname() + while argname in self.kwargs: + argname = subselect.allocate_varname() + subselect.add_constant_restriction(subselect.get_variable(self.u_varname), + 'eid', text_type(argname), 'Substitute') + self.kwargs[argname] = self.session.user.eid + add_types_restriction(self.schema, subselect, subselect, + solutions=self.solutions) + myunion = stmts.Union() + myunion.append(subselect) + aliases = [n.VariableRef(self.select.get_variable(name, i)) + for i, name in enumerate(aliases)] + self.select.add_subquery(n.SubQuery(aliases, myunion), check=False) + self._cleanup_inserted(transformedsnippet) + try: + self.compute_solutions() + except Unsupported: + # some solutions have been lost, can't apply this rql expr + self.select.remove_subquery(self.select.with_[-1]) + raise + return subselect, snippetrqlst + + def remove_ambiguities(self, snippets, newsolutions): + # the snippet has introduced some ambiguities, we have to resolve them + # "manually" + variantes = self.build_variantes(newsolutions) + # insert "is" where necessary + varexistsmap = {} + self.removing_ambiguity = True + for (erqlexpr, varmap, oldvarname), etype in variantes[0].items(): + varname = self.rewritten[(erqlexpr, varmap, oldvarname)] + var = self.select.defined_vars[varname] + exists = var.references()[0].scope + exists.add_constant_restriction(var, 'is', etype, 'etype') + varexistsmap[varmap] = exists + # insert ORED exists where necessary + for variante in variantes[1:]: + self.insert_snippets(snippets, varexistsmap) + for key, etype in variante.items(): + varname = self.rewritten[key] + try: + var = self.select.defined_vars[varname] + except KeyError: + # not a newly inserted variable + continue + exists = var.references()[0].scope + exists.add_constant_restriction(var, 'is', etype, 'etype') + # recompute solutions + self.compute_solutions() + # clean solutions according to initial solutions + return remove_solutions(self.solutions, self.select.solutions, + self.select.defined_vars) + + def build_variantes(self, newsolutions): + variantes = set() + for sol in newsolutions: + variante = [] + for key, newvar in self.rewritten.items(): + variante.append( (key, sol[newvar]) ) + variantes.add(tuple(variante)) + # rebuild variantes as dict + variantes = [dict(variante) for variante in variantes] + # remove variable which have always the same type + for key in self.rewritten: + it = iter(variantes) + etype = next(it)[key] + for variante in it: + if variante[key] != etype: + break + else: + for variante in variantes: + del variante[key] + return variantes + + def _cleanup_inserted(self, node): + # cleanup inserted variable references + removed = set() + for vref in node.iget_nodes(n.VariableRef): + vref.unregister_reference() + if not vref.variable.stinfo['references']: + # no more references, undefine the variable + del self.select.defined_vars[vref.name] + removed.add(vref.name) + for key, newvar in list(self.rewritten.items()): + if newvar in removed: + del self.rewritten[key] + + + def _may_be_shared_with(self, sniprel, target): + """if the snippet relation can be skipped to use a relation from the + original query, return that relation node + """ + if sniprel.neged(strict=True): + return None # no way + rschema = self.schema.rschema(sniprel.r_type) + stmt = self.current_statement() + for vi in self.varinfos: + try: + if target == 'object': + orels = vi['lhs_rels'][sniprel.r_type] + cardindex = 0 + ttypes_func = rschema.objects + rdef = rschema.rdef + else: # target == 'subject': + orels = vi['rhs_rels'][sniprel.r_type] + cardindex = 1 + ttypes_func = rschema.subjects + rdef = lambda x, y: rschema.rdef(y, x) + except KeyError: + # may be raised by vi['xhs_rels'][sniprel.r_type] + continue + # if cardinality isn't in '?1', we can't ignore the snippet relation + # and use variable from the original query + if _has_multiple_cardinality(vi['stinfo']['possibletypes'], rdef, + ttypes_func, cardindex): + continue + orel = _compatible_relation(orels, stmt, sniprel) + if orel is not None: + return orel + return None + + def _use_orig_term(self, snippet_varname, term): + key = (self.current_expr, self.varmap, snippet_varname) + if key in self.rewritten: + stmt = self.current_statement() + insertedvar = stmt.defined_vars.pop(self.rewritten[key]) + for inserted_vref in insertedvar.references(): + inserted_vref.parent.replace(inserted_vref, term.copy(stmt)) + self.rewritten[key] = term.name + + def _get_varname_or_term(self, vname): + stmt = self.current_statement() + if vname == 'U': + stmt = self.select + if self.u_varname is None: + self.u_varname = stmt.allocate_varname() + # generate an identifier for the substitution + argname = stmt.allocate_varname() + while argname in self.kwargs: + argname = stmt.allocate_varname() + # insert "U eid %(u)s" + stmt.add_constant_restriction( + stmt.get_variable(self.u_varname), + 'eid', text_type(argname), 'Substitute') + self.kwargs[argname] = self.session.user.eid + return self.u_varname + key = (self.current_expr, self.varmap, vname) + try: + return self.rewritten[key] + except KeyError: + self.rewritten[key] = newvname = stmt.allocate_varname() + return newvname + + # visitor methods ########################################################## + + def _visit_binary(self, node, cls): + newnode = cls() + for c in node.children: + new = c.accept(self) + if new is None: + continue + newnode.append(new) + if len(newnode.children) == 0: + return None + if len(newnode.children) == 1: + return newnode.children[0] + return newnode + + def _visit_unary(self, node, cls): + newc = node.children[0].accept(self) + if newc is None: + return None + newnode = cls() + newnode.append(newc) + return newnode + + def visit_and(self, node): + return self._visit_binary(node, n.And) + + def visit_or(self, node): + return self._visit_binary(node, n.Or) + + def visit_not(self, node): + return self._visit_unary(node, n.Not) + + def visit_exists(self, node): + return self._visit_unary(node, n.Exists) + + def keep_var(self, varname): + if varname in 'SO': + return varname in self.existingvars + if varname == 'U': + return True + vargraph = self.current_expr.vargraph + for existingvar in self.existingvars: + #path = has_path(vargraph, varname, existingvar) + if not varname in vargraph or has_path(vargraph, varname, existingvar): + return True + # no path from this variable to an existing variable + return False + + def visit_relation(self, node): + lhs, rhs = node.get_variable_parts() + # remove relations where an unexistant variable and or a variable linked + # to an unexistant variable is used. + if self.existingvars: + if not self.keep_var(lhs.name): + return + if node.r_type in ('has_add_permission', 'has_update_permission', + 'has_delete_permission', 'has_read_permission'): + assert lhs.name == 'U' + action = node.r_type.split('_')[1] + key = (self.current_expr, self.varmap, rhs.name) + self.pending_keys.append( (key, action) ) + return + if isinstance(rhs, n.VariableRef): + if self.existingvars and not self.keep_var(rhs.name): + return + if lhs.name in self.revvarmap and rhs.name != 'U': + orel = self._may_be_shared_with(node, 'object') + if orel is not None: + self._use_orig_term(rhs.name, orel.children[1].children[0]) + return + elif rhs.name in self.revvarmap and lhs.name != 'U': + orel = self._may_be_shared_with(node, 'subject') + if orel is not None: + self._use_orig_term(lhs.name, orel.children[0]) + return + rel = n.Relation(node.r_type, node.optional) + for c in node.children: + rel.append(c.accept(self)) + return rel + + def visit_comparison(self, node): + cmp_ = n.Comparison(node.operator) + for c in node.children: + cmp_.append(c.accept(self)) + return cmp_ + + def visit_mathexpression(self, node): + cmp_ = n.MathExpression(node.operator) + for c in node.children: + cmp_.append(c.accept(self)) + return cmp_ + + def visit_function(self, node): + """generate filter name for a function""" + function_ = n.Function(node.name) + for c in node.children: + function_.append(c.accept(self)) + return function_ + + def visit_constant(self, node): + """generate filter name for a constant""" + return n.Constant(node.value, node.type) + + def visit_variableref(self, node): + """get the sql name for a variable reference""" + stmt = self.current_statement() + if node.name in self.revvarmap: + selectvar, index = self.revvarmap[node.name] + vi = self.varinfos[index] + if vi.get('const') is not None: + return n.Constant(vi['const'], 'Int') + return n.VariableRef(stmt.get_variable(selectvar)) + vname_or_term = self._get_varname_or_term(node.name) + if isinstance(vname_or_term, string_types): + return n.VariableRef(stmt.get_variable(vname_or_term)) + # shared term + return vname_or_term.copy(stmt) + + def current_statement(self): + if self._insert_scope is None: + return self.select + return self._insert_scope.stmt + + +class RQLRelationRewriter(RQLRewriter): + """Insert some rql snippets into another rql syntax tree, replacing computed + relations by their associated rule. + + This class *isn't thread safe*. + """ + def __init__(self, session): + super(RQLRelationRewriter, self).__init__(session) + self.rules = {} + for rschema in self.schema.iter_computed_relations(): + self.rules[rschema.type] = RRQLExpression(rschema.rule) + + def rewrite(self, union, kwargs=None): + self.kwargs = kwargs + self.removing_ambiguity = False + self.existingvars = None + self.pending_keys = None + for relation in union.iget_nodes(n.Relation): + if relation.r_type in self.rules: + self.select = relation.stmt + self.solutions = solutions = self.select.solutions[:] + self.current_expr = self.rules[relation.r_type] + self._insert_scope = relation.scope + self.rewritten = {} + lhs, rhs = relation.get_variable_parts() + varmap = {lhs.name: 'S', rhs.name: 'O'} + self.init_from_varmap(tuple(sorted(varmap.items()))) + self.insert_snippet(varmap, self.current_expr.snippet_rqlst) + self.select.remove_node(relation) + + def _subquery_variable(self, selectvar): + return self.select.aliases[selectvar].stinfo + + def _inserted_root(self, new): + return new