devtools/testlib.py
brancholdstable
changeset 7078 bad26a22fe29
parent 7071 db7608cb32bc
child 7088 76e0dba5f8f3
--- a/devtools/testlib.py	Fri Mar 11 09:46:45 2011 +0100
+++ b/devtools/testlib.py	Tue Dec 07 12:18:20 2010 +0100
@@ -49,7 +49,7 @@
 from cubicweb.server.session import security_enabled
 from cubicweb.server.hook import SendMailOp
 from cubicweb.devtools import SYSTEM_ENTITIES, SYSTEM_RELATIONS, VIEW_VALIDATORS
-from cubicweb.devtools import BASE_URL, fake, htmlparser
+from cubicweb.devtools import BASE_URL, fake, htmlparser, DEFAULT_EMPTY_DB_ID
 from cubicweb.utils import json
 
 # low-level utilities ##########################################################
@@ -61,7 +61,8 @@
     def do_view(self, arg):
         import webbrowser
         data = self._getval(arg)
-        file('/tmp/toto.html', 'w').write(data)
+        with file('/tmp/toto.html', 'w') as toto:
+            toto.write(data)
         webbrowser.open('file:///tmp/toto.html')
 
 def line_context_filter(line_no, center, before=3, after=None):
@@ -83,22 +84,6 @@
         protected_entities = yams.schema.BASE_TYPES.union(SYSTEM_ENTITIES)
     return set(schema.entities()) - protected_entities
 
-def refresh_repo(repo, resetschema=False, resetvreg=False):
-    for pool in repo.pools:
-        pool.close(True)
-    repo.system_source.shutdown()
-    devtools.reset_test_database(repo.config)
-    for pool in repo.pools:
-        pool.reconnect()
-    repo._type_source_cache = {}
-    repo._extid_cache = {}
-    repo.querier._rql_cache = {}
-    for source in repo.sources:
-        source.reset_caches()
-    if resetschema:
-        repo.set_schema(repo.config.load_schema(), resetvreg=resetvreg)
-
-
 # email handling, to test emails sent by an application ########################
 
 MAILBOX = []
@@ -191,6 +176,19 @@
     configcls = devtools.ApptestConfiguration
     reset_schema = reset_vreg = False # reset schema / vreg between tests
     tags = TestCase.tags | Tags('cubicweb', 'cw_repo')
+    test_db_id = DEFAULT_EMPTY_DB_ID
+    _cnxs = set() # establised connection
+    _cnx  = None  # current connection
+
+    # Too much complicated stuff. the class doesn't need to bear the repo anymore
+    @classmethod
+    def set_cnx(cls, cnx):
+        cls._cnxs.add(cnx)
+        cls._cnx = cnx
+
+    @property
+    def cnx(self):
+        return self.__class__._cnx
 
     @classproperty
     def config(cls):
@@ -199,6 +197,7 @@
         Configuration is cached on the test class.
         """
         try:
+            assert not cls is CubicWebTC, "Don't use CubicWebTC directly to prevent database caching issue"
             return cls.__dict__['_config']
         except KeyError:
             home = abspath(join(dirname(sys.modules[cls.__module__].__file__), cls.appid))
@@ -237,36 +236,33 @@
         except: # not in server only configuration
             pass
 
+    #XXX this doesn't need to a be classmethod anymore
     @classmethod
     def _init_repo(cls):
         """init the repository and connection to it.
+        """
+        # setup configuration for test
+        cls.init_config(cls.config)
+        # get or restore and working db.
+        db_handler = devtools.get_test_db_handler(cls.config)
+        db_handler.build_db_cache(cls.test_db_id, cls.pre_setup_database)
 
-        Repository and connection are cached on the test class. Once
-        initialized, we simply reset connections and repository caches.
-        """
-        if not 'repo' in cls.__dict__:
-            cls._build_repo()
-        else:
-            try:
-                cls.cnx.rollback()
-            except ProgrammingError:
-                pass
-            cls._refresh_repo()
-
-    @classmethod
-    def _build_repo(cls):
-        cls.repo, cls.cnx = devtools.init_test_database(config=cls.config)
-        cls.init_config(cls.config)
-        cls.repo.hm.call_hooks('server_startup', repo=cls.repo)
+        cls.repo, cnx = db_handler.get_repo_and_cnx(cls.test_db_id)
+        # no direct assignation to cls.cnx anymore.
+        # cnx is now an instance property that use a class protected attributes.
+        cls.set_cnx(cnx)
         cls.vreg = cls.repo.vreg
