devtools/__init__.py
brancholdstable
changeset 7078 bad26a22fe29
parent 7071 db7608cb32bc
child 7090 d9e6e79e023a
--- a/devtools/__init__.py	Fri Mar 11 09:46:45 2011 +0100
+++ b/devtools/__init__.py	Tue Dec 07 12:18:20 2010 +0100
@@ -22,12 +22,19 @@
 import os
 import sys
 import logging
+import shutil
+import pickle
+import glob
+import warnings
 from datetime import timedelta
 from os.path import (abspath, join, exists, basename, dirname, normpath, split,
-                     isfile, isabs, splitext)
+                     isfile, isabs, splitext, isdir, expanduser)
+from functools import partial
+import hashlib
 
 from logilab.common.date import strptime
-from cubicweb import CW_SOFTWARE_ROOT, ConfigurationError, schema, cwconfig
+from logilab.common.decorators import cached, clear_cache
+from cubicweb import CW_SOFTWARE_ROOT, ConfigurationError, schema, cwconfig, BadConnectionId
 from cubicweb.server.serverconfig import ServerConfiguration
 from cubicweb.etwist.twconfig import TwistedConfiguration
 
@@ -78,13 +85,49 @@
                               },
                    }
 
+def turn_repo_off(repo):
+    """ Idea: this is less costly than a full re-creation of the repo object.
+    off:
+    * session are closed,
+    * pools are closed
+    * system source is shutdown
+    """
+    if not repo._needs_refresh:
+        for sessionid in list(repo._sessions):
+            warnings.warn('%s Open session found while turning repository off'
+                          %sessionid, RuntimeWarning)
+            try:
+                repo.close(sessionid)
+            except BadConnectionId: #this is strange ? thread issue ?
+                print 'XXX unknown session', sessionid
+        for pool in repo.pools:
+            pool.close(True)
+        repo.system_source.shutdown()
+        repo._needs_refresh = True
+        repo._has_started = False
+
+def turn_repo_on(repo):
+    """Idea: this is less costly than a full re-creation of the repo object.
+    on:
+    * pools are connected
+    * cache are cleared
+    """
+    if repo._needs_refresh:
+        for pool in repo.pools:
+            pool.reconnect()
+        repo._type_source_cache = {}
+        repo._extid_cache = {}
+        repo.querier._rql_cache = {}
+        for source in repo.sources:
+            source.reset_caches()
+        repo._needs_refresh = False
+
 
 class TestServerConfiguration(ServerConfiguration):
     mode = 'test'
     set_language = False
     read_instance_schema = False
     init_repository = True
-    db_require_setup = True
 
     def __init__(self, appid='data', apphome=None, log_threshold=logging.CRITICAL+10):
         # must be set before calling parent __init__
