|
1 """RQL rewriting utilities, used for read security checking |
|
2 |
|
3 :organization: Logilab |
|
4 :copyright: 2007-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved. |
|
5 :contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr |
|
6 """ |
|
7 |
|
8 from rql import nodes, stmts, TypeResolverException |
|
9 from cubicweb import Unauthorized, server, typed_eid |
|
10 from cubicweb.server.ssplanner import add_types_restriction |
|
11 |
|
def remove_solutions(origsolutions, solutions, defined):
    """when a rqlst has been generated from another by introducing security
    assertions, this function returns the new solutions which are contained
    in the original solutions

    :param origsolutions: solutions of the original syntax tree, as a list of
      dictionaries mapping a variable name to an entity type
    :param solutions: solutions of the rewritten syntax tree. **Modified in
      place**: each retained solution is removed from this list
    :param defined: mapping from variable name to an object exposing a
      `stinfo` dictionary; when a new solution is rejected because of a
      variable, the offending type is removed from that variable's
      `stinfo['possibletypes']`
    :return: the list of retained solutions
    """
    newsolutions = []
    for origsol in origsolutions:
        # iterate on a copy since matching solutions are removed on the fly
        for newsol in solutions[:]:
            for var, etype in origsol.items():
                try:
                    newtype = newsol[var]
                except KeyError:
                    # variable has been rewritten
                    continue
                if newtype != etype:
                    try:
                        defined[var].stinfo['possibletypes'].remove(newtype)
                    except KeyError:
                        # unknown variable or type already discarded
                        pass
                    break
            else:
                # every variable of the original solution is compatible
                newsolutions.append(newsol)
                solutions.remove(newsol)
    return newsolutions
|
35 |
|
class Unsupported(Exception):
    """raised when a rql snippet can't be applied to the query being
    rewritten (types can't be resolved or some solutions are lost)
    """
|
37 |
|
38 class RQLRewriter(object): |
|
39 """insert some rql snippets into another rql syntax tree""" |
|
40 def __init__(self, querier, session): |
|
41 self.session = session |
|
42 self.annotate = querier._rqlhelper.annotate |
|
43 self._compute_solutions = querier.solutions |
|
44 self.schema = querier.schema |
|
45 |
|
46 def compute_solutions(self): |
|
47 self.annotate(self.select) |
|
48 try: |
|
49 self._compute_solutions(self.session, self.select, self.kwargs) |
|
50 except TypeResolverException: |
|
51 raise Unsupported() |
|
52 if len(self.select.solutions) < len(self.solutions): |
|
53 raise Unsupported() |
|
54 |
|
    def rewrite(self, select, snippets, solutions, kwargs):
        """insert the given security snippets into `select`, in place

        :param select: the syntax tree to rewrite; restrictions are added to
          it and its `solutions` attribute is reset at the end
        :param snippets: iterable of `(varname, erqlexprs)` pairs, as consumed
          by `insert_snippets`
        :param solutions: solutions of the original syntax tree
        :param kwargs: query arguments dictionary; entries may be added to it
          while inserting snippets (substitution for the user's eid)

        Exceptions raised by `insert_snippets` (e.g. `Unauthorized`) are not
        caught here.
        """
        if server.DEBUG:
            print '---- rewrite', select, snippets, solutions
        # per-rewrite state, read by the visit_* methods and helpers
        self.select = select
        self.solutions = solutions
        self.kwargs = kwargs
        self.u_varname = None
        self.removing_ambiguity = False
        self.exists_snippet = {}
        # we have to annotate the rqlst before inserting snippets, even though
        # we'll have to redo it later
        self.annotate(select)
        self.insert_snippets(snippets)
        if not self.exists_snippet and self.u_varname:
            # U has been inserted then cancelled, cleanup
            select.undefine_variable(select.defined_vars[self.u_varname])
        # clean solutions according to initial solutions
        newsolutions = remove_solutions(solutions, select.solutions,
                                        select.defined_vars)
        assert len(newsolutions) >= len(solutions), \
               'rewritten rql %s has lost some solutions, there is probably something '\
               'wrong in your schema permission (for instance using a '\
               'RQLExpression which insert a relation which doesn\'t exists in '\
               'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
            select, solutions, newsolutions)
        if len(newsolutions) > len(solutions):
            # the snippet has introduced some ambiguities, we have to resolve them
            # "manually"
            variantes = self.build_variantes(newsolutions)
            # insert "is" where necessary
            varexistsmap = {}
            # from now on insert_snippet doesn't recompute solutions after
            # each insertion, this is done once below
            self.removing_ambiguity = True
            # the first variante constrains the already inserted snippets
            for (erqlexpr, mainvar, oldvarname), etype in variantes[0].iteritems():
                varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
                var = select.defined_vars[varname]
                exists = var.references()[0].scope
                exists.add_constant_restriction(var, 'is', etype, 'etype')
                varexistsmap[mainvar] = exists
            # insert ORED exists where necessary
            for variante in variantes[1:]:
                # insert_snippets resets self.rewritten, so the lookups below
                # concern the freshly inserted copy
                self.insert_snippets(snippets, varexistsmap)
                for (erqlexpr, mainvar, oldvarname), etype in variante.iteritems():
                    varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
                    try:
                        var = select.defined_vars[varname]
                    except KeyError:
                        # not a newly inserted variable
                        continue
                    exists = var.references()[0].scope
                    exists.add_constant_restriction(var, 'is', etype, 'etype')
            # recompute solutions
            #select.annotated = False # avoid assertion error
            self.compute_solutions()
            # clean solutions according to initial solutions
            newsolutions = remove_solutions(solutions, select.solutions,
                                            select.defined_vars)
        select.solutions = newsolutions
        add_types_restriction(self.schema, select)
        if server.DEBUG:
            print '---- rewriten', select
