cubicweb/test/unittest_migration.py
changeset 12745 cc681b6fcffa
parent 12629 6b314fc558ed
--- a/cubicweb/test/unittest_migration.py	Wed May 22 17:08:09 2019 +0200
+++ b/cubicweb/test/unittest_migration.py	Wed May 22 17:10:06 2019 +0200
@@ -18,14 +18,18 @@
 """cubicweb.migration unit tests"""
 
 from os.path import dirname, join
+from unittest.mock import patch
+
 from logilab.common.testlib import TestCase, unittest_main
 
-from cubicweb import devtools
+from cubicweb import devtools, utils
+from logilab.common.shellutils import ASK
 from cubicweb.cwconfig import CubicWebConfiguration
 from cubicweb.migration import (
     filter_scripts,
     split_constraint,
     version_strictly_lower,
+    MigrationHelper,
 )
 
 
@@ -128,5 +132,54 @@
     assert split_constraint("<= 42.1.0") == ("<=", "42.1.0")
 
 
+class WontColideWithOtherExceptionsException(Exception):
+    pass
+
+
+class MigrationHelperTC(TestCase):
+    @patch.object(utils, 'get_pdb')
+    @patch.object(ASK, 'ask', return_value="pdb")
+    def test_confirm_no_traceback(self, ask, get_pdb):
+        post_mortem = get_pdb.return_value.post_mortem
+        set_trace = get_pdb.return_value.set_trace
+
+        # we need to break after post_mortem is called otherwise we get
+        # infinite recursion
+        set_trace.side_effect = WontColideWithOtherExceptionsException
+
+        mh = MigrationHelper(config=None)
+
+        with self.assertRaises(WontColideWithOtherExceptionsException):
+            mh.confirm("some question")
+
+        get_pdb.assert_called_once()
+        set_trace.assert_called_once()
+        post_mortem.assert_not_called()
+
+    @patch.object(utils, 'get_pdb')
+    @patch.object(ASK, 'ask', return_value="pdb")
+    def test_confirm_got_traceback(self, ask, get_pdb):
+        post_mortem = get_pdb.return_value.post_mortem
+        set_trace = get_pdb.return_value.set_trace
+
+        # we need to break after post_mortem is called otherwise we get
+        # infinite recursion
+        post_mortem.side_effect = WontColideWithOtherExceptionsException
+
+        mh = MigrationHelper(config=None)
+
+        fake_traceback = object()
+
+        with self.assertRaises(WontColideWithOtherExceptionsException):
+            mh.confirm("some question", traceback=fake_traceback)
+
+        get_pdb.assert_called_once()
+        set_trace.assert_not_called()
+        post_mortem.assert_called_once()
+
+        # we want post_mortem to actually receive the traceback
+        self.assertEqual(post_mortem.call_args, ((fake_traceback,),))
+
+
 if __name__ == '__main__':
     unittest_main()