diff -r 9d5cfbde9856 -r 1f3757ef3762 server/sqlutils.py --- a/server/sqlutils.py Thu Oct 18 15:52:05 2012 +0200 +++ b/server/sqlutils.py Tue Oct 18 16:55:16 2011 +0200 @@ -20,8 +20,10 @@ __docformat__ = "restructuredtext en" import os +import re import subprocess from datetime import datetime, date +from itertools import ifilter from logilab import database as db, common as lgc from logilab.common.shellutils import ProgressBar @@ -49,27 +51,53 @@ def sqlexec(sqlstmts, cursor_or_execute, withpb=not os.environ.get('APYCOT_ROOT'), - pbtitle='', delimiter=';'): + pbtitle='', delimiter=';', cnx=None): """execute sql statements ignoring DROP/ CREATE GROUP or USER statements - error. If a cnx is given, commit at each statement + error. + + :sqlstmts_as_string: a string or a list of sql statements. + :cursor_or_execute: sql cursor or a callback used to execute statements + :cnx: if given, commit/rollback at each statement. + + :withpb: if True, display a progresse bar + :pbtitle: a string displayed as the progress bar title (if `withpb=True`) + + :delimiter: a string used to split sqlstmts (if it is a string) + + Return the failed statements (same type as sqlstmts) """ if hasattr(cursor_or_execute, 'execute'): execute = cursor_or_execute.execute else: execute = cursor_or_execute - sqlstmts = sqlstmts.split(delimiter) + sqlstmts_as_string = False + if isinstance(sqlstmts, basestring): + sqlstmts_as_string = True + sqlstmts = sqlstmts.split(delimiter) if withpb: pb = ProgressBar(len(sqlstmts), title=pbtitle) + failed = [] for sql in sqlstmts: sql = sql.strip() if withpb: pb.update() if not sql: continue - # some dbapi modules doesn't accept unicode for sql string - execute(str(sql)) + try: + # some dbapi modules doesn't accept unicode for sql string + execute(str(sql)) + except Exception, err: + if cnx: + cnx.rollback() + failed.append(sql) + else: + if cnx: + cnx.commit() if withpb: print + if sqlstmts_as_string: + failed = delimiter.join(failed) + return failed def sqlgrants(schema, driver, user, @@ -137,6 +165,23 @@ return '\n'.join(output) +_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION = re.compile('^(?!(sql|pg)_)').match +def sql_drop_all_user_tables(driver_or_helper, sqlcursor): + """Return ths sql to drop all tables found in the database system.""" + if not getattr(driver_or_helper, 'list_tables', None): + dbhelper = db.get_db_helper(driver_or_helper) + else: + dbhelper = driver_or_helper + + cmds = [dbhelper.sql_drop_sequence('entities_id_seq')] + # for mssql, we need to drop views before tables + if hasattr(dbhelper, 'list_views'): + cmds += ['DROP VIEW %s;' % name + for name in ifilter(_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION, dbhelper.list_views(sqlcursor))] + cmds += ['DROP TABLE %s;' % name + for name in ifilter(_SQL_DROP_ALL_USER_TABLES_FILTER_FUNCTION, dbhelper.list_tables(sqlcursor))] + return '\n'.join(cmds) + class SQLAdapterMixIn(object): """Mixin for SQL data sources, getting a connection from a configuration dictionary and handling connection locking