|
115 |
|
116 def build_variantes(self, newsolutions): |
|
117 variantes = set() |
|
118 for sol in newsolutions: |
|
119 variante = [] |
|
120 for (erqlexpr, mainvar, oldvar), newvar in self.rewritten.iteritems(): |
|
121 variante.append( ((erqlexpr, mainvar, oldvar), sol[newvar]) ) |
|
122 variantes.add(tuple(variante)) |
|
123 # rebuild variantes as dict |
|
124 variantes = [dict(variante) for variante in variantes] |
|
125 # remove variable which have always the same type |
|
126 for erqlexpr, mainvar, oldvar in self.rewritten: |
|
127 it = iter(variantes) |
|
128 etype = it.next()[(erqlexpr, mainvar, oldvar)] |
|
129 for variante in it: |
|
130 if variante[(erqlexpr, mainvar, oldvar)] != etype: |
|
131 break |
|
132 else: |
|
133 for variante in variantes: |
|
134 del variante[(erqlexpr, mainvar, oldvar)] |
|
135 return variantes |
|
136 |
|
    def insert_snippets(self, snippets, varexistsmap=None):
        """insert rql expression snippets into the syntax tree being rewritten

        :param snippets: iterable of `(varname, erqlexprs)` pairs, where
          `varname` is either a variable name of the query or something
          `typed_eid` accepts, and `erqlexprs` the rql expressions to insert
          for it
        :param varexistsmap: None on the initial insertion. Otherwise a
          dictionary mapping a variable name to a previously inserted node:
          only variables in the map are considered and, for them, only the
          expressions recorded in `self.exists_snippet` for that node are
          inserted again (used by `rewrite` to resolve ambiguities)

        :raise Unauthorized: on initial insertion, if no expression could be
          inserted for one of the variables
        """
        # mapping (erqlexpr, main variable, snippet variable) -> name of the
        # variable inserted in the query or shared term, reset on each call
        self.rewritten = {}
        for varname, erqlexprs in snippets:
            if varexistsmap is not None and not varname in varexistsmap:
                continue
            try:
                # typed_eid raising ValueError means varname is a variable
                # name, else the snippet's main variable is a constant
                self.const = typed_eid(varname)
                self.varname = self.const
                self.rhs_rels = self.lhs_rels = {}
            except ValueError:
                self.varname = varname
                self.const = None
                self.varstinfo = stinfo = self.select.defined_vars[varname].stinfo
                if varexistsmap is None:
                    # relations of the original query, by relation type, which
                    # visit_relation may reuse instead of inserting new ones
                    self.rhs_rels = dict( (rel.r_type, rel) for rel in stinfo['rhsrelations'])
                    self.lhs_rels = dict( (rel.r_type, rel) for rel in stinfo['relations']
                                                  if not rel in stinfo['rhsrelations'])
                else:
                    # no relation sharing when reinserting snippets
                    self.rhs_rels = self.lhs_rels = {}
            parent = None
            inserted = False
            for erqlexpr in erqlexprs:
                self.current_expr = erqlexpr
                if varexistsmap is None:
                    try:
                        new = self.insert_snippet(varname, erqlexpr.snippet_rqlst, parent)
                    except Unsupported:
                        # this expression can't be applied, try the next one
                        continue
                    inserted = True
                    if new is not None:
                        self.exists_snippet[erqlexpr] = new
                    # following expressions are passed the first inserted
                    # node as parent
                    parent = parent or new
                else:
                    # called to reintroduce snippet due to ambiguity creation,
                    # so skip snippets which are not introducing this ambiguity
                    exists = varexistsmap[varname]
                    if self.exists_snippet[erqlexpr] is exists:
                        self.insert_snippet(varname, erqlexpr.snippet_rqlst, exists)
            if varexistsmap is None and not inserted:
                # no rql expression found matching rql solutions. User has no access right
                raise Unauthorized()
