server/sources/native.py
branch: stable
changeset 7342 d1c8b5b3531c
parent 7243 9ab01bf84eac
child 7398 26695dd703d8
child 7501 2983dd24494a
equal deleted inserted replaced
7341:c419c2d0d13e 7342:d1c8b5b3531c
    26 
    26 
    27 from __future__ import with_statement
    27 from __future__ import with_statement
    28 
    28 
    29 __docformat__ = "restructuredtext en"
    29 __docformat__ = "restructuredtext en"
    30 
    30 
    31 from pickle import loads, dumps
    31 try:
       
    32     from cPickle import loads, dumps
       
    33     import cPickle as pickle
       
    34 except ImportError:
       
    35     from pickle import loads, dumps
       
    36     import pickle
    32 from threading import Lock
    37 from threading import Lock
    33 from datetime import datetime
    38 from datetime import datetime
    34 from base64 import b64decode, b64encode
    39 from base64 import b64decode, b64encode
    35 from contextlib import contextmanager
    40 from contextlib import contextmanager
    36 from os.path import abspath
    41 from os.path import abspath, basename
    37 import re
    42 import re
    38 import itertools
    43 import itertools
       
    44 import zipfile
       
    45 import logging
       
    46 import sys
    39 
    47 
    40 from logilab.common.compat import any
    48 from logilab.common.compat import any
    41 from logilab.common.cache import Cache
    49 from logilab.common.cache import Cache
    42 from logilab.common.decorators import cached, clear_cache
    50 from logilab.common.decorators import cached, clear_cache
    43 from logilab.common.configuration import Method
    51 from logilab.common.configuration import Method
    44 from logilab.common.shellutils import getlogin
    52 from logilab.common.shellutils import getlogin
    45 from logilab.database import get_db_helper
    53 from logilab.database import get_db_helper, sqlgen
    46 
    54 
    47 from yams import schema2sql as y2sql
    55 from yams import schema2sql as y2sql
    48 from yams.schema import role_name
    56 from yams.schema import role_name
    49 
    57 
    50 from cubicweb import (UnknownEid, AuthenticationError, ValidationError, Binary,
    58 from cubicweb import (UnknownEid, AuthenticationError, ValidationError, Binary,
   352                 self.do_fti = False
   360                 self.do_fti = False
   353             if pool is None:
   361             if pool is None:
   354                 _pool.pool_reset()
   362                 _pool.pool_reset()
   355                 self.repo._free_pool(_pool)
   363                 self.repo._free_pool(_pool)
   356 
   364 
   357     def backup(self, backupfile, confirm):
   365     def backup(self, backupfile, confirm, format='native'):
   358         """method called to create a backup of the source's data"""
   366         """method called to create a backup of the source's data"""
   359         self.close_pool_connections()
   367         if format == 'portable':
   360         try:
   368             self.repo.fill_schema()
   361             self.backup_to_file(backupfile, confirm)
   369             self.set_schema(self.repo.schema)
   362         finally:
   370             helper = DatabaseIndependentBackupRestore(self)
   363             self.open_pool_connections()
   371             self.close_pool_connections()
   364 
   372             try:
   365     def restore(self, backupfile, confirm, drop):
   373                 helper.backup(backupfile)
       
   374             finally:
       
   375                 self.open_pool_connections()
       
   376         elif format == 'native':
       
   377             self.close_pool_connections()
       
   378             try:
       
   379                 self.backup_to_file(backupfile, confirm)
       
   380             finally:
       
   381                 self.open_pool_connections()
       
   382         else:
       
   383             raise ValueError('Unknown format %r' % format)
       
   384 
       
   385 
       
   386     def restore(self, backupfile, confirm, drop, format='native'):
   366         """method called to restore a backup of source's data"""
   387         """method called to restore a backup of source's data"""
   367         if self.repo.config.open_connections_pools:
   388         if self.repo.config.open_connections_pools:
   368             self.close_pool_connections()
   389             self.close_pool_connections()
   369         try:
   390         try:
   370             self.restore_from_file(backupfile, confirm, drop=drop)
   391             if format == 'portable':
       
   392                 helper = DatabaseIndependentBackupRestore(self)
       
   393                 helper.restore(backupfile)
       
   394             elif format == 'native':
       
   395                 self.restore_from_file(backupfile, confirm, drop=drop)
       
   396             else:
       
   397                 raise ValueError('Unknown format %r' % format)
   371         finally:
   398         finally:
   372             if self.repo.config.open_connections_pools:
   399             if self.repo.config.open_connections_pools:
   373                 self.open_pool_connections()
   400                 self.open_pool_connections()
       
   401 
   374 
   402 
   375     def init(self, activated, source_entity):
   403     def init(self, activated, source_entity):
   376         self.init_creating(source_entity._cw.pool)
   404         self.init_creating(source_entity._cw.pool)
   377 
   405 
   378     def shutdown(self):
   406     def shutdown(self):
  1562         if rset.rowcount != 1:
  1590         if rset.rowcount != 1:
  1563             raise AuthenticationError('unexisting email')
  1591             raise AuthenticationError('unexisting email')
  1564         login = rset.rows[0][0]
  1592         login = rset.rows[0][0]
  1565         authinfo['email_auth'] = True
  1593         authinfo['email_auth'] = True
  1566         return self.source.repo.check_auth_info(session, login, authinfo)
  1594         return self.source.repo.check_auth_info(session, login, authinfo)
       
  1595 
       
  1596 class DatabaseIndependentBackupRestore(object):
       
  1597     """Helper class to perform db backend agnostic backup and restore
       
  1598 
       
  1599     The backup and restore methods are used to dump / restore the
       
  1600     system database in a database independent format. The file is a
       
  1601     Zip archive containing the following files:
       
  1602 
       
  1603     * format.txt: the format of the archive. Currently '1.0'
       
  1604     * tables.txt: list of filenames in the archive tables/ directory
       
  1605     * sequences.txt: list of filenames in the archive sequences/ directory
       
  1606     * versions.txt: the list of cube versions from CWProperty
       
  1607     * tables/<tablename>.<chunkno>: pickled data
       
  1608     * sequences/<sequencename>: pickled data
       
  1609 
       
  1610     The pickled data format for tables and sequences is a tuple of 3 elements:
       
  1611     * the table name
       
  1612     * a tuple of column names
       
  1613     * a list of rows (as tuples with one element per column)
       
  1614 
       
  1615     Tables are saved in chunks in different files in order to prevent
       
  1616     a too high memory consumption. 
       
  1617     """
       
  1618     def __init__(self, source):
       
  1619         """
       
  1620         :param: source an instance of the system source
       
  1621         """
       
  1622         self._source = source
       
  1623         self.logger = logging.getLogger('cubicweb.ctl')
       
  1624         self.logger.setLevel(logging.INFO)
       
  1625         self.logger.addHandler(logging.StreamHandler(sys.stdout))
       
  1626         self.schema = self._source.schema
       
  1627         self.dbhelper = self._source.dbhelper
       
  1628         self.cnx = None
       
  1629         self.cursor = None
       
  1630         self.sql_generator = sqlgen.SQLGenerator()
       
  1631 
       
  1632     def get_connection(self):
       
  1633         return self._source.get_connection()
       
  1634 
       
  1635     def backup(self, backupfile):
       
  1636         archive=zipfile.ZipFile(backupfile, 'w')
       
  1637         self.cnx = self.get_connection()
       
  1638         try:
       
  1639             self.cursor = self.cnx.cursor()
       
  1640             self.cursor.arraysize=100
       
  1641             self.logger.info('writing metadata')
       
  1642             self.write_metadata(archive)
       
  1643             for seq in self.get_sequences():
       
  1644                 self.logger.info('processing sequence %s', seq)
       
  1645                 self.write_sequence(archive, seq)
       
  1646             for table in self.get_tables():
       
  1647                 self.logger.info('processing table %s', table)
       
  1648                 self.write_table(archive, table)
       
  1649         finally:
       
  1650             archive.close()
       
  1651             self.cnx.close()
       
  1652         self.logger.info('done')
       
  1653 
       
  1654     def get_tables(self):
       
  1655         non_entity_tables = ['entities',
       
  1656                              'deleted_entities',
       
  1657                              'transactions',
       
  1658                              'tx_entity_actions',
       
  1659                              'tx_relation_actions',
       
  1660                              ]
       
  1661         etype_tables = []
       
  1662         relation_tables = []
       
  1663         prefix = 'cw_'
       
  1664         for etype in self.schema.entities():
       
  1665             eschema = self.schema.eschema(etype)
       
  1666             print etype, eschema.final
       
  1667             if eschema.final:
       
  1668                 continue
       
  1669             etype_tables.append('%s%s'%(prefix, etype))
       
  1670         for rtype in self.schema.relations():
       
  1671             rschema = self.schema.rschema(rtype)
       
  1672             if rschema.final or rschema.inlined:
       
  1673                 continue
       
  1674             relation_tables.append('%s_relation' % rtype)
       
  1675         return non_entity_tables + etype_tables + relation_tables
       
  1676 
       
  1677     def get_sequences(self):
       
  1678         return ['entities_id_seq']
       
  1679 
       
  1680     def write_metadata(self, archive):
       
  1681         archive.writestr('format.txt', '1.0')
       
  1682         archive.writestr('tables.txt', '\n'.join(self.get_tables()))
       
  1683         archive.writestr('sequences.txt', '\n'.join(self.get_sequences()))
       
  1684         versions = self._get_versions()
       
  1685         versions_str = '\n'.join('%s %s' % (k,v)
       
  1686                                  for k,v in versions)
       
  1687         archive.writestr('versions.txt', versions_str)
       
  1688 
       
  1689     def write_sequence(self, archive, seq):
       
  1690         sql = self.dbhelper.sql_sequence_current_state(seq)
       
  1691         columns, rows_iterator = self._get_cols_and_rows(sql)
       
  1692         rows = list(rows_iterator)
       
  1693         serialized = self._serialize(seq, columns, rows)
       
  1694         archive.writestr('sequences/%s' % seq, serialized)
       
  1695 
       
  1696     def write_table(self, archive, table):
       
  1697         sql = 'SELECT * FROM %s' % table
       
  1698         columns, rows_iterator = self._get_cols_and_rows(sql)
       
  1699         self.logger.info('number of rows: %d', self.cursor.rowcount)
       
  1700         if table.startswith('cw_'): # entities
       
  1701             blocksize = 2000
       
  1702         else: # relations and metadata
       
  1703             blocksize = 10000
       
  1704         if self.cursor.rowcount > 0:
       
  1705             for i, start in enumerate(xrange(0, self.cursor.rowcount, blocksize)):
       
  1706                 rows = list(itertools.islice(rows_iterator, blocksize))
       
  1707                 serialized = self._serialize(table, columns, rows)
       
  1708                 archive.writestr('tables/%s.%04d' % (table, i), serialized)
       
  1709                 self.logger.debug('wrote rows %d to %d (out of %d) to %s.%04d',
       
  1710                                   start, start+len(rows)-1,
       
  1711                                   self.cursor.rowcount,
       
  1712                                   table, i)
       
  1713         else:
       
  1714             rows = []
       
  1715             serialized = self._serialize(table, columns, rows)
       
  1716             archive.writestr('tables/%s.%04d' % (table, 0), serialized)
       
  1717 
       
  1718     def _get_cols_and_rows(self, sql):
       
  1719         process_result = self._source.iter_process_result
       
  1720         self.cursor.execute(sql)
       
  1721         columns = (d[0] for d in self.cursor.description)
       
  1722         rows = process_result(self.cursor)
       
  1723         return tuple(columns), rows
       
  1724 
       
  1725     def _serialize(self, name, columns, rows):
       
  1726         return dumps((name, columns, rows), pickle.HIGHEST_PROTOCOL)
       
  1727 
       
  1728     def restore(self, backupfile):
       
  1729         archive = zipfile.ZipFile(backupfile, 'r')
       
  1730         self.cnx = self.get_connection()
       
  1731         self.cursor = self.cnx.cursor()
       
  1732         sequences, tables, table_chunks = self.read_metadata(archive, backupfile)
       
  1733         for seq in sequences:
       
  1734             self.logger.info('restoring sequence %s', seq)
       
  1735             self.read_sequence(archive, seq)
       
  1736         for table in tables:
       
  1737             self.logger.info('restoring table %s', table)
       
  1738             self.read_table(archive, table, sorted(table_chunks[table]))
       
  1739         self.cnx.close()
       
  1740         archive.close()
       
  1741         self.logger.info('done')
       
  1742 
       
  1743     def read_metadata(self, archive, backupfile):
       
  1744         formatinfo = archive.read('format.txt')
       
  1745         self.logger.info('checking metadata')
       
  1746         if formatinfo.strip() != "1.0":
       
  1747             self.logger.critical('Unsupported format in archive: %s', formatinfo)
       
  1748             raise ValueError('Unknown format in %s: %s' % (backupfile, formatinfo))
       
  1749         tables = archive.read('tables.txt').splitlines()
       
  1750         sequences = archive.read('sequences.txt').splitlines()
       
  1751         file_versions = self._parse_versions(archive.read('versions.txt'))
       
  1752         versions = set(self._get_versions())
       
  1753         if file_versions != versions:
       
  1754             self.logger.critical('Unable to restore : versions do not match')
       
  1755             self.logger.critical('Expected:\n%s', '\n'.join(list(sorted(versions))))
       
  1756             self.logger.critical('Found:\n%s', '\n'.join(list(sorted(file_versions))))
       
  1757             raise ValueError('Unable to restore : versions do not match')
       
  1758         table_chunks = {}
       
  1759         for name in archive.namelist():
       
  1760             if not name.startswith('tables/'):
       
  1761                 continue
       
  1762             filename = basename(name)
       
  1763             tablename, _ext = filename.rsplit('.', 1)
       
  1764             table_chunks.setdefault(tablename, []).append(name)
       
  1765         return sequences, tables, table_chunks
       
  1766 
       
  1767     def read_sequence(self, archive, seq):
       
  1768         seqname, columns, rows = loads(archive.read('sequences/%s' % seq))
       
  1769         assert seqname == seq
       
  1770         assert len(rows) == 1
       
  1771         assert len(rows[0]) == 1
       
  1772         value = rows[0][0]
       
  1773         sql = self.dbhelper.sql_restart_sequence(seq, value)
       
  1774         self.cursor.execute(sql)
       
  1775         self.cnx.commit()
       
  1776 
       
  1777     def read_table(self, archive, table, filenames):
       
  1778         merge_args = self._source.merge_args
       
  1779         self.cursor.execute('DELETE FROM %s' % table)
       
  1780         self.cnx.commit()
       
  1781         row_count = 0
       
  1782         for filename in filenames:
       
  1783             tablename, columns, rows = loads(archive.read(filename))
       
  1784             assert tablename == table
       
  1785             if not rows:
       
  1786                 continue
       
  1787             insert = self.sql_generator.insert(table,
       
  1788                                                dict(zip(columns, rows[0])))
       
  1789             for row in rows:
       
  1790                 self.cursor.execute(insert, merge_args(dict(zip(columns, row)), {}))
       
  1791             row_count += len(rows)
       
  1792             self.cnx.commit()
       
  1793         self.logger.info('inserted %d rows', row_count)
       
  1794 
       
  1795 
       
  1796     def _parse_versions(self, version_str):
       
  1797         versions = set()
       
  1798         for line in version_str.splitlines():
       
  1799             versions.add(tuple(line.split()))
       
  1800         return versions
       
  1801 
       
  1802     def _get_versions(self):
       
  1803         version_sql = 'SELECT cw_pkey, cw_value FROM cw_CWProperty'
       
  1804         versions = []
       
  1805         self.cursor.execute(version_sql)
       
  1806         for pkey, value in self.cursor.fetchall():
       
  1807             if pkey.startswith(u'system.version'):
       
  1808                 versions.append((pkey, value))
       
  1809         return versions