give session to doexec so it's able to rollback the connection on unexpected error stable
authorSylvain Thénault <sylvain.thenault@logilab.fr>
Tue, 07 Jul 2009 12:24:40 +0200
branchstable
changeset 2306 95da5d9f0870
parent 2305 8f6dbe884700
child 2307 060c3f3f7d28
child 2308 b478c3a8ad2a
give session to doexec so it's able to rollback the connection on unexpected error
server/__init__.py
server/session.py
server/sources/extlite.py
server/sources/native.py
--- a/server/__init__.py	Tue Jul 07 11:42:24 2009 +0200
+++ b/server/__init__.py	Tue Jul 07 12:24:40 2009 +0200
@@ -50,8 +50,7 @@
     driver = source['db-driver']
     sqlcnx = repo.system_source.get_connection()
     sqlcursor = sqlcnx.cursor()
-    def execute(sql, args=None):
-        repo.system_source.doexec(sqlcursor, sql, args)
+    execute = sqlcursor.execute
     if drop:
         dropsql = sqldropschema(schema, driver)
         try:
--- a/server/session.py	Tue Jul 07 11:42:24 2009 +0200
+++ b/server/session.py	Tue Jul 07 12:24:40 2009 +0200
@@ -88,9 +88,7 @@
         """return a sql cursor on the system database"""
         if not sql.split(None, 1)[0].upper() == 'SELECT':
             self.mode = 'write'
-        cursor = self.pool['system']
-        self.pool.source('system').doexec(cursor, sql, args)
-        return cursor
+        return self.pool.source('system').doexec(self, sql, args)
 
     def set_language(self, language):
         """i18n configuration for translation"""
@@ -137,12 +135,12 @@
             raise Exception('try to set pool on a closed session')
         if self.pool is None:
             # get pool first to avoid race-condition
-            self._threaddata.pool = self.repo._get_pool()
+            self._threaddata.pool = pool = self.repo._get_pool()
             try:
-                self._threaddata.pool.pool_set()
+                pool.pool_set()
             except:
                 self._threaddata.pool = None
-                self.repo._free_pool(self.pool)
+                self.repo._free_pool(pool)
                 raise
             self._threads_in_transaction.add(threading.currentThread())
         return self._threaddata.pool
--- a/server/sources/extlite.py	Tue Jul 07 11:42:24 2009 +0200
+++ b/server/sources/extlite.py	Tue Jul 07 12:24:40 2009 +0200
@@ -174,9 +174,7 @@
         if server.DEBUG:
             print self.uri, 'SOURCE RQL', union.as_string()
         args = self.sqladapter.merge_args(args, query_args)
-        cursor = session.pool[self.uri]
-        self.doexec(cursor, sql, args)
-        res = self.sqladapter.process_result(cursor)
+        res = self.sqladapter.process_result(self.doexec(session, sql, args))
         if server.DEBUG:
             print '------>', res
         return res
@@ -190,7 +188,7 @@
         """
         attrs = self.sqladapter.preprocess_entity(entity)
         sql = self.sqladapter.sqlgen.insert(SQL_PREFIX + str(entity.e_schema), attrs)
-        self.doexec(session.pool[self.uri], sql, attrs)
+        self.doexec(session, sql, attrs)
 
     def add_entity(self, session, entity):
         """add a new entity to the source"""
@@ -207,7 +205,7 @@
             attrs = self.sqladapter.preprocess_entity(entity)
         sql = self.sqladapter.sqlgen.update(SQL_PREFIX + str(entity.e_schema),
                                             attrs, [SQL_PREFIX + 'eid'])
-        self.doexec(session.pool[self.uri], sql, attrs)
+        self.doexec(session, sql, attrs)
 
     def update_entity(self, session, entity):
         """update an entity in the source"""
@@ -222,7 +220,7 @@
         """
         attrs = {SQL_PREFIX + 'eid': eid}
         sql = self.sqladapter.sqlgen.delete(SQL_PREFIX + etype, attrs)
