[session] call rollback in Connection.__exit__
If we just free the cnxset and clear the Connection, we're missing the
pending operation's rollback_event callbacks, which may be needed for
cleanup.
--- a/server/session.py Thu Jul 24 14:52:16 2014 +0200
+++ b/server/session.py Fri Jul 25 16:24:44 2014 +0200
@@ -532,8 +532,7 @@
def __exit__(self, exctype=None, excvalue=None, tb=None):
assert self._open # actually already open
assert self._cnxset_count == 0
- self._free_cnxset(ignoremode=True)
- self.clear()
+ self.rollback()
self._open = False
--- a/server/test/unittest_session.py Thu Jul 24 14:52:16 2014 +0200
+++ b/server/test/unittest_session.py Fri Jul 25 16:24:44 2014 +0200
@@ -18,6 +18,8 @@
from cubicweb.devtools.testlib import CubicWebTC
from cubicweb.server.session import HOOKS_ALLOW_ALL, HOOKS_DENY_ALL
+from cubicweb.server import hook
+from cubicweb.predicates import is_instance
class InternalSessionTC(CubicWebTC):
def test_dbapi_query(self):
@@ -76,7 +78,7 @@
self.assertEqual(set(), session.disabled_hook_categories)
self.assertEqual(set(), session.enabled_hook_categories)
- def test_explicite_connection(self):
+ def test_explicit_connection(self):
with self.session.new_cnx() as cnx:
rset = cnx.execute('Any X LIMIT 1 WHERE X is CWUser')
self.assertEqual(1, len(rset))
@@ -98,7 +100,24 @@
self.assertIsNotNone(new_user.login)
self.assertFalse(cnx._open)
-
+ def test_connection_exit(self):
+ """exiting a connection should roll back the transaction, including any
+ pending operations"""
+ self.rollbacked = False
+ class RollbackOp(hook.Operation):
+ _test = self
+ def rollback_event(self):
+ self._test.rollbacked = True
+ class RollbackHook(hook.Hook):
+ __regid__ = 'rollback'
+ events = ('after_update_entity',)
+ __select__ = hook.Hook.__select__ & is_instance('CWGroup')
+ def __call__(self):
+ RollbackOp(self._cw)
+ with self.temporary_appobjects(RollbackHook):
+ with self.admin_access.client_cnx() as cnx:
+ cnx.execute('SET G name "foo" WHERE G is CWGroup, G name "managers"')
+ self.assertTrue(self.rollbacked)
if __name__ == '__main__':
from logilab.common.testlib import unittest_main