-        cls.websession = DBAPISession(cls.cnx, cls.admlogin,
+        cls.websession = DBAPISession(cnx, cls.admlogin,
                                       {'password': cls.admpassword})
-        cls._orig_cnx = (cls.cnx, cls.websession)
+        cls._orig_cnx = (cnx, cls.websession)
         cls.config.repository = lambda x=None: cls.repo
 
-    @classmethod
-    def _refresh_repo(cls):
-        refresh_repo(cls.repo, cls.reset_schema, cls.reset_vreg)
+    def _close_cnx(self):
+        for cnx in list(self._cnxs):
+            if not cnx._closed:
+                cnx.rollback()
+                cnx.close()
+            self._cnxs.remove(cnx)
 
     # global resources accessors ###############################################
 
@@ -308,34 +304,47 @@
 
     def setUp(self):
         # monkey patch send mail operation so emails are sent synchronously
-        self._old_mail_postcommit_event = SendMailOp.postcommit_event
-        SendMailOp.postcommit_event = SendMailOp.sendmails
+        self._patch_SendMailOp()
         pause_tracing()
         previous_failure = self.__class__.__dict__.get('_repo_init_failed')
         if previous_failure is not None:
             self.skipTest('repository is not initialised: %r' % previous_failure)
         try:
             self._init_repo()
+            self.addCleanup(self._close_cnx)
         except Exception, ex:
             self.__class__._repo_init_failed = ex
             raise
         resume_tracing()
-        self._cnxs = []
         self.setup_database()
         self.commit()
         MAILBOX[:] = [] # reset mailbox
 
     def tearDown(self):
-        if not self.cnx._closed:
-            self.cnx.rollback()
-        for cnx in self._cnxs:
-            if not cnx._closed:
-                cnx.close()
-        SendMailOp.postcommit_event = self._old_mail_postcommit_event
+        # XXX hack until logilab.common.testlib is fixed
+        while self._cleanups:
+            cleanup, args, kwargs = self._cleanups.pop(-1)
+            cleanup(*args, **kwargs)
+
+    def _patch_SendMailOp(self):
+        # monkey patch send mail operation so emails are sent synchronously
+        _old_mail_postcommit_event = SendMailOp.postcommit_event
+        SendMailOp.postcommit_event = SendMailOp.sendmails
+        def reverse_SendMailOp_monkey_patch():
+            SendMailOp.postcommit_event = _old_mail_postcommit_event
+        self.addCleanup(reverse_SendMailOp_monkey_patch)
 
     def setup_database(self):
         """add your database setup code by overriding this method"""
 
+    @classmethod
+    def pre_setup_database(cls, session, config):
+        """add your pre database setup code by overriding this method
+
+        Do not forget to set the cls.test_db_id value to enable caching of the
+        result.
+        """
+
     # user / session management ###############################################
 
     def user(self, req=None):
@@ -372,9 +381,8 @@
         autoclose = kwargs.pop('autoclose', True)
         if not kwargs:
             kwargs['password'] = str(login)
-        self.cnx = repo_connect(self.repo, unicode(login), **kwargs)
+        self.set_cnx(repo_connect(self.repo, unicode(login), **kwargs))
         self.websession = DBAPISession(self.cnx)
-        self._cnxs.append(self.cnx)
         if login == self.vreg.config.anonymous_user()[0]:
             self.cnx.anonymous_connection = True
         if autoclose:
@@ -385,11 +393,8 @@
         if not self.cnx is self._orig_cnx[0]:
             if not self.cnx._closed:
                 self.cnx.close()
-            try:
-                self._cnxs.remove(self.cnx)
-            except ValueError:
-                pass
-        self.cnx, self.websession = self._orig_cnx
+        cnx, self.websession = self._orig_cnx
+        self.set_cnx(cnx)
 
     # db api ##################################################################
 
@@ -953,6 +958,8 @@
     """base class for test with auto-populating of the database"""
     __abstract__ = True
 
+    test_db_id = 'autopopulate'
+
     tags = CubicWebTC.tags | Tags('autopopulated')
 
     pdbclass = CubicWebDebugger
@@ -1086,7 +1093,9 @@
     tags = AutoPopulateTest.tags | Tags('web', 'generated')
 
     def setUp(self):
-        AutoPopulateTest.setUp(self)
+        assert not self.__class__ is AutomaticWebTest, 'Please subclass AutomaticWebTest to pprevent database caching issue'
+        super(AutomaticWebTest, self).setUp()
+
         # access to self.app for proper initialization of the authentication
         # machinery (else some views may fail)
         self.app