--- a/rqlrewrite.py Wed Mar 30 11:07:16 2011 +0200
+++ b/rqlrewrite.py Wed Mar 30 11:08:15 2011 +0200
@@ -24,6 +24,7 @@
__docformat__ = "restructuredtext en"
from rql import nodes as n, stmts, TypeResolverException
+from rql.utils import common_parent
from yams import BadSchemaDefinition
from logilab.common.graph import has_path
@@ -180,13 +181,23 @@
def insert_snippets(self, snippets, varexistsmap=None):
self.rewritten = {}
for varmap, rqlexprs in snippets:
+ if isinstance(varmap, dict):
+ varmap = tuple(sorted(varmap.items()))
+ else:
+ assert isinstance(varmap, tuple), varmap
if varexistsmap is not None and not varmap in varexistsmap:
continue
- self.varmap = varmap
- selectvar, snippetvar = varmap
+ self.insert_varmap_snippets(varmap, rqlexprs, varexistsmap)
+
+ def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap):
+ self.varmap = varmap
+ self.revvarmap = {}
+ self.varinfos = []
+ for i, (selectvar, snippetvar) in enumerate(varmap):
assert snippetvar in 'SOX'
- self.revvarmap = {snippetvar: selectvar}
- self.varinfo = vi = {}
+ self.revvarmap[snippetvar] = (selectvar, i)
+ vi = {}
+ self.varinfos.append(vi)
try:
vi['const'] = typed_eid(selectvar) # XXX gae
vi['rhs_rels'] = vi['lhs_rels'] = {}
@@ -194,42 +205,42 @@
try:
vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo
except KeyError:
- # variable has been moved to a newly inserted subquery
+ # variable may have been moved to a newly inserted subquery
# we should insert snippet in that subquery
subquery = self.select.aliases[selectvar].query
assert len(subquery.children) == 1
subselect = subquery.children[0]
RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)],
subselect.solutions, self.kwargs)
- continue
+ return
if varexistsmap is None:
vi['rhs_rels'] = dict( (r.r_type, r) for r in sti['rhsrelations'])
vi['lhs_rels'] = dict( (r.r_type, r) for r in sti['relations']
if not r in sti['rhsrelations'])
else:
vi['rhs_rels'] = vi['lhs_rels'] = {}
- parent = None
- inserted = False
- for rqlexpr in rqlexprs:
- self.current_expr = rqlexpr
- if varexistsmap is None:
- try:
- new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, parent)
- except Unsupported:
- continue
- inserted = True
- if new is not None:
- self.exists_snippet[rqlexpr] = new
- parent = parent or new
- else:
- # called to reintroduce snippet due to ambiguity creation,
- # so skip snippets which are not introducing this ambiguity
- exists = varexistsmap[varmap]
- if self.exists_snippet[rqlexpr] is exists:
- self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
- if varexistsmap is None and not inserted:
- # no rql expression found matching rql solutions. User has no access right
- raise Unauthorized()
+ parent = None
+ inserted = False
+ for rqlexpr in rqlexprs:
+ self.current_expr = rqlexpr
+ if varexistsmap is None:
+ try:
+ new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, parent)
+ except Unsupported:
+ continue
+ inserted = True
+ if new is not None:
+ self.exists_snippet[rqlexpr] = new
+ parent = parent or new
+ else:
+ # called to reintroduce snippet due to ambiguity creation,
+ # so skip snippets which are not introducing this ambiguity
+ exists = varexistsmap[varmap]
+ if self.exists_snippet[rqlexpr] is exists:
+ self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
+ if varexistsmap is None and not inserted:
+ # no rql expression found matching rql solutions. User has no access right
+ raise Unauthorized() # XXX bad constraint when inserting constraints
def insert_snippet(self, varmap, snippetrqlst, parent=None):
new = snippetrqlst.where.accept(self)
@@ -243,10 +254,16 @@
def _insert_snippet(self, varmap, parent, new):
if new is not None:
if self._insert_scope is None:
- insert_scope = self.varinfo.get('stinfo', {}).get('scope', self.select)
+ insert_scope = None
+ for vi in self.varinfos:
+ scope = vi.get('stinfo', {}).get('scope', self.select)
+ if insert_scope is None:
+ insert_scope = scope
+ else:
+ insert_scope = common_parent(scope, insert_scope)
else:
insert_scope = self._insert_scope
- if self.varinfo.get('stinfo', {}).get('optrelations'):
+ if any(vi.get('stinfo', {}).get('optrelations') for vi in self.varinfos):
assert parent is None
self._insert_scope = self.snippet_subquery(varmap, new)
self.insert_pending()
@@ -292,7 +309,7 @@
varname = self.rewritten[key]
except KeyError:
try:
- varname = self.revvarmap[key[-1]]
+ varname = self.revvarmap[key[-1]][0]
except KeyError:
# variable isn't used anywhere else, we can't insert security
raise Unauthorized()
@@ -309,45 +326,51 @@
rqlexprs = eschema.get_rqlexprs(action)
if not rqlexprs:
raise Unauthorized()
- self.insert_snippets([((varname, 'X'), rqlexprs)])
+ self.insert_snippets([({varname: 'X'}, rqlexprs)])
def snippet_subquery(self, varmap, transformedsnippet):
"""introduce the given snippet in a subquery"""
subselect = stmts.Select()
- selectvar = varmap[0]
- subselectvar = subselect.get_variable(selectvar)
- subselect.append_selected(n.VariableRef(subselectvar))
snippetrqlst = n.Exists(transformedsnippet.copy(subselect))
- aliases = [selectvar]
- stinfo = self.varinfo['stinfo']
- need_null_test = False
- for rel in stinfo['relations']:
- rschema = self.schema.rschema(rel.r_type)
- if rschema.final or (rschema.inlined and
- not rel in stinfo['rhsrelations']):
- rel.children[0].name = selectvar # XXX explain why
- subselect.add_restriction(rel.copy(subselect))
- for vref in rel.children[1].iget_nodes(n.VariableRef):
- if isinstance(vref.variable, n.ColumnAlias):
- # XXX could probably be handled by generating the subquery
- # into the detected subquery
- raise BadSchemaDefinition(
- "cant insert security because of usage two inlined "
- "relations in this query. You should probably at "
- "least uninline %s" % rel.r_type)
- subselect.append_selected(vref.copy(subselect))
- aliases.append(vref.name)
- self.select.remove_node(rel)
- # when some inlined relation has to be copied in the subquery,
- # we need to test that either value is NULL or that the snippet
- # condition is satisfied
- if rschema.inlined and rel.optional:
- need_null_test = True
- if need_null_test:
- snippetrqlst = n.Or(
- n.make_relation(subselectvar, 'is', (None, None), n.Constant,
- operator='='),
- snippetrqlst)
+ aliases = []
+ rels_done = set()
+ for i, (selectvar, snippetvar) in enumerate(varmap):
+ subselectvar = subselect.get_variable(selectvar)
+ subselect.append_selected(n.VariableRef(subselectvar))
+ aliases.append(selectvar)
+ vi = self.varinfos[i]
+ need_null_test = False
+ stinfo = vi['stinfo']
+ for rel in stinfo['relations']:
+ if rel in rels_done:
+ continue
+ rels_done.add(rel)
+ rschema = self.schema.rschema(rel.r_type)
+ if rschema.final or (rschema.inlined and
+ not rel in stinfo['rhsrelations']):
+ rel.children[0].name = selectvar # XXX explain why
+ subselect.add_restriction(rel.copy(subselect))
+ for vref in rel.children[1].iget_nodes(n.VariableRef):
+ if isinstance(vref.variable, n.ColumnAlias):
+ # XXX could probably be handled by generating the
+ # subquery into the detected subquery
+ raise BadSchemaDefinition(
+ "cant insert security because of usage two inlined "
+ "relations in this query. You should probably at "
+ "least uninline %s" % rel.r_type)
+ subselect.append_selected(vref.copy(subselect))
+ aliases.append(vref.name)
+ self.select.remove_node(rel)
+ # when some inlined relation has to be copied in the
+ # subquery, we need to test that either value is NULL or
+ # that the snippet condition is satisfied
+ if rschema.inlined and rel.optional:
+ need_null_test = True
+ if need_null_test:
+ snippetrqlst = n.Or(
+ n.make_relation(subselectvar, 'is', (None, None), n.Constant,
+ operator='='),
+ snippetrqlst)
subselect.add_restriction(snippetrqlst)
if self.u_varname:
# generate an identifier for the substitution
@@ -439,30 +462,32 @@
original query, return that relation node
"""
rschema = self.schema.rschema(sniprel.r_type)
- try:
- if target == 'object':
- orel = self.varinfo['lhs_rels'][sniprel.r_type]
- cardindex = 0
- ttypes_func = rschema.objects
- rdef = rschema.rdef
- else: # target == 'subject':
- orel = self.varinfo['rhs_rels'][sniprel.r_type]
- cardindex = 1
- ttypes_func = rschema.subjects
- rdef = lambda x, y: rschema.rdef(y, x)
- except KeyError:
- # may be raised by self.varinfo['xhs_rels'][sniprel.r_type]
- return None
- # can't share neged relation or relations with different outer join
- if (orel.neged(strict=True) or sniprel.neged(strict=True)
- or (orel.optional and orel.optional != sniprel.optional)):
- return None
- # if cardinality is in '?1', we can ignore the snippet relation and use
- # variable from the original query
- for etype in self.varinfo['stinfo']['possibletypes']:
- for ttype in ttypes_func(etype):
- if rdef(etype, ttype).cardinality[cardindex] in '+*':
- return None
+ for vi in self.varinfos:
+ try:
+ if target == 'object':
+ orel = vi['lhs_rels'][sniprel.r_type]
+ cardindex = 0
+ ttypes_func = rschema.objects
+ rdef = rschema.rdef
+ else: # target == 'subject':
+ orel = vi['rhs_rels'][sniprel.r_type]
+ cardindex = 1
+ ttypes_func = rschema.subjects
+ rdef = lambda x, y: rschema.rdef(y, x)
+ except KeyError:
+ # may be raised by vi['xhs_rels'][sniprel.r_type]
+ return None
+ # can't share neged relation or relations with different outer join
+ if (orel.neged(strict=True) or sniprel.neged(strict=True)
+ or (orel.optional and orel.optional != sniprel.optional)):
+ return None
+ # if cardinality is in '?1', we can ignore the snippet relation and use
+ # variable from the original query
+ for etype in vi['stinfo']['possibletypes']:
+ for ttype in ttypes_func(etype):
+ if rdef(etype, ttype).cardinality[cardindex] in '+*':
+ return None
+ break
return orel
def _use_orig_term(self, snippet_varname, term):
@@ -601,10 +626,11 @@
def visit_variableref(self, node):
"""get the sql name for a variable reference"""
if node.name in self.revvarmap:
- if self.varinfo.get('const') is not None:
- return n.Constant(self.varinfo['const'], 'Int') # XXX gae
- return n.VariableRef(self.select.get_variable(
- self.revvarmap[node.name]))
+ selectvar, index = self.revvarmap[node.name]
+ vi = self.varinfos[index]
+ if vi.get('const') is not None:
+ return n.Constant(vi['const'], 'Int') # XXX gae
+ return n.VariableRef(self.select.get_variable(selectvar))
vname_or_term = self._get_varname_or_term(node.name)
if isinstance(vname_or_term, basestring):
return n.VariableRef(self.select.get_variable(vname_or_term))