16 # You should have received a copy of the GNU Lesser General Public License along |
16 # You should have received a copy of the GNU Lesser General Public License along |
17 # with CubicWeb. If not, see <http://www.gnu.org/licenses/>. |
17 # with CubicWeb. If not, see <http://www.gnu.org/licenses/>. |
18 """cubicweb.migration unit tests""" |
18 """cubicweb.migration unit tests""" |
19 |
19 |
20 from os.path import dirname, join |
20 from os.path import dirname, join |
|
21 from unittest.mock import patch |
|
22 |
21 from logilab.common.testlib import TestCase, unittest_main |
23 from logilab.common.testlib import TestCase, unittest_main |
22 |
24 |
23 from cubicweb import devtools |
25 from cubicweb import devtools, utils |
|
26 from logilab.common.shellutils import ASK |
24 from cubicweb.cwconfig import CubicWebConfiguration |
27 from cubicweb.cwconfig import CubicWebConfiguration |
25 from cubicweb.migration import ( |
28 from cubicweb.migration import ( |
26 filter_scripts, |
29 filter_scripts, |
27 split_constraint, |
30 split_constraint, |
28 version_strictly_lower, |
31 version_strictly_lower, |
|
32 MigrationHelper, |
29 ) |
33 ) |
30 |
34 |
31 |
35 |
32 class Schema(dict): |
36 class Schema(dict): |
33 def has_entity(self, e_type): |
37 def has_entity(self, e_type): |
126 assert split_constraint("< 0.2.0") == ("<", "0.2.0") |
130 assert split_constraint("< 0.2.0") == ("<", "0.2.0") |
127 assert split_constraint("<=42.1.0") == ("<=", "42.1.0") |
131 assert split_constraint("<=42.1.0") == ("<=", "42.1.0") |
128 assert split_constraint("<= 42.1.0") == ("<=", "42.1.0") |
132 assert split_constraint("<= 42.1.0") == ("<=", "42.1.0") |
129 |
133 |
130 |
134 |
|
class WontColideWithOtherExceptionsException(Exception):
    """Dedicated exception type raised by mocks in this module's tests.

    Using a type defined only here lets a test assert that the exception it
    caught is the one it injected itself.
    """
|
137 |
|
138 |
|
class MigrationHelperTC(TestCase):
    """Check which debugger entry point ``MigrationHelper.confirm`` calls.

    In every test ``ASK.ask`` is patched to answer ``"pdb"`` and
    ``utils.get_pdb`` is patched with a mock, so no real prompt or debugger
    is ever started.
    """

    @patch.object(utils, 'get_pdb')
    @patch.object(ASK, 'ask', return_value="pdb")
    def test_confirm_no_traceback(self, ask, get_pdb):
        # Without a traceback, answering "pdb" must go through set_trace()
        # and never through post_mortem().
        post_mortem = get_pdb.return_value.post_mortem
        set_trace = get_pdb.return_value.set_trace

        # we need to break out once set_trace is called, otherwise we get
        # infinite recursion
        set_trace.side_effect = WontColideWithOtherExceptionsException

        mh = MigrationHelper(config=None)

        with self.assertRaises(WontColideWithOtherExceptionsException):
            mh.confirm("some question")

        get_pdb.assert_called_once()
        set_trace.assert_called_once()
        post_mortem.assert_not_called()

    @patch.object(utils, 'get_pdb')
    @patch.object(ASK, 'ask', return_value="pdb")
    def test_confirm_got_traceback(self, ask, get_pdb):
        # With a traceback, answering "pdb" must go through post_mortem()
        # with that traceback, and never through set_trace().
        post_mortem = get_pdb.return_value.post_mortem
        set_trace = get_pdb.return_value.set_trace

        # we need to break out once post_mortem is called, otherwise we get
        # infinite recursion
        post_mortem.side_effect = WontColideWithOtherExceptionsException

        mh = MigrationHelper(config=None)

        # an opaque placeholder is enough: the test only checks that this
        # exact object is handed to post_mortem
        fake_traceback = object()

        with self.assertRaises(WontColideWithOtherExceptionsException):
            mh.confirm("some question", traceback=fake_traceback)

        get_pdb.assert_called_once()
        set_trace.assert_not_called()

        # we want post_mortem to be called exactly once and to actually
        # receive the traceback as its only positional argument
        post_mortem.assert_called_once_with(fake_traceback)
|
182 |
|
183 |
# Run this module's tests when it is executed directly as a script.
if __name__ == '__main__':
    unittest_main()