# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb. If not, see <http://www.gnu.org/licenses/>.
"""RQL rewriting utilities: insert RQL expression snippets into an RQL syntax
tree.

This is used, for instance, for read security checking in the repository.
"""
__docformat__ = "restructuredtext en"
from rql import nodes as n, stmts, TypeResolverException
from rql.utils import common_parent
from yams import BadSchemaDefinition
from logilab.common import tempattr
from logilab.common.graph import has_path
from cubicweb import Unauthorized
def cleanup_solutions(rqlst, solutions):
    """Strip, in place, every solution of `solutions` from the variables that
    are neither defined variables nor aliases of `rqlst`.
    """
    for solution in solutions:
        unknown = [name for name in solution
                   if name not in rqlst.defined_vars
                   and name not in rqlst.aliases]
        for name in unknown:
            del solution[name]
def add_types_restriction(schema, rqlst, newroot=None, solutions=None):
    """Add or update explicit type restrictions on the non-final variables of
    `newroot`, according to the entity types they take in `solutions`.

    :param schema: the instance schema; `schema.eschema(etype).final` is used
      to skip final (attribute) types
    :param rqlst: the statement in which existing type restrictions are looked
      for. When `newroot` is given, `rqlst.stmt` is used instead
    :param newroot: the statement holding the variables to restrict. When
      None, `rqlst` itself is used along with its own solutions (`solutions`
      must then be None), and the function does nothing if it has already
      been applied to `rqlst`
    :param solutions: list of {varname: etype} dicts, required when `newroot`
      is given

    Variables whose eid is specified are left untouched. For every other
    processed variable, the stinfo 'possibletypes' entry is set to the set of
    its possible types, and 'typerel' is set when a new restriction is added.
    """
    if newroot is None:
        assert solutions is None
        if hasattr(rqlst, '_types_restr_added'):
            return
        solutions = rqlst.solutions
        newroot = rqlst
        rqlst._types_restr_added = True
    else:
        assert solutions is not None
        rqlst = rqlst.stmt
    eschema = schema.eschema
    allpossibletypes = {}
    for solution in solutions:
        for varname, etype in solution.iteritems():
            # XXX not considering aliases by design, right ?
            if varname not in newroot.defined_vars or eschema(etype).final:
                continue
            allpossibletypes.setdefault(varname, set()).add(etype)
    # XXX could be factorized with add_etypes_restriction from rql 0.31
    for varname in sorted(allpossibletypes):
        var = newroot.defined_vars[varname]
        stinfo = var.stinfo
        if stinfo.get('uidrel') is not None:
            continue # eid specified, no need for additional type specification
        try:
            typerel = rqlst.defined_vars[varname].stinfo.get('typerel')
        except KeyError:
            assert varname in rqlst.aliases
            continue
        if newroot is rqlst and typerel is not None:
            mytyperel = typerel
        else:
            # look for an existing type restriction relation on the variable
            for vref in var.references():
                rel = vref.relation()
                if rel and rel.is_types_restriction():
                    mytyperel = rel
                    break
            else:
                mytyperel = None
        possibletypes = allpossibletypes[varname]
        if mytyperel is not None:
            if mytyperel.r_type == 'is_instance_of':
                # turn is_instance_of relation into a is relation since we've
                # all possible solutions and don't want to bother with
                # potential is_instance_of incompatibility
                mytyperel.r_type = 'is'
                if len(possibletypes) > 1:
                    node = n.Function('IN')
                    # sorted for a deterministic generated query
                    for etype in sorted(possibletypes):
                        node.append(n.Constant(etype, 'etype'))
                else:
                    # use the single possible type of *this* variable: `etype`
                    # must not be reused from the solutions loop above, where
                    # it holds the last type seen for an arbitrary variable
                    etype = iter(possibletypes).next()
                    node = n.Constant(etype, 'etype')
                comp = mytyperel.children[1]
                comp.replace(comp.children[0], node)
            else:
                # variable has already some strict types restriction. new
                # possible types can only be a subset of existing ones, so only
                # remove no more possible types
                for cst in mytyperel.get_nodes(n.Constant):
                    if not cst.value in possibletypes:
                        cst.parent.remove(cst)
        else:
            # we have to add types restriction
            if stinfo.get('scope') is not None:
                rel = var.scope.add_type_restriction(var, possibletypes)
            else:
                # tree is not annotated yet, no scope set so add the restriction
                # to the root
                rel = newroot.add_type_restriction(var, possibletypes)
            stinfo['typerel'] = rel
        stinfo['possibletypes'] = possibletypes
