server/rqlrewrite.py
changeset 0 b97547f5f1fa
child 1132 96752791c2b6
equal deleted inserted replaced
-1:000000000000 0:b97547f5f1fa
       
     1 """RQL rewriting utilities, used for read security checking
       
     2 
       
     3 :organization: Logilab
       
     4 :copyright: 2007-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
       
     5 :contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
       
     6 """
       
     7 
       
     8 from rql import nodes, stmts, TypeResolverException
       
     9 from cubicweb import Unauthorized, server, typed_eid
       
    10 from cubicweb.server.ssplanner import add_types_restriction
       
    11 
       
    12 def remove_solutions(origsolutions, solutions, defined):
       
    13     """when a rqlst has been generated from another by introducing security
       
    14     assertions, this method returns solutions which are contained in orig
       
    15     solutions
       
    16     """
       
    17     newsolutions = []
       
    18     for origsol in origsolutions:
       
    19         for newsol in solutions[:]:
       
    20             for var, etype in origsol.items():
       
    21                 try:
       
    22                     if newsol[var] != etype:
       
    23                         try:
       
    24                             defined[var].stinfo['possibletypes'].remove(newsol[var])
       
    25                         except KeyError:
       
    26                             pass
       
    27                         break
       
    28                 except KeyError,ex:
       
    29                     # variable has been rewritten
       
    30                     continue
       
    31             else:
       
    32                 newsolutions.append(newsol)
       
    33                 solutions.remove(newsol)
       
    34     return newsolutions
       
    35 
       
    36 class Unsupported(Exception): pass
       
    37         
       
    38 class RQLRewriter(object):
       
    39     """insert some rql snippets into another rql syntax tree"""
       
    40     def __init__(self, querier, session):
       
    41         self.session = session
       
    42         self.annotate = querier._rqlhelper.annotate
       
    43         self._compute_solutions = querier.solutions
       
    44         self.schema = querier.schema
       
    45 
       
    46     def compute_solutions(self):
       
    47         self.annotate(self.select)
       
    48         try:
       
    49             self._compute_solutions(self.session, self.select, self.kwargs)
       
    50         except TypeResolverException:
       
    51             raise Unsupported()
       
    52         if len(self.select.solutions) < len(self.solutions):
       
    53             raise Unsupported()
       
    54         
       
    55     def rewrite(self, select, snippets, solutions, kwargs):
       
    56         if server.DEBUG:
       
    57             print '---- rewrite', select, snippets, solutions
       
    58         self.select = select
       
    59         self.solutions = solutions
       
    60         self.kwargs = kwargs
       
    61         self.u_varname = None
       
    62         self.removing_ambiguity = False
       
    63         self.exists_snippet = {}
       
    64         # we have to annotate the rqlst before inserting snippets, even though
       
    65         # we'll have to redo it latter
       
    66         self.annotate(select)
       
    67         self.insert_snippets(snippets)
       
    68         if not self.exists_snippet and self.u_varname:
       
    69             # U has been inserted than cancelled, cleanup
       
    70             select.undefine_variable(select.defined_vars[self.u_varname])
       
    71         # clean solutions according to initial solutions
       
    72         newsolutions = remove_solutions(solutions, select.solutions,
       
    73                                         select.defined_vars)
       
    74         assert len(newsolutions) >= len(solutions), \
       
    75                'rewritten rql %s has lost some solutions, there is probably something '\
       
    76                'wrong in your schema permission (for instance using a '\
       
    77               'RQLExpression which insert a relation which doesn\'t exists in '\
       
    78                'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
       
    79             select, solutions, newsolutions)
       
    80         if len(newsolutions) > len(solutions):
       
    81             # the snippet has introduced some ambiguities, we have to resolve them
       
    82             # "manually"
       
    83             variantes = self.build_variantes(newsolutions)
       
    84             # insert "is" where necessary
       
    85             varexistsmap = {}
       
    86             self.removing_ambiguity = True
       
    87             for (erqlexpr, mainvar, oldvarname), etype in variantes[0].iteritems():
       
    88                 varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
       
    89                 var = select.defined_vars[varname]
       
    90                 exists = var.references()[0].scope
       
    91                 exists.add_constant_restriction(var, 'is', etype, 'etype')
       
    92                 varexistsmap[mainvar] = exists
       
    93             # insert ORED exists where necessary
       
    94             for variante in variantes[1:]:
       
    95                 self.insert_snippets(snippets, varexistsmap)
       
    96                 for (erqlexpr, mainvar, oldvarname), etype in variante.iteritems():
       
    97                     varname = self.rewritten[(erqlexpr, mainvar, oldvarname)]
       
    98                     try:
       
    99                         var = select.defined_vars[varname]
       
   100                     except KeyError:
       
   101                         # not a newly inserted variable
       
   102                         continue
       
   103                     exists = var.references()[0].scope
       
   104                     exists.add_constant_restriction(var, 'is', etype, 'etype')
       
   105             # recompute solutions
       
   106             #select.annotated = False # avoid assertion error
       
   107             self.compute_solutions()
       
   108             # clean solutions according to initial solutions
       
   109             newsolutions = remove_solutions(solutions, select.solutions,
       
   110                                             select.defined_vars)
       
   111         select.solutions = newsolutions
       
   112         add_types_restriction(self.schema, select)
       
   113         if server.DEBUG:
       
   114             print '---- rewriten', select
       
   115             
       
   116     def build_variantes(self, newsolutions):
       
   117         variantes = set()
       
   118         for sol in newsolutions:
       
   119             variante = []
       
   120             for (erqlexpr, mainvar, oldvar), newvar in self.rewritten.iteritems():
       
   121                 variante.append( ((erqlexpr, mainvar, oldvar), sol[newvar]) )
       
   122             variantes.add(tuple(variante))
       
   123         # rebuild variantes as dict
       
   124         variantes = [dict(variante) for variante in variantes]
       
   125         # remove variable which have always the same type
       
   126         for erqlexpr, mainvar, oldvar in self.rewritten:
       
   127             it = iter(variantes)
       
   128             etype = it.next()[(erqlexpr, mainvar, oldvar)]
       
   129             for variante in it:
       
   130                 if variante[(erqlexpr, mainvar, oldvar)] != etype:
       
   131                     break
       
   132             else:
       
   133                 for variante in variantes:
       
   134                     del variante[(erqlexpr, mainvar, oldvar)]
       
   135         return variantes
       
   136     
       
   137     def insert_snippets(self, snippets, varexistsmap=None):
       
   138         self.rewritten = {}
       
   139         for varname, erqlexprs in snippets:
       
   140             if varexistsmap is not None and not varname in varexistsmap:
       
   141                 continue
       
   142             try:
       
   143                 self.const = typed_eid(varname)
       
   144                 self.varname = self.const
       
   145                 self.rhs_rels = self.lhs_rels = {}
       
   146             except ValueError:
       
   147                 self.varname = varname
       
   148                 self.const = None
       
   149                 self.varstinfo = stinfo = self.select.defined_vars[varname].stinfo
       
   150                 if varexistsmap is None:
       
   151                     self.rhs_rels = dict( (rel.r_type, rel) for rel in stinfo['rhsrelations'])
       
   152                     self.lhs_rels = dict( (rel.r_type, rel) for rel in stinfo['relations']
       
   153                                                   if not rel in stinfo['rhsrelations'])
       
   154                 else:
       
   155                     self.rhs_rels = self.lhs_rels = {}
       
   156             parent = None
       
   157             inserted = False
       
   158             for erqlexpr in erqlexprs:
       
   159                 self.current_expr = erqlexpr
       
   160                 if varexistsmap is None:
       
   161                     try:
       
   162                         new = self.insert_snippet(varname, erqlexpr.snippet_rqlst, parent)
       
   163                     except Unsupported:
       
   164                         continue
       
   165                     inserted = True
       
   166                     if new is not None:
       
   167                         self.exists_snippet[erqlexpr] = new
       
   168                     parent = parent or new
       
   169                 else:
       
   170                     # called to reintroduce snippet due to ambiguity creation,
       
   171                     # so skip snippets which are not introducing this ambiguity
       
   172                     exists = varexistsmap[varname]
       
   173                     if self.exists_snippet[erqlexpr] is exists:
       
   174                         self.insert_snippet(varname, erqlexpr.snippet_rqlst, exists)
       
   175             if varexistsmap is None and not inserted:
       
   176                 # no rql expression found matching rql solutions. User has no access right
       
   177                 raise Unauthorized()
       
   178             
       
   179     def insert_snippet(self, varname, snippetrqlst, parent=None):
       
   180         new = snippetrqlst.where.accept(self)
       
   181         if new is not None:
       
   182             try:
       
   183                 var = self.select.defined_vars[varname]
       
   184             except KeyError:
       
   185                 # not a variable
       
   186                 pass
       
   187             else:
       
   188                 if var.stinfo['optrelations']:
       
   189                     # use a subquery
       
   190                     subselect = stmts.Select()
       
   191                     subselect.append_selected(nodes.VariableRef(subselect.get_variable(varname)))
       
   192                     subselect.add_restriction(new.copy(subselect))
       
   193                     aliases = [varname]
       
   194                     for rel in var.stinfo['relations']:
       
   195                         rschema = self.schema.rschema(rel.r_type)
       
   196                         if rschema.is_final() or (rschema.inlined and not rel in var.stinfo['rhsrelations']):
       
   197                             self.select.remove_node(rel)
       
   198                             rel.children[0].name = varname
       
   199                             subselect.add_restriction(rel.copy(subselect))
       
   200                             for vref in rel.children[1].iget_nodes(nodes.VariableRef):
       
   201                                 subselect.append_selected(vref.copy(subselect))
       
   202                                 aliases.append(vref.name)
       
   203                     if self.u_varname:
       
   204                         # generate an identifier for the substitution
       
   205                         argname = subselect.allocate_varname()
       
   206                         while argname in self.kwargs:
       
   207                             argname = subselect.allocate_varname()
       
   208                         subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
       
   209                                                         'eid', unicode(argname), 'Substitute')
       
   210                         self.kwargs[argname] = self.session.user.eid
       
   211                     add_types_restriction(self.schema, subselect, subselect, solutions=self.solutions)
       
   212                     assert parent is None
       
   213                     myunion = stmts.Union()
       
   214                     myunion.append(subselect)
       
   215                     aliases = [nodes.VariableRef(self.select.get_variable(name, i))
       
   216                                for i, name in enumerate(aliases)]
       
   217                     self.select.add_subquery(nodes.SubQuery(aliases, myunion), check=False)
       
   218                     self._cleanup_inserted(new)
       
   219                     try:
       
   220                         self.compute_solutions()
       
   221                     except Unsupported:
       
   222                         # some solutions have been lost, can't apply this rql expr
       
   223                         self.select.remove_subquery(new, undefine=True)
       
   224                         raise
       
   225                     return
       
   226             new = nodes.Exists(new)
       
   227             if parent is None:
       
   228                 self.select.add_restriction(new)
       
   229             else:
       
   230                 grandpa = parent.parent
       
   231                 or_ = nodes.Or(parent, new)
       
   232                 grandpa.replace(parent, or_)
       
   233             if not self.removing_ambiguity:
       
   234                 try:
       
   235                     self.compute_solutions()
       
   236                 except Unsupported:
       
   237                     # some solutions have been lost, can't apply this rql expr
       
   238                     if parent is None:
       
   239                         self.select.remove_node(new, undefine=True)
       
   240                     else:
       
   241                         parent.parent.replace(or_, or_.children[0])
       
   242                         self._cleanup_inserted(new)
       
   243                     raise 
       
   244             return new
       
   245 
       
   246     def _cleanup_inserted(self, node):
       
   247         # cleanup inserted variable references
       
   248         for vref in node.iget_nodes(nodes.VariableRef):
       
   249             vref.unregister_reference()
       
   250             if not vref.variable.stinfo['references']:
       
   251                 # no more references, undefine the variable
       
   252                 del self.select.defined_vars[vref.name]
       
   253         
       
   254     def _visit_binary(self, node, cls):
       
   255         newnode = cls()
       
   256         for c in node.children:
       
   257             new = c.accept(self)
       
   258             if new is None:
       
   259                 continue
       
   260             newnode.append(new)
       
   261         if len(newnode.children) == 0:
       
   262             return None
       
   263         if len(newnode.children) == 1:
       
   264             return newnode.children[0]
       
   265         return newnode
       
   266 
       
   267     def _visit_unary(self, node, cls):
       
   268         newc = node.children[0].accept(self)
       
   269         if newc is None:
       
   270             return None
       
   271         newnode = cls()
       
   272         newnode.append(newc)
       
   273         return newnode 
       
   274         
       
   275     def visit_and(self, et):
       
   276         return self._visit_binary(et, nodes.And)
       
   277 
       
   278     def visit_or(self, ou):
       
   279         return self._visit_binary(ou, nodes.Or)
       
   280         
       
   281     def visit_not(self, node):
       
   282         return self._visit_unary(node, nodes.Not)
       
   283 
       
   284     def visit_exists(self, node):
       
   285         return self._visit_unary(node, nodes.Exists)
       
   286    
       
   287     def visit_relation(self, relation):
       
   288         lhs, rhs = relation.get_variable_parts()
       
   289         if lhs.name == 'X':
       
   290             # on lhs
       
   291             # see if we can reuse this relation
       
   292             if relation.r_type in self.lhs_rels and isinstance(rhs, nodes.VariableRef) and rhs.name != 'U':
       
   293                 if self._may_be_shared(relation, 'object'):
       
   294                     # ok, can share variable
       
   295                     term = self.lhs_rels[relation.r_type].children[1].children[0]
       
   296                     self._use_outer_term(rhs.name, term)
       
   297                     return
       
   298         elif isinstance(rhs, nodes.VariableRef) and rhs.name == 'X' and lhs.name != 'U':
       
   299             # on rhs
       
   300             # see if we can reuse this relation
       
   301             if relation.r_type in self.rhs_rels and self._may_be_shared(relation, 'subject'):
       
   302                 # ok, can share variable
       
   303                 term = self.rhs_rels[relation.r_type].children[0]
       
   304                 self._use_outer_term(lhs.name, term)            
       
   305                 return
       
   306         rel = nodes.Relation(relation.r_type, relation.optional)
       
   307         for c in relation.children:
       
   308             rel.append(c.accept(self))
       
   309         return rel
       
   310 
       
   311     def visit_comparison(self, cmp):
       
   312         cmp_ = nodes.Comparison(cmp.operator)
       
   313         for c in cmp.children:
       
   314             cmp_.append(c.accept(self))
       
   315         return cmp_
       
   316 
       
   317     def visit_mathexpression(self, mexpr):
       
   318         cmp_ = nodes.MathExpression(cmp.operator)
       
   319         for c in cmp.children:
       
   320             cmp_.append(c.accept(self))
       
   321         return cmp_
       
   322         
       
   323     def visit_function(self, function):
       
   324         """generate filter name for a function"""
       
   325         function_ = nodes.Function(function.name)
       
   326         for c in function.children:
       
   327             function_.append(c.accept(self))
       
   328         return function_
       
   329 
       
   330     def visit_constant(self, constant):
       
   331         """generate filter name for a constant"""
       
   332         return nodes.Constant(constant.value, constant.type)
       
   333 
       
   334     def visit_variableref(self, vref):
       
   335         """get the sql name for a variable reference"""
       
   336         if vref.name == 'X':
       
   337             if self.const is not None:
       
   338                 return nodes.Constant(self.const, 'Int')
       
   339             return nodes.VariableRef(self.select.get_variable(self.varname))
       
   340         vname_or_term = self._get_varname_or_term(vref.name)
       
   341         if isinstance(vname_or_term, basestring):
       
   342             return nodes.VariableRef(self.select.get_variable(vname_or_term))
       
   343         # shared term
       
   344         return vname_or_term.copy(self.select)
       
   345 
       
   346     def _may_be_shared(self, relation, target):
       
   347         """return True if the snippet relation can be skipped to use a relation
       
   348         from the original query
       
   349         """
       
   350         # if cardinality is in '?1', we can ignore the relation and use variable
       
   351         # from the original query
       
   352         rschema = self.schema.rschema(relation.r_type)
       
   353         if target == 'object':
       
   354             cardindex = 0
       
   355             ttypes_func = rschema.objects
       
   356             rprop = rschema.rproperty
       
   357         else: # target == 'subject':
       
   358             cardindex = 1
       
   359             ttypes_func = rschema.subjects
       
   360             rprop = lambda x,y,z: rschema.rproperty(y, x, z)
       
   361         for etype in self.varstinfo['possibletypes']:
       
   362             for ttype in ttypes_func(etype):
       
   363                 if rprop(etype, ttype, 'cardinality')[cardindex] in '+*':
       
   364                     return False
       
   365         return True
       
   366 
       
   367     def _use_outer_term(self, snippet_varname, term):
       
   368         key = (self.current_expr, self.varname, snippet_varname)
       
   369         if key in self.rewritten:
       
   370             insertedvar = self.select.defined_vars.pop(self.rewritten[key])
       
   371             for inserted_vref in insertedvar.references():
       
   372                 inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
       
   373         self.rewritten[key] = term
       
   374         
       
   375     def _get_varname_or_term(self, vname):
       
   376         if vname == 'U':
       
   377             if self.u_varname is None:
       
   378                 select = self.select
       
   379                 self.u_varname = select.allocate_varname()
       
   380                 # generate an identifier for the substitution
       
   381                 argname = select.allocate_varname()
       
   382                 while argname in self.kwargs:
       
   383                     argname = select.allocate_varname()
       
   384                 # insert "U eid %(u)s"
       
   385                 var = select.get_variable(self.u_varname)
       
   386                 select.add_constant_restriction(select.get_variable(self.u_varname),
       
   387                                                 'eid', unicode(argname), 'Substitute')
       
   388                 self.kwargs[argname] = self.session.user.eid
       
   389             return self.u_varname
       
   390         key = (self.current_expr, self.varname, vname)
       
   391         try:
       
   392             return self.rewritten[key]
       
   393         except KeyError:
       
   394             self.rewritten[key] = newvname = self.select.allocate_varname()
       
   395             return newvname