--- a/cubicweb/test/unittest_migration.py Wed May 22 17:08:09 2019 +0200
+++ b/cubicweb/test/unittest_migration.py Wed May 22 17:10:06 2019 +0200
@@ -18,14 +18,18 @@
"""cubicweb.migration unit tests"""
from os.path import dirname, join
+from unittest.mock import patch
+
from logilab.common.testlib import TestCase, unittest_main
-from cubicweb import devtools
+from cubicweb import devtools, utils
+from logilab.common.shellutils import ASK
from cubicweb.cwconfig import CubicWebConfiguration
from cubicweb.migration import (
filter_scripts,
split_constraint,
version_strictly_lower,
+ MigrationHelper,
)
@@ -128,5 +132,54 @@
assert split_constraint("<= 42.1.0") == ("<=", "42.1.0")
+class WontColideWithOtherExceptionsException(Exception):
+ pass
+
+
+class MigrationHelperTC(TestCase):
+ @patch.object(utils, 'get_pdb')
+ @patch.object(ASK, 'ask', return_value="pdb")
+ def test_confirm_no_traceback(self, ask, get_pdb):
+ post_mortem = get_pdb.return_value.post_mortem
+ set_trace = get_pdb.return_value.set_trace
+
+ # we need to break after post_mortem is called otherwise we get
+ # infinite recursion
+ set_trace.side_effect = WontColideWithOtherExceptionsException
+
+ mh = MigrationHelper(config=None)
+
+ with self.assertRaises(WontColideWithOtherExceptionsException):
+ mh.confirm("some question")
+
+ get_pdb.assert_called_once()
+ set_trace.assert_called_once()
+ post_mortem.assert_not_called()
+
+ @patch.object(utils, 'get_pdb')
+ @patch.object(ASK, 'ask', return_value="pdb")
+ def test_confirm_got_traceback(self, ask, get_pdb):
+ post_mortem = get_pdb.return_value.post_mortem
+ set_trace = get_pdb.return_value.set_trace
+
+ # we need to break after post_mortem is called otherwise we get
+ # infinite recursion
+ post_mortem.side_effect = WontColideWithOtherExceptionsException
+
+ mh = MigrationHelper(config=None)
+
+ fake_traceback = object()
+
+ with self.assertRaises(WontColideWithOtherExceptionsException):
+ mh.confirm("some question", traceback=fake_traceback)
+
+ get_pdb.assert_called_once()
+ set_trace.assert_not_called()
+ post_mortem.assert_called_once()
+
+ # we want post_mortem to actually receive the traceback
+ self.assertEqual(post_mortem.call_args, ((fake_traceback,),))
+
+
if __name__ == '__main__':
unittest_main()