@@ -216,131 +259,402 @@
               self.view('foaf', rset)
 
     """
-    db_require_setup = False    # skip init_db / reset_db steps
     read_instance_schema = True # read schema from database
 
 
 # test database handling #######################################################
 
-def init_test_database(config=None, appid='data', apphome=None):
-    """init a test database for a specific driver"""
-    from cubicweb.dbapi import in_memory_repo_cnx
-    config = config or TestServerConfiguration(appid, apphome=apphome)
-    sources = config.sources()
-    driver = sources['system']['db-driver']
-    if config.db_require_setup:
-        if driver == 'sqlite':
-            init_test_database_sqlite(config)
-        elif driver == 'postgres':
-            init_test_database_postgres(config)
+DEFAULT_EMPTY_DB_ID = '__default_empty_db__'
+
+class TestDataBaseHandler(object):
+    DRIVER = None
+    db_cache = {}
+    explored_glob = set()
+
+    def __init__(self, config):
+        self.config = config
+        self._repo = None
+        # pure consistency check
+        assert self.system_source['db-driver'] == self.DRIVER
+
+    def _ensure_test_backup_db_dir(self):
+        """Return path of directory for database backup.
+
+        The function create it if necessary"""
+        backupdir = join(self.config.apphome, 'database')
+        if not isdir(backupdir):
+            os.makedirs(backupdir)
+        return backupdir
+
+    def config_path(self, db_id):
+        """Path for config backup of a given database id"""
+        return self.absolute_backup_file(db_id, 'config')
+
+    def absolute_backup_file(self, db_id, suffix):
+        """Path for config backup of a given database id"""
+        dbname = self.dbname.replace('-', '_')
+        assert '.' not in db_id
+        filename = '%s-%s.%s' % (dbname, db_id, suffix)
+        return join(self._ensure_test_backup_db_dir(), filename)
+
+    def db_cache_key(self, db_id, dbname=None):
+        """Build a database cache key for a db_id with the current config
+
+        This key is meant to be used in the cls.db_cache mapping"""
+        if dbname is None:
+            dbname = self.dbname
+        dbname = os.path.basename(dbname)
+        dbname = dbname.replace('-', '_')
+        return (self.config.apphome, dbname, db_id)
+
+    def backup_database(self, db_id):
+        """Store the content of the current database as <db_id>
+
+        The config used are also stored."""
+        backup_data = self._backup_database(db_id)
+        config_path = self.config_path(db_id)
+        # XXX we dump a dict of the config
+        # This is an experimental to help config dependant setup (like BFSS) to
+        # be propertly restored
+        with open(config_path, 'wb') as conf_file:
+            conf_file.write(pickle.dumps(dict(self.config)))
+        self.db_cache[self.db_cache_key(db_id)] = (backup_data, config_path)
+
+    def _backup_database(self, db_id):
+        """Actual backup the current database.
+
+        return a value to be stored in db_cache to allow restoration"""
+        raise NotImplementedError()
+
+    def restore_database(self, db_id):
+        """Restore a database.
+
+        takes as argument value stored in db_cache by self._backup_database"""
+        # XXX set a clearer error message ???
+        backup_coordinates, config_path = self.db_cache[self.db_cache_key(db_id)]
+        # reload the config used to create the database.
+        config = pickle.loads(open(config_path, 'rb').read())
+        # shutdown repo before changing database content
+        if self._repo is not None:
+            self._repo.turn_repo_off()
+        self._restore_database(backup_coordinates, config)
+
+    def _restore_database(self, backup_coordinates, config):
+        """Actual restore of the current database.
+
+        Use the value tostored in db_cache as input """
+        raise NotImplementedError()
+
+    def get_repo(self, startup=False):
+        """ return Repository object on the current database.
+
+        (turn the current repo object "on" if there is one or recreate one)
+        if startup is True, server startup server hooks will be called if needed
+        """
+        if self._repo is None:
+            self._repo = self._new_repo(self.config)
+        repo = self._repo
+        repo.turn_repo_on()
+        if startup and not repo._has_started:
+            repo.hm.call_hooks('server_startup', repo=repo)
+            repo._has_started = True
+        return repo
+
+    def _new_repo(self, config):
+        """Factory method to create a new Repository Instance"""
+        from cubicweb.dbapi import in_memory_repo
+        config._cubes = None
+        repo = in_memory_repo(config)
+        # extending Repository class
+        repo._has_started = False
+        repo._needs_refresh = False
+        repo.turn_repo_on = partial(turn_repo_on, repo)
+        repo.turn_repo_off = partial(turn_repo_off, repo)
+        return repo
+
+
+    def get_cnx(self):
+        """return Connection object ont he current repository"""
+        from cubicweb.dbapi import in_memory_cnx
+        repo = self.get_repo()
+        sources = self.config.sources()
+        login  = unicode(sources['admin']['login'])
+        password = sources['admin']['password'] or 'xxx'
+        cnx = in_memory_cnx(repo, login, password=password)
+        return cnx
+
+    def get_repo_and_cnx(self, db_id=DEFAULT_EMPTY_DB_ID):
+        """Reset database with the current db_id and return (repo, cnx)
+
+        A database *MUST* have been build with the current <db_id> prior to
+        call this method. See the ``build_db_cache`` method. The returned
+        repository have it's startup hooks called and the connection is
+        establised as admin."""
+
+        self.restore_database(db_id)
+        repo = self.get_repo(startup=True)
+        cnx  = self.get_cnx()
+        return repo, cnx
+
+    @property
+    def system_source(self):
+        sources = self.config.sources()
+        return sources['system']
+
+    @property
+    def dbname(self):
+        return self.system_source['db-name']
+
+    def init_test_database():
+        """actual initialisation of the database"""
+        raise ValueError('no initialization function for driver %r' % driver)
+
+    def has_cache(self, db_id):
+        """Check if a given database id exist in cb cache for the current config"""
+        cache_glob = self.absolute_backup_file('*', '*')
+        if cache_glob not in self.explored_glob:
+            self.discover_cached_db()
+        return self.db_cache_key(db_id) in self.db_cache
+
+    def discover_cached_db(self):
+        """Search available db_if for the current config"""
+        cache_glob = self.absolute_backup_file('*', '*')
+        directory = os.path.dirname(cache_glob)
+        entries={}
+        candidates = glob.glob(cache_glob)
+        for filepath in candidates:
+            data = os.path.basename(filepath)
+            # database backup are in the forms are <dbname>-<db_id>.<backtype>
+            dbname, data = data.split('-', 1)
+            db_id, filetype = data.split('.', 1)
+            entries.setdefault((dbname, db_id), {})[filetype] = filepath
+        for (dbname, db_id), entry in entries.iteritems():
+            # apply necessary transformation from the driver
+            value = self.process_cache_entry(directory, dbname, db_id, entry)
+            assert 'config' in entry
+            if value is not None: # None value means "not handled by this driver
+                                  # XXX Ignored value are shadowed to other Handler if cache are common.
+                key = self.db_cache_key(db_id, dbname=dbname)
+                self.db_cache[key] = value, entry['config']
+        self.explored_glob.add(cache_glob)
+
+    def process_cache_entry(self, directory, dbname, db_id, entry):
+        """Transforms potential cache entry to proper backup coordinate
+
+        entry argument is a "filetype" -> "filepath" mapping
+        Return None if an entry should be ignored."""
+        return None
+
+    def build_db_cache(self, test_db_id=DEFAULT_EMPTY_DB_ID, pre_setup_func=None):
+        """Build Database cache for ``test_db_id`` if a cache doesn't exist
+
+        if ``test_db_id is DEFAULT_EMPTY_DB_ID`` self.init_test_database is
+        called. otherwise, DEFAULT_EMPTY_DB_ID is build/restored and
+        ``pre_setup_func`` to setup the database.
+
+        This function backup any database it build"""
+
+        if self.has_cache(test_db_id):
+            return #test_db_id, 'already in cache'
+        if test_db_id is DEFAULT_EMPTY_DB_ID:
+            self.init_test_database()
         else:
-            raise ValueError('no initialization function for driver %r' % driver)
-    config._cubes = None # avoid assertion error
-    repo, cnx = in_memory_repo_cnx(config, unicode(sources['admin']['login']),
-                              password=sources['admin']['password'] or 'xxx')
-    if driver == 'sqlite':
-        install_sqlite_patch(repo.querier)
-    return repo, cnx
-
-def reset_test_database(config):
-    """init a test database for a specific driver"""
-    if not config.db_require_setup:
-        return
-    driver = config.sources()['system']['db-driver']
-    if driver == 'sqlite':
-        reset_test_database_sqlite(config)
-    elif driver == 'postgres':
-        init_test_database_postgres(config)
-    else:
-        raise ValueError('no reset function for driver %r' % driver)
-
+            print 'Building %s for database %s' % (test_db_id, self.dbname)
+            self.build_db_cache(DEFAULT_EMPTY_DB_ID)
+            self.restore_database(DEFAULT_EMPTY_DB_ID)
+            repo = self.get_repo(startup=True)
+            cnx = self.get_cnx()
+            session = repo._sessions[cnx.sessionid]
+            session.set_pool()
+            _commit = session.commit
+            def always_pooled_commit():
+                _commit()
+                session.set_pool()
+            session.commit = always_pooled_commit
+            pre_setup_func(session, self.config)
+            session.commit()
+            cnx.close()
+        self.backup_database(test_db_id)
 
 ### postgres test database handling ############################################
 
-def init_test_database_postgres(config):
-    """initialize a fresh postgresql databse used for testing purpose"""
-    from logilab.database import get_db_helper
-    from cubicweb.server import init_repository
-    from cubicweb.server.serverctl import (createdb, system_source_cnx,
-                                           _db_sys_cnx)
-    source = config.sources()['system']
-    dbname = source['db-name']
-    templdbname = dbname + '_template'
-    helper = get_db_helper('postgres')
-    # connect on the dbms system base to create our base
-    dbcnx = _db_sys_cnx(source, 'CREATE DATABASE and / or USER', verbose=0)
-    cursor = dbcnx.cursor()
-    try:
-        if dbname in helper.list_databases(cursor):
-            cursor.execute('DROP DATABASE %s' % dbname)
-        if not templdbname in helper.list_databases(cursor):
-            source['db-name'] = templdbname
-            createdb(helper, source, dbcnx, cursor)
-            dbcnx.commit()
-            cnx = system_source_cnx(source, special_privs='LANGUAGE C', verbose=0)
+class PostgresTestDataBaseHandler(TestDataBaseHandler):
+
+    # XXX
+    # XXX PostgresTestDataBaseHandler Have not been tested at all.
+    # XXX
+    DRIVER = 'postgres'
+
+    @property
+    @cached
+    def helper(self):
+        from logilab.database import get_db_helper
+        return get_db_helper('postgres')
+
+    @property
+    @cached
+    def dbcnx(self):
+        from cubicweb.server.serverctl import _db_sys_cnx
+        return  _db_sys_cnx(self.system_source, 'CREATE DATABASE and / or USER', verbose=0)
+
+    @property
+    @cached
+    def cursor(self):
+        return self.dbcnx.cursor()
+
+    def init_test_database(self):
+        """initialize a fresh postgresql databse used for testing purpose"""
+        from cubicweb.server import init_repository
+        from cubicweb.server.serverctl import system_source_cnx, createdb
+        # connect on the dbms system base to create our base
+        try:
+            self._drop(self.dbname)
+
+            createdb(self.helper, self.system_source, self.dbcnx, self.cursor)
+            self.dbcnx.commit()
+            cnx = system_source_cnx(self.system_source, special_privs='LANGUAGE C', verbose=0)
             templcursor = cnx.cursor()
-            # XXX factorize with db-create code
-            helper.init_fti_extensions(templcursor)
-            # install plpythonu/plpgsql language if not installed by the cube
-            langs = sys.platform == 'win32' and ('plpgsql',) or ('plpythonu', 'plpgsql')
-            for extlang in langs:
-                helper.create_language(templcursor, extlang)
-            cnx.commit()
-            templcursor.close()
-            cnx.close()
-            init_repository(config, interactive=False)
-            source['db-name'] = dbname
-    except:
-        dbcnx.rollback()
-        # XXX drop template
-        raise
-    createdb(helper, source, dbcnx, cursor, template=templdbname)
-    dbcnx.commit()
-    dbcnx.close()
+            try:
+                # XXX factorize with db-create code
+                self.helper.init_fti_extensions(templcursor)
+                # install plpythonu/plpgsql language if not installed by the cube
+                langs = sys.platform == 'win32' and ('plpgsql',) or ('plpythonu', 'plpgsql')
+                for extlang in langs:
+                    self.helper.create_language(templcursor, extlang)
+                cnx.commit()
+            finally:
+                templcursor.close()
+                cnx.close()
+            init_repository(self.config, interactive=False)
+        except:
+            self.dbcnx.rollback()
+            print >> sys.stderr, 'building', self.dbname, 'failed'
+            #self._drop(self.dbname)
+            raise
+
+    def helper_clear_cache(self):
+        self.dbcnx.commit()
+        self.dbcnx.close()
+        clear_cache(self, 'dbcnx')
+        clear_cache(self, 'helper')
+        clear_cache(self, 'cursor')
+
+    def __del__(self):
+        self.helper_clear_cache()
+
+    @property
+    def _config_id(self):
+        return hashlib.sha1(self.config.apphome).hexdigest()[:10]
+
+    def _backup_name(self, db_id): # merge me with parent
+        backup_name = '_'.join(('cache', self._config_id, self.dbname, db_id))
+        return backup_name.lower()
+
+    def _drop(self, db_name):
+        if db_name in self.helper.list_databases(self.cursor):
+            #print 'dropping overwritted database:', db_name
+            self.cursor.execute('DROP DATABASE %s' % db_name)
+            self.dbcnx.commit()
+
+    def _backup_database(self, db_id):
+        """Actual backup the current database.
+
+        return a value to be stored in db_cache to allow restoration"""
+        from cubicweb.server.serverctl import createdb
+        orig_name = self.system_source['db-name']
+        try:
+            backup_name = self._backup_name(db_id)
+            #print 'storing postgres backup as', backup_name
+            self._drop(backup_name)
+            self.system_source['db-name'] = backup_name
+            createdb(self.helper, self.system_source, self.dbcnx, self.cursor, template=orig_name)
+            self.dbcnx.commit()
+            return backup_name
+        finally:
+            self.system_source['db-name'] = orig_name
+
+    def _restore_database(self, backup_coordinates, config):
+        from cubicweb.server.serverctl import createdb
+        """Actual restore of the current database.
+
+        Use the value tostored in db_cache as input """
+        #print 'restoring postgrest backup from', backup_coordinates
+        self._drop(self.dbname)
+        createdb(self.helper, self.system_source, self.dbcnx, self.cursor,
+                 template=backup_coordinates)
+        self.dbcnx.commit()
+
+
 
 ### sqlserver2005 test database handling #######################################
 
-def init_test_database_sqlserver2005(config):
-    """initialize a fresh sqlserver databse used for testing purpose"""
-    if config.init_repository:
-        from cubicweb.server import init_repository
-        init_repository(config, interactive=False, drop=True)
+class SQLServerTestDataBaseHandler(TestDataBaseHandler):
+    DRIVER = 'sqlserver'
+
+    # XXX complete me
+
+    def init_test_database(self):
+        """initialize a fresh sqlserver databse used for testing purpose"""
+        if self.config.init_repository:
+            from cubicweb.server import init_repository
+            init_repository(config, interactive=False, drop=True)
 
 ### sqlite test database handling ##############################################
 
-def cleanup_sqlite(dbfile, removetemplate=False):
-    try:
-        os.remove(dbfile)
-        os.remove('%s-journal' % dbfile)
-    except OSError:
-        pass
-    if removetemplate:
+class SQLiteTestDataBaseHandler(TestDataBaseHandler):
+    DRIVER = 'sqlite'
+
+    @staticmethod
+    def _cleanup_database(dbfile):
         try:
-            os.remove('%s-template' % dbfile)
+            os.remove(dbfile)
+            os.remove('%s-journal' % dbfile)
         except OSError:
             pass
 
-def reset_test_database_sqlite(config):
-    import shutil
-    dbfile = config.sources()['system']['db-name']
-    cleanup_sqlite(dbfile)
-    template = '%s-template' % dbfile
-    if exists(template):
-        shutil.copy(template, dbfile)
-        return True
-    return False
+    def absolute_dbfile(self):
+        """absolute path of current database file"""
+        dbfile = join(self._ensure_test_backup_db_dir(),
+                      self.config.sources()['system']['db-name'])
+        self.config.sources()['system']['db-name'] = dbfile
+        return dbfile
+
+
+    def process_cache_entry(self, directory, dbname, db_id, entry):
+        return entry.get('sqlite')
 
-def init_test_database_sqlite(config):
-    """initialize a fresh sqlite databse used for testing purpose"""
-    # remove database file if it exists
-    dbfile = join(config.apphome, config.sources()['system']['db-name'])
-    config.sources()['system']['db-name'] = dbfile
-    if not reset_test_database_sqlite(config):
+    def _backup_database(self, db_id=DEFAULT_EMPTY_DB_ID):
+        # XXX remove database file if it exists ???
+        dbfile = self.absolute_dbfile()
+        backup_file = self.absolute_backup_file(db_id, 'sqlite')
+        shutil.copy(dbfile, backup_file)
+        # Usefull to debug WHO write a database
+        # backup_stack = self.absolute_backup_file(db_id, '.stack')
+        #with open(backup_stack, 'w') as backup_stack_file:
+        #    import traceback
+        #    traceback.print_stack(file=backup_stack_file)
+        return backup_file
+
+    def _new_repo(self, config):
+        repo = super(SQLiteTestDataBaseHandler, self)._new_repo(config)
+        install_sqlite_patch(repo.querier)
+        return repo
+
+    def _restore_database(self, backup_coordinates, _config):
+        # remove database file if it exists ?
+        dbfile = self.absolute_dbfile()
+        self._cleanup_database(dbfile)
+        #print 'resto from', backup_coordinates
+        shutil.copy(backup_coordinates, dbfile)
+        repo = self.get_repo()
+
+    def init_test_database(self):
+        """initialize a fresh sqlite databse used for testing purpose"""
         # initialize the database
-        import shutil
         from cubicweb.server import init_repository
-        init_repository(config, interactive=False)
-        shutil.copy(dbfile, '%s-template' % dbfile)
+        self._cleanup_database(self.absolute_dbfile())
+        init_repository(self.config, interactive=False)
+
 
 def install_sqlite_patch(querier):
     """This patch hotfixes the following sqlite bug :
