server/querier.py
changeset 4953 c8c0e10dbd97
parent 4843 5f7363416765
child 4954 96f67c5be0e6
equal deleted inserted replaced
4951:7dc54e12c606 4953:c8c0e10dbd97
    39         if varmap.get(key, value) != value:
    39         if varmap.get(key, value) != value:
    40             raise Exception('variable name conflict on %s' % key)
    40             raise Exception('variable name conflict on %s' % key)
    41         varmap[key] = value
    41         varmap[key] = value
    42 
    42 
    43 # permission utilities ########################################################
    43 # permission utilities ########################################################
    44 
       
    45 def var_kwargs(restriction, args):
       
    46     varkwargs = {}
       
    47     for rel in restriction.iget_nodes(Relation):
       
    48         cmp = rel.children[1]
       
    49         if rel.r_type == 'eid' and cmp.operator == '=' and \
       
    50                not rel.neged(strict=True) and \
       
    51                isinstance(cmp.children[0], Constant) and \
       
    52                cmp.children[0].type == 'Substitute':
       
    53             varkwargs[rel.children[0].name] = typed_eid(cmp.children[0].eval(args))
       
    54     return varkwargs
       
    55 
    44 
    56 def check_no_password_selected(rqlst):
    45 def check_no_password_selected(rqlst):
    57     """check that Password entities are not selected"""
    46     """check that Password entities are not selected"""
    58     for solution in rqlst.solutions:
    47     for solution in rqlst.solutions:
    59         if 'Password' in solution.itervalues():
    48         if 'Password' in solution.itervalues():
    82             if not user.matching_groups(rdef.get_groups('read')):
    71             if not user.matching_groups(rdef.get_groups('read')):
    83                 raise Unauthorized('read', rel.r_type)
    72                 raise Unauthorized('read', rel.r_type)
    84     localchecks = {}
    73     localchecks = {}
    85     # iterate on defined_vars and not on solutions to ignore column aliases
    74     # iterate on defined_vars and not on solutions to ignore column aliases
    86     for varname in rqlst.defined_vars:
    75     for varname in rqlst.defined_vars:
    87         etype = solution[varname]
    76         eschema = schema.eschema(solution[varname])
    88         eschema = schema.eschema(etype)
       
    89         if eschema.final:
    77         if eschema.final:
    90             continue
    78             continue
    91         if not user.matching_groups(eschema.get_groups('read')):
    79         if not user.matching_groups(eschema.get_groups('read')):
    92             erqlexprs = eschema.get_rqlexprs('read')
    80             erqlexprs = eschema.get_rqlexprs('read')
    93             if not erqlexprs:
    81             if not erqlexprs:
    94                 ex = Unauthorized('read', etype)
    82                 ex = Unauthorized('read', solution[varname])
    95                 ex.var = varname
    83                 ex.var = varname
    96                 raise ex
    84                 raise ex
    97             #assert len(erqlexprs) == 1
    85             localchecks[varname] = erqlexprs
    98             localchecks[varname] = tuple(erqlexprs)
       
    99     return localchecks
    86     return localchecks
   100 
    87 
   101 def noinvariant_vars(restricted, select, nbtrees):
    88 def add_noinvariant(noinvariant, restricted, select, nbtrees):
   102     # a variable can actually be invariant if it has not been restricted for
    89     # a variable can actually be invariant if it has not been restricted for
   103     # security reason or if security assertion hasn't modified the possible
    90     # security reason or if security assertion hasn't modified the possible
   104     # solutions for the query
    91     # solutions for the query
   105     if nbtrees != 1:
    92     if nbtrees != 1:
   106         for vname in restricted:
    93         for vname in restricted:
   107             try:
    94             try:
   108                 yield select.defined_vars[vname]
    95                 noinvariant.add(select.defined_vars[vname])
   109             except KeyError:
    96             except KeyError:
   110                 # this is an alias
    97                 # this is an alias
   111                 continue
    98                 continue
   112     else:
    99     else:
   113         for vname in restricted:
   100         for vname in restricted:
   115                 var = select.defined_vars[vname]
   102                 var = select.defined_vars[vname]
   116             except KeyError:
   103             except KeyError:
   117                 # this is an alias
   104                 # this is an alias
   118                 continue
   105                 continue
   119             if len(var.stinfo['possibletypes']) != 1:
   106             if len(var.stinfo['possibletypes']) != 1:
   120                 yield var
   107                 noinvariant.add(var)
   121 
   108 
   122 def _expand_selection(terms, selected, aliases, select, newselect):
   109 def _expand_selection(terms, selected, aliases, select, newselect):
   123     for term in terms:
   110     for term in terms:
   124         for vref in term.iget_nodes(VariableRef):
   111         for vref in term.iget_nodes(VariableRef):
   125             if not vref.name in selected:
   112             if not vref.name in selected:
   278                     myunion.append(myrqlst)
   265                     myunion.append(myrqlst)
   279                     # in-place rewrite + annotation / simplification
   266                     # in-place rewrite + annotation / simplification
   280                     lcheckdef = [((varmap, 'X'), rqlexprs)
   267                     lcheckdef = [((varmap, 'X'), rqlexprs)
   281                                  for varmap, rqlexprs in lcheckdef]
   268                                  for varmap, rqlexprs in lcheckdef]
   282                     rewrite(myrqlst, lcheckdef, lchecksolutions, self.args)
   269                     rewrite(myrqlst, lcheckdef, lchecksolutions, self.args)
   283                     noinvariant.update(noinvariant_vars(restricted, myrqlst, nbtrees))
   270                     add_noinvariant(noinvariant, restricted, myrqlst, nbtrees)
   284                 if () in localchecks:
   271                 if () in localchecks:
   285                     select.set_possible_types(localchecks[()])
   272                     select.set_possible_types(localchecks[()])
   286                     add_types_restriction(self.schema, select)
   273                     add_types_restriction(self.schema, select)
   287                     noinvariant.update(noinvariant_vars(restricted, select, nbtrees))
   274                     add_noinvariant(noinvariant, restricted, select, nbtrees)
   288 
   275 
   289     def _check_permissions(self, rqlst):
   276     def _check_permissions(self, rqlst):
   290         """return a dict defining "local checks", e.g. RQLExpression defined in
   277         """return a dict defining "local checks", e.g. RQLExpression defined in
   291         the schema that should be inserted in the original query
   278         the schema that should be inserted in the original query
   292 
   279 
   302         So solutions which don't require local checks will be associated to
   289         So solutions which don't require local checks will be associated to
   303         the empty tuple key.
   290         the empty tuple key.
   304 
   291 
   305         note: rqlst should not have been simplified at this point
   292         note: rqlst should not have been simplified at this point
   306         """
   293         """
   307         user = self.session.user
   294         session = self.session
       
   295         user = session.user
   308         schema = self.schema
   296         schema = self.schema
   309         msgs = []
   297         msgs = []
       
   298         neweids = session.transaction_data.get('neweids', ())
       
   299         varkwargs = {}
       
   300         if not session.transaction_data.get('security-rqlst-cache'):
       
   301             for var in rqlst.defined_vars.itervalues():
       
   302                 for rel in var.stinfo['uidrels']:
       
   303                     const = rel.children[1].children[0]
       
   304                     varkwargs[var.name] = typed_eid(const.eval(self.args))
       
   305                     break
   310         # dictionnary of variables restricted for security reason
   306         # dictionnary of variables restricted for security reason
   311         localchecks = {}
   307         localchecks = {}
   312         if rqlst.where is not None:
       
   313             varkwargs = var_kwargs(rqlst.where, self.args)
       
   314             neweids = self.session.transaction_data.get('neweids', ())
       
   315         else:
       
   316             varkwargs = None
       
   317         restricted_vars = set()
   308         restricted_vars = set()
   318         newsolutions = []
   309         newsolutions = []
   319         for solution in rqlst.solutions:
   310         for solution in rqlst.solutions:
   320             try:
   311             try:
   321                 localcheck = check_read_access(schema, user, rqlst, solution)
   312                 localcheck = check_read_access(schema, user, rqlst, solution)
   324                 msg %= (solution, user.login, ex.args[0], ex.args[1])
   315                 msg %= (solution, user.login, ex.args[0], ex.args[1])
   325                 msgs.append(msg)
   316                 msgs.append(msg)
   326                 LOGGER.info(msg)
   317                 LOGGER.info(msg)
   327             else:
   318             else:
   328                 newsolutions.append(solution)
   319                 newsolutions.append(solution)
   329                 if varkwargs:
   320                 # try to benefit of rqlexpr.check cache for entities which
   330                     # try to benefit of rqlexpr.check cache for entities which
   321                 # are specified by eid in query'args
   331                     # are specified by eid in query'args
   322                 for varname, eid in varkwargs.iteritems():
   332                     for varname, eid in varkwargs.iteritems():
   323                     try:
   333                         try:
   324                         rqlexprs = localcheck.pop(varname)
   334                             rqlexprs = localcheck.pop(varname)
   325                     except KeyError:
   335                         except KeyError:
   326                         continue
   336                             continue
   327                     if eid in neweids:
   337                         if eid in neweids:
   328                         continue
   338                             continue
   329                     for rqlexpr in rqlexprs:
   339                         for rqlexpr in rqlexprs:
   330                         if rqlexpr.check(session, eid):
   340                             if rqlexpr.check(self.session, eid):
   331                             break
   341                                 break
   332                     else:
   342                         else:
   333                         raise Unauthorized()
   343                             raise Unauthorized()
       
   344                 restricted_vars.update(localcheck)
   334                 restricted_vars.update(localcheck)
   345                 localchecks.setdefault(tuple(localcheck.iteritems()), []).append(solution)
   335                 localchecks.setdefault(tuple(localcheck.iteritems()), []).append(solution)
   346         # raise Unautorized exception if the user can't access to any solution
   336         # raise Unautorized exception if the user can't access to any solution
   347         if not newsolutions:
   337         if not newsolutions:
   348             raise Unauthorized('\n'.join(msgs))
   338             raise Unauthorized('\n'.join(msgs))