server/sources/native.py
branch       stable
changeset    7342:d1c8b5b3531c
parent       7243:9ab01bf84eac
child        7398:26695dd703d8
child        7501:2983dd24494a
--- a/server/sources/native.py	Thu Apr 21 16:33:55 2011 +0200
+++ b/server/sources/native.py	Thu Apr 21 12:35:41 2011 +0200
@@ -28,21 +28,29 @@
 
 __docformat__ = "restructuredtext en"
 
-from pickle import loads, dumps
+try:
+    from cPickle import loads, dumps
+    import cPickle as pickle
+except ImportError:
+    from pickle import loads, dumps
+    import pickle
 from threading import Lock
 from datetime import datetime
 from base64 import b64decode, b64encode
 from contextlib import contextmanager
-from os.path import abspath
+from os.path import abspath, basename
 import re
 import itertools
+import zipfile
+import logging
+import sys
 
 from logilab.common.compat import any
 from logilab.common.cache import Cache
 from logilab.common.decorators import cached, clear_cache
 from logilab.common.configuration import Method
 from logilab.common.shellutils import getlogin
-from logilab.database import get_db_helper
+from logilab.database import get_db_helper, sqlgen
 
 from yams import schema2sql as y2sql
 from yams.schema import role_name
@@ -354,24 +362,44 @@
                 _pool.pool_reset()
                 self.repo._free_pool(_pool)
 
-    def backup(self, backupfile, confirm):
+    def backup(self, backupfile, confirm, format='native'):
         """method called to create a backup of the source's data"""
-        self.close_pool_connections()
-        try:
-            self.backup_to_file(backupfile, confirm)
-        finally:
-            self.open_pool_connections()
+        if format == 'portable':
+            self.repo.fill_schema()
+            self.set_schema(self.repo.schema)
+            helper = DatabaseIndependentBackupRestore(self)
+            self.close_pool_connections()
+            try:
+                helper.backup(backupfile)
+            finally:
+                self.open_pool_connections()
+        elif format == 'native':
+            self.close_pool_connections()
+            try:
+                self.backup_to_file(backupfile, confirm)
+            finally:
+                self.open_pool_connections()
+        else:
+            raise ValueError('Unknown format %r' % format)
 
-    def restore(self, backupfile, confirm, drop):
+    def restore(self, backupfile, confirm, drop, format='native'):
         """method called to restore a backup of source's data"""
         if self.repo.config.open_connections_pools:
             self.close_pool_connections()
         try:
-            self.restore_from_file(backupfile, confirm, drop=drop)
+            if format == 'portable':
+                helper = DatabaseIndependentBackupRestore(self)
+                helper.restore(backupfile)
+            elif format == 'native':
+                self.restore_from_file(backupfile, confirm, drop=drop)
+            else:
+                raise ValueError('Unknown format %r' % format)
         finally:
             if self.repo.config.open_connections_pools:
                 self.open_pool_connections()
 
     def init(self, activated, source_entity):
         self.init_creating(source_entity._cw.pool)
 
@@ -1564,3 +1592,218 @@
         login = rset.rows[0][0]
         authinfo['email_auth'] = True
         return self.source.repo.check_auth_info(session, login, authinfo)
+
+
+class DatabaseIndependentBackupRestore(object):
+    """Helper class to perform db backend agnostic backup and restore
+
+    The backup and restore methods are used to dump / restore the
+    system database in a database independent format. The file is a
+    Zip archive containing the following files:
+
+    * format.txt: the format of the archive. Currently '1.0'
+    * tables.txt: list of filenames in the archive tables/ directory
+    * sequences.txt: list of filenames in the archive sequences/ directory
+    * versions.txt: the list of cube versions from CWProperty
+    * tables/<tablename>.<chunkno>: pickled data
+    * sequences/<sequencename>: pickled data
+
+    The pickled data format for tables and sequences is a tuple of 3 elements:
+    * the table name
+    * a tuple of column names
+    * a list of rows (as tuples with one element per column)
+
+    Tables are saved in chunks across several files so that dumping a
+    large table does not drive memory consumption too high.
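+
+    As a rough sketch (the file name ``backup.zip`` is hypothetical), such
+    an archive can be inspected with the standard library alone::
+
+        import zipfile
+        archive = zipfile.ZipFile('backup.zip', 'r')
+        assert archive.read('format.txt').strip() == '1.0'
+        tables = archive.read('tables.txt').splitlines()
+        archive.close()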
+    """
+    def __init__(self, source):
+        """
+        :param source: an instance of the system source
+        """
+        self._source = source
+        self.logger = logging.getLogger('cubicweb.ctl')
+        self.logger.setLevel(logging.INFO)
+        self.logger.addHandler(logging.StreamHandler(sys.stdout))
+        self.schema = self._source.schema
+        self.dbhelper = self._source.dbhelper
+        self.cnx = None
+        self.cursor = None
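+        # builds the parameterized INSERT statements reused during restore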
+        self.sql_generator = sqlgen.SQLGenerator()
+
+    def get_connection(self):
+        return self._source.get_connection()
+
+    def backup(self, backupfile):
+        archive = zipfile.ZipFile(backupfile, 'w')
+        self.cnx = self.get_connection()
+        try:
+            self.cursor = self.cnx.cursor()
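+            # DB-API hint: let the driver fetch rows in batches rather than
+            # one by one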
+            self.cursor.arraysize = 100
+            self.logger.info('writing metadata')
+            self.write_metadata(archive)
+            for seq in self.get_sequences():
+                self.logger.info('processing sequence %s', seq)
+                self.write_sequence(archive, seq)
+            for table in self.get_tables():
+                self.logger.info('processing table %s', table)
+                self.write_table(archive, table)
+        finally:
+            archive.close()
+            self.cnx.close()
+        self.logger.info('done')
+
+    def get_tables(self):
+        non_entity_tables = ['entities',
+                             'deleted_entities',
+                             'transactions',
+                             'tx_entity_actions',
+                             'tx_relation_actions',
+                             ]
+        etype_tables = []
+        relation_tables = []
+        prefix = 'cw_'
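+        # entity tables are named cw_<etype>, relation tables <rtype>_relation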
+        for etype in self.schema.entities():
+            eschema = self.schema.eschema(etype)
+            if eschema.final:
+                continue
+            etype_tables.append('%s%s' % (prefix, etype))
+        for rtype in self.schema.relations():
+            rschema = self.schema.rschema(rtype)
+            if rschema.final or rschema.inlined:
+                continue
+            relation_tables.append('%s_relation' % rtype)
+        return non_entity_tables + etype_tables + relation_tables
+
+    def get_sequences(self):
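+        # the eid sequence is currently the only sequence that needs dumping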
+        return ['entities_id_seq']
+
+    def write_metadata(self, archive):
+        archive.writestr('format.txt', '1.0')
+        archive.writestr('tables.txt', '\n'.join(self.get_tables()))
+        archive.writestr('sequences.txt', '\n'.join(self.get_sequences()))
+        versions = self._get_versions()
+        versions_str = '\n'.join('%s %s' % (k, v)
+                                 for k, v in versions)
+        archive.writestr('versions.txt', versions_str)
+
+    def write_sequence(self, archive, seq):
+        sql = self.dbhelper.sql_sequence_current_state(seq)
+        columns, rows_iterator = self._get_cols_and_rows(sql)
+        rows = list(rows_iterator)
+        serialized = self._serialize(seq, columns, rows)
+        archive.writestr('sequences/%s' % seq, serialized)
+
+    def write_table(self, archive, table):
+        sql = 'SELECT * FROM %s' % table
+        columns, rows_iterator = self._get_cols_and_rows(sql)
+        self.logger.info('number of rows: %d', self.cursor.rowcount)
+        if table.startswith('cw_'): # entities
+            blocksize = 2000
+        else: # relations and metadata
+            blocksize = 10000
+        if self.cursor.rowcount > 0:
+            for i, start in enumerate(xrange(0, self.cursor.rowcount, blocksize)):
+                rows = list(itertools.islice(rows_iterator, blocksize))
+                serialized = self._serialize(table, columns, rows)
+                archive.writestr('tables/%s.%04d' % (table, i), serialized)
+                self.logger.debug('wrote rows %d to %d (out of %d) to %s.%04d',
+                                  start, start+len(rows)-1,
+                                  self.cursor.rowcount,
+                                  table, i)
+        else:
+            rows = []
+            serialized = self._serialize(table, columns, rows)
+            archive.writestr('tables/%s.%04d' % (table, 0), serialized)
+
+    def _get_cols_and_rows(self, sql):
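+        # iter_process_result converts raw cursor rows into Python values
+        # (e.g. decoding backend-specific binary wrappers)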
+        process_result = self._source.iter_process_result
+        self.cursor.execute(sql)
+        columns = (d[0] for d in self.cursor.description)
+        rows = process_result(self.cursor)
+        return tuple(columns), rows
+
+    def _serialize(self, name, columns, rows):
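+        # each chunk unpickles back to the (name, columns, rows) triple
+        # described in the class docstring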
+        return dumps((name, columns, rows), pickle.HIGHEST_PROTOCOL)
+
+    def restore(self, backupfile):
+        archive = zipfile.ZipFile(backupfile, 'r')
+        self.cnx = self.get_connection()
+        self.cursor = self.cnx.cursor()
+        sequences, tables, table_chunks = self.read_metadata(archive, backupfile)
+        for seq in sequences:
+            self.logger.info('restoring sequence %s', seq)
+            self.read_sequence(archive, seq)
+        for table in tables:
+            self.logger.info('restoring table %s', table)
+            self.read_table(archive, table, sorted(table_chunks[table]))
+        self.cnx.close()
+        archive.close()
+        self.logger.info('done')
+
+    def read_metadata(self, archive, backupfile):
+        formatinfo = archive.read('format.txt')
+        self.logger.info('checking metadata')
+        if formatinfo.strip() != "1.0":
+            self.logger.critical('Unsupported format in archive: %s', formatinfo)
+            raise ValueError('Unknown format in %s: %s' % (backupfile, formatinfo))
+        tables = archive.read('tables.txt').splitlines()
+        sequences = archive.read('sequences.txt').splitlines()
+        file_versions = self._parse_versions(archive.read('versions.txt'))
+        versions = set(self._get_versions())
+        if file_versions != versions:
+            self.logger.critical('Unable to restore: versions do not match')
+            self.logger.critical('Expected:\n%s',
+                                 '\n'.join('%s %s' % (k, v)
+                                           for k, v in sorted(versions)))
+            self.logger.critical('Found:\n%s',
+                                 '\n'.join('%s %s' % (k, v)
+                                           for k, v in sorted(file_versions)))
+            raise ValueError('Unable to restore: versions do not match')
+        table_chunks = {}
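+        # map each table to the archive members holding its chunks
+        # (tables/<tablename>.<chunkno>)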
+        for name in archive.namelist():
+            if not name.startswith('tables/'):
+                continue
+            filename = basename(name)
+            tablename, _ext = filename.rsplit('.', 1)
+            table_chunks.setdefault(tablename, []).append(name)
+        return sequences, tables, table_chunks
+
+    def read_sequence(self, archive, seq):
+        seqname, columns, rows = loads(archive.read('sequences/%s' % seq))
+        assert seqname == seq
+        assert len(rows) == 1
+        assert len(rows[0]) == 1
+        value = rows[0][0]
+        sql = self.dbhelper.sql_restart_sequence(seq, value)
+        self.cursor.execute(sql)
+        self.cnx.commit()
+
+    def read_table(self, archive, table, filenames):
+        merge_args = self._source.merge_args
+        self.cursor.execute('DELETE FROM %s' % table)
+        self.cnx.commit()
+        row_count = 0
+        for filename in filenames:
+            tablename, columns, rows = loads(archive.read(filename))
+            assert tablename == table
+            if not rows:
+                continue
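+            # build a parameterized INSERT from the first row's columns and
+            # reuse it for every row of the chunk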
+            insert = self.sql_generator.insert(table,
+                                               dict(zip(columns, rows[0])))
+            for row in rows:
+                self.cursor.execute(insert, merge_args(dict(zip(columns, row)), {}))
+            row_count += len(rows)
+            self.cnx.commit()
+        self.logger.info('inserted %d rows', row_count)
+
+    def _parse_versions(self, version_str):
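+        # each line is '<pkey> <value>', matching what write_metadata wrote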
+        versions = set()
+        for line in version_str.splitlines():
+            versions.add(tuple(line.split()))
+        return versions
+
+    def _get_versions(self):
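+        # fetch the system.version.* properties so that restore can check
+        # that cube versions match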
+        version_sql = 'SELECT cw_pkey, cw_value FROM cw_CWProperty'
+        versions = []
+        self.cursor.execute(version_sql)
+        for pkey, value in self.cursor.fetchall():
+            if pkey.startswith(u'system.version'):
+                versions.append((pkey, value))
+        return versions