1 """RQL rewriting utilities, used for read security checking |
|
2 |
|
3 :organization: Logilab |
|
4 :copyright: 2007-2009 LOGILAB S.A. (Paris, FRANCE), license is LGPL v2. |
|
5 :contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr |
|
6 :license: GNU Lesser General Public License, v2.1 - http://www.gnu.org/licenses |
|
7 """ |
|
8 |
|
9 from rql import nodes, stmts, TypeResolverException |
|
10 from cubicweb import Unauthorized, server, typed_eid |
|
11 from cubicweb.server.ssplanner import add_types_restriction |
|
12 |
|
def remove_solutions(origsolutions, solutions, defined):
    """Return the solutions of a rewritten syntax tree which agree with the
    original solutions.

    Used once a rqlst has been generated from another one by introducing
    security assertions.

    :param origsolutions: solutions (variable name -> entity type mappings)
      of the original syntax tree
    :param solutions: list of solutions of the rewritten syntax tree. It is
      modified in place: every solution which is kept is removed from it
    :param defined: mapping of variable name -> variable. When a solution
      is rejected because of a variable's type, that type is discarded from
      the variable's `stinfo['possibletypes']`
    :return: list of kept solutions
    """
    def agrees(candidate, reference):
        # tell if `candidate` gives to each variable of `reference` the same
        # type, pruning the first conflicting type from the possible types
        for name, expected in reference.items():
            try:
                actual = candidate[name]
            except KeyError:
                # variable has been rewritten
                continue
            if actual != expected:
                try:
                    defined[name].stinfo['possibletypes'].remove(actual)
                except KeyError:
                    # unknown variable or type already discarded
                    pass
                return False
        return True

    kept = []
    for reference in origsolutions:
        # walk through a snapshot since kept solutions are removed on the fly
        for candidate in list(solutions):
            if agrees(candidate, reference):
                kept.append(candidate)
                solutions.remove(candidate)
    return kept
36 |
|
class Unsupported(Exception):
    """Raised when solutions of a rewritten syntax tree can't be computed
    or when some of the original solutions have been lost.
    """
