Provide a new add_cubes() migration function for cases where the new cubes are linked together by new relations.
In this case, we need to add all new cubes at once.
"""RQL rewriting utilities, used for read security checking
:organization: Logilab
:copyright: 2007-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
"""
from rql import nodes, stmts, TypeResolverException
from cubicweb import Unauthorized, server, typed_eid
from cubicweb.server.ssplanner import add_types_restriction
def remove_solutions(origsolutions, solutions, defined):
    """Return the solutions of a rewritten rqlst which match the original ones.

    When a rqlst has been generated from another one by introducing security
    assertions, this function returns the new solutions which are contained
    in (i.e. compatible with) the original solutions.

    :param origsolutions: list of solutions (dict mapping variable name to
      entity type) of the original query
    :param solutions: list of solutions of the rewritten query.  It is
      modified in place: every solution that is kept is removed from it
    :param defined: mapping of variable name to an object whose
      ``stinfo['possibletypes']`` gets the rejected type removed
    :return: the list of kept solutions
    """
    newsolutions = []
    for origsol in origsolutions:
        # iterate on a copy since matching solutions are removed on the fly
        for newsol in solutions[:]:
            for var, etype in origsol.items():
                try:
                    newtype = newsol[var]
                except KeyError:
                    # variable has been rewritten
                    continue
                if newtype != etype:
                    try:
                        defined[var].stinfo['possibletypes'].remove(newtype)
                    except KeyError:
                        # variable not defined or type already removed
                        pass
                    break
            else:
                # no conflicting variable found: the solution is compatible
                newsolutions.append(newsol)
                solutions.remove(newsol)
    return newsolutions
class Unsupported(Exception):
    """Signal that the solutions of a rewritten query could not be computed
    or that some of them were lost.
    """
