
from __future__ import with_statement

__docformat__ = "restructuredtext en"

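# prefer the C implementation of pickle when available (Python 2); fall
# back to the pure-python module otherwise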
try:
    from cPickle import loads, dumps
    import cPickle as pickle
except ImportError:
    from pickle import loads, dumps
    import pickle
from threading import Lock
from datetime import datetime
from base64 import b64decode, b64encode
from contextlib import contextmanager
from os.path import abspath, basename
import re
import itertools
import zipfile
import logging
import sys

from logilab.common.compat import any
from logilab.common.cache import Cache
from logilab.common.decorators import cached, clear_cache
from logilab.common.configuration import Method
from logilab.common.shellutils import getlogin
from logilab.database import get_db_helper, sqlgen

from yams import schema2sql as y2sql
from yams.schema import role_name

from cubicweb import (UnknownEid, AuthenticationError, ValidationError, Binary,
# [...]
                self.do_fti = False
            if pool is None:
                _pool.pool_reset()
                self.repo._free_pool(_pool)

    def backup(self, backupfile, confirm, format='native'):
        """method called to create a backup of the source's data"""
        if format == 'portable':
            self.repo.fill_schema()
            self.set_schema(self.repo.schema)
            helper = DatabaseIndependentBackupRestore(self)
            self.close_pool_connections()
            try:
                helper.backup(backupfile)
            finally:
                self.open_pool_connections()
        elif format == 'native':
            self.close_pool_connections()
            try:
                self.backup_to_file(backupfile, confirm)
            finally:
                self.open_pool_connections()
        else:
            raise ValueError('Unknown format %r' % format)

    def restore(self, backupfile, confirm, drop, format='native'):
        """method called to restore a backup of source's data"""
        if self.repo.config.open_connections_pools:
            self.close_pool_connections()
        try:
            if format == 'portable':
                helper = DatabaseIndependentBackupRestore(self)
                helper.restore(backupfile)
            elif format == 'native':
                self.restore_from_file(backupfile, confirm, drop=drop)
            else:
                raise ValueError('Unknown format %r' % format)
        finally:
            if self.repo.config.open_connections_pools:
                self.open_pool_connections()

    def init(self, activated, source_entity):
        self.init_creating(source_entity._cw.pool)

    def shutdown(self):
        # [...]
        if rset.rowcount != 1:
            raise AuthenticationError('unexisting email')
        login = rset.rows[0][0]
        authinfo['email_auth'] = True
        return self.source.repo.check_auth_info(session, login, authinfo)

class DatabaseIndependentBackupRestore(object):
    """Helper class to perform backend-agnostic backup and restore of the
    system database.

    The backup and restore methods dump / restore the system database
    in a database-independent format. The file is a Zip archive
    containing the following files:

    * format.txt: the format of the archive. Currently '1.0'
    * tables.txt: list of filenames in the archive's tables/ directory
    * sequences.txt: list of filenames in the archive's sequences/ directory
    * versions.txt: the list of cube versions from CWProperty
    * tables/<tablename>.<chunkno>: pickled data
    * sequences/<sequencename>: pickled data

    The pickled data format for tables and sequences is a tuple of 3 elements:
    * the table name
    * a tuple of column names
    * a list of rows (as tuples with one element per column)

    Tables are saved in chunks spread over several files to keep memory
    consumption low.
    """
    def __init__(self, source):
        """
        :param source: an instance of the system source
        """
        self._source = source
        self.logger = logging.getLogger('cubicweb.ctl')
        self.logger.setLevel(logging.INFO)
        self.logger.addHandler(logging.StreamHandler(sys.stdout))
        self.schema = self._source.schema
        self.dbhelper = self._source.dbhelper
        self.cnx = None
        self.cursor = None
        self.sql_generator = sqlgen.SQLGenerator()

    def get_connection(self):
        return self._source.get_connection()

    def backup(self, backupfile):
        archive = zipfile.ZipFile(backupfile, 'w')
        self.cnx = self.get_connection()
        try:
            self.cursor = self.cnx.cursor()
            self.cursor.arraysize = 100
            self.logger.info('writing metadata')
            self.write_metadata(archive)
            for seq in self.get_sequences():
                self.logger.info('processing sequence %s', seq)
                self.write_sequence(archive, seq)
            for table in self.get_tables():
                self.logger.info('processing table %s', table)
                self.write_table(archive, table)
        finally:
            archive.close()
            self.cnx.close()
        self.logger.info('done')

    def get_tables(self):
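        # tables to dump: the fixed metadata tables, one cw_<etype> table per
        # non-final entity type, and one <rtype>_relation table per
        # non-final, non-inlined relation type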
        non_entity_tables = ['entities',
                             'deleted_entities',
                             'transactions',
                             'tx_entity_actions',
                             'tx_relation_actions',
                             ]
        etype_tables = []
        relation_tables = []
        prefix = 'cw_'
        for etype in self.schema.entities():
            eschema = self.schema.eschema(etype)
            if eschema.final:
                continue
            etype_tables.append('%s%s' % (prefix, etype))
        for rtype in self.schema.relations():
            rschema = self.schema.rschema(rtype)
            if rschema.final or rschema.inlined:
                continue
            relation_tables.append('%s_relation' % rtype)
        return non_entity_tables + etype_tables + relation_tables

    def get_sequences(self):
        return ['entities_id_seq']

    def write_metadata(self, archive):
        archive.writestr('format.txt', '1.0')
        archive.writestr('tables.txt', '\n'.join(self.get_tables()))
        archive.writestr('sequences.txt', '\n'.join(self.get_sequences()))
        versions = self._get_versions()
        versions_str = '\n'.join('%s %s' % (k, v)
                                 for k, v in versions)
        archive.writestr('versions.txt', versions_str)

    def write_sequence(self, archive, seq):
        sql = self.dbhelper.sql_sequence_current_state(seq)
        columns, rows_iterator = self._get_cols_and_rows(sql)
        rows = list(rows_iterator)
        serialized = self._serialize(seq, columns, rows)
        archive.writestr('sequences/%s' % seq, serialized)

    def write_table(self, archive, table):
        sql = 'SELECT * FROM %s' % table
        columns, rows_iterator = self._get_cols_and_rows(sql)
        self.logger.info('number of rows: %d', self.cursor.rowcount)
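        # dump the table in fixed-size chunks: entity (cw_*) tables tend to
        # have many columns, so they get a smaller chunk size; note that this
        # relies on the DB-API driver reporting a meaningful rowcount for
        # SELECT statements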
        if table.startswith('cw_'): # entities
            blocksize = 2000
        else: # relations and metadata
            blocksize = 10000
        if self.cursor.rowcount > 0:
            for i, start in enumerate(xrange(0, self.cursor.rowcount, blocksize)):
                rows = list(itertools.islice(rows_iterator, blocksize))
                serialized = self._serialize(table, columns, rows)
                archive.writestr('tables/%s.%04d' % (table, i), serialized)
                self.logger.debug('wrote rows %d to %d (out of %d) to %s.%04d',
                                  start, start + len(rows) - 1,
                                  self.cursor.rowcount,
                                  table, i)
        else:
            rows = []
            serialized = self._serialize(table, columns, rows)
            archive.writestr('tables/%s.%04d' % (table, 0), serialized)

    def _get_cols_and_rows(self, sql):
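        # iter_process_result is expected to post-process rows lazily as they
        # are fetched from the cursor (e.g. converting backend-specific
        # values), so large tables need not be fully loaded in memory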
        process_result = self._source.iter_process_result
        self.cursor.execute(sql)
        columns = (d[0] for d in self.cursor.description)
        rows = process_result(self.cursor)
        return tuple(columns), rows

    def _serialize(self, name, columns, rows):
        return dumps((name, columns, rows), pickle.HIGHEST_PROTOCOL)

    def restore(self, backupfile):
        archive = zipfile.ZipFile(backupfile, 'r')
        self.cnx = self.get_connection()
        self.cursor = self.cnx.cursor()
        sequences, tables, table_chunks = self.read_metadata(archive, backupfile)
        for seq in sequences:
            self.logger.info('restoring sequence %s', seq)
            self.read_sequence(archive, seq)
        for table in tables:
            self.logger.info('restoring table %s', table)
            self.read_table(archive, table, sorted(table_chunks[table]))
        self.cnx.close()
        archive.close()
        self.logger.info('done')

    def read_metadata(self, archive, backupfile):
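        # sanity checks before touching anything: the archive format must be
        # a known one and the cube versions recorded in the archive must
        # match those of the target instance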
        formatinfo = archive.read('format.txt')
        self.logger.info('checking metadata')
        if formatinfo.strip() != "1.0":
            self.logger.critical('Unsupported format in archive: %s', formatinfo)
            raise ValueError('Unknown format in %s: %s' % (backupfile, formatinfo))
        tables = archive.read('tables.txt').splitlines()
        sequences = archive.read('sequences.txt').splitlines()
        file_versions = self._parse_versions(archive.read('versions.txt'))
        versions = set(self._get_versions())
        if file_versions != versions:
            self.logger.critical('Unable to restore: versions do not match')
            self.logger.critical('Expected:\n%s',
                                 '\n'.join('%s %s' % (k, v)
                                           for k, v in sorted(versions)))
            self.logger.critical('Found:\n%s',
                                 '\n'.join('%s %s' % (k, v)
                                           for k, v in sorted(file_versions)))
            raise ValueError('Unable to restore: versions do not match')
        table_chunks = {}
        for name in archive.namelist():
            if not name.startswith('tables/'):
                continue
            filename = basename(name)
            tablename, _ext = filename.rsplit('.', 1)
            table_chunks.setdefault(tablename, []).append(name)
        return sequences, tables, table_chunks

    def read_sequence(self, archive, seq):
        seqname, columns, rows = loads(archive.read('sequences/%s' % seq))
        assert seqname == seq
        assert len(rows) == 1
        assert len(rows[0]) == 1
        value = rows[0][0]
        sql = self.dbhelper.sql_restart_sequence(seq, value)
        self.cursor.execute(sql)
        self.cnx.commit()

    def read_table(self, archive, table, filenames):
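        # the table is replaced wholesale: existing rows are deleted, then
        # the archived rows are re-inserted one chunk file at a time,
        # committing after each chunk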
        merge_args = self._source.merge_args
        self.cursor.execute('DELETE FROM %s' % table)
        self.cnx.commit()
        row_count = 0
        for filename in filenames:
            tablename, columns, rows = loads(archive.read(filename))
            assert tablename == table
            if not rows:
                continue
            insert = self.sql_generator.insert(table,
                                               dict(zip(columns, rows[0])))
            for row in rows:
                self.cursor.execute(insert, merge_args(dict(zip(columns, row)), {}))
            row_count += len(rows)
            self.cnx.commit()
        self.logger.info('inserted %d rows', row_count)

    def _parse_versions(self, version_str):
        versions = set()
        for line in version_str.splitlines():
            versions.add(tuple(line.split()))
        return versions

    def _get_versions(self):
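        # only the system.version.* properties matter here: one entry for
        # cubicweb itself and one per installed cube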
        version_sql = 'SELECT cw_pkey, cw_value FROM cw_CWProperty'
        versions = []
        self.cursor.execute(version_sql)
        for pkey, value in self.cursor.fetchall():
            if pkey.startswith(u'system.version'):
                versions.append((pkey, value))
        return versions