arche / internal/cli/cmd_squash.go

commit 154431fd
  1package cli
  2
  3import (
  4	"fmt"
  5	"strings"
  6	"time"
  7
  8	"arche/internal/object"
  9	"arche/internal/repo"
 10	"arche/internal/store"
 11	"arche/internal/wc"
 12
 13	"github.com/spf13/cobra"
 14)
 15
 16var squashCmd = &cobra.Command{
 17	Use:   "squash <id1>..<id2>",
 18	Short: "Collapse a range of commits into one",
 19	Long: `Collapse all commits from <id1> to <id2> (inclusive, linear range) into a
 20single new commit.  The new commit:
 21
 22  - Has the tree of <id2> (the final state).
 23  - Has the parents of <id1> (the commit before the range starts).
 24  - Gets a new change ID.
 25  - Combines the messages of all squashed commits.
 26
 27The squashed commits are each marked obsolete.  If HEAD is inside the
 28squashed range, HEAD is moved to the new commit and the working copy is
 29materialized from <id2>'s tree.`,
 30	Args: cobra.ExactArgs(1),
 31	RunE: squashRunE,
 32}
 33
 34func squashRunE(cmd *cobra.Command, args []string) error {
 35	r := openRepo()
 36	defer r.Close()
 37
 38	id1, id2, err := parseDotDotRange(r, args[0])
 39	if err != nil {
 40		return err
 41	}
 42
 43	chain, err := collectRangeInclusive(r, id1, id2)
 44	if err != nil {
 45		return err
 46	}
 47	if len(chain) == 1 {
 48		fmt.Println("Range contains only one commit — nothing to squash.")
 49		return nil
 50	}
 51
 52	if !squashForceRewrite {
 53		for _, cid := range chain {
 54			c, err := r.ReadCommit(cid)
 55			if err != nil {
 56				return fmt.Errorf("read commit %x: %w", cid[:6], err)
 57			}
 58			if c.Phase == object.PhasePublic {
 59				return fmt.Errorf("commit ch:%s is public; use --force-rewrite to rewrite history", c.ChangeID)
 60			}
 61		}
 62	}
 63
 64	first, err := r.ReadCommit(chain[0])
 65	if err != nil {
 66		return err
 67	}
 68	last, err := r.ReadCommit(chain[len(chain)-1])
 69	if err != nil {
 70		return err
 71	}
 72
 73	var msgs []string
 74	for _, cid := range chain {
 75		c, err := r.ReadCommit(cid)
 76		if err != nil {
 77			return err
 78		}
 79		if m := strings.TrimSpace(c.Message); m != "" {
 80			msgs = append(msgs, m)
 81		}
 82	}
 83	combined := strings.Join(msgs, "\n\n")
 84	if combined == "" {
 85		combined = "squash"
 86	}
 87
 88	before, _ := r.CaptureRefState()
 89	now := time.Now()
 90	sig := object.Signature{Name: r.Cfg.User.Name, Email: r.Cfg.User.Email, Timestamp: now}
 91
 92	tx, err := r.Store.Begin()
 93	if err != nil {
 94		return err
 95	}
 96
 97	newCID, err := r.Store.AllocChangeID(tx)
 98	if err != nil {
 99		r.Store.Rollback(tx)
100		return err
101	}
102
103	squashed := &object.Commit{
104		TreeID:    last.TreeID,
105		Parents:   first.Parents,
106		ChangeID:  newCID,
107		Author:    first.Author,
108		Committer: sig,
109		Message:   combined,
110		Phase:     first.Phase,
111	}
112	if first.Author.Timestamp.IsZero() {
113		squashed.Author = sig
114	}
115
116	newID, err := repo.WriteCommitTx(r.Store, tx, squashed)
117	if err != nil {
118		r.Store.Rollback(tx)
119		return err
120	}
121	if err := r.Store.SetChangeCommit(tx, newCID, newID); err != nil {
122		r.Store.Rollback(tx)
123		return err
124	}
125
126	for _, oldID := range chain {
127		obs := &object.ObsoleteMarker{
128			Predecessor: oldID,
129			Successors:  [][32]byte{newID},
130			Reason:      "squash",
131		}
132		if _, err := repo.WriteObsoleteTx(r.Store, tx, obs); err != nil {
133			r.Store.Rollback(tx)
134			return err
135		}
136	}
137
138	after := fmt.Sprintf(`{"head":%q,"tip":%q}`,
139		object.FormatChangeID(newCID), fmt.Sprintf("%x", newID))
140	op := store.Operation{
141		Kind:      "squash",
142		Timestamp: now.Unix(),
143		Before:    before,
144		After:     after,
145		Metadata:  fmt.Sprintf("squashed %d commits", len(chain)),
146	}
147	if _, err := r.Store.InsertOperation(tx, op); err != nil {
148		r.Store.Rollback(tx)
149		return err
150	}
151	if err := r.Store.Commit(tx); err != nil {
152		return err
153	}
154
155	_, headID, headErr := r.HeadCommit()
156	if headErr == nil {
157		for _, cid := range chain {
158			if cid == headID {
159				w := wc.New(r)
160				if err := w.Materialize(squashed.TreeID, object.FormatChangeID(newCID)); err != nil {
161					return err
162				}
163				if err := r.WriteHead(object.FormatChangeID(newCID)); err != nil {
164					return err
165				}
166				break
167			}
168		}
169	}
170
171	fmt.Printf("Squashed %d commits → ch:%s\n  %s\n",
172		len(chain), newCID, squashFirstLine(combined))
173	return nil
174}
175
176func parseDotDotRange(r *repo.Repo, s string) ([32]byte, [32]byte, error) {
177	parts := strings.SplitN(s, "..", 2)
178	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
179		return object.ZeroID, object.ZeroID,
180			fmt.Errorf("squash range must be <id1>..<id2>, got %q", s)
181	}
182	id1, err := resolveRef(r, parts[0])
183	if err != nil {
184		return object.ZeroID, object.ZeroID, fmt.Errorf("resolve %q: %w", parts[0], err)
185	}
186	id2, err := resolveRef(r, parts[1])
187	if err != nil {
188		return object.ZeroID, object.ZeroID, fmt.Errorf("resolve %q: %w", parts[1], err)
189	}
190	return id1, id2, nil
191}
192
193func collectRangeInclusive(r *repo.Repo, id1, id2 [32]byte) ([][32]byte, error) {
194	var chain [][32]byte
195	cur := id2
196	for {
197		chain = append(chain, cur)
198		if cur == id1 {
199			break
200		}
201		c, err := r.ReadCommit(cur)
202		if err != nil {
203			return nil, fmt.Errorf("read commit %x: %w", cur[:4], err)
204		}
205		if len(c.Parents) == 0 {
206			return nil, fmt.Errorf("reached root without finding start commit — id1 must be an ancestor of id2")
207		}
208		if len(c.Parents) > 1 {
209			return nil, fmt.Errorf("merge commit %x in range — squash only works on linear chains", cur[:4])
210		}
211		cur = c.Parents[0]
212	}
213	for i, j := 0, len(chain)-1; i < j; i, j = i+1, j-1 {
214		chain[i], chain[j] = chain[j], chain[i]
215	}
216	return chain, nil
217}
218
219var squashForceRewrite bool
220
221func init() {
222	squashCmd.Flags().BoolVar(&squashForceRewrite, "force-rewrite", false, "allow rewriting public commits")
223}
224
// squashFirstLine returns s up to (not including) its first '\n'. If s has
// no newline, it returns all of s.
func squashFirstLine(s string) string {
	end := strings.IndexByte(s, '\n')
	if end < 0 {
		end = len(s)
	}
	return s[:end]
}