38 |
|
class RQLRewriter(object):
    """Insert some rql snippets into another rql syntax tree.

    The rewriter is also the visitor used to copy a snippet's restriction
    (see the `visit_*` methods). While copying:

    * the snippet's `X` variable is replaced by the variable (or the eid
      constant) the snippet is inserted for,
    * the snippet's `U` variable is replaced by a variable of the rewritten
      select restricted to the eid of the session's user,
    * any other snippet variable is replaced either by a newly allocated
      variable or by a term shared with the original query.
    """

    def __init__(self, querier, session):
        """
        :param querier: object providing `_rqlhelper.annotate`, `solutions`
          and `schema`
        :param session: session given to the solutions computing function,
          whose `user.eid` is substituted to the snippets' `U` variable
        """
        self.session = session
        self.annotate = querier._rqlhelper.annotate
        self._compute_solutions = querier.solutions
        self.schema = querier.schema

    def compute_solutions(self):
        """Annotate the select being rewritten and recompute its solutions.

        :raise Unsupported: if solutions can't be computed or if there are
          less solutions than in the original ones
        """
        self.annotate(self.select)
        try:
            self._compute_solutions(self.session, self.select, self.kwargs)
        except TypeResolverException:
            raise Unsupported()
        if len(self.select.solutions) < len(self.solutions):
            raise Unsupported()

    def rewrite(self, select, snippets, solutions, kwargs):
        """Insert `snippets` into `select`, in place.

        :param select: the select node to rewrite
        :param snippets: iterable of (variable name or eid, rql expressions)
          couples, see `insert_snippets`
        :param solutions: solutions of the original select
        :param kwargs: query arguments dictionary, which may be extended
          with the substitution holding the user's eid
        :raise Unauthorized: propagated from `insert_snippets`
        """
        if server.DEBUG:
            # single argument call: same output whether `print` is a
            # statement or a function
            print('---- rewrite %s %s %s' % (select, snippets, solutions))
        self.select = select
        self.solutions = solutions
        self.kwargs = kwargs
        self.u_varname = None
        self.removing_ambiguity = False
        self.exists_snippet = {}
        # we have to annotate the rqlst before inserting snippets, even though
        # we'll have to redo it later
        self.annotate(select)
        self.insert_snippets(snippets)
        if not self.exists_snippet and self.u_varname:
            # U has been inserted then cancelled, cleanup
            select.undefine_variable(select.defined_vars[self.u_varname])
        # clean solutions according to initial solutions
        newsolutions = remove_solutions(solutions, select.solutions,
                                        select.defined_vars)
        assert len(newsolutions) >= len(solutions), (
            'rewritten rql %s has lost some solutions, there is probably something '
            'wrong in your schema permission (for instance using a '
            'RQLExpression which insert a relation which doesn\'t exists in '
            'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
                select, solutions, newsolutions))
        if len(newsolutions) > len(solutions):
            # the snippet has introduced some ambiguities, we have to resolve
            # them "manually"
            variantes = self.build_variantes(newsolutions)
            # insert "is" where necessary
            varexistsmap = {}
            self.removing_ambiguity = True
            for (erqlexpr, mainvar, oldvarname), etype in variantes[0].iteritems():
                varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
                var = select.defined_vars[varname]
                exists = var.references()[0].scope
                exists.add_constant_restriction(var, 'is', etype, 'etype')
                varexistsmap[mainvar] = exists
            # insert ORED exists where necessary
            for variante in variantes[1:]:
                # insert the snippets once more (this resets self.rewritten)
                self.insert_snippets(snippets, varexistsmap)
                for (erqlexpr, mainvar, oldvarname), etype in variante.iteritems():
                    varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
                    try:
                        var = select.defined_vars[varname]
                    except KeyError:
                        # not a newly inserted variable
                        continue
                    exists = var.references()[0].scope
                    exists.add_constant_restriction(var, 'is', etype, 'etype')
            # recompute solutions
            self.compute_solutions()
            # clean solutions according to initial solutions
            newsolutions = remove_solutions(solutions, select.solutions,
                                            select.defined_vars)
        select.solutions = newsolutions
        add_types_restriction(self.schema, select)
        if server.DEBUG:
            print('---- rewriten %s' % (select,))

    def build_variantes(self, newsolutions):
        """Return the list of distinct type assignments found in
        `newsolutions` for the rewritten snippet variables.

        Each assignment is a dictionary mapping a key of `self.rewritten`,
        i.e. a (rql expression, main variable, snippet variable) tuple, to
        an entity type. Keys whose type is the same in every assignment are
        removed.
        """
        variantes = set()
        for sol in newsolutions:
            variante = []
            for key, newvar in self.rewritten.iteritems():
                variante.append((key, sol[newvar]))
            variantes.add(tuple(variante))
        # rebuild variantes as dict
        variantes = [dict(variante) for variante in variantes]
        # remove variables which always have the same type
        for key in self.rewritten:
            it = iter(variantes)
            etype = it.next()[key]
            for variante in it:
                if variante[key] != etype:
                    break
            else:
                for variante in variantes:
                    del variante[key]
        return variantes

    def insert_snippets(self, snippets, varexistsmap=None):
        """Insert the given snippets into the select being rewritten.

        :param snippets: iterable of (varname, erqlexprs) couples, where
          `varname` is either a variable name or a value accepted by
          `typed_eid`, and each rql expression provides a `snippet_rqlst`
        :param varexistsmap: None on the first pass. When removing
          ambiguities, dictionary mapping a variable name to the node the
          snippet has to be ORed with: variables which are not in the
          dictionary are skipped, as well as rql expressions which have not
          been inserted as this node
        :raise Unauthorized: on the first pass, if none of the rql
          expressions of a variable could be inserted
        """
        self.rewritten = {}
        for varname, erqlexprs in snippets:
            if varexistsmap is not None and varname not in varexistsmap:
                continue
            try:
                self.const = typed_eid(varname)
                self.varname = self.const
                self.rhs_rels = self.lhs_rels = {}
            except ValueError:
                # not an eid, hence a variable name
                self.varname = varname
                self.const = None
                self.varstinfo = stinfo = self.select.defined_vars[varname].stinfo
                if varexistsmap is None:
                    # relations of the original query which may be shared
                    # with the snippet, indexed by relation type
                    self.rhs_rels = dict((rel.r_type, rel)
                                         for rel in stinfo['rhsrelations'])
                    self.lhs_rels = dict((rel.r_type, rel)
                                         for rel in stinfo['relations']
                                         if rel not in stinfo['rhsrelations'])
                else:
                    self.rhs_rels = self.lhs_rels = {}
            parent = None
            inserted = False
            for erqlexpr in erqlexprs:
                self.current_expr = erqlexpr
                if varexistsmap is None:
                    try:
                        new = self.insert_snippet(varname, erqlexpr.snippet_rqlst, parent)
                    except Unsupported:
                        continue
                    inserted = True
                    if new is not None:
                        self.exists_snippet[erqlexpr] = new
                    # following snippets are ORed with the first inserted one
                    parent = parent or new
                else:
                    # called to reintroduce snippet due to ambiguity creation,
                    # so skip snippets which are not introducing this ambiguity
                    exists = varexistsmap[varname]
                    if self.exists_snippet[erqlexpr] is exists:
                        self.insert_snippet(varname, erqlexpr.snippet_rqlst, exists)
            if varexistsmap is None and not inserted:
                # no rql expression found matching rql solutions. User has no
                # access right
                raise Unauthorized()

    def insert_snippet(self, varname, snippetrqlst, parent=None):
        """Insert a copy of the snippet's restriction into the select.

        If `varname` is a variable of the select with optional relations,
        the restriction is inserted using a subquery. Otherwise it is
        wrapped into an EXISTS node, which is added as a restriction of the
        select or, if `parent` is given, ORed with `parent`.

        :return: the inserted EXISTS node, or None if a subquery has been
          used or if nothing remains of the snippet once copied
        :raise Unsupported: if some solutions are lost once the snippet is
          inserted, in which case the insertion is undone
        """
        new = snippetrqlst.where.accept(self)
        if new is not None:
            try:
                var = self.select.defined_vars[varname]
            except KeyError:
                # not a variable
                pass
            else:
                if var.stinfo['optrelations']:
                    # use a subquery
                    subselect = stmts.Select()
                    subselect.append_selected(nodes.VariableRef(subselect.get_variable(varname)))
                    subselect.add_restriction(new.copy(subselect))
                    aliases = [varname]
                    for rel in var.stinfo['relations']:
                        rschema = self.schema.rschema(rel.r_type)
                        if rschema.is_final() or (rschema.inlined and rel not in var.stinfo['rhsrelations']):
                            # move the relation from the select to the
                            # subquery, which has to select its variables
                            self.select.remove_node(rel)
                            rel.children[0].name = varname
                            subselect.add_restriction(rel.copy(subselect))
                            for vref in rel.children[1].iget_nodes(nodes.VariableRef):
                                subselect.append_selected(vref.copy(subselect))
                                aliases.append(vref.name)
                    if self.u_varname:
                        # generate an identifier for the substitution
                        argname = subselect.allocate_varname()
                        while argname in self.kwargs:
                            argname = subselect.allocate_varname()
                        subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
                                                           'eid', unicode(argname), 'Substitute')
                        self.kwargs[argname] = self.session.user.eid
                    add_types_restriction(self.schema, subselect, subselect, solutions=self.solutions)
                    assert parent is None
                    myunion = stmts.Union()
                    myunion.append(subselect)
                    aliases = [nodes.VariableRef(self.select.get_variable(name, i))
                               for i, name in enumerate(aliases)]
                    self.select.add_subquery(nodes.SubQuery(aliases, myunion), check=False)
                    self._cleanup_inserted(new)
                    try:
                        self.compute_solutions()
                    except Unsupported:
                        # some solutions have been lost, can't apply this rql expr
                        # NOTE(review): `new` is the copied restriction, not
                        # the SubQuery node added above -- confirm this is
                        # what remove_subquery() expects
                        self.select.remove_subquery(new, undefine=True)
                        raise
                    return
            new = nodes.Exists(new)
            if parent is None:
                self.select.add_restriction(new)
            else:
                grandpa = parent.parent
                or_ = nodes.Or(parent, new)
                grandpa.replace(parent, or_)
            if not self.removing_ambiguity:
                try:
                    self.compute_solutions()
                except Unsupported:
                    # some solutions have been lost, can't apply this rql expr
                    if parent is None:
                        self.select.remove_node(new, undefine=True)
                    else:
                        # put back `parent` in place of the OR node
                        parent.parent.replace(or_, or_.children[0])
                        self._cleanup_inserted(new)
                    raise
        return new

    def _cleanup_inserted(self, node):
        """Unregister variable references found under `node` and undefine
        variables which are not referenced anymore.
        """
        for vref in node.iget_nodes(nodes.VariableRef):
            vref.unregister_reference()
            if not vref.variable.stinfo['references']:
                # no more references, undefine the variable
                del self.select.defined_vars[vref.name]

    def _visit_binary(self, node, cls):
        """Copy a node with several children as a `cls` node.

        Children whose copy is None are skipped. Return None if no child
        remains, the child itself if a single one remains, else the new
        `cls` node.
        """
        newnode = cls()
        for child in node.children:
            new = child.accept(self)
            if new is None:
                continue
            newnode.append(new)
        remaining = len(newnode.children)
        if remaining == 0:
            return None
        if remaining == 1:
            return newnode.children[0]
        return newnode

    def _visit_unary(self, node, cls):
        """Copy a node with a single child as a `cls` node, or return None
        if the copy of the child is None.
        """
        newc = node.children[0].accept(self)
        if newc is None:
            return None
        newnode = cls()
        newnode.append(newc)
        return newnode

    def visit_and(self, et):
        """Copy an AND node, see `_visit_binary`."""
        return self._visit_binary(et, nodes.And)

    def visit_or(self, ou):
        """Copy an OR node, see `_visit_binary`."""
        return self._visit_binary(ou, nodes.Or)

    def visit_not(self, node):
        """Copy a NOT node, see `_visit_unary`."""
        return self._visit_unary(node, nodes.Not)

    def visit_exists(self, node):
        """Copy an EXISTS node, see `_visit_unary`."""
        return self._visit_unary(node, nodes.Exists)

    def visit_relation(self, relation):
        """Copy a relation of the snippet.

        Return None instead of a copy when the relation may be skipped
        because a relation of the same type on the same side of the main
        variable exists in the original query and `_may_be_shared` allows
        it. The snippet variable is then mapped to the term of the original
        relation.
        """
        lhs, rhs = relation.get_variable_parts()
        if lhs.name == 'X':
            # on lhs
            # see if we can reuse this relation
            if (relation.r_type in self.lhs_rels
                    and isinstance(rhs, nodes.VariableRef)
                    and rhs.name != 'U'
                    and self._may_be_shared(relation, 'object')):
                # ok, can share variable
                term = self.lhs_rels[relation.r_type].children[1].children[0]
                self._use_outer_term(rhs.name, term)
                return
        elif isinstance(rhs, nodes.VariableRef) and rhs.name == 'X' and lhs.name != 'U':
            # on rhs
            # see if we can reuse this relation
            if relation.r_type in self.rhs_rels and self._may_be_shared(relation, 'subject'):
                # ok, can share variable
                term = self.rhs_rels[relation.r_type].children[0]
                self._use_outer_term(lhs.name, term)
                return
        rel = nodes.Relation(relation.r_type, relation.optional)
        for child in relation.children:
            rel.append(child.accept(self))
        return rel

    def visit_comparison(self, cmp):
        """Copy a comparison node and its children."""
        cmp_ = nodes.Comparison(cmp.operator)
        for child in cmp.children:
            cmp_.append(child.accept(self))
        return cmp_

    def visit_mathexpression(self, mexpr):
        """Copy a math expression node and its children."""
        mexpr_ = nodes.MathExpression(mexpr.operator)
        # iterate on the visited node's children (and not on those of an
        # unrelated `cmp` name)
        for child in mexpr.children:
            mexpr_.append(child.accept(self))
        return mexpr_

    def visit_function(self, function):
        """Copy a function node and its children."""
        function_ = nodes.Function(function.name)
        for child in function.children:
            function_.append(child.accept(self))
        return function_

    def visit_constant(self, constant):
        """Copy a constant node."""
        return nodes.Constant(constant.value, constant.type)

    def visit_variableref(self, vref):
        """Return the node to use in the rewritten select in place of a
        variable reference of the snippet.

        `X` gives an `Int` constant if the snippet is inserted for an eid,
        else a reference to the variable the snippet is inserted for. Other
        variables give either a reference to the variable they have been
        renamed to, or a copy of the term shared with the original query.
        """
        if vref.name == 'X':
            if self.const is not None:
                return nodes.Constant(self.const, 'Int')
            return nodes.VariableRef(self.select.get_variable(self.varname))
        vname_or_term = self._get_varname_or_term(vref.name)
        if isinstance(vname_or_term, basestring):
            return nodes.VariableRef(self.select.get_variable(vname_or_term))
        # shared term
        return vname_or_term.copy(self.select)

    def _may_be_shared(self, relation, target):
        """Return True if the snippet relation can be skipped to use a
        relation from the original query.

        :param relation: the snippet relation
        :param target: 'object' if the main variable is the subject of the
          relation, else 'subject'
        """
        # if cardinality is in '?1', we can ignore the relation and use
        # variable from the original query
        rschema = self.schema.rschema(relation.r_type)
        if target == 'object':
            cardindex = 0
            ttypes_func = rschema.objects
            rprop = rschema.rproperty
        else: # target == 'subject'
            cardindex = 1
            ttypes_func = rschema.subjects

            def rprop(etype, ttype, prop):
                # the main variable's type is the object of the relation
                return rschema.rproperty(ttype, etype, prop)
        for etype in self.varstinfo['possibletypes']:
            for ttype in ttypes_func(etype):
                if rprop(etype, ttype, 'cardinality')[cardindex] in '+*':
                    return False
        return True

    def _use_outer_term(self, snippet_varname, term):
        """Map the given snippet variable to `term`, a term of the original
        query.

        If a variable has already been inserted for this snippet variable,
        it is removed from the select's defined variables and each of its
        references is replaced by a copy of `term`.
        """
        key = (self.current_expr, self.varname, snippet_varname)
        if key in self.rewritten:
            insertedvar = self.select.defined_vars.pop(self.rewritten[key])
            for inserted_vref in insertedvar.references():
                inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
        self.rewritten[key] = term

    def _get_varname_or_term(self, vname):
        """Return what the snippet variable `vname` is rewritten to.

        For `U`, this is the name of a variable of the select which is
        restricted, the first time it is asked for, to the eid of the
        session's user using a new substitution added to `self.kwargs`.

        For other variables, this is the value recorded in `self.rewritten`
        (a variable name or a shared term), a new variable name being
        allocated and recorded if necessary.
        """
        if vname == 'U':
            if self.u_varname is None:
                select = self.select
                self.u_varname = select.allocate_varname()
                # generate an identifier for the substitution
                argname = select.allocate_varname()
                while argname in self.kwargs:
                    argname = select.allocate_varname()
                # insert "U eid %(u)s"
                var = select.get_variable(self.u_varname)
                select.add_constant_restriction(var, 'eid', unicode(argname),
                                                'Substitute')
                self.kwargs[argname] = self.session.user.eid
            return self.u_varname
        key = (self.current_expr, self.varname, vname)
        try:
            return self.rewritten[key]
        except KeyError:
            self.rewritten[key] = newvname = self.select.allocate_varname()
            return newvname