arche / internal/cli/cmd_rebase.go

commit 154431fd
  1package cli
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"arche/internal/merge"
  8	"arche/internal/object"
  9	"arche/internal/repo"
 10	"arche/internal/store"
 11	"arche/internal/wc"
 12
 13	"github.com/spf13/cobra"
 14)
 15
 16var rebaseForceRewrite bool
 17
 18var rebaseCmd = &cobra.Command{
 19	Use:   "rebase <dest>",
 20	Short: "Replay commits from the current change onto a new base",
 21	Long: `Replay the working-copy lineage on top of <dest>.
 22Each replayed commit is re-parented; the originals are marked obsolete.
 23Conflicts stop the rebase and require 'arche resolve'.`,
 24	Args: cobra.ExactArgs(1),
 25	RunE: func(cmd *cobra.Command, args []string) error {
 26		r := openRepo()
 27		defer r.Close()
 28
 29		destID, err := resolveRef(r, args[0])
 30		if err != nil {
 31			return err
 32		}
 33
 34		_, headID, err := r.HeadCommit()
 35		if err != nil {
 36			return err
 37		}
 38
 39		chain, err := collectLinearChain(r, headID, destID)
 40		if err != nil {
 41			return err
 42		}
 43		if len(chain) == 0 {
 44			fmt.Println("Nothing to rebase.")
 45			return nil
 46		}
 47
 48		if !rebaseForceRewrite {
 49			for _, id := range chain {
 50				c, err := r.ReadCommit(id)
 51				if err != nil {
 52					return fmt.Errorf("read commit %x: %w", id[:6], err)
 53				}
 54				if c.Phase == object.PhasePublic {
 55					return fmt.Errorf("commit %x is public; use --force-rewrite to rewrite history", id[:8])
 56				}
 57			}
 58		}
 59
 60		before, _ := r.CaptureRefState()
 61		now := time.Now()
 62		newParentID := destID
 63		newTipID := destID
 64		newTipChangeID := ""
 65
 66		for _, origID := range chain {
 67			orig, err := r.ReadCommit(origID)
 68			if err != nil {
 69				return fmt.Errorf("read commit %x: %w", origID[:6], err)
 70			}
 71
 72			newParentCommit, err := r.ReadCommit(newParentID)
 73			if err != nil {
 74				return err
 75			}
 76
 77			var baseTreeID [32]byte
 78			if len(orig.Parents) > 0 {
 79				p, err := r.ReadCommit(orig.Parents[0])
 80				if err != nil {
 81					return err
 82				}
 83				baseTreeID = p.TreeID
 84			}
 85
 86			result, err := merge.Trees(r, baseTreeID, orig.TreeID, newParentCommit.TreeID)
 87			if err != nil {
 88				return fmt.Errorf("rebase merge step: %w", err)
 89			}
 90
 91			sig := object.Signature{
 92				Name:      orig.Author.Name,
 93				Email:     orig.Author.Email,
 94				Timestamp: orig.Author.Timestamp,
 95			}
 96			committer := object.Signature{Name: r.Cfg.User.Name, Email: r.Cfg.User.Email, Timestamp: now}
 97			_ = committer
 98
 99			tx, err := r.Store.Begin()
100			if err != nil {
101				return err
102			}
103
104			newCommit := &object.Commit{
105				TreeID:    result.TreeID,
106				Parents:   [][32]byte{newParentID},
107				ChangeID:  orig.ChangeID,
108				Author:    sig,
109				Committer: object.Signature{Name: r.Cfg.User.Name, Email: r.Cfg.User.Email, Timestamp: now},
110				Message:   orig.Message,
111				Phase:     orig.Phase,
112			}
113			newCommitID, err := repo.WriteCommitTx(r.Store, tx, newCommit)
114			if err != nil {
115				r.Store.Rollback(tx)
116				return err
117			}
118			if err := r.Store.SetChangeCommit(tx, orig.ChangeID, newCommitID); err != nil {
119				r.Store.Rollback(tx)
120				return err
121			}
122
123			obs := &object.ObsoleteMarker{
124				Predecessor: origID,
125				Successors:  [][32]byte{newCommitID},
126				Reason:      "rebase",
127			}
128			obsID, err := repo.WriteObsoleteTx(r.Store, tx, obs)
129			if err != nil {
130				r.Store.Rollback(tx)
131				return err
132			}
133			_ = obsID
134
135			opAfter := buildMergeRefState(newCommitID, object.FormatChangeID(orig.ChangeID))
136			op := store.Operation{
137				Kind: "rebase-step", Timestamp: now.Unix(), Before: before, After: opAfter,
138				Metadata: fmt.Sprintf("rebased %s onto %x", object.FormatChangeID(orig.ChangeID), destID[:6]),
139			}
140			if _, err := r.Store.InsertOperation(tx, op); err != nil {
141				r.Store.Rollback(tx)
142				return err
143			}
144
145			if err := r.Store.ClearAllConflicts(tx); err != nil {
146				r.Store.Rollback(tx)
147				return err
148			}
149			for _, cp := range result.Conflicts {
150				if err := r.Store.AddConflict(tx, cp); err != nil {
151					r.Store.Rollback(tx)
152					return err
153				}
154			}
155
156			if err := r.Store.Commit(tx); err != nil {
157				return err
158			}
159
160			newParentID = newCommitID
161			newTipID = newCommitID
162			newTipChangeID = object.FormatChangeID(orig.ChangeID)
163
164			if len(result.Conflicts) > 0 {
165				w := wc.New(r)
166				if err := w.Materialize(result.TreeID, newTipChangeID); err != nil {
167					return err
168				}
169				if err := r.WriteHead(newTipChangeID); err != nil {
170					return err
171				}
172				fmt.Printf("Rebase paused at %s (%d conflict(s)):\n", newTipChangeID, len(result.Conflicts))
173				for _, p := range result.Conflicts {
174					fmt.Printf("  conflict: %s\n", p)
175				}
176				fmt.Println("Resolve conflicts, then run 'arche resolve <path>' and 'arche snap'.")
177				return nil
178			}
179		}
180
181		_ = newTipID
182		finalCommit, err := r.ReadCommit(newParentID)
183		if err != nil {
184			return err
185		}
186		w := wc.New(r)
187		if err := w.Materialize(finalCommit.TreeID, newTipChangeID); err != nil {
188			return err
189		}
190		if err := r.WriteHead(newTipChangeID); err != nil {
191			return err
192		}
193		fmt.Printf("Rebase complete: now at %s\n", newTipChangeID)
194		return nil
195	},
196}
197
198func init() {
199	rebaseCmd.Flags().BoolVar(&rebaseForceRewrite, "force-rewrite", false, "allow rewriting public commits")
200}
201
202func collectLinearChain(r *repo.Repo, headID, destID [32]byte) ([][32]byte, error) {
203	destAncestors := make(map[[32]byte]bool)
204	q := [][32]byte{destID}
205	for len(q) > 0 {
206		id := q[0]
207		q = q[1:]
208		if destAncestors[id] {
209			continue
210		}
211		destAncestors[id] = true
212		c, err := r.ReadCommit(id)
213		if err != nil {
214			break
215		}
216		q = append(q, c.Parents...)
217	}
218
219	var chain [][32]byte
220	current := headID
221	for !destAncestors[current] {
222
223		c, err := r.ReadCommit(current)
224		if err != nil {
225			return nil, err
226		}
227		chain = append(chain, current)
228		if len(c.Parents) == 0 {
229			break
230		}
231		current = c.Parents[0]
232	}
233
234	for i, j := 0, len(chain)-1; i < j; i, j = i+1, j-1 {
235		chain[i], chain[j] = chain[j], chain[i]
236	}
237
238	return chain, nil
239}