|
178 |
|
    def insert_snippet(self, varname, snippetrqlst, parent=None):
        """insert a single snippet into the syntax tree being rewritten

        :param varname: name of the query variable (or eid) the snippet
          applies to
        :param snippetrqlst: syntax tree of the snippet; its `where` node is
          visited by this rewriter to build the restriction to insert
        :param parent: a previously inserted node the new one should be ORed
          with, or None to add the new node as a restriction of the query

        :return: the inserted `Exists` node, or None if visiting the snippet
          produced nothing to insert or if the snippet has been inserted as
          a subquery
        :raise Unsupported: if solutions are lost once the snippet has been
          inserted; the insertion is undone before raising
        """
        new = snippetrqlst.where.accept(self)
        if new is not None:
            try:
                var = self.select.defined_vars[varname]
            except KeyError:
                # not a variable
                pass
            else:
                if var.stinfo['optrelations']:
                    # use a subquery
                    subselect = stmts.Select()
                    subselect.append_selected(nodes.VariableRef(subselect.get_variable(varname)))
                    subselect.add_restriction(new.copy(subselect))
                    aliases = [varname]
                    # move final relations, and inlined relations where the
                    # variable is not on the right hand side, from the main
                    # query to the subquery
                    for rel in var.stinfo['relations']:
                        rschema = self.schema.rschema(rel.r_type)
                        if rschema.is_final() or (rschema.inlined and not rel in var.stinfo['rhsrelations']):
                            self.select.remove_node(rel)
                            rel.children[0].name = varname
                            subselect.add_restriction(rel.copy(subselect))
                            for vref in rel.children[1].iget_nodes(nodes.VariableRef):
                                subselect.append_selected(vref.copy(subselect))
                                aliases.append(vref.name)
                    if self.u_varname:
                        # generate an identifier for the substitution
                        argname = subselect.allocate_varname()
                        while argname in self.kwargs:
                            argname = subselect.allocate_varname()
                        subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
                                                           'eid', unicode(argname), 'Substitute')
                        self.kwargs[argname] = self.session.user.eid
                    add_types_restriction(self.schema, subselect, subselect, solutions=self.solutions)
                    assert parent is None
                    myunion = stmts.Union()
                    myunion.append(subselect)
                    aliases = [nodes.VariableRef(self.select.get_variable(name, i))
                               for i, name in enumerate(aliases)]
                    self.select.add_subquery(nodes.SubQuery(aliases, myunion), check=False)
                    # `new` has only been copied into the subquery, drop the
                    # references it holds on variables of the main query
                    self._cleanup_inserted(new)
                    try:
                        self.compute_solutions()
                    except Unsupported:
                        # some solutions have been lost, can't apply this rql expr
                        # NOTE(review): `new` is the restriction built from the
                        # snippet, not the SubQuery node added above -- confirm
                        # this is what remove_subquery expects
                        self.select.remove_subquery(new, undefine=True)
                        raise
                    return
            new = nodes.Exists(new)
            if parent is None:
                self.select.add_restriction(new)
            else:
                # replace parent by "parent OR new" in the tree
                grandpa = parent.parent
                or_ = nodes.Or(parent, new)
                grandpa.replace(parent, or_)
            if not self.removing_ambiguity:
                try:
                    self.compute_solutions()
                except Unsupported:
                    # some solutions have been lost, can't apply this rql expr
                    if parent is None:
                        self.select.remove_node(new, undefine=True)
                    else:
                        # put back the original node in place of the OR
                        parent.parent.replace(or_, or_.children[0])
                        self._cleanup_inserted(new)
                    raise
            return new