def remove_solutions(origsolutions, solutions, defined):
    """Return the solutions of `solutions` compatible with one of the original
    solutions `origsolutions`.

    This is used when a rqlst has been generated from another one by
    introducing security assertions. A new solution is compatible with an
    original one when every variable of the original solution which is also
    in the new solution has the same type in both. Variables missing from the
    new solution (rewritten ones) are ignored.

    Compatible solutions are removed from `solutions`, which is modified in
    place. When a type mismatch is found, the mismatching type is discarded
    from the 'possibletypes' stinfo entry of the variable in `defined`, if
    available.
    """
    kept = []
    for origsol in origsolutions:
        for candidate in list(solutions):
            compatible = True
            for varname, etype in origsol.items():
                if varname not in candidate:
                    # variable has been rewritten
                    continue
                candidate_etype = candidate[varname]
                if candidate_etype != etype:
                    try:
                        defined[varname].stinfo['possibletypes'].remove(
                            candidate_etype)
                    except KeyError:
                        pass
                    compatible = False
                    break
            if compatible:
                kept.append(candidate)
                solutions.remove(candidate)
    return kept
def _add_noinvariant(noinvariant, restricted, select, nbtrees):
# a variable can actually be invariant if it has not been restricted for
# security reason or if security assertion hasn't modified the possible
# solutions for the query
for vname in restricted:
try:
var = select.defined_vars[vname]
except KeyError:
# this is an alias
continue
if nbtrees != 1 or len(var.stinfo['possibletypes']) != 1:
noinvariant.add(var)
def _expand_selection(terms, selected, aliases, select, newselect):
for term in terms:
for vref in term.iget_nodes(n.VariableRef):
if not vref.name in selected:
select.append_selected(vref)
colalias = newselect.get_variable(vref.name, len(aliases))
aliases.append(n.VariableRef(colalias))
selected.add(vref.name)
def _has_multiple_cardinality(etypes, rdef, ttypes_func, cardindex):
"""return True if relation definitions from entity types (`etypes`) to
target types returned by the `ttypes_func` function all have single (1 or ?)
cardinality.
"""
for etype in etypes:
for ttype in ttypes_func(etype):
if rdef(etype, ttype).cardinality[cardindex] in '+*':
return True
return False
def _compatible_relation(relations, stmt, sniprel):
"""Search among given rql relation nodes if there is one 'compatible' with the
snippet relation, and return it if any, else None.
A relation is compatible if it:
* belongs to the currently processed statement,
* isn't negged (i.e. direct parent is a NOT node)
* isn't optional (outer join) or similarly as the snippet relation
"""
for rel in relations:
# don't share if relation's scope is not the current statement
if rel.scope is not stmt:
continue
# don't share neged relation
if rel.neged(strict=True):
continue
# don't share optional relation, unless the snippet relation is
# similarly optional
if rel.optional and rel.optional != sniprel.optional:
continue
return rel
return None
def iter_relations(stinfo):
    """Return the set of relations of a variable described by its `stinfo`
    dict: its 'relations' entry minus its 'rhsrelations' entry (presumably
    the relations where it appears on the right hand side).
    """
    # kept as a standalone function so that tests may get relations in a
    # predictable order
    relations = stinfo['relations']
    return relations - stinfo['rhsrelations']
class Unsupported(Exception):
    """Raised when an RQL expression can't be inserted into some RQL query
    because doing so would make the query unresolvable (e.g. no solutions
    found).
    """
