[read security] minor optimizations
authorSylvain Thénault <sylvain.thenault@logilab.fr>
Fri, 19 Mar 2010 08:18:31 +0100
changeset 4953 c8c0e10dbd97
parent 4951 7dc54e12c606
child 4954 96f67c5be0e6
[read security] minor optimizations
server/querier.py
--- a/server/querier.py	Thu Mar 18 17:39:17 2010 +0100
+++ b/server/querier.py	Fri Mar 19 08:18:31 2010 +0100
@@ -42,17 +42,6 @@
 
 # permission utilities ########################################################
 
-def var_kwargs(restriction, args):
-    varkwargs = {}
-    for rel in restriction.iget_nodes(Relation):
-        cmp = rel.children[1]
-        if rel.r_type == 'eid' and cmp.operator == '=' and \
-               not rel.neged(strict=True) and \
-               isinstance(cmp.children[0], Constant) and \
-               cmp.children[0].type == 'Substitute':
-            varkwargs[rel.children[0].name] = typed_eid(cmp.children[0].eval(args))
-    return varkwargs
-
 def check_no_password_selected(rqlst):
     """check that Password entities are not selected"""
     for solution in rqlst.solutions:
@@ -84,28 +73,26 @@
     localchecks = {}
     # iterate on defined_vars and not on solutions to ignore column aliases
     for varname in rqlst.defined_vars:
-        etype = solution[varname]
-        eschema = schema.eschema(etype)
+        eschema = schema.eschema(solution[varname])
         if eschema.final:
             continue
         if not user.matching_groups(eschema.get_groups('read')):
             erqlexprs = eschema.get_rqlexprs('read')
             if not erqlexprs:
-                ex = Unauthorized('read', etype)
+                ex = Unauthorized('read', solution[varname])
                 ex.var = varname
                 raise ex
-            #assert len(erqlexprs) == 1
-            localchecks[varname] = tuple(erqlexprs)
+            localchecks[varname] = erqlexprs
     return localchecks
 
-def noinvariant_vars(restricted, select, nbtrees):
+def add_noinvariant(noinvariant, restricted, select, nbtrees):
     # a variable can actually be invariant if it has not been restricted for
     # security reason or if security assertion hasn't modified the possible
     # solutions for the query
     if nbtrees != 1:
         for vname in restricted:
             try:
-                yield select.defined_vars[vname]
+                noinvariant.add(select.defined_vars[vname])
             except KeyError:
                 # this is an alias
                 continue
@@ -117,7 +104,7 @@
                 # this is an alias
                 continue
             if len(var.stinfo['possibletypes']) != 1:
-                yield var
+                noinvariant.add(var)
 
 def _expand_selection(terms, selected, aliases, select, newselect):
     for term in terms:
@@ -280,11 +267,11 @@
                     lcheckdef = [((varmap, 'X'), rqlexprs)
                                  for varmap, rqlexprs in lcheckdef]
                     rewrite(myrqlst, lcheckdef, lchecksolutions, self.args)
-                    noinvariant.update(noinvariant_vars(restricted, myrqlst, nbtrees))
+                    add_noinvariant(noinvariant, restricted, myrqlst, nbtrees)
                 if () in localchecks:
                     select.set_possible_types(localchecks[()])
                     add_types_restriction(self.schema, select)
-                    noinvariant.update(noinvariant_vars(restricted, select, nbtrees))
+                    add_noinvariant(noinvariant, restricted, select, nbtrees)
 
     def _check_permissions(self, rqlst):
         """return a dict defining "local checks", e.g. RQLExpression defined in
@@ -304,16 +291,20 @@
 
         note: rqlst should not have been simplified at this point
         """
-        user = self.session.user
+        session = self.session
+        user = session.user
         schema = self.schema
         msgs = []
+        neweids = session.transaction_data.get('neweids', ())
+        varkwargs = {}
+        if not session.transaction_data.get('security-rqlst-cache'):
+            for var in rqlst.defined_vars.itervalues():
+                for rel in var.stinfo['uidrels']:
+                    const = rel.children[1].children[0]
+                    varkwargs[var.name] = typed_eid(const.eval(self.args))
+                    break
         # dictionnary of variables restricted for security reason
         localchecks = {}
-        if rqlst.where is not None:
-            varkwargs = var_kwargs(rqlst.where, self.args)
-            neweids = self.session.transaction_data.get('neweids', ())
-        else:
-            varkwargs = None
         restricted_vars = set()
         newsolutions = []
         for solution in rqlst.solutions:
@@ -326,21 +317,20 @@
                 LOGGER.info(msg)
             else:
                 newsolutions.append(solution)
-                if varkwargs:
-                    # try to benefit of rqlexpr.check cache for entities which
-                    # are specified by eid in query'args
-                    for varname, eid in varkwargs.iteritems():
-                        try:
-                            rqlexprs = localcheck.pop(varname)
-                        except KeyError:
-                            continue
-                        if eid in neweids:
-                            continue
-                        for rqlexpr in rqlexprs:
-                            if rqlexpr.check(self.session, eid):
-                                break
-                        else:
-                            raise Unauthorized()
+                # try to benefit of rqlexpr.check cache for entities which
+                # are specified by eid in query'args
+                for varname, eid in varkwargs.iteritems():
+                    try:
+                        rqlexprs = localcheck.pop(varname)
+                    except KeyError:
+                        continue
+                    if eid in neweids:
+                        continue
+                    for rqlexpr in rqlexprs:
+                        if rqlexpr.check(session, eid):
+                            break
+                    else:
+                        raise Unauthorized()
                 restricted_vars.update(localcheck)
                 localchecks.setdefault(tuple(localcheck.iteritems()), []).append(solution)
         # raise Unautorized exception if the user can't access to any solution