--- a/rqlrewrite.py Thu Jun 09 16:41:41 2011 +0200
+++ b/rqlrewrite.py Fri Jun 17 18:51:01 2011 +0200
@@ -20,12 +20,16 @@
This is used for instance for read security checking in the repository.
"""
+from __future__ import with_statement
__docformat__ = "restructuredtext en"
from rql import nodes as n, stmts, TypeResolverException
from rql.utils import common_parent
+
from yams import BadSchemaDefinition
+
+from logilab.common import tempattr
from logilab.common.graph import has_path
from cubicweb import Unauthorized, typed_eid
@@ -156,7 +160,6 @@
self.exists_snippet = {}
self.pending_keys = []
self.existingvars = existingvars
- self._insert_scope = None
# we have to annotate the rqlst before inserting snippets, even though
# we'll have to redo it latter
self.annotate(select)
@@ -193,6 +196,7 @@
self.varmap = varmap
self.revvarmap = {}
self.varinfos = []
+ self._insert_scope = None
for i, (selectvar, snippetvar) in enumerate(varmap):
assert snippetvar in 'SOX'
self.revvarmap[snippetvar] = (selectvar, i)
@@ -229,7 +233,7 @@
except Unsupported:
continue
inserted = True
- if new is not None:
+ if new is not None and self._insert_scope is None:
self.exists_snippet[rqlexpr] = new
parent = parent or new
else:
@@ -263,11 +267,12 @@
insert_scope = common_parent(scope, insert_scope)
else:
insert_scope = self._insert_scope
- if any(vi.get('stinfo', {}).get('optrelations') for vi in self.varinfos):
+ if self._insert_scope is None and any(vi.get('stinfo', {}).get('optrelations')
+ for vi in self.varinfos):
assert parent is None
self._insert_scope = self.snippet_subquery(varmap, new)
self.insert_pending()
- self._insert_scope = None
+ #self._insert_scope = None
return
if not isinstance(new, (n.Exists, n.Not)):
new = n.Exists(new)
@@ -283,15 +288,14 @@
except Unsupported:
# some solutions have been lost, can't apply this rql expr
if parent is None:
- self.select.remove_node(new, undefine=True)
+ self.current_statement().remove_node(new, undefine=True)
else:
parent.parent.replace(or_, or_.children[0])
self._cleanup_inserted(new)
raise
else:
- self._insert_scope = new
- self.insert_pending()
- self._insert_scope = None
+ with tempattr(self, '_insert_scope', new):
+ self.insert_pending()
return new
self.insert_pending()
@@ -303,6 +307,7 @@
recomputed, we have to insert snippet defined for <action> of entity
types taken by X
"""
+ stmt = self.current_statement()
while self.pending_keys:
key, action = self.pending_keys.pop()
try:
@@ -313,12 +318,12 @@
except KeyError:
# variable isn't used anywhere else, we can't insert security
raise Unauthorized()
- ptypes = self.select.defined_vars[varname].stinfo['possibletypes']
+ ptypes = stmt.defined_vars[varname].stinfo['possibletypes']
if len(ptypes) > 1:
# XXX dunno how to handle this
self.session.error(
'cant check security of %s, ambigous type for %s in %s',
- self.select, varname, key[0]) # key[0] == the rql expression
+ stmt, varname, key[0]) # key[0] == the rql expression
raise Unauthorized()
etype = iter(ptypes).next()
eschema = self.schema.eschema(etype)
@@ -499,17 +504,18 @@
self.rewritten[key] = term.name
def _get_varname_or_term(self, vname):
+ stmt = self.current_statement()
if vname == 'U':
+ stmt = self.select
if self.u_varname is None:
- select = self.select
- self.u_varname = select.allocate_varname()
+ self.u_varname = stmt.allocate_varname()
# generate an identifier for the substitution
- argname = select.allocate_varname()
+ argname = stmt.allocate_varname()
while argname in self.kwargs:
- argname = select.allocate_varname()
+ argname = stmt.allocate_varname()
# insert "U eid %(u)s"
- select.add_constant_restriction(
- select.get_variable(self.u_varname),
+ stmt.add_constant_restriction(
+ stmt.get_variable(self.u_varname),
'eid', unicode(argname), 'Substitute')
self.kwargs[argname] = self.session.user.eid
return self.u_varname
@@ -517,7 +523,7 @@
try:
return self.rewritten[key]
except KeyError:
- self.rewritten[key] = newvname = self.select.allocate_varname()
+ self.rewritten[key] = newvname = stmt.allocate_varname()
return newvname
# visitor methods ##########################################################
@@ -625,14 +631,20 @@
def visit_variableref(self, node):
"""get the sql name for a variable reference"""
+ stmt = self.current_statement()
if node.name in self.revvarmap:
selectvar, index = self.revvarmap[node.name]
vi = self.varinfos[index]
if vi.get('const') is not None:
return n.Constant(vi['const'], 'Int') # XXX gae
- return n.VariableRef(self.select.get_variable(selectvar))
+ return n.VariableRef(stmt.get_variable(selectvar))
vname_or_term = self._get_varname_or_term(node.name)
if isinstance(vname_or_term, basestring):
- return n.VariableRef(self.select.get_variable(vname_or_term))
+ return n.VariableRef(stmt.get_variable(vname_or_term))
# shared term
- return vname_or_term.copy(self.select)
+ return vname_or_term.copy(stmt)
+
+ def current_statement(self):
+ if self._insert_scope is None:
+ return self.select
+ return self._insert_scope.stmt