class RQLRewriter(object):
"""insert some rql snippets into another rql syntax tree
this class *isn't thread safe*
"""
def __init__(self, session):
self.session = session
vreg = session.vreg
self.schema = vreg.schema
self.annotate = vreg.rqlhelper.annotate
self._compute_solutions = vreg.solutions
def compute_solutions(self):
self.annotate(self.select)
try:
self._compute_solutions(self.session, self.select, self.kwargs)
except TypeResolverException:
raise Unsupported(str(self.select))
if len(self.select.solutions) < len(self.solutions):
raise Unsupported()
def insert_local_checks(self, select, kwargs,
localchecks, restricted, noinvariant):
"""
select: the rql syntax tree Select node
kwargs: query arguments
localchecks: {(('Var name', (rqlexpr1, rqlexpr2)),
('Var name1', (rqlexpr1, rqlexpr23))): [solution]}
(see querier._check_permissions docstring for more information)
restricted: set of variable names to which an rql expression has to be
applied
noinvariant: set of variable names that can't be considered has
invariant due to security reason (will be filed by this method)
"""
nbtrees = len(localchecks)
myunion = union = select.parent
# transform in subquery when len(localchecks)>1 and groups
if nbtrees > 1 and (select.orderby or select.groupby or
select.having or select.has_aggregat or
select.distinct or
select.limit or select.offset):
newselect = stmts.Select()
# only select variables in subqueries
origselection = select.selection
select.select_only_variables()
select.has_aggregat = False
# create subquery first so correct node are used on copy
# (eg ColumnAlias instead of Variable)
aliases = [n.VariableRef(newselect.get_variable(vref.name, i))
for i, vref in enumerate(select.selection)]
selected = set(vref.name for vref in aliases)
# now copy original selection and groups
for term in origselection:
newselect.append_selected(term.copy(newselect))
if select.orderby:
sortterms = []
for sortterm in select.orderby:
sortterms.append(sortterm.copy(newselect))
for fnode in sortterm.get_nodes(n.Function):
if fnode.name == 'FTIRANK':
# we've to fetch the has_text relation as well
var = fnode.children[0].variable
rel = iter(var.stinfo['ftirels']).next()
assert not rel.ored(), 'unsupported'
newselect.add_restriction(rel.copy(newselect))
# remove relation from the orig select and
# cleanup variable stinfo
rel.parent.remove(rel)
var.stinfo['ftirels'].remove(rel)
var.stinfo['relations'].remove(rel)
# XXX not properly re-annotated after security insertion?
newvar = newselect.get_variable(var.name)
newvar.stinfo.setdefault('ftirels', set()).add(rel)
newvar.stinfo.setdefault('relations', set()).add(rel)
newselect.set_orderby(sortterms)
_expand_selection(select.orderby, selected, aliases, select, newselect)
select.orderby = () # XXX dereference?
if select.groupby:
newselect.set_groupby([g.copy(newselect) for g in select.groupby])
_expand_selection(select.groupby, selected, aliases, select, newselect)
select.groupby = () # XXX dereference?
if select.having:
newselect.set_having([g.copy(newselect) for g in select.having])
_expand_selection(select.having, selected, aliases, select, newselect)
select.having = () # XXX dereference?
if select.limit:
newselect.limit = select.limit
select.limit = None
if select.offset:
newselect.offset = select.offset
select.offset = 0
myunion = stmts.Union()
newselect.set_with([n.SubQuery(aliases, myunion)], check=False)
newselect.distinct = select.distinct
solutions = [sol.copy() for sol in select.solutions]
cleanup_solutions(newselect, solutions)
newselect.set_possible_types(solutions)
# if some solutions doesn't need rewriting, insert original
# select as first union subquery
if () in localchecks:
myunion.append(select)
# we're done, replace original select by the new select with
# subqueries (more added in the loop below)
union.replace(select, newselect)
elif not () in localchecks:
union.remove(select)
for lcheckdef, lchecksolutions in localchecks.iteritems():
if not lcheckdef:
continue
myrqlst = select.copy(solutions=lchecksolutions)
myunion.append(myrqlst)
# in-place rewrite + annotation / simplification
lcheckdef = [({var: 'X'}, rqlexprs) for var, rqlexprs in lcheckdef]
self.rewrite(myrqlst, lcheckdef, kwargs)
_add_noinvariant(noinvariant, restricted, myrqlst, nbtrees)
if () in localchecks:
select.set_possible_types(localchecks[()])
add_types_restriction(self.schema, select)
_add_noinvariant(noinvariant, restricted, select, nbtrees)
self.annotate(union)
def rewrite(self, select, snippets, kwargs, existingvars=None):
"""
snippets: (varmap, list of rql expression)
with varmap a *tuple* (select var, snippet var)
"""
self.select = select
# remove_solutions used below require a copy
self.solutions = solutions = select.solutions[:]
self.kwargs = kwargs
self.u_varname = None
self.removing_ambiguity = False
self.exists_snippet = {}
self.pending_keys = []
self.existingvars = existingvars
# we have to annotate the rqlst before inserting snippets, even though
# we'll have to redo it latter
self.annotate(select)
self.insert_snippets(snippets)
if not self.exists_snippet and self.u_varname:
# U has been inserted than cancelled, cleanup
select.undefine_variable(select.defined_vars[self.u_varname])
# clean solutions according to initial solutions
newsolutions = remove_solutions(solutions, select.solutions,
select.defined_vars)
assert len(newsolutions) >= len(solutions), (
'rewritten rql %s has lost some solutions, there is probably '
'something wrong in your schema permission (for instance using a '
'RQLExpression which insert a relation which doesn\'t exists in '
'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
select, solutions, newsolutions))
if len(newsolutions) > len(solutions):
newsolutions = self.remove_ambiguities(snippets, newsolutions)
assert newsolutions
select.solutions = newsolutions
add_types_restriction(self.schema, select)
def insert_snippets(self, snippets, varexistsmap=None):
self.rewritten = {}
for varmap, rqlexprs in snippets:
if isinstance(varmap, dict):
varmap = tuple(sorted(varmap.items()))
else:
assert isinstance(varmap, tuple), varmap
if varexistsmap is not None and not varmap in varexistsmap:
continue
self.insert_varmap_snippets(varmap, rqlexprs, varexistsmap)
def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap):
self.varmap = varmap
self.revvarmap = {}
self.varinfos = []
self._insert_scope = None
for i, (selectvar, snippetvar) in enumerate(varmap):
assert snippetvar in 'SOX'
self.revvarmap[snippetvar] = (selectvar, i)
vi = {}
self.varinfos.append(vi)
try:
vi['const'] = int(selectvar)
vi['rhs_rels'] = vi['lhs_rels'] = {}
except ValueError:
try:
vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo
except KeyError:
# variable may have been moved to a newly inserted subquery
# we should insert snippet in that subquery
subquery = self.select.aliases[selectvar].query
assert len(subquery.children) == 1
subselect = subquery.children[0]
RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)],
subselect.solutions, self.kwargs)
return
if varexistsmap is None:
# build an index for quick access to relations
vi['rhs_rels'] = {}
for rel in sti['rhsrelations']:
vi['rhs_rels'].setdefault(rel.r_type, []).append(rel)
vi['lhs_rels'] = {}
for rel in sti['relations']:
if not rel in sti['rhsrelations']:
vi['lhs_rels'].setdefault(rel.r_type, []).append(rel)
else:
vi['rhs_rels'] = vi['lhs_rels'] = {}
previous = None
inserted = False
for rqlexpr in rqlexprs:
self.current_expr = rqlexpr
if varexistsmap is None:
try:
new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, previous)
except Unsupported:
continue
inserted = True
if new is not None and self._insert_scope is None:
self.exists_snippet[rqlexpr] = new
previous = previous or new
else:
# called to reintroduce snippet due to ambiguity creation,
# so skip snippets which are not introducing this ambiguity
exists = varexistsmap[varmap]
if self.exists_snippet.get(rqlexpr) is exists:
self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
if varexistsmap is None and not inserted:
# no rql expression found matching rql solutions. User has no access right
raise Unauthorized() # XXX may also be because of bad constraints in schema definition
def insert_snippet(self, varmap, snippetrqlst, previous=None):
new = snippetrqlst.where.accept(self)
existing = self.existingvars
self.existingvars = None
try:
return self._insert_snippet(varmap, previous, new)
finally:
self.existingvars = existing
def _insert_snippet(self, varmap, previous, new):
"""insert `new` snippet into the syntax tree, which have been rewritten
using `varmap`. In cases where an action is protected by several rql
expresssion, `previous` will be the first rql expression which has been
inserted, and so should be ORed with the following expressions.
"""
if new is not None:
if self._insert_scope is None:
insert_scope = None
for vi in self.varinfos:
scope = vi.get('stinfo', {}).get('scope', self.select)
if insert_scope is None:
insert_scope = scope
else:
insert_scope = common_parent(scope, insert_scope)
else:
insert_scope = self._insert_scope
if self._insert_scope is None and any(vi.get('stinfo', {}).get('optrelations')
for vi in self.varinfos):
assert previous is None
self._insert_scope, new = self.snippet_subquery(varmap, new)
self.insert_pending()
#self._insert_scope = None
return new
if not isinstance(new, (n.Exists, n.Not)):
new = n.Exists(new)
if previous is None:
insert_scope.add_restriction(new)
else:
grandpa = previous.parent
or_ = n.Or(previous, new)
grandpa.replace(previous, or_)
if not self.removing_ambiguity:
try:
self.compute_solutions()
except Unsupported:
# some solutions have been lost, can't apply this rql expr
if previous is None:
self.current_statement().remove_node(new, undefine=True)
else:
grandpa.replace(or_, previous)
self._cleanup_inserted(new)
raise
else:
with tempattr(self, '_insert_scope', new):
self.insert_pending()
return new
self.insert_pending()
def insert_pending(self):
"""pending_keys hold variable referenced by U has_<action>_permission X
relation.
Once the snippet introducing this has been inserted and solutions
recomputed, we have to insert snippet defined for <action> of entity
types taken by X
"""
stmt = self.current_statement()
while self.pending_keys:
key, action = self.pending_keys.pop()
try:
varname = self.rewritten[key]
except KeyError:
try:
varname = self.revvarmap[key[-1]][0]
except KeyError:
# variable isn't used anywhere else, we can't insert security
raise Unauthorized()
ptypes = stmt.defined_vars[varname].stinfo['possibletypes']
if len(ptypes) > 1:
# XXX dunno how to handle this
self.session.error(
'cant check security of %s, ambigous type for %s in %s',
stmt, varname, key[0]) # key[0] == the rql expression
raise Unauthorized()
etype = iter(ptypes).next()
eschema = self.schema.eschema(etype)
if not eschema.has_perm(self.session, action):
rqlexprs = eschema.get_rqlexprs(action)
if not rqlexprs:
raise Unauthorized()
self.insert_snippets([({varname: 'X'}, rqlexprs)])
def snippet_subquery(self, varmap, transformedsnippet):
"""introduce the given snippet in a subquery"""
subselect = stmts.Select()
snippetrqlst = n.Exists(transformedsnippet.copy(subselect))
get_rschema = self.schema.rschema
aliases = []
done = set()
for i, (selectvar, _) in enumerate(varmap):
need_null_test = False
subselectvar = subselect.get_variable(selectvar)
subselect.append_selected(n.VariableRef(subselectvar))
aliases.append(selectvar)
todo = [(selectvar, self.varinfos[i]['stinfo'])]
while todo:
varname, stinfo = todo.pop()
done.add(varname)
for rel in iter_relations(stinfo):
if rel in done:
continue
done.add(rel)
rschema = get_rschema(rel.r_type)
if rschema.final or rschema.inlined:
rel.children[0].name = varname # XXX explain why
subselect.add_restriction(rel.copy(subselect))
for vref in rel.children[1].iget_nodes(n.VariableRef):
if isinstance(vref.variable, n.ColumnAlias):
# XXX could probably be handled by generating the
# subquery into the detected subquery
raise BadSchemaDefinition(
"cant insert security because of usage two inlined "
"relations in this query. You should probably at "
"least uninline %s" % rel.r_type)
subselect.append_selected(vref.copy(subselect))
aliases.append(vref.name)
self.select.remove_node(rel)
# when some inlined relation has to be copied in the
# subquery and that relation is optional, we need to
# test that either value is NULL or that the snippet
# condition is satisfied
if varname == selectvar and rel.optional and rschema.inlined:
need_null_test = True
# also, if some attributes or inlined relation of the
# object variable are accessed, we need to get all those
# from the subquery as well
if vref.name not in done and rschema.inlined:
# we can use vref here define in above for loop
ostinfo = vref.variable.stinfo
for orel in iter_relations(ostinfo):
orschema = get_rschema(orel.r_type)
if orschema.final or orschema.inlined:
todo.append( (vref.name, ostinfo) )
break
if need_null_test:
snippetrqlst = n.Or(
n.make_relation(subselect.get_variable(selectvar), 'is',
(None, None), n.Constant,
operator='='),
snippetrqlst)
subselect.add_restriction(snippetrqlst)
if self.u_varname:
# generate an identifier for the substitution
argname = subselect.allocate_varname()
while argname in self.kwargs:
argname = subselect.allocate_varname()
subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
'eid', unicode(argname), 'Substitute')
self.kwargs[argname] = self.session.user.eid
add_types_restriction(self.schema, subselect, subselect,
solutions=self.solutions)
myunion = stmts.Union()
myunion.append(subselect)
aliases = [n.VariableRef(self.select.get_variable(name, i))
for i, name in enumerate(aliases)]
self.select.add_subquery(n.SubQuery(aliases, myunion), check=False)
self._cleanup_inserted(transformedsnippet)
try:
self.compute_solutions()
except Unsupported:
# some solutions have been lost, can't apply this rql expr
self.select.remove_subquery(self.select.with_[-1])
raise
return subselect, snippetrqlst
def remove_ambiguities(self, snippets, newsolutions):
# the snippet has introduced some ambiguities, we have to resolve them
# "manually"
variantes = self.build_variantes(newsolutions)
# insert "is" where necessary
varexistsmap = {}
self.removing_ambiguity = True
for (erqlexpr, varmap, oldvarname), etype in variantes[0].iteritems():
varname = self.rewritten[(erqlexpr, varmap, oldvarname)]
var = self.select.defined_vars[varname]
exists = var.references()[0].scope
exists.add_constant_restriction(var, 'is', etype, 'etype')
varexistsmap[varmap] = exists
# insert ORED exists where necessary
for variante in variantes[1:]:
self.insert_snippets(snippets, varexistsmap)
for key, etype in variante.iteritems():
varname = self.rewritten[key]
try:
var = self.select.defined_vars[varname]
except KeyError:
# not a newly inserted variable
continue
exists = var.references()[0].scope
exists.add_constant_restriction(var, 'is', etype, 'etype')
# recompute solutions
self.compute_solutions()
# clean solutions according to initial solutions
return remove_solutions(self.solutions, self.select.solutions,
self.select.defined_vars)
def build_variantes(self, newsolutions):
variantes = set()
for sol in newsolutions:
variante = []
for key, newvar in self.rewritten.iteritems():
variante.append( (key, sol[newvar]) )
variantes.add(tuple(variante))
# rebuild variantes as dict
variantes = [dict(variante) for variante in variantes]
# remove variable which have always the same type
for key in self.rewritten:
it = iter(variantes)
etype = it.next()[key]
for variante in it:
if variante[key] != etype:
break
else:
for variante in variantes:
del variante[key]
return variantes
def _cleanup_inserted(self, node):
# cleanup inserted variable references
removed = set()
for vref in node.iget_nodes(n.VariableRef):
vref.unregister_reference()
if not vref.variable.stinfo['references']:
# no more references, undefine the variable
del self.select.defined_vars[vref.name]
removed.add(vref.name)
for key, newvar in self.rewritten.items(): # I mean items we alter it
if newvar in removed:
del self.rewritten[key]
def _may_be_shared_with(self, sniprel, target):
"""if the snippet relation can be skipped to use a relation from the
original query, return that relation node
"""
if sniprel.neged(strict=True):
return None # no way
rschema = self.schema.rschema(sniprel.r_type)
stmt = self.current_statement()
for vi in self.varinfos:
try:
if target == 'object':
orels = vi['lhs_rels'][sniprel.r_type]
cardindex = 0
ttypes_func = rschema.objects
rdef = rschema.rdef
else: # target == 'subject':
orels = vi['rhs_rels'][sniprel.r_type]
cardindex = 1
ttypes_func = rschema.subjects
rdef = lambda x, y: rschema.rdef(y, x)
except KeyError:
# may be raised by vi['xhs_rels'][sniprel.r_type]
continue
# if cardinality isn't in '?1', we can't ignore the snippet relation
# and use variable from the original query
if _has_multiple_cardinality(vi['stinfo']['possibletypes'], rdef,
ttypes_func, cardindex):
continue
orel = _compatible_relation(orels, stmt, sniprel)
if orel is not None:
return orel
return None
def _use_orig_term(self, snippet_varname, term):
key = (self.current_expr, self.varmap, snippet_varname)
if key in self.rewritten:
stmt = self.current_statement()
insertedvar = stmt.defined_vars.pop(self.rewritten[key])
for inserted_vref in insertedvar.references():
inserted_vref.parent.replace(inserted_vref, term.copy(stmt))
self.rewritten[key] = term.name
def _get_varname_or_term(self, vname):
stmt = self.current_statement()
if vname == 'U':
stmt = self.select
if self.u_varname is None:
self.u_varname = stmt.allocate_varname()
# generate an identifier for the substitution
argname = stmt.allocate_varname()
while argname in self.kwargs:
argname = stmt.allocate_varname()
# insert "U eid %(u)s"
stmt.add_constant_restriction(
stmt.get_variable(self.u_varname),
'eid', unicode(argname), 'Substitute')
self.kwargs[argname] = self.session.user.eid
return self.u_varname
key = (self.current_expr, self.varmap, vname)
try:
return self.rewritten[key]
except KeyError:
self.rewritten[key] = newvname = stmt.allocate_varname()
return newvname
# visitor methods ##########################################################
def _visit_binary(self, node, cls):
newnode = cls()
for c in node.children:
new = c.accept(self)
if new is None:
continue
newnode.append(new)
if len(newnode.children) == 0:
return None
if len(newnode.children) == 1:
return newnode.children[0]
return newnode
def _visit_unary(self, node, cls):
newc = node.children[0].accept(self)
if newc is None:
return None
newnode = cls()
newnode.append(newc)
return newnode
def visit_and(self, node):
return self._visit_binary(node, n.And)
def visit_or(self, node):
return self._visit_binary(node, n.Or)
def visit_not(self, node):
return self._visit_unary(node, n.Not)
def visit_exists(self, node):
return self._visit_unary(node, n.Exists)
def keep_var(self, varname):
if varname in 'SO':
return varname in self.existingvars
if varname == 'U':
return True
vargraph = self.current_expr.vargraph
for existingvar in self.existingvars:
#path = has_path(vargraph, varname, existingvar)
if has_path(vargraph, varname, existingvar):
return True
# no path from this variable to an existing variable
return False
def visit_relation(self, node):
lhs, rhs = node.get_variable_parts()
# remove relations where an unexistant variable and or a variable linked
# to an unexistant variable is used.
if self.existingvars:
if not self.keep_var(lhs.name):
return
if node.r_type in ('has_add_permission', 'has_update_permission',
'has_delete_permission', 'has_read_permission'):
assert lhs.name == 'U'
action = node.r_type.split('_')[1]
key = (self.current_expr, self.varmap, rhs.name)
self.pending_keys.append( (key, action) )
return
if isinstance(rhs, n.VariableRef):
if self.existingvars and not self.keep_var(rhs.name):
return
if lhs.name in self.revvarmap and rhs.name != 'U':
orel = self._may_be_shared_with(node, 'object')
if orel is not None:
self._use_orig_term(rhs.name, orel.children[1].children[0])
return
elif rhs.name in self.revvarmap and lhs.name != 'U':
orel = self._may_be_shared_with(node, 'subject')
if orel is not None:
self._use_orig_term(lhs.name, orel.children[0])
return
rel = n.Relation(node.r_type, node.optional)
for c in node.children:
rel.append(c.accept(self))
return rel
def visit_comparison(self, node):
cmp_ = n.Comparison(node.operator)
for c in node.children:
cmp_.append(c.accept(self))
return cmp_
def visit_mathexpression(self, node):
cmp_ = n.MathExpression(node.operator)
for c in node.children:
cmp_.append(c.accept(self))
return cmp_
def visit_function(self, node):
"""generate filter name for a function"""
function_ = n.Function(node.name)
for c in node.children:
function_.append(c.accept(self))
return function_
def visit_constant(self, node):
"""generate filter name for a constant"""
return n.Constant(node.value, node.type)
def visit_variableref(self, node):
"""get the sql name for a variable reference"""
stmt = self.current_statement()
if node.name in self.revvarmap:
selectvar, index = self.revvarmap[node.name]
vi = self.varinfos[index]
if vi.get('const') is not None:
return n.Constant(vi['const'], 'Int')
return n.VariableRef(stmt.get_variable(selectvar))
vname_or_term = self._get_varname_or_term(node.name)
if isinstance(vname_or_term, basestring):
return n.VariableRef(stmt.get_variable(vname_or_term))
# shared term
return vname_or_term.copy(stmt)
def current_statement(self):
if self._insert_scope is None:
return self.select
return self._insert_scope.stmt