@@ -379,3 +693,74 @@
         return new_execute
     querier.__class__.execute = wrap_execute(querier.__class__.execute)
     querier.__class__._devtools_sqlite_patched = True
+
+
+
+HANDLERS = {}
+
+def register_handler(handlerkls):
+    assert handlerkls is not None
+    HANDLERS[handlerkls.DRIVER] = handlerkls
+
+register_handler(PostgresTestDataBaseHandler)
+register_handler(SQLiteTestDataBaseHandler)
+register_handler(SQLServerTestDataBaseHandler)
+
+
+class HCache(object):
+    """Handler cache object: store database handler for a given configuration.
+
+    We only keep one repo in cache to prevent too much objects to stay alive
+    (database handler holds a reference to a repository). As at the moment a new
+    handler is created for each TestCase class and all test methods are executed
+    sequentialy whithin this class, there should not have more cache miss that
+    if we had a wider cache as once a Handler stop being used it won't be used
+    again.
+    """
+
+    def __init__(self):
+        self.config = None
+        self.handler = None
+
+    def get(self, config):
+        if config is self.config:
+            return self.handler
+        else:
+            return None
+
+    def set(self, config, handler):
+        self.config = config
+        self.handler = handler
+
+HCACHE = HCache()
+
+
+# XXX a class method on Test ?
+def get_test_db_handler(config):
+    handler = HCACHE.get(config)
+    if handler is not None:
+        return handler
+    sources = config.sources()
+    driver = sources['system']['db-driver']
+    key = (driver, config)
+    handlerkls = HANDLERS.get(driver, None)
+    if handlerkls is not None:
+        handler = handlerkls(config)
+        HCACHE.set(config, handler)
+        return handler
+    else:
+        raise ValueError('no initialization function for driver %r' % driver)
+
+### compatibility layer ##############################################
+from logilab.common.deprecation import deprecated
+
+@deprecated("please use the new DatabaseHandler mecanism")
+def init_test_database(config=None, configdir='data', apphome=None):
+    """init a test database for a specific driver"""
+    if config is None:
+        config = TestServerConfiguration(apphome=apphome)
+    handler = get_test_db_handler(config)
+    handler.build_db_cache()
+    return handler.get_repo_and_cnx()
+
+