[session] enhance session's transaction storage handling to fix cases where commit/rollback is done while in the context of hooks_control/security_enabled managers. Closes #1412648: commit or rollback during postcreate reset hooks control state
authorSylvain Thénault <sylvain.thenault@logilab.fr>
Tue, 10 May 2011 12:07:54 +0200
changeset 7350 c2452cd57026
parent 7349 43416f63eca9
child 7352 d68f9319bfda
[session] enhance session's transaction storage handling to fix cases where commit/rollback is done while in the context of hooks_control/security_enabled managers. Closes #1412648: commit or rollback during postcreate reset hooks control state
server/session.py
server/test/unittest_session.py
--- a/server/session.py	Tue May 10 10:28:29 2011 +0200
+++ b/server/session.py	Tue May 10 12:07:54 2011 +0200
@@ -98,21 +98,13 @@
         self.categories = categories
 
     def __enter__(self):
-        self.oldmode = self.session.set_hooks_mode(self.mode)
-        if self.mode is self.session.HOOKS_DENY_ALL:
-            self.changes = self.session.enable_hook_categories(*self.categories)
-        else:
-            self.changes = self.session.disable_hook_categories(*self.categories)
+        self.oldmode, self.changes = self.session.init_hooks_mode_categories(
+            self.mode, self.categories)
 
     def __exit__(self, exctype, exc, traceback):
-        if self.changes:
-            if self.mode is self.session.HOOKS_DENY_ALL:
-                self.session.disable_hook_categories(*self.changes)
-            else:
-                self.session.enable_hook_categories(*self.changes)
-        self.session.set_hooks_mode(self.oldmode)
+        self.session.reset_hooks_mode_categories(self.oldmode, self.mode, self.changes)
 
-INDENT = ''
+
 class security_enabled(object):
     """context manager to control security w/ session.execute, since by
     default security is disabled on queries executed on the repository
@@ -124,29 +116,18 @@
         self.write = write
 
     def __enter__(self):
-#        global INDENT
-        if self.read is not None:
-            self.oldread = self.session.set_read_security(self.read)
-#            print INDENT + 'read', self.read, self.oldread
-        if self.write is not None:
-            self.oldwrite = self.session.set_write_security(self.write)
-#            print INDENT + 'write', self.write, self.oldwrite
-#        INDENT += '  '
+        self.oldread, self.oldwrite = self.session.init_security(
+            self.read, self.write)
 
     def __exit__(self, exctype, exc, traceback):
-#        global INDENT
-#        INDENT = INDENT[:-2]
-        if self.read is not None:
-            self.session.set_read_security(self.oldread)
-#            print INDENT + 'reset read to', self.oldread
-        if self.write is not None:
-            self.session.set_write_security(self.oldwrite)
-#            print INDENT + 'reset write to', self.oldwrite
+        self.session.reset_security(self.oldread, self.oldwrite)
 
 
 class TransactionData(object):
     def __init__(self, txid):
         self.transactionid = txid
+        self.ctx_count = 0
+
 
 class Session(RequestSessionBase):
     """tie session id, user, connections pool and other session data all
@@ -209,6 +190,9 @@
         session = Session(user, self.repo)
         threaddata = session._threaddata
         threaddata.pool = self.pool
+        # we attributed a pool, need to update ctx_count else it will be freed
+        # while undesired
+        threaddata.ctx_count = 1
         # share pending_operations, else operation added in the hi-jacked
         # session such as SendMailOp won't ever be processed
         threaddata.pending_operations = self.pending_operations