|
245 |
|
246 def _cleanup_inserted(self, node): |
|
247 # cleanup inserted variable references |
|
248 for vref in node.iget_nodes(nodes.VariableRef): |
|
249 vref.unregister_reference() |
|
250 if not vref.variable.stinfo['references']: |
|
251 # no more references, undefine the variable |
|
252 del self.select.defined_vars[vref.name] |
|
253 |
|
254 def _visit_binary(self, node, cls): |
|
255 newnode = cls() |
|
256 for c in node.children: |
|
257 new = c.accept(self) |
|
258 if new is None: |
|
259 continue |
|
260 newnode.append(new) |
|
261 if len(newnode.children) == 0: |
|
262 return None |
|
263 if len(newnode.children) == 1: |
|
264 return newnode.children[0] |
|
265 return newnode |
|
266 |
|
267 def _visit_unary(self, node, cls): |
|
268 newc = node.children[0].accept(self) |
|
269 if newc is None: |
|
270 return None |
|
271 newnode = cls() |
|
272 newnode.append(newc) |
|
273 return newnode |
|
274 |
|
275 def visit_and(self, et): |
|
276 return self._visit_binary(et, nodes.And) |
|
277 |
|
278 def visit_or(self, ou): |
|
279 return self._visit_binary(ou, nodes.Or) |
|
280 |
|
281 def visit_not(self, node): |
|
282 return self._visit_unary(node, nodes.Not) |
|
283 |
|
284 def visit_exists(self, node): |
|
285 return self._visit_unary(node, nodes.Exists) |
|
286 |
|
287 def visit_relation(self, relation): |
|
288 lhs, rhs = relation.get_variable_parts() |
|
289 if lhs.name == 'X': |
|
290 # on lhs |
|
291 # see if we can reuse this relation |
|
292 if relation.r_type in self.lhs_rels and isinstance(rhs, nodes.VariableRef) and rhs.name != 'U': |
|
293 if self._may_be_shared(relation, 'object'): |
|
294 # ok, can share variable |
|
295 term = self.lhs_rels[relation.r_type].children[1].children[0] |
|
296 self._use_outer_term(rhs.name, term) |
|
297 return |
|
298 elif isinstance(rhs, nodes.VariableRef) and rhs.name == 'X' and lhs.name != 'U': |
|
299 # on rhs |
|
300 # see if we can reuse this relation |
|
301 if relation.r_type in self.rhs_rels and self._may_be_shared(relation, 'subject'): |
|
302 # ok, can share variable |
|
303 term = self.rhs_rels[relation.r_type].children[0] |
|
304 self._use_outer_term(lhs.name, term) |
|
305 return |
|
306 rel = nodes.Relation(relation.r_type, relation.optional) |
|
307 for c in relation.children: |
|
308 rel.append(c.accept(self)) |
|
309 return rel |
|
310 |
|
311 def visit_comparison(self, cmp): |
|
312 cmp_ = nodes.Comparison(cmp.operator) |
|
313 for c in cmp.children: |
|
314 cmp_.append(c.accept(self)) |
|
315 return cmp_ |
|
316 |
|
317 def visit_mathexpression(self, mexpr): |
|
318 cmp_ = nodes.MathExpression(cmp.operator) |
|
319 for c in cmp.children: |
|
320 cmp_.append(c.accept(self)) |
|
321 return cmp_ |
|
322 |
|
323 def visit_function(self, function): |
|
324 """generate filter name for a function""" |
|
325 function_ = nodes.Function(function.name) |
|
326 for c in function.children: |
|
327 function_.append(c.accept(self)) |
|
328 return function_ |
|
329 |
|
330 def visit_constant(self, constant): |
|
331 """generate filter name for a constant""" |
|
332 return nodes.Constant(constant.value, constant.type) |
|
333 |
|
334 def visit_variableref(self, vref): |
|
335 """get the sql name for a variable reference""" |
|
336 if vref.name == 'X': |
|
337 if self.const is not None: |
|
338 return nodes.Constant(self.const, 'Int') |
|
339 return nodes.VariableRef(self.select.get_variable(self.varname)) |
|
340 vname_or_term = self._get_varname_or_term(vref.name) |
|
341 if isinstance(vname_or_term, basestring): |
|
342 return nodes.VariableRef(self.select.get_variable(vname_or_term)) |
|
343 # shared term |
|
344 return vname_or_term.copy(self.select) |
|
345 |
|
346 def _may_be_shared(self, relation, target): |
|
347 """return True if the snippet relation can be skipped to use a relation |
|
348 from the original query |
|
349 """ |
|
350 # if cardinality is in '?1', we can ignore the relation and use variable |
|
351 # from the original query |
|
352 rschema = self.schema.rschema(relation.r_type) |
|
353 if target == 'object': |
|
354 cardindex = 0 |
|
355 ttypes_func = rschema.objects |
|
356 rprop = rschema.rproperty |
|
357 else: # target == 'subject': |
|
358 cardindex = 1 |
|
359 ttypes_func = rschema.subjects |
|
360 rprop = lambda x,y,z: rschema.rproperty(y, x, z) |
|
361 for etype in self.varstinfo['possibletypes']: |
|
362 for ttype in ttypes_func(etype): |
|
363 if rprop(etype, ttype, 'cardinality')[cardindex] in '+*': |
|
364 return False |
|
365 return True |
|
366 |
|
367 def _use_outer_term(self, snippet_varname, term): |
|
368 key = (self.current_expr, self.varname, snippet_varname) |
|
369 if key in self.rewritten: |
|
370 insertedvar = self.select.defined_vars.pop(self.rewritten[key]) |
|
371 for inserted_vref in insertedvar.references(): |
|
372 inserted_vref.parent.replace(inserted_vref, term.copy(self.select)) |
|
373 self.rewritten[key] = term |
|
374 |
|
375 def _get_varname_or_term(self, vname): |
|
376 if vname == 'U': |
|
377 if self.u_varname is None: |
|
378 select = self.select |
|
379 self.u_varname = select.allocate_varname() |
|
380 # generate an identifier for the substitution |
|
381 argname = select.allocate_varname() |
|
382 while argname in self.kwargs: |
|
383 argname = select.allocate_varname() |
|
384 # insert "U eid %(u)s" |
|
385 var = select.get_variable(self.u_varname) |
|
386 select.add_constant_restriction(select.get_variable(self.u_varname), |
|
387 'eid', unicode(argname), 'Substitute') |
|
388 self.kwargs[argname] = self.session.user.eid |
|
389 return self.u_varname |
|
390 key = (self.current_expr, self.varname, vname) |
|
391 try: |
|
392 return self.rewritten[key] |
|
393 except KeyError: |
|
394 self.rewritten[key] = newvname = self.select.allocate_varname() |
|
395 return newvname |