states.py
changeset 23 423c62a146c7
parent 22 93dd72d028a1
child 24 20ac7fa3fd29
--- a/states.py	Fri Jul 01 14:55:02 2011 +0200
+++ b/states.py	Fri Jul 01 16:00:19 2011 +0200
@@ -17,7 +17,9 @@
 
 name are not fixed yet.
 '''
+import os
 from functools import partial
+
 from mercurial.i18n import _
 from mercurial import cmdutil
 from mercurial import scmutil
@@ -31,6 +33,7 @@
 from mercurial import extensions
 from mercurial import wireproto
 from mercurial import pushkey
+from mercurial.lock import release
 
 
 _NOSHARE=2
@@ -196,6 +199,8 @@
     opull = repo.pull
     opush = repo.push
     o_tag = repo._tag
+    orollback = repo.rollback
+    o_writejournal = repo._writejournal
     class statefulrepo(repo.__class__):
 
         def nodestate(self, node):
@@ -236,11 +241,13 @@
             except IOError:
                 pass
             return heads
-        def _readstatesheads(self):
+
+        def _readstatesheads(self, undo=False):
             statesheads = {}
             for state in STATES:
                 if state.trackheads:
-                    filename = 'states/%s-heads' % state.name
+                    filemask = 'states/%s-heads'
+                    filename = filemask % state.name
                     statesheads[state] = self._readheadsfile(filename)
             return statesheads
 
@@ -333,10 +340,47 @@
                 remote = map(node.bin, remote.listkeys('immutableheads'))
             return remote
 
+        ### Tag support
+
         def _tag(self, names, node, *args, **kwargs):
             tagnode = o_tag(names, node, *args, **kwargs)
             self.setstate(ST0, [node, tagnode])
             return tagnode
 
+        ### rollback support
+
+        def _writejournal(self, desc):
+            entries = list(o_writejournal(desc))
+            for state in STATES:
+                if state.trackheads:
+                    filename = 'states/%s-heads' % state.name
+                    filepath = self.join(filename)
+                    if  os.path.exists(filepath):
+                        journalname = 'states/journal.%s-heads' % state.name
+                        journalpath = self.join(journalname)
+                        util.copyfile(filepath, journalpath)
+                        entries.append(journalpath)
+            return tuple(entries)
+
+        def rollback(self, dryrun=False):
+            wlock = lock = None
+            try:
+                wlock = self.wlock()
+                lock = self.lock()
+                ret = orollback(dryrun)
+                if not (ret or dryrun): #rollback did not failed
+                    for state in STATES:
+                        if state.trackheads:
+                            src  = self.join('states/undo.%s-heads') % state.name
+                            dest = self.join('states/%s-heads') % state.name
+                            if os.path.exists(src):
+                                util.rename(src, dest)
+                            elif os.path.exists(dest): #unlink in any case
+                                os.unlink(dest)
+                    self.__dict__.pop('_statesheads', None)
+                return ret
+            finally:
+                release(lock, wlock)
+
     repo.__class__ = statefulrepo