server/session.py
changeset 5849 9db65b381028
parent 5815 282194aa43f3
parent 5826 462435bf5457
child 5891 99024ad59223
--- a/server/session.py	Wed Jun 30 15:43:36 2010 +0200
+++ b/server/session.py	Thu Jul 01 17:06:37 2010 +0200
@@ -563,11 +563,15 @@
     @property
     def pool(self):
         """connections pool, set according to transaction mode for each query"""
+        if self._closed:
+            self.reset_pool(True)
+            raise Exception('try to access pool on a closed session')
         return getattr(self._threaddata, 'pool', None)
 
-    def set_pool(self, checkclosed=True):
+    def set_pool(self):
         """the session need a pool to execute some queries"""
-        if checkclosed and self._closed:
+        if self._closed:
+            self.reset_pool(True)
             raise Exception('try to set pool on a closed session')
         if self.pool is None:
             # get pool first to avoid race-condition
@@ -578,24 +582,34 @@
                 self._threaddata.pool = None
                 self.repo._free_pool(pool)
                 raise
-            self._threads_in_transaction.add(threading.currentThread())
+            self._threads_in_transaction.add(
+                (threading.currentThread(), pool) )
         return self._threaddata.pool
 
+    def _free_thread_pool(self, thread, pool, force_close=False):
+        try:
+            self._threads_in_transaction.remove( (thread, pool) )
+        except KeyError:
+            # race condition on pool freeing (freed by commit or rollback vs
+            # close)
+            pass
+        else:
+            if force_close:
+                pool.reconnect()
+            else:
+                pool.pool_reset()
+            # free pool once everything is done to avoid race-condition
+            self.repo._free_pool(pool)
+
     def reset_pool(self, ignoremode=False):
         """the session is no longer using its pool, at least for some time"""
         # pool may be none if no operation has been done since last commit
         # or rollback
-        if self.pool is not None and (ignoremode or self.mode == 'read'):
+        pool = getattr(self._threaddata, 'pool', None)
+        if pool is not None and (ignoremode or self.mode == 'read'):
             # even in read mode, we must release the current transaction
-            pool = self.pool
-            try:
-                self._threads_in_transaction.remove(threading.currentThread())
-            except KeyError:
-                pass
-            pool.pool_reset()
+            self._free_thread_pool(threading.currentThread(), pool)
             del self._threaddata.pool
-            # free pool once everything is done to avoid race-condition
-            self.repo._free_pool(pool)
 
     def _touch(self):
         """update latest session usage timestamp and reset mode to read"""
@@ -772,7 +786,9 @@
 
     def rollback(self, reset_pool=True):
         """rollback the current session's transaction"""
-        if self.pool is None:
+        # don't use self.pool, rollback may be called with _closed == True
+        pool = getattr(self._threaddata, 'pool', None)
+        if pool is None:
             self._clear_thread_data()
             self._touch()
             self.debug('rollback session %s done (no db activity)', self.id)
@@ -787,7 +803,7 @@
                     except:
                         self.critical('rollback error', exc_info=sys.exc_info())
                         continue
-                self.pool.rollback()
+                pool.rollback()
                 self.debug('rollback for session %s done', self.id)
         finally:
             self._touch()
@@ -799,7 +815,7 @@
         """do not close pool on session close, since they are shared now"""
         self._closed = True
         # copy since _threads_in_transaction maybe modified while waiting
-        for thread in self._threads_in_transaction.copy():
+        for thread, pool in self._threads_in_transaction.copy():
             if thread is threading.currentThread():
                 continue
             self.info('waiting for thread %s', thread)
@@ -809,11 +825,12 @@
             for i in xrange(10):
                 thread.join(1)
                 if not (thread.isAlive() and
-                        thread in self._threads_in_transaction):
+                        (thread, pool) in self._threads_in_transaction):
                     break
             else:
                 self.error('thread %s still alive after 10 seconds, will close '
                            'session anyway', thread)
+                self._free_thread_pool(thread, pool, force_close=True)
         self.rollback()
         del self.__threaddata
         del self._tx_data