|
1 """RQL rewriting utilities : insert rql expression snippets into rql syntax |
|
2 tree. |
|
3 |
|
4 This is used for instance for read security checking in the repository. |
|
5 |
|
6 :organization: Logilab |
|
7 :copyright: 2007-2009 LOGILAB S.A. (Paris, FRANCE), license is LGPL v2. |
|
8 :contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr |
|
9 :license: GNU Lesser General Public License, v2.1 - http://www.gnu.org/licenses |
|
10 """ |
|
11 __docformat__ = "restructuredtext en" |
|
12 |
|
13 from rql import nodes as n, stmts, TypeResolverException |
|
14 |
|
15 from logilab.common.compat import any |
|
16 |
|
17 from cubicweb import Unauthorized, server, typed_eid |
|
18 from cubicweb.server.ssplanner import add_types_restriction |
|
19 |
|
20 |
|
def remove_solutions(origsolutions, solutions, defined):
    """when a rqlst has been generated from another by introducing security
    assertions, this method returns solutions which are contained in orig
    solutions

    :param origsolutions: solutions (dicts mapping variable name to entity
      type) of the query before rewriting
    :param solutions: solutions of the rewritten query; every solution which
      is returned is also removed from this list *in place*
    :param defined: mapping from variable name to variable object, whose
      `stinfo['possibletypes']` is pruned of types only found in rejected
      solutions
    :return: the solutions of `solutions` compatible with at least one
      original solution, in matching order
    """
    newsolutions = []
    # (varname, etype) pairs found in a new solution conflicting with some
    # original solution
    mismatches = set()
    for origsol in origsolutions:
        for newsol in solutions[:]:
            for var, etype in origsol.items():
                try:
                    newetype = newsol[var]
                except KeyError:
                    # variable has been rewritten
                    continue
                if newetype != etype:
                    mismatches.add((var, newetype))
                    break
            else:
                newsolutions.append(newsol)
                solutions.remove(newsol)
    # a new solution may conflict with one original solution yet match
    # another one: only drop a type from a variable's possible types when no
    # kept solution uses it anymore
    for var, etype in mismatches:
        if any(sol.get(var) == etype for sol in newsolutions):
            continue
        try:
            defined[var].stinfo['possibletypes'].remove(etype)
        except KeyError:
            # unknown variable or type already removed
            pass
    return newsolutions
|
44 |
|
45 |
|
class Unsupported(Exception):
    """Raised when an rql expression snippet can't be applied to the query
    being rewritten (e.g. types can't be resolved or solutions would be lost).
    """
