cubicweb/test/unittest_migration.py
changeset 12745 cc681b6fcffa
parent 12629 6b314fc558ed
equal deleted inserted replaced
12744:19aef4729d45 12745:cc681b6fcffa
    16 # You should have received a copy of the GNU Lesser General Public License along
    16 # You should have received a copy of the GNU Lesser General Public License along
    17 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
    17 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
    18 """cubicweb.migration unit tests"""
    18 """cubicweb.migration unit tests"""
    19 
    19 
    20 from os.path import dirname, join
    20 from os.path import dirname, join
       
    21 from unittest.mock import patch
       
    22 
    21 from logilab.common.testlib import TestCase, unittest_main
    23 from logilab.common.testlib import TestCase, unittest_main
    22 
    24 
    23 from cubicweb import devtools
    25 from cubicweb import devtools, utils
       
    26 from logilab.common.shellutils import ASK
    24 from cubicweb.cwconfig import CubicWebConfiguration
    27 from cubicweb.cwconfig import CubicWebConfiguration
    25 from cubicweb.migration import (
    28 from cubicweb.migration import (
    26     filter_scripts,
    29     filter_scripts,
    27     split_constraint,
    30     split_constraint,
    28     version_strictly_lower,
    31     version_strictly_lower,
       
    32     MigrationHelper,
    29 )
    33 )
    30 
    34 
    31 
    35 
    32 class Schema(dict):
    36 class Schema(dict):
    33     def has_entity(self, e_type):
    37     def has_entity(self, e_type):
   126     assert split_constraint("< 0.2.0") == ("<", "0.2.0")
   130     assert split_constraint("< 0.2.0") == ("<", "0.2.0")
   127     assert split_constraint("<=42.1.0") == ("<=", "42.1.0")
   131     assert split_constraint("<=42.1.0") == ("<=", "42.1.0")
   128     assert split_constraint("<= 42.1.0") == ("<=", "42.1.0")
   132     assert split_constraint("<= 42.1.0") == ("<=", "42.1.0")
   129 
   133 
   130 
   134 
       
   135 class WontColideWithOtherExceptionsException(Exception):
       
   136     pass
       
   137 
       
   138 
       
   139 class MigrationHelperTC(TestCase):
       
   140     @patch.object(utils, 'get_pdb')
       
   141     @patch.object(ASK, 'ask', return_value="pdb")
       
   142     def test_confirm_no_traceback(self, ask, get_pdb):
       
   143         post_mortem = get_pdb.return_value.post_mortem
       
   144         set_trace = get_pdb.return_value.set_trace
       
   145 
       
   146         # we need to break after post_mortem is called otherwise we get
       
   147         # infinite recursion
       
   148         set_trace.side_effect = WontColideWithOtherExceptionsException
       
   149 
       
   150         mh = MigrationHelper(config=None)
       
   151 
       
   152         with self.assertRaises(WontColideWithOtherExceptionsException):
       
   153             mh.confirm("some question")
       
   154 
       
   155         get_pdb.assert_called_once()
       
   156         set_trace.assert_called_once()
       
   157         post_mortem.assert_not_called()
       
   158 
       
   159     @patch.object(utils, 'get_pdb')
       
   160     @patch.object(ASK, 'ask', return_value="pdb")
       
   161     def test_confirm_got_traceback(self, ask, get_pdb):
       
   162         post_mortem = get_pdb.return_value.post_mortem
       
   163         set_trace = get_pdb.return_value.set_trace
       
   164 
       
   165         # we need to break after post_mortem is called otherwise we get
       
   166         # infinite recursion
       
   167         post_mortem.side_effect = WontColideWithOtherExceptionsException
       
   168 
       
   169         mh = MigrationHelper(config=None)
       
   170 
       
   171         fake_traceback = object()
       
   172 
       
   173         with self.assertRaises(WontColideWithOtherExceptionsException):
       
   174             mh.confirm("some question", traceback=fake_traceback)
       
   175 
       
   176         get_pdb.assert_called_once()
       
   177         set_trace.assert_not_called()
       
   178         post_mortem.assert_called_once()
       
   179 
       
   180         # we want post_mortem to actually receive the traceback
       
   181         self.assertEqual(post_mortem.call_args, ((fake_traceback,),))
       
   182 
       
   183 
   131 if __name__ == '__main__':
   184 if __name__ == '__main__':
   132     unittest_main()
   185     unittest_main()