class RQLRewriter(object):
    """Insert some rql snippets into another rql syntax tree.

    `rewrite` stores the state of the query being rewritten (``select``,
    ``solutions``, ``kwargs``...) on the instance.  The ``visit_*`` methods
    implement the visitor used to copy a snippet's restriction into the
    rewritten select.  In a snippet, the ``X`` variable is replaced by the
    variable (or eid) the snippet applies to, and ``U`` by a variable
    restricted to the eid of the session's user.
    """

    def __init__(self, querier, session):
        """Keep the session and the querier services used while rewriting.

        :param querier: object providing ``_rqlhelper.annotate``,
          ``solutions`` and ``schema``
        :param session: session whose ``user.eid`` is substituted for ``U``
        """
        self.session = session
        self.annotate = querier._rqlhelper.annotate
        self._compute_solutions = querier.solutions
        self.schema = querier.schema

    def compute_solutions(self):
        """Re-annotate the select being rewritten and recompute its solutions.

        :raise Unsupported: if the type resolver fails, or if the select ends
          up with less solutions than the original query
        """
        self.annotate(self.select)
        try:
            self._compute_solutions(self.session, self.select, self.kwargs)
        except TypeResolverException:
            raise Unsupported()
        if len(self.select.solutions) < len(self.solutions):
            raise Unsupported()

    def rewrite(self, select, snippets, solutions, kwargs):
        """Insert `snippets` into `select` and fix up its solutions.

        :param select: the select node to rewrite, modified in place
        :param snippets: iterable of ``(varname, rql expressions)`` couples
        :param solutions: solutions of the query before it is rewritten
        :param kwargs: query arguments, updated with the value substituted
          for the user's eid when a snippet uses ``U``
        :raise Unauthorized: propagated from `insert_snippets` when no
          expression could be inserted for some variable
        """
        if server.DEBUG:
            print('---- rewrite %s %s %s' % (select, snippets, solutions))
        self.select = select
        self.solutions = solutions
        self.kwargs = kwargs
        self.u_varname = None
        self.removing_ambiguity = False
        self.exists_snippet = {}
        # we have to annotate the rqlst before inserting snippets, even though
        # we'll have to redo it later
        self.annotate(select)
        self.insert_snippets(snippets)
        if not self.exists_snippet and self.u_varname:
            # U has been inserted then cancelled, cleanup
            select.undefine_variable(select.defined_vars[self.u_varname])
        # clean solutions according to initial solutions
        newsolutions = remove_solutions(solutions, select.solutions,
                                        select.defined_vars)
        assert len(newsolutions) >= len(solutions), (
            'rewritten rql %s has lost some solutions, there is probably something '
            'wrong in your schema permission (for instance using a '
            'RQLExpression which insert a relation which doesn\'t exists in '
            'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
                select, solutions, newsolutions))
        if len(newsolutions) > len(solutions):
            # the snippet has introduced some ambiguities, we have to resolve
            # them "manually"
            variantes = self.build_variantes(newsolutions)
            # insert "is" where necessary
            varexistsmap = {}
            self.removing_ambiguity = True
            for (erqlexpr, mainvar, oldvarname), etype in variantes[0].items():
                varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
                var = select.defined_vars[varname]
                exists = var.references()[0].scope
                exists.add_constant_restriction(var, 'is', etype, 'etype')
                varexistsmap[mainvar] = exists
            # insert ORED exists where necessary
            for variante in variantes[1:]:
                # this resets self.rewritten with freshly allocated variables
                self.insert_snippets(snippets, varexistsmap)
                for (erqlexpr, mainvar, oldvarname), etype in variante.items():
                    varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
                    try:
                        var = select.defined_vars[varname]
                    except KeyError:
                        # not a newly inserted variable
                        continue
                    exists = var.references()[0].scope
                    exists.add_constant_restriction(var, 'is', etype, 'etype')
            # recompute solutions
            self.compute_solutions()
            # clean solutions according to initial solutions
            newsolutions = remove_solutions(solutions, select.solutions,
                                            select.defined_vars)
        select.solutions = newsolutions
        add_types_restriction(self.schema, select)
        if server.DEBUG:
            print('---- rewriten %s' % (select,))

    def build_variantes(self, newsolutions):
        """Return the distinct type assignments of the rewritten variables.

        Each variante is a dict mapping a ``(rql expression, main variable,
        snippet variable)`` key of `self.rewritten` to the type found in one
        of `newsolutions` for the associated inserted variable.  Keys having
        the same type in every variante are removed from all of them.

        :param newsolutions: list of solutions (dict mapping variable name
          to entity type)
        :return: list of dicts, in no particular order
        """
        rewritten = self.rewritten.items()
        # use a set of tuples to get rid of duplicated assignments
        variantes = set()
        for sol in newsolutions:
            variantes.add(tuple((key, sol[newvar]) for key, newvar in rewritten))
        # rebuild variantes as dict
        variantes = [dict(variante) for variante in variantes]
        # remove variable which have always the same type
        for key in self.rewritten:
            etypes = set(variante[key] for variante in variantes)
            if len(etypes) == 1:
                for variante in variantes:
                    del variante[key]
        return variantes

    def insert_snippets(self, snippets, varexistsmap=None):
        """Insert the rql expressions of each snippet into the select.

        :param snippets: iterable of ``(varname, rql expressions)`` couples,
          where `varname` is either a variable name or something
          `typed_eid` accepts
        :param varexistsmap: None on the first pass.  When given (mapping of
          variable name to an exists node), only variables found in it are
          considered and only expressions which previously produced that
          very node are inserted again, ORed with it
        :raise Unauthorized: on the first pass, if none of the expressions
          of a variable could be inserted
        """
        self.rewritten = {}
        for varname, erqlexprs in snippets:
            if varexistsmap is not None and varname not in varexistsmap:
                continue
            try:
                self.const = typed_eid(varname)
                self.varname = self.const
                self.rhs_rels = self.lhs_rels = {}
            except ValueError:
                # not an eid, hence a variable name
                self.varname = varname
                self.const = None
                self.varstinfo = stinfo = self.select.defined_vars[varname].stinfo
                if varexistsmap is None:
                    # relations of the original query which may be shared
                    # with the snippet, indexed by relation type
                    self.rhs_rels = dict((rel.r_type, rel)
                                         for rel in stinfo['rhsrelations'])
                    self.lhs_rels = dict((rel.r_type, rel)
                                         for rel in stinfo['relations']
                                         if rel not in stinfo['rhsrelations'])
                else:
                    self.rhs_rels = self.lhs_rels = {}
            parent = None
            inserted = False
            for erqlexpr in erqlexprs:
                self.current_expr = erqlexpr
                if varexistsmap is None:
                    try:
                        new = self.insert_snippet(varname, erqlexpr.snippet_rqlst, parent)
                    except Unsupported:
                        continue
                    inserted = True
                    if new is not None:
                        self.exists_snippet[erqlexpr] = new
                    # following expressions are ORed with the first inserted one
                    parent = parent or new
                else:
                    # called to reintroduce snippet due to ambiguity creation,
                    # so skip snippets which are not introducing this ambiguity
                    exists = varexistsmap[varname]
                    if self.exists_snippet[erqlexpr] is exists:
                        self.insert_snippet(varname, erqlexpr.snippet_rqlst, exists)
            if varexistsmap is None and not inserted:
                # no rql expression found matching rql solutions. User has no
                # access right
                raise Unauthorized()

    def insert_snippet(self, varname, snippetrqlst, parent=None):
        """Insert the restriction of `snippetrqlst` into the select.

        The restriction is first copied using this visitor.  If the variable
        has some optional relations, it's inserted as a subquery, else as an
        EXISTS restriction which is ORed with `parent` when given.

        :param varname: name of the variable (or eid) the snippet applies to
        :param snippetrqlst: syntax tree of the snippet
        :param parent: node the new EXISTS should be ORed with, if any
        :return: the inserted exists node, or None if nothing had to be
          inserted or if a subquery has been used
        :raise Unsupported: if solutions can't be recomputed once the snippet
          has been inserted; the insertion is undone before raising
        """
        new = snippetrqlst.where.accept(self)
        if new is not None:
            try:
                var = self.select.defined_vars[varname]
            except KeyError:
                # not a variable
                pass
            else:
                if var.stinfo['optrelations']:
                    # use a subquery
                    subselect = stmts.Select()
                    subselect.append_selected(nodes.VariableRef(subselect.get_variable(varname)))
                    subselect.add_restriction(new.copy(subselect))
                    aliases = [varname]
                    for rel in var.stinfo['relations']:
                        rschema = self.schema.rschema(rel.r_type)
                        if rschema.is_final() or (rschema.inlined and
                                                  rel not in var.stinfo['rhsrelations']):
                            # move the relation into the subquery and select
                            # variables used by its right hand side
                            self.select.remove_node(rel)
                            rel.children[0].name = varname
                            subselect.add_restriction(rel.copy(subselect))
                            for vref in rel.children[1].iget_nodes(nodes.VariableRef):
                                subselect.append_selected(vref.copy(subselect))
                                aliases.append(vref.name)
                    if self.u_varname:
                        # generate an identifier for the substitution
                        argname = subselect.allocate_varname()
                        while argname in self.kwargs:
                            argname = subselect.allocate_varname()
                        subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
                                                           'eid', unicode(argname), 'Substitute')
                        self.kwargs[argname] = self.session.user.eid
                    add_types_restriction(self.schema, subselect, subselect, solutions=self.solutions)
                    assert parent is None
                    myunion = stmts.Union()
                    myunion.append(subselect)
                    aliases = [nodes.VariableRef(self.select.get_variable(name, i))
                               for i, name in enumerate(aliases)]
                    self.select.add_subquery(nodes.SubQuery(aliases, myunion), check=False)
                    self._cleanup_inserted(new)
                    try:
                        self.compute_solutions()
                    except Unsupported:
                        # some solutions have been lost, can't apply this rql expr
                        # NOTE(review): `new` is the snippet restriction, not
                        # the SubQuery node added above -- confirm this is
                        # what remove_subquery expects
                        self.select.remove_subquery(new, undefine=True)
                        raise
                    return
            new = nodes.Exists(new)
            if parent is None:
                self.select.add_restriction(new)
            else:
                grandpa = parent.parent
                or_ = nodes.Or(parent, new)
                grandpa.replace(parent, or_)
            if not self.removing_ambiguity:
                try:
                    self.compute_solutions()
                except Unsupported:
                    # some solutions have been lost, can't apply this rql expr
                    if parent is None:
                        self.select.remove_node(new, undefine=True)
                    else:
                        # put back the parent in place of the inserted OR
                        parent.parent.replace(or_, or_.children[0])
                        self._cleanup_inserted(new)
                    raise
            return new

    def _cleanup_inserted(self, node):
        """Unregister variable references found under `node` and undefine
        variables which are no more referenced.
        """
        # cleanup inserted variable references
        for vref in node.iget_nodes(nodes.VariableRef):
            vref.unregister_reference()
            if not vref.variable.stinfo['references']:
                # no more references, undefine the variable
                del self.select.defined_vars[vref.name]

    def _visit_binary(self, node, cls):
        """Visit children of `node` and gather non-None results in a new
        `cls` node.

        Return None if every child has been dropped, the remaining child
        itself if there is only one, else the new node.
        """
        newnode = cls()
        for child in node.children:
            new = child.accept(self)
            if new is not None:
                newnode.append(new)
        remaining = len(newnode.children)
        if remaining == 0:
            return None
        if remaining == 1:
            return newnode.children[0]
        return newnode

    def _visit_unary(self, node, cls):
        """Visit the only child of `node` and wrap the result in a new `cls`
        node, or return None if the child has been dropped.
        """
        newc = node.children[0].accept(self)
        if newc is None:
            return None
        newnode = cls()
        newnode.append(newc)
        return newnode

    def visit_and(self, et):
        """Copy an And node (see `_visit_binary`)."""
        return self._visit_binary(et, nodes.And)

    def visit_or(self, ou):
        """Copy an Or node (see `_visit_binary`)."""
        return self._visit_binary(ou, nodes.Or)

    def visit_not(self, node):
        """Copy a Not node (see `_visit_unary`)."""
        return self._visit_unary(node, nodes.Not)

    def visit_exists(self, node):
        """Copy an Exists node (see `_visit_unary`)."""
        return self._visit_unary(node, nodes.Exists)

    def visit_relation(self, relation):
        """Copy a snippet relation, unless an equivalent relation of the
        original query can be reused.

        Return None when the relation is dropped because its other side has
        been bound to a term of the original query, else the new relation.
        """
        lhs, rhs = relation.get_variable_parts()
        if lhs.name == 'X':
            # on lhs
            # see if we can reuse this relation
            if (relation.r_type in self.lhs_rels
                    and isinstance(rhs, nodes.VariableRef) and rhs.name != 'U'):
                if self._may_be_shared(relation, 'object'):
                    # ok, can share variable
                    term = self.lhs_rels[relation.r_type].children[1].children[0]
                    self._use_outer_term(rhs.name, term)
                    return
        elif isinstance(rhs, nodes.VariableRef) and rhs.name == 'X' and lhs.name != 'U':
            # on rhs
            # see if we can reuse this relation
            if relation.r_type in self.rhs_rels and self._may_be_shared(relation, 'subject'):
                # ok, can share variable
                term = self.rhs_rels[relation.r_type].children[0]
                self._use_outer_term(lhs.name, term)
                return
        rel = nodes.Relation(relation.r_type, relation.optional)
        for child in relation.children:
            rel.append(child.accept(self))
        return rel

    def visit_comparison(self, cmp):
        """Copy a Comparison node and its children."""
        cmp_ = nodes.Comparison(cmp.operator)
        for child in cmp.children:
            cmp_.append(child.accept(self))
        return cmp_

    def visit_mathexpression(self, mexpr):
        """Copy a MathExpression node and its children."""
        # operator and children must be read from `mexpr`, the visited node
        mexpr_ = nodes.MathExpression(mexpr.operator)
        for child in mexpr.children:
            mexpr_.append(child.accept(self))
        return mexpr_

    def visit_function(self, function):
        """Copy a Function node and its children."""
        function_ = nodes.Function(function.name)
        for child in function.children:
            function_.append(child.accept(self))
        return function_

    def visit_constant(self, constant):
        """Copy a Constant node."""
        return nodes.Constant(constant.value, constant.type)

    def visit_variableref(self, vref):
        """Return the node to use in the rewritten query for a snippet's
        variable reference.

        ``X`` becomes the eid constant or a reference to the variable the
        snippet applies to; other variables become a reference to their
        variable in the select, or a copy of the shared outer term.
        """
        if vref.name == 'X':
            if self.const is not None:
                return nodes.Constant(self.const, 'Int')
            return nodes.VariableRef(self.select.get_variable(self.varname))
        vname_or_term = self._get_varname_or_term(vref.name)
        if isinstance(vname_or_term, basestring):
            return nodes.VariableRef(self.select.get_variable(vname_or_term))
        # shared term
        return vname_or_term.copy(self.select)

    def _may_be_shared(self, relation, target):
        """Return True if the snippet relation can be skipped to use a
        relation from the original query.

        :param relation: the snippet relation
        :param target: 'object' or 'subject', the role of the variable which
          would be shared
        """
        # if cardinality is in '?1', we can ignore the relation and use
        # variable from the original query
        rschema = self.schema.rschema(relation.r_type)
        if target == 'object':
            cardindex = 0
            ttypes_func = rschema.objects
            rprop = rschema.rproperty
        else:  # target == 'subject'
            cardindex = 1
            ttypes_func = rschema.subjects
            # rproperty expects the subject type first
            rprop = lambda x, y, z: rschema.rproperty(y, x, z)
        for etype in self.varstinfo['possibletypes']:
            for ttype in ttypes_func(etype):
                if rprop(etype, ttype, 'cardinality')[cardindex] in '+*':
                    return False
        return True

    def _use_outer_term(self, snippet_varname, term):
        """Bind the snippet variable `snippet_varname` to `term`, a term of
        the original query.

        If a variable had already been inserted for it, that variable is
        removed from the select's defined variables and each of its
        references is replaced by a copy of `term`.
        """
        key = (self.current_expr, self.varname, snippet_varname)
        if key in self.rewritten:
            insertedvar = self.select.defined_vars.pop(self.rewritten[key])
            for inserted_vref in insertedvar.references():
                inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
        self.rewritten[key] = term

    def _get_varname_or_term(self, vname):
        """Return what the snippet variable `vname` maps to in the rewritten
        query: a variable name, or a term shared with the original query.

        For ``U``, a variable restricted on the eid of the session's user is
        inserted on first use.  For other variables, a new variable name is
        allocated on first use and recorded in `self.rewritten`.
        """
        if vname == 'U':
            if self.u_varname is None:
                select = self.select
                self.u_varname = select.allocate_varname()
                # generate an identifier for the substitution
                argname = select.allocate_varname()
                while argname in self.kwargs:
                    argname = select.allocate_varname()
                # insert "U eid %(u)s"
                var = select.get_variable(self.u_varname)
                select.add_constant_restriction(var, 'eid', unicode(argname),
                                                'Substitute')
                self.kwargs[argname] = self.session.user.eid
            return self.u_varname
        key = (self.current_expr, self.varname, vname)
        try:
            return self.rewritten[key]
        except KeyError:
            self.rewritten[key] = newvname = self.select.allocate_varname()
            return newvname