server/rqlrewrite.py
changeset 3293 69c0ba095536
parent 3230 1d25e928c299
parent 3292 70c0dd1c3b7d
child 3300 c7c4775a5619
equal deleted inserted replaced
3230:1d25e928c299 3293:69c0ba095536
     1 """RQL rewriting utilities, used for read security checking
       
     2 
       
     3 :organization: Logilab
       
     4 :copyright: 2007-2009 LOGILAB S.A. (Paris, FRANCE), license is LGPL v2.
       
     5 :contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
       
     6 :license: GNU Lesser General Public License, v2.1 - http://www.gnu.org/licenses
       
     7 """
       
     8 
       
     9 from rql import nodes, stmts, TypeResolverException
       
    10 from cubicweb import Unauthorized, server, typed_eid
       
    11 from cubicweb.server.ssplanner import add_types_restriction
       
    12 
       
    def remove_solutions(origsolutions, solutions, defined):
        """Extract from `solutions` those compatible with `origsolutions`.

        Used once a syntax tree has been generated from another one by
        introducing security assertions.  A new solution is kept when, for
        every variable of an original solution that it still contains, both
        agree on the entity type (variables missing from the new solution
        have been rewritten and are ignored).

        Kept solutions are removed from `solutions` in place.  When a new
        solution disagrees on a variable, the offending type is also discarded
        from that variable's ``stinfo['possibletypes']`` in `defined`.

        Return the list of kept solutions.
        """
        kept = []
        for reference in origsolutions:
            # iterate on a copy since kept solutions are removed on the fly
            for candidate in list(solutions):
                compatible = True
                for name, expected in reference.items():
                    if name not in candidate:
                        # variable has been rewritten
                        continue
                    actual = candidate[name]
                    if actual != expected:
                        try:
                            defined[name].stinfo['possibletypes'].remove(actual)
                        except KeyError:
                            # unknown variable or type already discarded
                            pass
                        compatible = False
                        break
                if compatible:
                    kept.append(candidate)
                    solutions.remove(candidate)
        return kept
       
    36 
       
    class Unsupported(Exception):
        """Raised when solutions of a rewritten tree can't be computed or when
        some of the expected solutions have been lost.
        """
       
    38 
       
    39 class RQLRewriter(object):
       
    40     """insert some rql snippets into another rql syntax tree"""
       
    41     def __init__(self, querier, session):
       
    42         self.session = session
       
    43         self.annotate = querier._rqlhelper.annotate
       
    44         self._compute_solutions = querier.solutions
       
    45         self.schema = querier.schema
       
    46 
       
    47     def compute_solutions(self):
       
    48         self.annotate(self.select)
       
    49         try:
       
    50             self._compute_solutions(self.session, self.select, self.kwargs)
       
    51         except TypeResolverException:
       
    52             raise Unsupported()
       
    53         if len(self.select.solutions) < len(self.solutions):
       
    54             raise Unsupported()
       
    55 
       
    56     def rewrite(self, select, snippets, solutions, kwargs):
       
    57         if server.DEBUG:
       
    58             print '---- rewrite', select, snippets, solutions
       
    59         self.select = select
       
    60         self.solutions = solutions
       
    61         self.kwargs = kwargs
       
    62         self.u_varname = None
       
    63         self.removing_ambiguity = False
       
    64         self.exists_snippet = {}
       
    65         # we have to annotate the rqlst before inserting snippets, even though
       
    66         # we'll have to redo it latter
       
    67         self.annotate(select)
       
    68         self.insert_snippets(snippets)
       
    69         if not self.exists_snippet and self.u_varname:
       
    70             # U has been inserted than cancelled, cleanup
       
    71             select.undefine_variable(select.defined_vars[self.u_varname])
       
    72         # clean solutions according to initial solutions
       
    73         newsolutions = remove_solutions(solutions, select.solutions,
       
    74                                         select.defined_vars)
       
    75         assert len(newsolutions) >= len(solutions), \
       
    76                'rewritten rql %s has lost some solutions, there is probably something '\
       
    77                'wrong in your schema permission (for instance using a '\
       
    78               'RQLExpression which insert a relation which doesn\'t exists in '\
       
    79                'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
       
    80             select, solutions, newsolutions)
       
    81         if len(newsolutions) > len(solutions):
       
    82             # the snippet has introduced some ambiguities, we have to resolve them
       
    83             # "manually"
       
    84             variantes = self.build_variantes(newsolutions)
       
    85             # insert "is" where necessary
       
    86             varexistsmap = {}
       
    87             self.removing_ambiguity = True
       
    88             for (erqlexpr, mainvar, oldvarname), etype in variantes[0].iteritems():
       
    89                 varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
       
    90                 var = select.defined_vars[varname]
       
    91                 exists = var.references()[0].scope
       
    92                 exists.add_constant_restriction(var, 'is', etype, 'etype')
       
    93                 varexistsmap[mainvar] = exists
       
    94             # insert ORED exists where necessary
       
    95             for variante in variantes[1:]:
       
    96                 self.insert_snippets(snippets, varexistsmap)
       
    97                 for (erqlexpr, mainvar, oldvarname), etype in variante.iteritems():
       
    98                     varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
       
    99                     try:
       
   100                         var = select.defined_vars[varname]
       
   101                     except KeyError:
       
   102                         # not a newly inserted variable
       
   103                         continue
       
   104                     exists = var.references()[0].scope
       
   105                     exists.add_constant_restriction(var, 'is', etype, 'etype')
       
   106             # recompute solutions
       
   107             #select.annotated = False # avoid assertion error
       
   108             self.compute_solutions()
       
   109             # clean solutions according to initial solutions
       
   110             newsolutions = remove_solutions(solutions, select.solutions,
       
   111                                             select.defined_vars)
       
   112         select.solutions = newsolutions
       
   113         add_types_restriction(self.schema, select)
       
   114         if server.DEBUG:
       
   115             print '---- rewriten', select
       
   116 
       
   117     def build_variantes(self, newsolutions):
       
   118         variantes = set()
       
   119         for sol in newsolutions:
       
   120             variante = []
       
   121             for (erqlexpr, mainvar, oldvar), newvar in self.rewritten.iteritems():
       
   122                 variante.append( ((erqlexpr, mainvar, oldvar), sol[newvar]) )
       
   123             variantes.add(tuple(variante))
       
   124         # rebuild variantes as dict
       
   125         variantes = [dict(variante) for variante in variantes]
       
   126         # remove variable which have always the same type
       
   127         for erqlexpr, mainvar, oldvar in self.rewritten:
       
   128             it = iter(variantes)
       
   129             etype = it.next()[(erqlexpr, mainvar, oldvar)]
       
   130             for variante in it:
       
   131                 if variante[(erqlexpr, mainvar, oldvar)] != etype:
       
   132                     break
       
   133             else:
       
   134                 for variante in variantes:
       
   135                     del variante[(erqlexpr, mainvar, oldvar)]
       
   136         return variantes
       
   137 
       
   def insert_snippets(self, snippets, varexistsmap=None):
       """Insert the rql expressions of `snippets` into the tree being rewritten.

       `snippets` is an iterable of (varname, rql expressions) couples.  If
       `typed_eid` accepts `varname`, the snippet applies to that constant,
       else `varname` must be the name of a variable defined by the query.

       When `varexistsmap` is None (first pass), every expression is tried:
       those raising `Unsupported` are skipped and the node returned for the
       others, if any, is recorded in `self.exists_snippet`.  `Unauthorized`
       is raised as soon as no expression could be inserted for a variable.

       When `varexistsmap` is given (ambiguity removal pass), it maps variable
       names to the node their snippets have been inserted in.  Only variables
       found in this map are considered, and only the expressions recorded
       for that very node are inserted again.
       """
       self.rewritten = {}
       for varname, erqlexprs in snippets:
           if varexistsmap is not None and varname not in varexistsmap:
               continue
           try:
               self.const = typed_eid(varname)
               self.varname = self.const
               # nothing may be shared with the original query for a constant
               self.rhs_rels = self.lhs_rels = {}
           except ValueError:
               self.varname = varname
               self.const = None
               self.varstinfo = stinfo = self.select.defined_vars[varname].stinfo
               if varexistsmap is None:
                   # relations of the original query, by type, that snippets
                   # may reuse instead of inserting their own
                   self.rhs_rels = dict((rel.r_type, rel)
                                        for rel in stinfo['rhsrelations'])
                   self.lhs_rels = dict((rel.r_type, rel)
                                        for rel in stinfo['relations']
                                        if rel not in stinfo['rhsrelations'])
               else:
                   self.rhs_rels = self.lhs_rels = {}
           parent = None
           inserted = False
           for erqlexpr in erqlexprs:
               self.current_expr = erqlexpr
               if varexistsmap is None:
                   try:
                       new = self.insert_snippet(varname, erqlexpr.snippet_rqlst, parent)
                   except Unsupported:
                       continue
                   inserted = True
                   if new is not None:
                       self.exists_snippet[erqlexpr] = new
                   # following expressions are ORed with the first inserted one
                   parent = parent or new
               else:
                   # called to reintroduce snippet due to ambiguity creation,
                   # so skip snippets which are not introducing this ambiguity
                   exists = varexistsmap[varname]
                   # use .get(): expressions skipped during the first pass
                   # (unsupported, or inserted without any resulting node)
                   # have no entry in exists_snippet
                   if self.exists_snippet.get(erqlexpr) is exists:
                       self.insert_snippet(varname, erqlexpr.snippet_rqlst, exists)
           if varexistsmap is None and not inserted:
               # no rql expression found matching rql solutions. User has no
               # access right
               raise Unauthorized()
       
   179 
       
   180     def insert_snippet(self, varname, snippetrqlst, parent=None):
       
   181         new = snippetrqlst.where.accept(self)
       
   182         if new is not None:
       
   183             try:
       
   184                 var = self.select.defined_vars[varname]
       
   185             except KeyError:
       
   186                 # not a variable
       
   187                 pass
       
   188             else:
       
   189                 if var.stinfo['optrelations']:
       
   190                     # use a subquery
       
   191                     subselect = stmts.Select()
       
   192                     subselect.append_selected(nodes.VariableRef(subselect.get_variable(varname)))
       
   193                     subselect.add_restriction(new.copy(subselect))
       
   194                     aliases = [varname]
       
   195                     for rel in var.stinfo['relations']:
       
   196                         rschema = self.schema.rschema(rel.r_type)
       
   197                         if rschema.is_final() or (rschema.inlined and not rel in var.stinfo['rhsrelations']):
       
   198                             self.select.remove_node(rel)
       
   199                             rel.children[0].name = varname
       
   200                             subselect.add_restriction(rel.copy(subselect))
       
   201                             for vref in rel.children[1].iget_nodes(nodes.VariableRef):
       
   202                                 subselect.append_selected(vref.copy(subselect))
       
   203                                 aliases.append(vref.name)
       
   204                     if self.u_varname:
       
   205                         # generate an identifier for the substitution
       
   206                         argname = subselect.allocate_varname()
       
   207                         while argname in self.kwargs:
       
   208                             argname = subselect.allocate_varname()
       
   209                         subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
       
   210                                                         'eid', unicode(argname), 'Substitute')
       
   211                         self.kwargs[argname] = self.session.user.eid
       
   212                     add_types_restriction(self.schema, subselect, subselect, solutions=self.solutions)
       
   213                     assert parent is None
       
   214                     myunion = stmts.Union()
       
   215                     myunion.append(subselect)
       
   216                     aliases = [nodes.VariableRef(self.select.get_variable(name, i))
       
   217                                for i, name in enumerate(aliases)]
       
   218                     self.select.add_subquery(nodes.SubQuery(aliases, myunion), check=False)
       
   219                     self._cleanup_inserted(new)
       
   220                     try:
       
   221                         self.compute_solutions()
       
   222                     except Unsupported:
       
   223                         # some solutions have been lost, can't apply this rql expr
       
   224                         self.select.remove_subquery(new, undefine=True)
       
   225                         raise
       
   226                     return
       
   227             new = nodes.Exists(new)
       
   228             if parent is None:
       
   229                 self.select.add_restriction(new)
       
   230             else:
       
   231                 grandpa = parent.parent
       
   232                 or_ = nodes.Or(parent, new)
       
   233                 grandpa.replace(parent, or_)
       
   234             if not self.removing_ambiguity:
       
   235                 try:
       
   236                     self.compute_solutions()
       
   237                 except Unsupported:
       
   238                     # some solutions have been lost, can't apply this rql expr
       
   239                     if parent is None:
       
   240                         self.select.remove_node(new, undefine=True)
       
   241                     else:
       
   242                         parent.parent.replace(or_, or_.children[0])
       
   243                         self._cleanup_inserted(new)
       
   244                     raise
       
   245             return new
       
   246 
       
   def _cleanup_inserted(self, node):
       """Unregister every variable reference found under `node`, undefining
       from the query the variables which are left without any reference.
       """
       defined = self.select.defined_vars
       for ref in node.iget_nodes(nodes.VariableRef):
           ref.unregister_reference()
           if ref.variable.stinfo['references']:
               # variable still used somewhere else
               continue
           del defined[ref.name]
       
   254 
       
   255     def _visit_binary(self, node, cls):
       
   256         newnode = cls()
       
   257         for c in node.children:
       
   258             new = c.accept(self)
       
   259             if new is None:
       
   260                 continue
       
   261             newnode.append(new)
       
   262         if len(newnode.children) == 0:
       
   263             return None
       
   264         if len(newnode.children) == 1:
       
   265             return newnode.children[0]
       
   266         return newnode
       
   267 
       
   268     def _visit_unary(self, node, cls):
       
   269         newc = node.children[0].accept(self)
       
   270         if newc is None:
       
   271             return None
       
   272         newnode = cls()
       
   273         newnode.append(newc)
       
   274         return newnode
       
   275 
       
   def visit_and(self, et):
       """Visit an AND node: its visited children are gathered in a new
       `nodes.And` by `_visit_binary`.
       """
       node_cls = nodes.And
       return self._visit_binary(et, node_cls)
       
   278 
       
   def visit_or(self, ou):
       """Visit an OR node: its visited children are gathered in a new
       `nodes.Or` by `_visit_binary`.
       """
       node_cls = nodes.Or
       return self._visit_binary(ou, node_cls)
       
   281 
       
   def visit_not(self, node):
       """Visit a NOT node: its visited child is wrapped in a new
       `nodes.Not` by `_visit_unary`.
       """
       node_cls = nodes.Not
       return self._visit_unary(node, node_cls)
       
   284 
       
   def visit_exists(self, node):
       """Visit an EXISTS node: its visited child is wrapped in a new
       `nodes.Exists` by `_visit_unary`.
       """
       node_cls = nodes.Exists
       return self._visit_unary(node, node_cls)
       
   287 
       
   def visit_relation(self, relation):
       """Translate a relation of the snippet.

       If the relation involves X and a relation of the same type on the
       rewritten variable exists in the original query and may be shared (see
       `_may_be_shared`), the term of the original query is used in place of
       the snippet variable and None is returned.  Relations between X and U
       are never shared.

       Otherwise return a new Relation node whose children are the visited
       children of `relation`.
       """
       lhs, rhs = relation.get_variable_parts()
       rtype = relation.r_type
       if lhs.name == 'X':
           # X is the subject: see if a relation of the query can be reused
           reusable = (rtype in self.lhs_rels
                       and isinstance(rhs, nodes.VariableRef)
                       and rhs.name != 'U'
                       and self._may_be_shared(relation, 'object'))
           if reusable:
               # ok, can share variable
               outer_term = self.lhs_rels[rtype].children[1].children[0]
               self._use_outer_term(rhs.name, outer_term)
               return None
       elif isinstance(rhs, nodes.VariableRef) and rhs.name == 'X' and lhs.name != 'U':
           # X is the object: see if a relation of the query can be reused
           if rtype in self.rhs_rels and self._may_be_shared(relation, 'subject'):
               # ok, can share variable
               outer_term = self.rhs_rels[rtype].children[0]
               self._use_outer_term(lhs.name, outer_term)
               return None
       copied = nodes.Relation(rtype, relation.optional)
       for child in relation.children:
           copied.append(child.accept(self))
       return copied
       
   311 
       
   def visit_comparison(self, cmp):
       """Return a new Comparison node with the same operator as `cmp`,
       holding the visited children of `cmp`.
       """
       copied = nodes.Comparison(cmp.operator)
       for child in cmp.children:
           copied.append(child.accept(self))
       return copied
       
   317 
       
   def visit_mathexpression(self, mexpr):
       """Return a new MathExpression node with the same operator as `mexpr`,
       holding the visited children of `mexpr`.
       """
       mexpr_ = nodes.MathExpression(mexpr.operator)
       # iterate on the children of the visited node: there is no local
       # `cmp` here, the name would resolve to the builtin function
       for c in mexpr.children:
           mexpr_.append(c.accept(self))
       return mexpr_
       
   323 
       
   def visit_function(self, function):
       """Return a new Function node with the same name as `function`,
       holding the visited children of `function`.
       """
       copied = nodes.Function(function.name)
       for child in function.children:
           copied.append(child.accept(self))
       return copied
       
   330 
       
   def visit_constant(self, constant):
       """Return a new Constant node with the value and type of `constant`."""
       value, type_ = constant.value, constant.type
       return nodes.Constant(value, type_)
       
   334 
       
   def visit_variableref(self, vref):
       """Return the node to use in the rewritten query in place of `vref`.

       X is replaced by the constant being rewritten if any, else by a
       reference to the rewritten variable.  For any other snippet variable,
       `_get_varname_or_term` gives either the name of a variable of the
       query, to which a reference is returned, or a term shared with the
       original query, of which a copy is returned.
       """
       if vref.name != 'X':
           target = self._get_varname_or_term(vref.name)
           if not isinstance(target, basestring):
               # shared term
               return target.copy(self.select)
           return nodes.VariableRef(self.select.get_variable(target))
       if self.const is None:
           return nodes.VariableRef(self.select.get_variable(self.varname))
       return nodes.Constant(self.const, 'Int')
       
   346 
       
   347     def _may_be_shared(self, relation, target):
       
   348         """return True if the snippet relation can be skipped to use a relation
       
   349         from the original query
       
   350         """
       
   351         # if cardinality is in '?1', we can ignore the relation and use variable
       
   352         # from the original query
       
   353         rschema = self.schema.rschema(relation.r_type)
       
   354         if target == 'object':
       
   355             cardindex = 0
       
   356             ttypes_func = rschema.objects
       
   357             rprop = rschema.rproperty
       
   358         else: # target == 'subject':
       
   359             cardindex = 1
       
   360             ttypes_func = rschema.subjects
       
   361             rprop = lambda x, y, z: rschema.rproperty(y, x, z)
       
   362         for etype in self.varstinfo['possibletypes']:
       
   363             for ttype in ttypes_func(etype):
       
   364                 if rprop(etype, ttype, 'cardinality')[cardindex] in '+*':
       
   365                     return False
       
   366         return True
       
   367 
       
   368     def _use_outer_term(self, snippet_varname, term):
       
   369         key = (self.current_expr, self.varname, snippet_varname)
       
   370         if key in self.rewritten:
       
   371             insertedvar = self.select.defined_vars.pop(self.rewritten[key])
       
   372             for inserted_vref in insertedvar.references():
       
   373                 inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
       
   374         self.rewritten[key] = term
       
   375 
       
   376     def _get_varname_or_term(self, vname):
       
   377         if vname == 'U':
       
   378             if self.u_varname is None:
       
   379                 select = self.select
       
   380                 self.u_varname = select.allocate_varname()
       
   381                 # generate an identifier for the substitution
       
   382                 argname = select.allocate_varname()
       
   383                 while argname in self.kwargs:
       
   384                     argname = select.allocate_varname()
       
   385                 # insert "U eid %(u)s"
       
   386                 var = select.get_variable(self.u_varname)
       
   387                 select.add_constant_restriction(select.get_variable(self.u_varname),
       
   388                                                 'eid', unicode(argname), 'Substitute')
       
   389                 self.kwargs[argname] = self.session.user.eid
       
   390             return self.u_varname
       
   391         key = (self.current_expr, self.varname, vname)
       
   392         try:
       
   393             return self.rewritten[key]
       
   394         except KeyError:
       
   395             self.rewritten[key] = newvname = self.select.allocate_varname()
       
   396             return newvname