@@ -233,7 +217,7 @@
 
     def add_relations(self, relations):
         '''set many relation using a shortcut similar to the one in add_relation
-        
+
         relations is a list of 2-uples, the first element of each
         2-uple is the rtype, and the second is a list of (fromeid,
         toeid) tuples
@@ -405,6 +389,29 @@
 
     DEFAULT_SECURITY = object() # evaluated to true by design
 
+    def init_security(self, read, write):
+        if read is None:
+            oldread = None
+        else:
+            oldread = self.set_read_security(read)
+        if write is None:
+            oldwrite = None
+        else:
+            oldwrite = self.set_write_security(write)
+        self._threaddata.ctx_count += 1
+        return oldread, oldwrite
+
+    def reset_security(self, read, write):
+        txstore = self._threaddata
+        txstore.ctx_count -= 1
+        if txstore.ctx_count == 0:
+            self._clear_thread_storage(txstore)
+        else:
+            if read is not None:
+                self.set_read_security(read)
+            if write is not None:
+                self.set_write_security(write)
+
     @property
     def read_security(self):
         """return a boolean telling if read security is activated or not"""
@@ -500,6 +507,28 @@
         self._threaddata.hooks_mode = mode
         return oldmode
 
+    def init_hooks_mode_categories(self, mode, categories):
+        oldmode = self.set_hooks_mode(mode)
+        if mode is self.HOOKS_DENY_ALL:
+            changes = self.enable_hook_categories(*categories)
+        else:
+            changes = self.disable_hook_categories(*categories)
+        self._threaddata.ctx_count += 1
+        return oldmode, changes
+
+    def reset_hooks_mode_categories(self, oldmode, mode, categories):
+        txstore = self._threaddata
+        txstore.ctx_count -= 1
+        if txstore.ctx_count == 0:
+            self._clear_thread_storage(txstore)
+        else:
+            if categories:
+                if mode is self.HOOKS_DENY_ALL:
+                    return self.disable_hook_categories(*categories)
+                else:
+                    return self.enable_hook_categories(*categories)
+            self.set_hooks_mode(oldmode)
+
     @property
     def disabled_hook_categories(self):
         try:
@@ -624,6 +653,7 @@
         if self.pool is None:
             # get pool first to avoid race-condition
             self._threaddata.pool = pool = self.repo._get_pool()
+            self._threaddata.ctx_count += 1
             try:
                 pool.pool_set()
             except:
@@ -658,6 +688,7 @@
             # even in read mode, we must release the current transaction
             self._free_thread_pool(threading.currentThread(), pool)
             del self._threaddata.pool
+            self._threaddata.ctx_count -= 1
 
     def _touch(self):
         """update latest session usage timestamp and reset mode to read"""
@@ -757,18 +788,28 @@
             pass
         else:
             if reset_pool:
-                self._tx_data.pop(txstore.transactionid, None)
-                try:
-                    del self.__threaddata.txdata
-                except AttributeError:
-                    pass
+                self.reset_pool()
+                if txstore.ctx_count == 0:
+                    self._clear_thread_storage(txstore)
+                else:
+                    self._clear_tx_storage(txstore)
             else:
-                for name in ('commit_state', 'transaction_data',
-                             'pending_operations', '_rewriter'):
-                    try:
-                        delattr(txstore, name)
-                    except AttributeError:
-                        continue
+                self._clear_tx_storage(txstore)
+
+    def _clear_thread_storage(self, txstore):
+        self._tx_data.pop(txstore.transactionid, None)
+        try:
+            del self.__threaddata.txdata
+        except AttributeError:
+            pass
+
+    def _clear_tx_storage(self, txstore):
+        for name in ('commit_state', 'transaction_data',
+                     'pending_operations', '_rewriter'):
+            try:
+                delattr(txstore, name)
+            except AttributeError:
+                continue
 
     def commit(self, reset_pool=True):
         """commit the current session's transaction"""
--- a/server/test/unittest_session.py	Tue May 10 10:28:29 2011 +0200
+++ b/server/test/unittest_session.py	Tue May 10 12:07:54 2011 +0200
@@ -1,4 +1,4 @@
-# copyright 2003-2010 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
 #
 # This file is part of CubicWeb.
@@ -15,13 +15,10 @@
 #
 # You should have received a copy of the GNU Lesser General Public License along
 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
-"""
-
-"""
 from logilab.common.testlib import TestCase, unittest_main, mock_object
 
 from cubicweb.devtools.testlib import CubicWebTC
-from cubicweb.server.session import _make_description
+from cubicweb.server.session import _make_description, hooks_control
 
 class Variable:
     def __init__(self, name):
@@ -46,11 +43,38 @@
         self.assertEqual(_make_description((Function('max', 'A'), Variable('B')), {}, solution),
                           ['Int','CWUser'])
 
+
 class InternalSessionTC(CubicWebTC):
     def test_dbapi_query(self):
         session = self.repo.internal_session()
         self.assertFalse(session.running_dbapi_query)
         session.close()
 
+
+class SessionTC(CubicWebTC):
+
+    def test_hooks_control(self):
+        session = self.session
+        self.assertEqual(session.hooks_mode, session.HOOKS_ALLOW_ALL)
+        self.assertEqual(session.disabled_hook_categories, set())
+        self.assertEqual(session.enabled_hook_categories, set())
+        self.assertEqual(len(session._tx_data), 1)
+        with hooks_control(session, session.HOOKS_DENY_ALL, 'metadata'):
+            self.assertEqual(session.hooks_mode, session.HOOKS_DENY_ALL)
+            self.assertEqual(session.disabled_hook_categories, set())
+            self.assertEqual(session.enabled_hook_categories, set(('metadata',)))
+            session.commit()
+            self.assertEqual(session.hooks_mode, session.HOOKS_DENY_ALL)
+            self.assertEqual(session.disabled_hook_categories, set())
+            self.assertEqual(session.enabled_hook_categories, set(('metadata',)))
+            session.rollback()
+            self.assertEqual(session.hooks_mode, session.HOOKS_DENY_ALL)
+            self.assertEqual(session.disabled_hook_categories, set())
+            self.assertEqual(session.enabled_hook_categories, set(('metadata',)))
+        # leaving context manager with no transaction running should reset the
+        # transaction local storage (and associated pool)
+        self.assertEqual(session._tx_data, {})
+        self.assertEqual(session.pool, None)
+
 if __name__ == '__main__':
     unittest_main()