|
47 |
|
48 |
|
49 class RQLRewriter(object): |
|
50 """insert some rql snippets into another rql syntax tree |
|
51 |
|
52 this class *isn't thread safe* |
|
53 """ |
|
54 |
|
55 def __init__(self, session): |
|
56 self.session = session |
|
57 vreg = session.vreg |
|
58 self.schema = vreg.schema |
|
59 self.annotate = vreg.rqlhelper.annotate |
|
60 self._compute_solutions = vreg.solutions |
|
61 |
|
62 def compute_solutions(self): |
|
63 self.annotate(self.select) |
|
64 try: |
|
65 self._compute_solutions(self.session, self.select, self.kwargs) |
|
66 except TypeResolverException: |
|
67 raise Unsupported(str(self.select)) |
|
68 if len(self.select.solutions) < len(self.solutions): |
|
69 raise Unsupported() |
|
70 |
|
71 def rewrite(self, select, snippets, solutions, kwargs): |
|
72 """ |
|
73 snippets: (varmap, list of rql expression) |
|
74 with varmap a *tuple* (select var, snippet var) |
|
75 """ |
|
76 if server.DEBUG: |
|
77 print '---- rewrite', select, snippets, solutions |
|
78 self.select = self.insert_scope = select |
|
79 self.solutions = solutions |
|
80 self.kwargs = kwargs |
|
81 self.u_varname = None |
|
82 self.removing_ambiguity = False |
|
83 self.exists_snippet = {} |
|
84 self.pending_keys = [] |
|
85 # we have to annotate the rqlst before inserting snippets, even though |
|
86 # we'll have to redo it latter |
|
87 self.annotate(select) |
|
88 self.insert_snippets(snippets) |
|
89 if not self.exists_snippet and self.u_varname: |
|
90 # U has been inserted than cancelled, cleanup |
|
91 select.undefine_variable(select.defined_vars[self.u_varname]) |
|
92 # clean solutions according to initial solutions |
|
93 newsolutions = remove_solutions(solutions, select.solutions, |
|
94 select.defined_vars) |
|
95 assert len(newsolutions) >= len(solutions), ( |
|
96 'rewritten rql %s has lost some solutions, there is probably ' |
|
97 'something wrong in your schema permission (for instance using a ' |
|
98 'RQLExpression which insert a relation which doesn\'t exists in ' |
|
99 'the schema)\nOrig solutions: %s\nnew solutions: %s' % ( |
|
100 select, solutions, newsolutions)) |
|
101 if len(newsolutions) > len(solutions): |
|
102 newsolutions = self.remove_ambiguities(snippets, newsolutions) |
|
103 select.solutions = newsolutions |
|
104 add_types_restriction(self.schema, select) |
|
105 if server.DEBUG: |
|
106 print '---- rewriten', select |
|
107 |
|
108 def insert_snippets(self, snippets, varexistsmap=None): |
|
109 self.rewritten = {} |
|
110 for varmap, rqlexprs in snippets: |
|
111 if varexistsmap is not None and not varmap in varexistsmap: |
|
112 continue |
|
113 self.varmap = varmap |
|
114 selectvar, snippetvar = varmap |
|
115 assert snippetvar in 'SOX' |
|
116 self.revvarmap = {snippetvar: selectvar} |
|
117 self.varinfo = vi = {} |
|
118 try: |
|
119 vi['const'] = typed_eid(selectvar) # XXX gae |
|
120 vi['rhs_rels'] = vi['lhs_rels'] = {} |
|
121 except ValueError: |
|
122 vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo |
|
123 if varexistsmap is None: |
|
124 vi['rhs_rels'] = dict( (r.r_type, r) for r in sti['rhsrelations']) |
|
125 vi['lhs_rels'] = dict( (r.r_type, r) for r in sti['relations'] |
|
126 if not r in sti['rhsrelations']) |
|
127 else: |
|
128 vi['rhs_rels'] = vi['lhs_rels'] = {} |
|
129 parent = None |
|
130 inserted = False |
|
131 for rqlexpr in rqlexprs: |
|
132 self.current_expr = rqlexpr |
|
133 if varexistsmap is None: |
|
134 try: |
|
135 new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, parent) |
|
136 except Unsupported: |
|
137 import traceback |
|
138 traceback.print_exc() |
|
139 continue |
|
140 inserted = True |
|
141 if new is not None: |
|
142 self.exists_snippet[rqlexpr] = new |
|
143 parent = parent or new |
|
144 else: |
|
145 # called to reintroduce snippet due to ambiguity creation, |
|
146 # so skip snippets which are not introducing this ambiguity |
|
147 exists = varexistsmap[varmap] |
|
148 if self.exists_snippet[rqlexpr] is exists: |
|
149 self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists) |
|
150 if varexistsmap is None and not inserted: |
|
151 # no rql expression found matching rql solutions. User has no access right |
|
152 raise Unauthorized(str((varmap, str(self.select), [expr.expression for expr in rqlexprs]))) |
|
153 |
|
154 def insert_snippet(self, varmap, snippetrqlst, parent=None): |
|
155 new = snippetrqlst.where.accept(self) |
|
156 if new is not None: |
|
157 if self.varinfo.get('stinfo', {}).get('optrelations'): |
|
158 assert parent is None |
|
159 self.insert_scope = self.snippet_subquery(varmap, new) |
|
160 self.insert_pending() |
|
161 self.insert_scope = self.select |
|
162 return |
|
163 new = n.Exists(new) |
|
164 if parent is None: |
|
165 self.insert_scope.add_restriction(new) |
|
166 else: |
|
167 grandpa = parent.parent |
|
168 or_ = n.Or(parent, new) |
|
169 grandpa.replace(parent, or_) |
|
170 if not self.removing_ambiguity: |
|
171 try: |
|
172 self.compute_solutions() |
|
173 except Unsupported: |
|
174 # some solutions have been lost, can't apply this rql expr |
|
175 if parent is None: |
|
176 self.select.remove_node(new, undefine=True) |
|
177 else: |
|
178 parent.parent.replace(or_, or_.children[0]) |
|
179 self._cleanup_inserted(new) |
|
180 raise |
|
181 else: |
|
182 self.insert_scope = new |
|
183 self.insert_pending() |
|
184 self.insert_scope = self.select |
|
185 return new |
|
186 self.insert_pending() |
|
187 |
|
188 def insert_pending(self): |
|
189 """pending_keys hold variable referenced by U has_<action>_permission X |
|
190 relation. |
|
191 |
|
192 Once the snippet introducing this has been inserted and solutions |
|
193 recomputed, we have to insert snippet defined for <action> of entity |
|
194 types taken by X |
|
195 """ |
|
196 while self.pending_keys: |
|
197 key, action = self.pending_keys.pop() |
|
198 try: |
|
199 varname = self.rewritten[key] |
|
200 except KeyError: |
|
201 try: |
|
202 varname = self.revvarmap[key[-1]] |
|
203 except KeyError: |
|
204 # variable isn't used anywhere else, we can't insert security |
|
205 raise Unauthorized() |
|
206 ptypes = self.select.defined_vars[varname].stinfo['possibletypes'] |
|
207 if len(ptypes) > 1: |
|
208 # XXX dunno how to handle this |
|
209 self.session.error( |
|
210 'cant check security of %s, ambigous type for %s in %s', |
|
211 self.select, varname, key[0]) # key[0] == the rql expression |
|
212 raise Unauthorized() |
|
213 etype = iter(ptypes).next() |
|
214 eschema = self.schema.eschema(etype) |
|
215 if not eschema.has_perm(self.session, action): |
|
216 rqlexprs = eschema.get_rqlexprs(action) |
|
217 if not rqlexprs: |
|
218 raise Unauthorised() |
|
219 self.insert_snippets([((varname, 'X'), rqlexprs)]) |
|
220 |
|
221 def snippet_subquery(self, varmap, transformedsnippet): |
|
222 """introduce the given snippet in a subquery""" |
|
223 subselect = stmts.Select() |
|
224 selectvar, snippetvar = varmap |
|
225 subselect.append_selected(n.VariableRef( |
|
226 subselect.get_variable(selectvar))) |
|
227 aliases = [selectvar] |
|
228 subselect.add_restriction(transformedsnippet.copy(subselect)) |
|
229 stinfo = self.varinfo['stinfo'] |
|
230 for rel in stinfo['relations']: |
|
231 rschema = self.schema.rschema(rel.r_type) |
|
232 if rschema.is_final() or (rschema.inlined and |
|
233 not rel in stinfo['rhsrelations']): |
|
234 self.select.remove_node(rel) |
|
235 rel.children[0].name = selectvar |
|
236 subselect.add_restriction(rel.copy(subselect)) |
|
237 for vref in rel.children[1].iget_nodes(n.VariableRef): |
|
238 subselect.append_selected(vref.copy(subselect)) |
|
239 aliases.append(vref.name) |
|
240 if self.u_varname: |
|
241 # generate an identifier for the substitution |
|
242 argname = subselect.allocate_varname() |
|
243 while argname in self.kwargs: |
|
244 argname = subselect.allocate_varname() |
|
245 subselect.add_constant_restriction(subselect.get_variable(self.u_varname), |
|
246 'eid', unicode(argname), 'Substitute') |
|
247 self.kwargs[argname] = self.session.user.eid |
|
248 add_types_restriction(self.schema, subselect, subselect, |
|
249 solutions=self.solutions) |
|
250 myunion = stmts.Union() |
|
251 myunion.append(subselect) |
|
252 aliases = [n.VariableRef(self.select.get_variable(name, i)) |
|
253 for i, name in enumerate(aliases)] |
|
254 self.select.add_subquery(n.SubQuery(aliases, myunion), check=False) |
|
255 self._cleanup_inserted(transformedsnippet) |
|
256 try: |
|
257 self.compute_solutions() |
|
258 except Unsupported: |
|
259 # some solutions have been lost, can't apply this rql expr |
|
260 self.select.remove_subquery(new, undefine=True) |
|
261 raise |
|
262 return subselect |
|
263 |
|
264 def remove_ambiguities(self, snippets, newsolutions): |
|
265 # the snippet has introduced some ambiguities, we have to resolve them |
|
266 # "manually" |
|
267 variantes = self.build_variantes(newsolutions) |
|
268 # insert "is" where necessary |
|
269 varexistsmap = {} |
|
270 self.removing_ambiguity = True |
|
271 for (erqlexpr, varmap, oldvarname), etype in variantes[0].iteritems(): |
|
272 varname = self.rewritten[(erqlexpr, varmap, oldvarname)] |
|
273 var = self.select.defined_vars[varname] |
|
274 exists = var.references()[0].scope |
|
275 exists.add_constant_restriction(var, 'is', etype, 'etype') |
|
276 varexistsmap[varmap] = exists |
|
277 # insert ORED exists where necessary |
|
278 for variante in variantes[1:]: |
|
279 self.insert_snippets(snippets, varexistsmap) |
|
280 for key, etype in variante.iteritems(): |
|
281 varname = self.rewritten[key] |
|
282 try: |
|
283 var = self.select.defined_vars[varname] |
|
284 except KeyError: |
|
285 # not a newly inserted variable |
|
286 continue |
|
287 exists = var.references()[0].scope |
|
288 exists.add_constant_restriction(var, 'is', etype, 'etype') |
|
289 # recompute solutions |
|
290 #select.annotated = False # avoid assertion error |
|
291 self.compute_solutions() |
|
292 # clean solutions according to initial solutions |
|
293 return remove_solutions(self.solutions, self.select.solutions, |
|
294 self.select.defined_vars) |
|
295 |
|
296 def build_variantes(self, newsolutions): |
|
297 variantes = set() |
|
298 for sol in newsolutions: |
|
299 variante = [] |
|
300 for key, newvar in self.rewritten.iteritems(): |
|
301 variante.append( (key, sol[newvar]) ) |
|
302 variantes.add(tuple(variante)) |
|
303 # rebuild variantes as dict |
|
304 variantes = [dict(variante) for variante in variantes] |
|
305 # remove variable which have always the same type |
|
306 for key in self.rewritten: |
|
307 it = iter(variantes) |
|
308 etype = it.next()[key] |
|
309 for variante in it: |
|
310 if variante[key] != etype: |
|
311 break |
|
312 else: |
|
313 for variante in variantes: |
|
314 del variante[key] |
|
315 return variantes |
|
316 |
|
317 def _cleanup_inserted(self, node): |
|
318 # cleanup inserted variable references |
|
319 for vref in node.iget_nodes(n.VariableRef): |
|
320 vref.unregister_reference() |
|
321 if not vref.variable.stinfo['references']: |
|
322 # no more references, undefine the variable |
|
323 del self.select.defined_vars[vref.name] |
|
324 |
|
325 def _may_be_shared(self, relation, target, searchedvarname): |
|
326 """return True if the snippet relation can be skipped to use a relation |
|
327 from the original query |
|
328 """ |
|
329 # if cardinality is in '?1', we can ignore the relation and use variable |
|
330 # from the original query |
|
331 rschema = self.schema.rschema(relation.r_type) |
|
332 if target == 'object': |
|
333 cardindex = 0 |
|
334 ttypes_func = rschema.objects |
|
335 rprop = rschema.rproperty |
|
336 else: # target == 'subject': |
|
337 cardindex = 1 |
|
338 ttypes_func = rschema.subjects |
|
339 rprop = lambda x, y, z: rschema.rproperty(y, x, z) |
|
340 for etype in self.varinfo['stinfo']['possibletypes']: |
|
341 for ttype in ttypes_func(etype): |
|
342 if rprop(etype, ttype, 'cardinality')[cardindex] in '+*': |
|
343 return False |
|
344 return True |
|
345 |
|
346 def _use_outer_term(self, snippet_varname, term): |
|
347 key = (self.current_expr, self.varmap, snippet_varname) |
|
348 if key in self.rewritten: |
|
349 insertedvar = self.select.defined_vars.pop(self.rewritten[key]) |
|
350 for inserted_vref in insertedvar.references(): |
|
351 inserted_vref.parent.replace(inserted_vref, term.copy(self.select)) |
|
352 self.rewritten[key] = term.name |
|
353 |
|
354 def _get_varname_or_term(self, vname): |
|
355 if vname == 'U': |
|
356 if self.u_varname is None: |
|
357 select = self.select |
|
358 self.u_varname = select.allocate_varname() |
|
359 # generate an identifier for the substitution |
|
360 argname = select.allocate_varname() |
|
361 while argname in self.kwargs: |
|
362 argname = select.allocate_varname() |
|
363 # insert "U eid %(u)s" |
|
364 var = select.get_variable(self.u_varname) |
|
365 select.add_constant_restriction(select.get_variable(self.u_varname), |
|
366 'eid', unicode(argname), 'Substitute') |
|
367 self.kwargs[argname] = self.session.user.eid |
|
368 return self.u_varname |
|
369 key = (self.current_expr, self.varmap, vname) |
|
370 try: |
|
371 return self.rewritten[key] |
|
372 except KeyError: |
|
373 self.rewritten[key] = newvname = self.select.allocate_varname() |
|
374 return newvname |
|
375 |
|
376 # visitor methods ########################################################## |
|
377 |
|
378 def _visit_binary(self, node, cls): |
|
379 newnode = cls() |
|
380 for c in node.children: |
|
381 new = c.accept(self) |
|
382 if new is None: |
|
383 continue |
|
384 newnode.append(new) |
|
385 if len(newnode.children) == 0: |
|
386 return None |
|
387 if len(newnode.children) == 1: |
|
388 return newnode.children[0] |
|
389 return newnode |
|
390 |
|
391 def _visit_unary(self, node, cls): |
|
392 newc = node.children[0].accept(self) |
|
393 if newc is None: |
|
394 return None |
|
395 newnode = cls() |
|
396 newnode.append(newc) |
|
397 return newnode |
|
398 |
|
399 def visit_and(self, node): |
|
400 return self._visit_binary(node, n.And) |
|
401 |
|
402 def visit_or(self, node): |
|
403 return self._visit_binary(node, n.Or) |
|
404 |
|
405 def visit_not(self, node): |
|
406 return self._visit_unary(node, n.Not) |
|
407 |
|
408 def visit_exists(self, node): |
|
409 return self._visit_unary(node, n.Exists) |
|
410 |
|
411 def visit_relation(self, node): |
|
412 lhs, rhs = node.get_variable_parts() |
|
413 if node.r_type in ('has_add_permission', 'has_update_permission', |
|
414 'has_delete_permission', 'has_read_permission'): |
|
415 assert lhs.name == 'U' |
|
416 action = node.r_type.split('_')[1] |
|
417 key = (self.current_expr, self.varmap, rhs.name) |
|
418 self.pending_keys.append( (key, action) ) |
|
419 return |
|
420 if lhs.name in self.revvarmap: |
|
421 # on lhs |
|
422 # see if we can reuse this relation |
|
423 rels = self.varinfo['lhs_rels'] |
|
424 if (node.r_type in rels and isinstance(rhs, n.VariableRef) |
|
425 and rhs.name != 'U' and not rels[node.r_type].neged(strict=True) |
|
426 and self._may_be_shared(node, 'object', lhs.name)): |
|
427 # ok, can share variable |
|
428 term = rels[node.r_type].children[1].children[0] |
|
429 self._use_outer_term(rhs.name, term) |
|
430 return |
|
431 elif isinstance(rhs, n.VariableRef) and rhs.name in self.revvarmap and lhs.name != 'U': |
|
432 # on rhs |
|
433 # see if we can reuse this relation |
|
434 rels = self.varinfo['rhs_rels'] |
|
435 if (node.r_type in rels and not rels[node.r_type].neged(strict=True) |
|
436 and self._may_be_shared(node, 'subject', rhs.name)): |
|
437 # ok, can share variable |
|
438 term = rels[node.r_type].children[0] |
|
439 self._use_outer_term(lhs.name, term) |
|
440 return |
|
441 rel = n.Relation(node.r_type, node.optional) |
|
442 for c in node.children: |
|
443 rel.append(c.accept(self)) |
|
444 return rel |
|
445 |
|
446 def visit_comparison(self, node): |
|
447 cmp_ = n.Comparison(node.operator) |
|
448 for c in node.children: |
|
449 cmp_.append(c.accept(self)) |
|
450 return cmp_ |
|
451 |
|
452 def visit_mathexpression(self, node): |
|
453 cmp_ = n.MathExpression(node.operator) |
|
454 for c in cmp.children: |
|
455 cmp_.append(c.accept(self)) |
|
456 return cmp_ |
|
457 |
|
458 def visit_function(self, node): |
|
459 """generate filter name for a function""" |
|
460 function_ = n.Function(node.name) |
|
461 for c in node.children: |
|
462 function_.append(c.accept(self)) |
|
463 return function_ |
|
464 |
|
465 def visit_constant(self, node): |
|
466 """generate filter name for a constant""" |
|
467 return n.Constant(node.value, node.type) |
|
468 |
|
469 def visit_variableref(self, node): |
|
470 """get the sql name for a variable reference""" |
|
471 if node.name in self.revvarmap: |
|
472 if self.varinfo.get('const') is not None: |
|
473 return n.Constant(self.varinfo['const'], 'Int') # XXX gae |
|
474 return n.VariableRef(self.select.get_variable( |
|
475 self.revvarmap[node.name])) |
|
476 vname_or_term = self._get_varname_or_term(node.name) |
|
477 if isinstance(vname_or_term, basestring): |
|
478 return n.VariableRef(self.select.get_variable(vname_or_term)) |
|
479 # shared term |
|
480 return vname_or_term.copy(self.select) |