-        self.doexec(session.pool[self.uri], sql, attrs)
+        self.doexec(session, sql, attrs)
 
     def local_add_relation(self, session, subject, rtype, object):
         """add a relation to the source
@@ -233,7 +231,7 @@
         """
         attrs = {'eid_from': subject, 'eid_to': object}
         sql = self.sqladapter.sqlgen.insert('%s_relation' % rtype, attrs)
-        self.doexec(session.pool[self.uri], sql, attrs)
+        self.doexec(session, sql, attrs)
 
     def add_relation(self, session, subject, rtype, object):
         """add a relation to the source"""
@@ -252,21 +250,25 @@
         else:
             attrs = {'eid_from': subject, 'eid_to': object}
             sql = self.sqladapter.sqlgen.delete('%s_relation' % rtype, attrs)
-        self.doexec(session.pool[self.uri], sql, attrs)
+        self.doexec(session, sql, attrs)
 
-    def doexec(self, cursor, query, args=None):
+    def doexec(self, session, query, args=None):
         """Execute a query.
         it's a function just so that it shows up in profiling
         """
-        #t1 = time()
         if server.DEBUG:
             print 'exec', query, args
-        #import sys
-        #sys.stdout.flush()
-        # str(query) to avoid error if it's an unicode string
+        cursor = session.pool[self.uri]
         try:
+            # str(query) to avoid error if it's an unicode string
             cursor.execute(str(query), args)
         except Exception, ex:
             self.critical("sql: %r\n args: %s\ndbms message: %r",
                           query, args, ex.args[0])
+            try:
+                session.pool.connection(self.uri).rollback()
+                self.critical('transaction has been rollbacked')
+            except:
+                pass
             raise
+        return cursor
--- a/server/sources/native.py	Tue Jul 07 11:42:24 2009 +0200
+++ b/server/sources/native.py	Tue Jul 07 12:24:40 2009 +0200
@@ -185,9 +185,7 @@
 
     def sqlexec(self, session, sql, args=None):
         """execute the query and return its result"""
-        cursor = session.pool[self.uri]
-        self.doexec(cursor, sql, args)
-        return self.process_result(cursor)
+        return self.process_result(self.doexec(session, sql, args))
 
     def init_creating(self):
         pool = self.repo._get_pool()
@@ -305,17 +303,15 @@
                 sql, query_args = self._rql_sqlgen.generate(union, args, varmap)
                 self._cache[cachekey] = sql, query_args
         args = self.merge_args(args, query_args)
-        cursor = session.pool[self.uri]
         assert isinstance(sql, basestring), repr(sql)
         try:
-            self.doexec(cursor, sql, args)
+            cursor = self.doexec(session, sql, args)
         except (self.dbapi_module.OperationalError,
                 self.dbapi_module.InterfaceError):
             # FIXME: better detection of deconnection pb
             self.info("request failed '%s' ... retry with a new cursor", sql)
             session.pool.reconnect(self)
-            cursor = session.pool[self.uri]
-            self.doexec(cursor, sql, args)
+            cursor = self.doexec(session, sql, args)
         res = self.process_result(cursor)
         if server.DEBUG:
             print '------>', res
@@ -337,8 +333,7 @@
             # generate sql queries if we are able to do so
             sql, query_args = self._rql_sqlgen.generate(union, args, varmap)
             query = 'INSERT INTO %s %s' % (table, sql.encode(self.encoding))
-            self.doexec(session.pool[self.uri], query,
-                        self.merge_args(args, query_args))
+            self.doexec(session, query, self.merge_args(args, query_args))
         else:
             super(NativeSQLSource, self).flying_insert(table, session, union,
                                                        args, varmap)
@@ -358,15 +353,14 @@
                     cell = self.binary(cell.getvalue())
                 kwargs[str(index)] = cell
             kwargs_list.append(kwargs)
-        self.doexecmany(session.pool[self.uri], query, kwargs_list)
+        self.doexecmany(session, query, kwargs_list)
 
     def clean_temp_data(self, session, temptables):
         """remove temporary data, usually associated to temporary tables"""
         if temptables:
-            cursor = session.pool[self.uri]
             for table in temptables:
                 try:
-                    self.doexec(cursor,'DROP TABLE %s' % table)
+                    self.doexec(session,'DROP TABLE %s' % table)
                 except:
                     pass
                 try:
@@ -378,25 +372,25 @@
         """add a new entity to the source"""
         attrs = self.preprocess_entity(entity)
         sql = self.sqlgen.insert(SQL_PREFIX + str(entity.e_schema), attrs)
-        self.doexec(session.pool[self.uri], sql, attrs)
+        self.doexec(session, sql, attrs)
 
     def update_entity(self, session, entity):
         """replace an entity in the source"""
         attrs = self.preprocess_entity(entity)
         sql = self.sqlgen.update(SQL_PREFIX + str(entity.e_schema), attrs, [SQL_PREFIX + 'eid'])
-        self.doexec(session.pool[self.uri], sql, attrs)
+        self.doexec(session, sql, attrs)
 
     def delete_entity(self, session, etype, eid):
         """delete an entity from the source"""
         attrs = {SQL_PREFIX + 'eid': eid}
         sql = self.sqlgen.delete(SQL_PREFIX + etype, attrs)
-        self.doexec(session.pool[self.uri], sql, attrs)
+        self.doexec(session, sql, attrs)
 
     def add_relation(self, session, subject, rtype, object):
         """add a relation to the source"""
         attrs = {'eid_from': subject, 'eid_to': object}
         sql = self.sqlgen.insert('%s_relation' % rtype, attrs)
-        self.doexec(session.pool[self.uri], sql, attrs)
+        self.doexec(session, sql, attrs)
 
     def delete_relation(self, session, subject, rtype, object):
         """delete a relation from the source"""
@@ -410,39 +404,47 @@
         else:
             attrs = {'eid_from': subject, 'eid_to': object}
             sql = self.sqlgen.delete('%s_relation' % rtype, attrs)
-        self.doexec(session.pool[self.uri], sql, attrs)
+        self.doexec(session, sql, attrs)
 
-    def doexec(self, cursor, query, args=None):
+    def doexec(self, session, query, args=None):
         """Execute a query.
         it's a function just so that it shows up in profiling
         """
-        #t1 = time()
         if server.DEBUG:
             print 'exec', query, args
-        #import sys
-        #sys.stdout.flush()
-        # str(query) to avoid error if it's an unicode string
+        cursor = session.pool[self.uri]
         try:
+            # str(query) to avoid error if it's an unicode string
             cursor.execute(str(query), args)
         except Exception, ex:
             self.critical("sql: %r\n args: %s\ndbms message: %r",
                           query, args, ex.args[0])
+            try:
+                session.pool.connection(self.uri).rollback()
+                self.critical('transaction has been rollbacked')
+            except:
+                pass
             raise
+        return cursor
 
-    def doexecmany(self, cursor, query, args):
+    def doexecmany(self, session, query, args):
         """Execute a query.
         it's a function just so that it shows up in profiling
         """
-        #t1 = time()
         if server.DEBUG:
             print 'execmany', query, 'with', len(args), 'arguments'
-        #import sys
-        #sys.stdout.flush()
-        # str(query) to avoid error if it's an unicode string
+        cursor = session.pool[self.uri]
         try:
+            # str(query) to avoid error if it's an unicode string
             cursor.executemany(str(query), args)
-        except:
-            self.critical("sql many: %r\n args: %s", query, args)
+        except Exception, ex:
+            self.critical("sql many: %r\n args: %s\ndbms message: %r",
+                          query, args, ex.args[0])
+            try:
+                session.pool.connection(self.uri).rollback()
+                self.critical('transaction has been rollbacked')
+            except:
+                pass
             raise
 
     # short cut to method requiring advanced db helper usage ##################
@@ -498,14 +500,13 @@
         # running with an ldap source, and table will be deleted manually any way
         # on commit
         sql = self.dbhelper.sql_temporary_table(table, schema, False)
-        self.doexec(session.pool[self.uri], sql)
+        self.doexec(session, sql)
 
     def create_eid(self, session):
         self._eid_creation_lock.acquire()
         try:
-            cursor = session.pool[self.uri]
             for sql in self.dbhelper.sqls_increment_sequence('entities_id_seq'):
-                self.doexec(cursor, sql)
+                self.doexec(session, sql)
             return cursor.fetchone()[0]
         finally:
             self._eid_creation_lock.release()