# HG changeset patch # User Sylvain Thénault # Date 1305029321 -7200 # Node ID d68f9319bfda02e7381a8c9f980769a07f1f0c7d # Parent c2452cd570264786f809007caa141d66d7d368d0# Parent ed66f236715d8269f2ca1a684cbdfb48539e24f6 backport stable diff -r c2452cd57026 -r d68f9319bfda selectors.py --- a/selectors.py Tue May 10 12:07:54 2011 +0200 +++ b/selectors.py Tue May 10 14:08:41 2011 +0200 @@ -1171,13 +1171,19 @@ def __str__(self): return '%s(%r)' % (self.__class__.__name__, self.rql) - def score(self, req, rset, row, col): + def _score(self, req, eid): try: - return req.execute(self.rql, {'x': rset[row][col], - 'u': req.user.eid})[0][0] + return req.execute(self.rql, {'x': eid, 'u': req.user.eid})[0][0] except Unauthorized: return 0 + def score(self, req, rset, row, col): + return self._score(req, rset[row][col]) + + def score_entity(self, entity): + return self._score(entity._cw, entity.eid) + + # workflow selectors ########################################################### class is_in_state(score_entity): diff -r c2452cd57026 -r d68f9319bfda server/migractions.py --- a/server/migractions.py Tue May 10 12:07:54 2011 +0200 +++ b/server/migractions.py Tue May 10 14:08:41 2011 +0200 @@ -162,7 +162,7 @@ # server specific migration methods ######################################## - def backup_database(self, backupfile=None, askconfirm=True): + def backup_database(self, backupfile=None, askconfirm=True, format='native'): config = self.config repo = self.repo_connect() # paths @@ -185,16 +185,24 @@ # backup tmpdir = tempfile.mkdtemp() try: + failed = False for source in repo.sources: try: - source.backup(osp.join(tmpdir, source.uri), self.confirm) + source.backup(osp.join(tmpdir, source.uri), self.confirm, format=format) except Exception, ex: print '-> error trying to backup %s [%s]' % (source.uri, ex) if not self.confirm('Continue anyway?', default='n'): raise SystemExit(1) else: - break - else: + failed = True + with open(osp.join(tmpdir, 'format.txt'), 'w') as format_file: + format_file.write('%s\n' % format) + with open(osp.join(tmpdir, 'versions.txt'), 'w') as version_file: + versions = repo.get_versions() + for cube, version in versions.iteritems(): + version_file.write('%s %s\n' % (cube, version)) + + if not failed: bkup = tarfile.open(backupfile, 'w|gz') for filename in os.listdir(tmpdir): bkup.add(osp.join(tmpdir, filename), filename) @@ -207,7 +215,7 @@ shutil.rmtree(tmpdir) def restore_database(self, backupfile, drop=True, systemonly=True, - askconfirm=True): + askconfirm=True, format='native'): # check if not osp.exists(backupfile): raise ExecutionError("Backup file %s doesn't exist" % backupfile) @@ -229,13 +237,18 @@ bkup = tarfile.open(backupfile, 'r|gz') bkup.extractall(path=tmpdir) bkup.close() + if osp.isfile(osp.join(tmpdir, 'format.txt')): + with open(osp.join(tmpdir, 'format.txt')) as format_file: + written_format = format_file.readline().strip() + if written_format in ('portable', 'native'): + format = written_format self.config.open_connections_pools = False repo = self.repo_connect() for source in repo.sources: if systemonly and source.uri != 'system': continue try: - source.restore(osp.join(tmpdir, source.uri), self.confirm, drop) + source.restore(osp.join(tmpdir, source.uri), self.confirm, drop, format) except Exception, exc: print '-> error trying to restore %s [%s]' % (source.uri, exc) if not self.confirm('Continue anyway?', default='n'): diff -r c2452cd57026 -r d68f9319bfda server/serverctl.py --- a/server/serverctl.py Tue May 10 12:07:54 2011 +0200 +++ b/server/serverctl.py Tue May 10 14:08:41 2011 +0200 @@ -691,19 +691,20 @@ 'Continue anyway?' % filename): raise ExecutionError('Error while deleting remote dump at /tmp/%s' % filename) -def _local_dump(appid, output): + +def _local_dump(appid, output, format='native'): config = ServerConfiguration.config_for(appid) config.quick_start = True mih = config.migration_handler(connect=False, verbosity=1) - mih.backup_database(output, askconfirm=False) + mih.backup_database(output, askconfirm=False, format=format) mih.shutdown() -def _local_restore(appid, backupfile, drop, systemonly=True): +def _local_restore(appid, backupfile, drop, systemonly=True, format='native'): config = ServerConfiguration.config_for(appid) config.verbosity = 1 # else we won't be asked for confirmation on problems config.quick_start = True mih = config.migration_handler(connect=False, verbosity=1) - mih.restore_database(backupfile, drop, systemonly, askconfirm=False) + mih.restore_database(backupfile, drop, systemonly, askconfirm=False, format=format) repo = mih.repo_connect() # version of the database dbversions = repo.get_versions() @@ -777,6 +778,12 @@ 'default' : False, 'help': 'Use sudo on the remote host.'} ), + ('format', + {'short': 'f', 'default': 'native', 'type': 'choice', + 'choices': ('native', 'portable'), + 'help': '"native" format uses db backend utilities to dump the database. ' + '"portable" format uses a database independent format'} + ), ) def run(self, args): @@ -785,7 +792,9 @@ host, appid = appid.split(':') _remote_dump(host, appid, self.config.output, self.config.sudo) else: - _local_dump(appid, self.config.output) + _local_dump(appid, self.config.output, format=self.config.format) + + class DBRestoreCommand(Command): @@ -811,13 +820,33 @@ 'instance data. In that case, is expected to be the ' 'timestamp of the backup to restore, not a file'} ), + ('format', + {'short': 'f', 'default': 'native', 'type': 'choice', + 'choices': ('native', 'portable'), + 'help': 'the format used when dumping the database'}), ) def run(self, args): appid, backupfile = args + if self.config.format == 'portable': + # we need to ensure a DB exist before restoring from portable format + if not self.config.no_drop: + try: + CWCTL.run(['db-create', '--automatic', appid]) + except SystemExit, exc: + # continue if the command exited with status 0 (success) + if exc.code: + raise _local_restore(appid, backupfile, drop=not self.config.no_drop, - systemonly=not self.config.restore_all) + systemonly=not self.config.restore_all, + format=self.config.format) + if self.config.format == 'portable': + try: + CWCTL.run(['db-rebuild-fti', appid]) + except SystemExit, exc: + if exc.code: + raise class DBCopyCommand(Command): @@ -850,6 +879,12 @@ 'default' : False, 'help': 'Use sudo on the remote host.'} ), + ('format', + {'short': 'f', 'default': 'native', 'type': 'choice', + 'choices': ('native', 'portable'), + 'help': '"native" format uses db backend utilities to dump the database. ' + '"portable" format uses a database independent format'} + ), ) def run(self, args): @@ -861,8 +896,9 @@ host, srcappid = srcappid.split(':') _remote_dump(host, srcappid, output, self.config.sudo) else: - _local_dump(srcappid, output) - _local_restore(destappid, output, not self.config.no_drop) + _local_dump(srcappid, output, format=self.config.format) + _local_restore(destappid, output, not self.config.no_drop, + self.config.format) if self.config.keep_dump: print '-> you can get the dump file at', output else: diff -r c2452cd57026 -r d68f9319bfda server/session.py --- a/server/session.py Tue May 10 12:07:54 2011 +0200 +++ b/server/session.py Tue May 10 14:08:41 2011 +0200 @@ -163,6 +163,7 @@ self.__threaddata = threading.local() self._threads_in_transaction = set() self._closed = False + self._closed_lock = threading.Lock() def __unicode__(self): return '<%ssession %s (%s 0x%x)>' % ( @@ -647,22 +648,23 @@ def set_pool(self): """the session need a pool to execute some queries""" - if self._closed: - self.reset_pool(True) - raise Exception('try to set pool on a closed session') - if self.pool is None: - # get pool first to avoid race-condition - self._threaddata.pool = pool = self.repo._get_pool() - self._threaddata.ctx_count += 1 - try: - pool.pool_set() - except: - self._threaddata.pool = None - self.repo._free_pool(pool) - raise - self._threads_in_transaction.add( - (threading.currentThread(), pool) ) - return self._threaddata.pool + with self._closed_lock: + if self._closed: + self.reset_pool(True) + raise Exception('try to set pool on a closed session') + if self.pool is None: + # get pool first to avoid race-condition + self._threaddata.pool = pool = self.repo._get_pool() + self._threaddata.ctx_count += 1 + try: + pool.pool_set() + except: + self._threaddata.pool = None + self.repo._free_pool(pool) + raise + self._threads_in_transaction.add( + (threading.currentThread(), pool) ) + return self._threaddata.pool def _free_thread_pool(self, thread, pool, force_close=False): try: @@ -911,7 +913,8 @@ def close(self): """do not close pool on session close, since they are shared now""" - self._closed = True + with self._closed_lock: + self._closed = True # copy since _threads_in_transaction maybe modified while waiting for thread, pool in self._threads_in_transaction.copy(): if thread is threading.currentThread(): diff -r c2452cd57026 -r d68f9319bfda server/sources/__init__.py --- a/server/sources/__init__.py Tue May 10 12:07:54 2011 +0200 +++ b/server/sources/__init__.py Tue May 10 14:08:41 2011 +0200 @@ -139,11 +139,11 @@ return -1 return cmp(self.uri, other.uri) - def backup(self, backupfile, confirm): + def backup(self, backupfile, confirm, format='native'): """method called to create a backup of source's data""" pass - def restore(self, backupfile, confirm, drop): + def restore(self, backupfile, confirm, drop, format='native'): """method called to restore a backup of source's data""" pass diff -r c2452cd57026 -r d68f9319bfda server/sources/datafeed.py --- a/server/sources/datafeed.py Tue May 10 12:07:54 2011 +0200 +++ b/server/sources/datafeed.py Tue May 10 14:08:41 2011 +0200 @@ -120,7 +120,7 @@ return False return datetime.now() < (self.latest_retrieval + self.synchro_interval) - def pull_data(self, session, force=False): + def pull_data(self, session, force=False, raise_on_error=False): if not force and self.fresh(): return {} if self.config['delete-entities']: @@ -135,6 +135,8 @@ if parser.process(url): error = True except IOError, exc: + if raise_on_error: + raise self.error('could not pull data while processing %s: %s', url, exc) error = True diff -r c2452cd57026 -r d68f9319bfda server/sources/native.py --- a/server/sources/native.py Tue May 10 12:07:54 2011 +0200 +++ b/server/sources/native.py Tue May 10 14:08:41 2011 +0200 @@ -28,21 +28,29 @@ __docformat__ = "restructuredtext en" -from pickle import loads, dumps +try: + from cPickle import loads, dumps + import cPickle as pickle +except ImportError: + from pickle import loads, dumps + import pickle from threading import Lock from datetime import datetime from base64 import b64decode, b64encode from contextlib import contextmanager -from os.path import abspath +from os.path import abspath, basename import re import itertools +import zipfile +import logging +import sys from logilab.common.compat import any from logilab.common.cache import Cache from logilab.common.decorators import cached, clear_cache from logilab.common.configuration import Method from logilab.common.shellutils import getlogin -from logilab.database import get_db_helper +from logilab.database import get_db_helper, sqlgen from yams import schema2sql as y2sql from yams.schema import role_name @@ -354,24 +362,44 @@ _pool.pool_reset() self.repo._free_pool(_pool) - def backup(self, backupfile, confirm): + def backup(self, backupfile, confirm, format='native'): """method called to create a backup of the source's data""" - self.close_pool_connections() - try: - self.backup_to_file(backupfile, confirm) - finally: - self.open_pool_connections() + if format == 'portable': + self.repo.fill_schema() + self.set_schema(self.repo.schema) + helper = DatabaseIndependentBackupRestore(self) + self.close_pool_connections() + try: + helper.backup(backupfile) + finally: + self.open_pool_connections() + elif format == 'native': + self.close_pool_connections() + try: + self.backup_to_file(backupfile, confirm) + finally: + self.open_pool_connections() + else: + raise ValueError('Unknown format %r' % format) - def restore(self, backupfile, confirm, drop): + + def restore(self, backupfile, confirm, drop, format='native'): """method called to restore a backup of source's data""" if self.repo.config.open_connections_pools: self.close_pool_connections() try: - self.restore_from_file(backupfile, confirm, drop=drop) + if format == 'portable': + helper = DatabaseIndependentBackupRestore(self) + helper.restore(backupfile) + elif format == 'native': + self.restore_from_file(backupfile, confirm, drop=drop) + else: + raise ValueError('Unknown format %r' % format) finally: if self.repo.config.open_connections_pools: self.open_pool_connections() + def init(self, activated, source_entity): self.init_creating(source_entity._cw.pool) @@ -1564,3 +1592,218 @@ login = rset.rows[0][0] authinfo['email_auth'] = True return self.source.repo.check_auth_info(session, login, authinfo) + +class DatabaseIndependentBackupRestore(object): + """Helper class to perform db backend agnostic backup and restore + + The backup and restore methods are used to dump / restore the + system database in a database independent format. The file is a + Zip archive containing the following files: + + * format.txt: the format of the archive. Currently '1.0' + * tables.txt: list of filenames in the archive tables/ directory + * sequences.txt: list of filenames in the archive sequences/ directory + * versions.txt: the list of cube versions from CWProperty + * tables/.: pickled data + * sequences/: pickled data + + The pickled data format for tables and sequences is a tuple of 3 elements: + * the table name + * a tuple of column names + * a list of rows (as tuples with one element per column) + + Tables are saved in chunks in different files in order to prevent + a too high memory consumption. + """ + def __init__(self, source): + """ + :param: source an instance of the system source + """ + self._source = source + self.logger = logging.getLogger('cubicweb.ctl') + self.logger.setLevel(logging.INFO) + self.logger.addHandler(logging.StreamHandler(sys.stdout)) + self.schema = self._source.schema + self.dbhelper = self._source.dbhelper + self.cnx = None + self.cursor = None + self.sql_generator = sqlgen.SQLGenerator() + + def get_connection(self): + return self._source.get_connection() + + def backup(self, backupfile): + archive=zipfile.ZipFile(backupfile, 'w') + self.cnx = self.get_connection() + try: + self.cursor = self.cnx.cursor() + self.cursor.arraysize=100 + self.logger.info('writing metadata') + self.write_metadata(archive) + for seq in self.get_sequences(): + self.logger.info('processing sequence %s', seq) + self.write_sequence(archive, seq) + for table in self.get_tables(): + self.logger.info('processing table %s', table) + self.write_table(archive, table) + finally: + archive.close() + self.cnx.close() + self.logger.info('done') + + def get_tables(self): + non_entity_tables = ['entities', + 'deleted_entities', + 'transactions', + 'tx_entity_actions', + 'tx_relation_actions', + ] + etype_tables = [] + relation_tables = [] + prefix = 'cw_' + for etype in self.schema.entities(): + eschema = self.schema.eschema(etype) + print etype, eschema.final + if eschema.final: + continue + etype_tables.append('%s%s'%(prefix, etype)) + for rtype in self.schema.relations(): + rschema = self.schema.rschema(rtype) + if rschema.final or rschema.inlined: + continue + relation_tables.append('%s_relation' % rtype) + return non_entity_tables + etype_tables + relation_tables + + def get_sequences(self): + return ['entities_id_seq'] + + def write_metadata(self, archive): + archive.writestr('format.txt', '1.0') + archive.writestr('tables.txt', '\n'.join(self.get_tables())) + archive.writestr('sequences.txt', '\n'.join(self.get_sequences())) + versions = self._get_versions() + versions_str = '\n'.join('%s %s' % (k,v) + for k,v in versions) + archive.writestr('versions.txt', versions_str) + + def write_sequence(self, archive, seq): + sql = self.dbhelper.sql_sequence_current_state(seq) + columns, rows_iterator = self._get_cols_and_rows(sql) + rows = list(rows_iterator) + serialized = self._serialize(seq, columns, rows) + archive.writestr('sequences/%s' % seq, serialized) + + def write_table(self, archive, table): + sql = 'SELECT * FROM %s' % table + columns, rows_iterator = self._get_cols_and_rows(sql) + self.logger.info('number of rows: %d', self.cursor.rowcount) + if table.startswith('cw_'): # entities + blocksize = 2000 + else: # relations and metadata + blocksize = 10000 + if self.cursor.rowcount > 0: + for i, start in enumerate(xrange(0, self.cursor.rowcount, blocksize)): + rows = list(itertools.islice(rows_iterator, blocksize)) + serialized = self._serialize(table, columns, rows) + archive.writestr('tables/%s.%04d' % (table, i), serialized) + self.logger.debug('wrote rows %d to %d (out of %d) to %s.%04d', + start, start+len(rows)-1, + self.cursor.rowcount, + table, i) + else: + rows = [] + serialized = self._serialize(table, columns, rows) + archive.writestr('tables/%s.%04d' % (table, 0), serialized) + + def _get_cols_and_rows(self, sql): + process_result = self._source.iter_process_result + self.cursor.execute(sql) + columns = (d[0] for d in self.cursor.description) + rows = process_result(self.cursor) + return tuple(columns), rows + + def _serialize(self, name, columns, rows): + return dumps((name, columns, rows), pickle.HIGHEST_PROTOCOL) + + def restore(self, backupfile): + archive = zipfile.ZipFile(backupfile, 'r') + self.cnx = self.get_connection() + self.cursor = self.cnx.cursor() + sequences, tables, table_chunks = self.read_metadata(archive, backupfile) + for seq in sequences: + self.logger.info('restoring sequence %s', seq) + self.read_sequence(archive, seq) + for table in tables: + self.logger.info('restoring table %s', table) + self.read_table(archive, table, sorted(table_chunks[table])) + self.cnx.close() + archive.close() + self.logger.info('done') + + def read_metadata(self, archive, backupfile): + formatinfo = archive.read('format.txt') + self.logger.info('checking metadata') + if formatinfo.strip() != "1.0": + self.logger.critical('Unsupported format in archive: %s', formatinfo) + raise ValueError('Unknown format in %s: %s' % (backupfile, formatinfo)) + tables = archive.read('tables.txt').splitlines() + sequences = archive.read('sequences.txt').splitlines() + file_versions = self._parse_versions(archive.read('versions.txt')) + versions = set(self._get_versions()) + if file_versions != versions: + self.logger.critical('Unable to restore : versions do not match') + self.logger.critical('Expected:\n%s', '\n'.join(list(sorted(versions)))) + self.logger.critical('Found:\n%s', '\n'.join(list(sorted(file_versions)))) + raise ValueError('Unable to restore : versions do not match') + table_chunks = {} + for name in archive.namelist(): + if not name.startswith('tables/'): + continue + filename = basename(name) + tablename, _ext = filename.rsplit('.', 1) + table_chunks.setdefault(tablename, []).append(name) + return sequences, tables, table_chunks + + def read_sequence(self, archive, seq): + seqname, columns, rows = loads(archive.read('sequences/%s' % seq)) + assert seqname == seq + assert len(rows) == 1 + assert len(rows[0]) == 1 + value = rows[0][0] + sql = self.dbhelper.sql_restart_sequence(seq, value) + self.cursor.execute(sql) + self.cnx.commit() + + def read_table(self, archive, table, filenames): + merge_args = self._source.merge_args + self.cursor.execute('DELETE FROM %s' % table) + self.cnx.commit() + row_count = 0 + for filename in filenames: + tablename, columns, rows = loads(archive.read(filename)) + assert tablename == table + if not rows: + continue + insert = self.sql_generator.insert(table, + dict(zip(columns, rows[0]))) + for row in rows: + self.cursor.execute(insert, merge_args(dict(zip(columns, row)), {})) + row_count += len(rows) + self.cnx.commit() + self.logger.info('inserted %d rows', row_count) + + + def _parse_versions(self, version_str): + versions = set() + for line in version_str.splitlines(): + versions.add(tuple(line.split())) + return versions + + def _get_versions(self): + version_sql = 'SELECT cw_pkey, cw_value FROM cw_CWProperty' + versions = [] + self.cursor.execute(version_sql) + for pkey, value in self.cursor.fetchall(): + if pkey.startswith(u'system.version'): + versions.append((pkey, value)) + return versions diff -r c2452cd57026 -r d68f9319bfda server/sqlutils.py --- a/server/sqlutils.py Tue May 10 12:07:54 2011 +0200 +++ b/server/sqlutils.py Tue May 10 14:08:41 2011 +0200 @@ -204,6 +204,12 @@ def process_result(self, cursor, column_callbacks=None, session=None): """return a list of CubicWeb compliant values from data in the given cursor """ + return list(self.iter_process_result(cursor, column_callbacks, session)) + + def iter_process_result(self, cursor, column_callbacks=None, session=None): + """return a iterator on tuples of CubicWeb compliant values from data + in the given cursor + """ # use two different implementations to avoid paying the price of # callback lookup for each *cell* in results when there is nothing to # lookup @@ -219,16 +225,19 @@ process_value = self._process_value binary = Binary # /end - results = cursor.fetchall() - for i, line in enumerate(results): - result = [] - for col, value in enumerate(line): - if value is None: - result.append(value) - continue - result.append(process_value(value, descr[col], encoding, binary)) - results[i] = result - return results + cursor.arraysize = 100 + while True: + results = cursor.fetchmany() + if not results: + break + for line in results: + result = [] + for col, value in enumerate(line): + if value is None: + result.append(value) + continue + result.append(process_value(value, descr[col], encoding, binary)) + yield result def _cb_process_result(self, cursor, column_callbacks, session): # begin bind to locals for optimization @@ -237,22 +246,25 @@ process_value = self._process_value binary = Binary # /end - results = cursor.fetchall() - for i, line in enumerate(results): - result = [] - for col, value in enumerate(line): - if value is None: + cursor.arraysize = 100 + while True: + results = cursor.fetchmany() + if not results: + break + for line in results: + result = [] + for col, value in enumerate(line): + if value is None: + result.append(value) + continue + cbstack = column_callbacks.get(col, None) + if cbstack is None: + value = process_value(value, descr[col], encoding, binary) + else: + for cb in cbstack: + value = cb(self, session, value) result.append(value) - continue - cbstack = column_callbacks.get(col, None) - if cbstack is None: - value = process_value(value, descr[col], encoding, binary) - else: - for cb in cbstack: - value = cb(self, session, value) - result.append(value) - results[i] = result - return results + yield result def preprocess_entity(self, entity): """return a dictionary to use as extra argument to cursor.execute diff -r c2452cd57026 -r d68f9319bfda server/test/unittest_repository.py --- a/server/test/unittest_repository.py Tue May 10 12:07:54 2011 +0200 +++ b/server/test/unittest_repository.py Tue May 10 14:08:41 2011 +0200 @@ -277,13 +277,16 @@ cnxid = repo.connect(self.admlogin, password=self.admpassword) repo.execute(cnxid, 'INSERT CWUser X: X login "toto", X upassword "tutu", X in_group G WHERE G name "users"') repo.commit(cnxid) + lock = threading.Lock() + lock.acquire() # close has to be in the thread due to sqlite limitations def close_in_a_few_moment(): - time.sleep(0.1) + lock.acquire() repo.close(cnxid) t = threading.Thread(target=close_in_a_few_moment) t.start() def run_transaction(): + lock.release() repo.execute(cnxid, 'DELETE CWUser X WHERE X login "toto"') repo.commit(cnxid) try: diff -r c2452cd57026 -r d68f9319bfda sobjects/parsers.py --- a/sobjects/parsers.py Tue May 10 12:07:54 2011 +0200 +++ b/sobjects/parsers.py Tue May 10 14:08:41 2011 +0200 @@ -361,15 +361,18 @@ {'x': entity.eid}) def _set_relation(self, entity, rtype, role, eids): - eidstr = ','.join(str(eid) for eid in eids) - rql = rtype_role_rql(rtype, role) - self._cw.execute('DELETE %s, NOT Y eid IN (%s)' % (rql, eidstr), - {'x': entity.eid}) - if role == 'object': - rql = 'SET %s, Y eid IN (%s), NOT Y %s X' % (rql, eidstr, rtype) - else: - rql = 'SET %s, Y eid IN (%s), NOT X %s Y' % (rql, eidstr, rtype) + rqlbase = rtype_role_rql(rtype, role) + rql = 'DELETE %s' % rqlbase + if eids: + eidstr = ','.join(str(eid) for eid in eids) + rql += ', NOT Y eid IN (%s)' % eidstr self._cw.execute(rql, {'x': entity.eid}) + if eids: + if role == 'object': + rql = 'SET %s, Y eid IN (%s), NOT Y %s X' % (rqlbase, eidstr, rtype) + else: + rql = 'SET %s, Y eid IN (%s), NOT X %s Y' % (rqlbase, eidstr, rtype) + self._cw.execute(rql, {'x': entity.eid}) def registration_callback(vreg): vreg.register_all(globals().values(), __name__) diff -r c2452cd57026 -r d68f9319bfda sobjects/test/unittest_parsers.py --- a/sobjects/test/unittest_parsers.py Tue May 10 12:07:54 2011 +0200 +++ b/sobjects/test/unittest_parsers.py Tue May 10 14:08:41 2011 +0200 @@ -129,7 +129,7 @@ } }) session = self.repo.internal_session() - stats = dfsource.pull_data(session, force=True) + stats = dfsource.pull_data(session, force=True, raise_on_error=True) self.assertEqual(sorted(stats.keys()), ['created', 'updated']) self.assertEqual(len(stats['created']), 2) self.assertEqual(stats['updated'], set()) @@ -156,12 +156,12 @@ self.assertEqual(tag.cwuri, 'http://testing.fr/cubicweb/%s' % tag.eid) self.assertEqual(tag.cw_source[0].name, 'system') - stats = dfsource.pull_data(session, force=True) + stats = dfsource.pull_data(session, force=True, raise_on_error=True) self.assertEqual(stats['created'], set()) self.assertEqual(len(stats['updated']), 2) self.repo._type_source_cache.clear() self.repo._extid_cache.clear() - stats = dfsource.pull_data(session, force=True) + stats = dfsource.pull_data(session, force=True, raise_on_error=True) self.assertEqual(stats['created'], set()) self.assertEqual(len(stats['updated']), 2) diff -r c2452cd57026 -r d68f9319bfda test/unittest_selectors.py --- a/test/unittest_selectors.py Tue May 10 12:07:54 2011 +0200 +++ b/test/unittest_selectors.py Tue May 10 14:08:41 2011 +0200 @@ -26,7 +26,7 @@ from cubicweb.appobject import Selector, AndSelector, OrSelector from cubicweb.selectors import (is_instance, adaptable, match_user_groups, multi_lines_rset, score_entity, is_in_state, - on_transition) + on_transition, rql_condition) from cubicweb.web import action @@ -221,7 +221,7 @@ def test_is_in_state(self): for state in ('created', 'validated', 'abandoned'): selector = is_in_state(state) - self.assertEqual(selector(None, self.req, self.rset), + self.assertEqual(selector(None, self.req, rset=self.rset), state=="created") self.adapter.fire_transition('validate') @@ -229,75 +229,75 @@ self.assertEqual(self.adapter.state, 'validated') selector = is_in_state('created') - self.assertEqual(selector(None, self.req, self.rset), 0) + self.assertEqual(selector(None, self.req, rset=self.rset), 0) selector = is_in_state('validated') - self.assertEqual(selector(None, self.req, self.rset), 1) + self.assertEqual(selector(None, self.req, rset=self.rset), 1) selector = is_in_state('validated', 'abandoned') - self.assertEqual(selector(None, self.req, self.rset), 1) + self.assertEqual(selector(None, self.req, rset=self.rset), 1) selector = is_in_state('abandoned') - self.assertEqual(selector(None, self.req, self.rset), 0) + self.assertEqual(selector(None, self.req, rset=self.rset), 0) self.adapter.fire_transition('forsake') self._commit() self.assertEqual(self.adapter.state, 'abandoned') selector = is_in_state('created') - self.assertEqual(selector(None, self.req, self.rset), 0) + self.assertEqual(selector(None, self.req, rset=self.rset), 0) selector = is_in_state('validated') - self.assertEqual(selector(None, self.req, self.rset), 0) + self.assertEqual(selector(None, self.req, rset=self.rset), 0) selector = is_in_state('validated', 'abandoned') - self.assertEqual(selector(None, self.req, self.rset), 1) + self.assertEqual(selector(None, self.req, rset=self.rset), 1) self.assertEqual(self.adapter.state, 'abandoned') - self.assertEqual(selector(None, self.req, self.rset), 1) + self.assertEqual(selector(None, self.req, rset=self.rset), 1) def test_is_in_state_unvalid_names(self): selector = is_in_state("unknown") with self.assertRaises(ValueError) as cm: - selector(None, self.req, self.rset) + selector(None, self.req, rset=self.rset) self.assertEqual(str(cm.exception), "wf_test: unknown state(s): unknown") selector = is_in_state("weird", "unknown", "created", "weird") with self.assertRaises(ValueError) as cm: - selector(None, self.req, self.rset) + selector(None, self.req, rset=self.rset) self.assertEqual(str(cm.exception), "wf_test: unknown state(s): unknown,weird") def test_on_transition(self): for transition in ('validate', 'forsake'): selector = on_transition(transition) - self.assertEqual(selector(None, self.req, self.rset), 0) + self.assertEqual(selector(None, self.req, rset=self.rset), 0) self.adapter.fire_transition('validate') self._commit() self.assertEqual(self.adapter.state, 'validated') selector = on_transition("validate") - self.assertEqual(selector(None, self.req, self.rset), 1) + self.assertEqual(selector(None, self.req, rset=self.rset), 1) selector = on_transition("validate", "forsake") - self.assertEqual(selector(None, self.req, self.rset), 1) + self.assertEqual(selector(None, self.req, rset=self.rset), 1) selector = on_transition("forsake") - self.assertEqual(selector(None, self.req, self.rset), 0) + self.assertEqual(selector(None, self.req, rset=self.rset), 0) self.adapter.fire_transition('forsake') self._commit() self.assertEqual(self.adapter.state, 'abandoned') selector = on_transition("validate") - self.assertEqual(selector(None, self.req, self.rset), 0) + self.assertEqual(selector(None, self.req, rset=self.rset), 0) selector = on_transition("validate", "forsake") - self.assertEqual(selector(None, self.req, self.rset), 1) + self.assertEqual(selector(None, self.req, rset=self.rset), 1) selector = on_transition("forsake") - self.assertEqual(selector(None, self.req, self.rset), 1) + self.assertEqual(selector(None, self.req, rset=self.rset), 1) def test_on_transition_unvalid_names(self): selector = on_transition("unknown") with self.assertRaises(ValueError) as cm: - selector(None, self.req, self.rset) + selector(None, self.req, rset=self.rset) self.assertEqual(str(cm.exception), "wf_test: unknown transition(s): unknown") selector = on_transition("weird", "unknown", "validate", "weird") with self.assertRaises(ValueError) as cm: - selector(None, self.req, self.rset) + selector(None, self.req, rset=self.rset) self.assertEqual(str(cm.exception), "wf_test: unknown transition(s): unknown,weird") @@ -308,11 +308,11 @@ self.assertEqual(self.adapter.state, 'validated') selector = on_transition("validate") - self.assertEqual(selector(None, self.req, self.rset), 0) + self.assertEqual(selector(None, self.req, rset=self.rset), 0) selector = on_transition("validate", "forsake") - self.assertEqual(selector(None, self.req, self.rset), 0) + self.assertEqual(selector(None, self.req, rset=self.rset), 0) selector = on_transition("forsake") - self.assertEqual(selector(None, self.req, self.rset), 0) + self.assertEqual(selector(None, self.req, rset=self.rset), 0) class MatchUserGroupsTC(CubicWebTC): @@ -362,13 +362,13 @@ def test_default_op_in_selector(self): expected = len(self.rset) selector = multi_lines_rset(expected) - self.assertEqual(selector(None, self.req, self.rset), 1) + self.assertEqual(selector(None, self.req, rset=self.rset), 1) self.assertEqual(selector(None, self.req, None), 0) selector = multi_lines_rset(expected + 1) - self.assertEqual(selector(None, self.req, self.rset), 0) + self.assertEqual(selector(None, self.req, rset=self.rset), 0) self.assertEqual(selector(None, self.req, None), 0) selector = multi_lines_rset(expected - 1) - self.assertEqual(selector(None, self.req, self.rset), 0) + self.assertEqual(selector(None, self.req, rset=self.rset), 0) self.assertEqual(selector(None, self.req, None), 0) def test_without_rset(self): @@ -399,7 +399,7 @@ for (expected, operator, assertion) in testdata: selector = multi_lines_rset(expected, operator) - yield self.assertEqual, selector(None, self.req, self.rset), assertion + yield self.assertEqual, selector(None, self.req, rset=self.rset), assertion class ScoreEntitySelectorTC(CubicWebTC): @@ -408,17 +408,24 @@ req = self.request() rset = req.execute('Any E WHERE E eid 1') selector = score_entity(lambda x: None) - self.assertEqual(selector(None, req, rset), 0) + self.assertEqual(selector(None, req, rset=rset), 0) selector = score_entity(lambda x: "something") - self.assertEqual(selector(None, req, rset), 1) + self.assertEqual(selector(None, req, rset=rset), 1) selector = score_entity(lambda x: object) - self.assertEqual(selector(None, req, rset), 1) + self.assertEqual(selector(None, req, rset=rset), 1) rset = req.execute('Any G LIMIT 2 WHERE G is CWGroup') selector = score_entity(lambda x: 10) - self.assertEqual(selector(None, req, rset), 20) + self.assertEqual(selector(None, req, rset=rset), 20) selector = score_entity(lambda x: 10, once_is_enough=True) - self.assertEqual(selector(None, req, rset), 10) + self.assertEqual(selector(None, req, rset=rset), 10) + def test_rql_condition_entity(self): + req = self.request() + selector = rql_condition('X identity U') + rset = req.user.as_rset() + self.assertEqual(selector(None, req, rset=rset), 1) + self.assertEqual(selector(None, req, entity=req.user), 1) + self.assertEqual(selector(None, req), 0) if __name__ == '__main__': unittest_main() diff -r c2452cd57026 -r d68f9319bfda view.py --- a/view.py Tue May 10 12:07:54 2011 +0200 +++ b/view.py Tue May 10 14:08:41 2011 +0200 @@ -447,11 +447,14 @@ rqlstdescr = self.cw_rset.syntax_tree().get_description(mainindex, translate)[0] labels = [] - for colindex, label in enumerate(rqlstdescr): - # compute column header - if label == 'Any': # find a better label - label = ','.join(translate(et) - for et in self.cw_rset.column_types(colindex)) + for colidx, label in enumerate(rqlstdescr): + try: + label = getattr(self, 'label_column_%s' % colidx)() + except AttributeError: + # compute column header + if label == 'Any': # find a better label + label = ','.join(translate(et) + for et in self.cw_rset.column_types(colidx)) labels.append(label) return labels