1package wc
2
3import (
4 "encoding/json"
5 "fmt"
6 "io/fs"
7 "os"
8 "path/filepath"
9 "sort"
10 "strings"
11 "syscall"
12 "time"
13
14 "arche/internal/merge"
15 "arche/internal/object"
16 "arche/internal/repo"
17 "arche/internal/store"
18 "arche/internal/watcher"
19)
20
21func dirtySet(r *repo.Repo) (map[string]bool, error) {
22 if !watcher.IsActive(r.ArcheDir()) {
23 return nil, nil
24 }
25 entries, err := r.Store.ListDirtyWCacheEntries()
26 if err != nil {
27 return nil, err
28 }
29 m := make(map[string]bool, len(entries))
30 for _, e := range entries {
31 m[e.Path] = true
32 }
33 return m, nil
34}
35
// FileStatus pairs a path with a one-character status code. WC.Status emits
// 'A', 'M' and 'D'.
type FileStatus struct {
	// Path is the slash-separated path the status applies to.
	Path string
	// Status is the status code for Path ('A', 'M' or 'D' as produced by
	// WC.Status).
	Status rune
}
40
41type WC struct {
42 Repo *repo.Repo
43 SignKey string
44 NoAutoAdvance bool
45 AuthorOverride *object.Signature
46}
47
48func New(r *repo.Repo) *WC { return &WC{Repo: r} }
49
50func (wc *WC) maybeSign(c *object.Commit) error {
51 if wc.SignKey == "" {
52 return nil
53 }
54 body := object.CommitBodyForSigning(c)
55 sig, _, err := object.SignCommitBody(body, wc.SignKey)
56 if err != nil {
57 return fmt.Errorf("commit signing: %w", err)
58 }
59 c.CommitSig = sig
60 return nil
61}
62
63func (wc *WC) snapshotIntoTx(tx *store.Tx, headCommit *object.Commit, paths []string, cacheMap map[string]store.WCacheEntry, dirty map[string]bool, message string, now time.Time) (*object.Commit, [32]byte, error) {
64 r := wc.Repo
65
66 var entries []fileEntry
67
68 if err := r.Store.ClearWCache(tx); err != nil {
69 return nil, object.ZeroID, fmt.Errorf("clear wcache: %w", err)
70 }
71
72 for _, rel := range paths {
73 if dirty != nil && !dirty[rel] {
74 if cached, ok := cacheMap[rel]; ok {
75 entries = append(entries, fileEntry{
76 path: rel,
77 blobID: cached.BlobID,
78 mode: object.EntryMode(cached.Mode),
79 })
80 if err := r.Store.SetWCacheEntry(tx, cached); err != nil {
81 return nil, object.ZeroID, fmt.Errorf("set wcache: %w", err)
82 }
83 continue
84 }
85 }
86
87 abs := filepath.Join(r.Root, rel)
88 info, err := os.Lstat(abs)
89 if err != nil {
90 continue
91 }
92
93 var blobID [32]byte
94 mode := fileMode(info)
95
96 if cached, ok := cacheMap[rel]; ok {
97 st := info.Sys().(*syscall.Stat_t)
98 inode := st.Ino
99 mtime := info.ModTime().UnixNano()
100 size := info.Size()
101 if cached.Inode == inode && cached.MtimeNs == mtime && cached.Size == size {
102 blobID = cached.BlobID
103 }
104 }
105
106 if blobID == object.ZeroID {
107 content, err := readFileContent(abs, info)
108 if err != nil {
109 return nil, object.ZeroID, err
110 }
111 id, err := repo.WriteBlobTx(r.Store, tx, &object.Blob{Content: content})
112 if err != nil {
113 return nil, object.ZeroID, err
114 }
115 blobID = id
116 }
117
118 st := info.Sys().(*syscall.Stat_t)
119 if err := r.Store.SetWCacheEntry(tx, store.WCacheEntry{
120 Path: rel,
121 Inode: st.Ino,
122 MtimeNs: info.ModTime().UnixNano(),
123 Size: info.Size(),
124 BlobID: blobID,
125 Mode: uint8(mode),
126 }); err != nil {
127 return nil, object.ZeroID, fmt.Errorf("set wcache: %w", err)
128 }
129
130 entries = append(entries, fileEntry{path: rel, blobID: blobID, mode: mode})
131 }
132
133 tree, err := buildTree(r, tx, entries)
134 if err != nil {
135 return nil, object.ZeroID, err
136 }
137
138 sig := object.Signature{
139 Name: r.Cfg.User.Name,
140 Email: r.Cfg.User.Email,
141 Timestamp: now,
142 }
143
144 c := &object.Commit{
145 TreeID: tree,
146 Parents: headCommit.Parents,
147 ChangeID: headCommit.ChangeID,
148 Author: headCommit.Author,
149 Committer: sig,
150 Message: message,
151 Phase: headCommit.Phase,
152 }
153 if headCommit.Author.Timestamp.IsZero() {
154 c.Author = sig
155 }
156
157 if err := wc.maybeSign(c); err != nil {
158 return nil, object.ZeroID, err
159 }
160
161 commitID, err := repo.WriteCommitTx(r.Store, tx, c)
162 if err != nil {
163 return nil, object.ZeroID, err
164 }
165 if err := r.Store.SetChangeCommit(tx, c.ChangeID, commitID); err != nil {
166 return nil, object.ZeroID, err
167 }
168
169 return c, commitID, nil
170}
171
172func (wc *WC) snapshotInput() (paths []string, cacheMap map[string]store.WCacheEntry, dirty map[string]bool, err error) {
173 r := wc.Repo
174
175 cacheEntries, err := r.Store.ListWCacheEntries()
176 if err != nil {
177 return nil, nil, nil, err
178 }
179 cacheMap = make(map[string]store.WCacheEntry, len(cacheEntries))
180 for _, e := range cacheEntries {
181 cacheMap[e.Path] = e
182 }
183
184 dirty, _ = dirtySet(r)
185
186 if dirty != nil {
187 seen := make(map[string]bool, len(cacheMap)+len(dirty))
188 for p := range cacheMap {
189 seen[p] = true
190 paths = append(paths, p)
191 }
192 for p := range dirty {
193 if !seen[p] {
194 paths = append(paths, p)
195 }
196 }
197 } else {
198 paths, err = wc.trackedPaths()
199 if err != nil {
200 return nil, nil, nil, err
201 }
202 }
203
204 return paths, cacheMap, dirty, nil
205}
206
207func (wc *WC) Snapshot(message string) (*object.Commit, [32]byte, error) {
208 r := wc.Repo
209 now := time.Now()
210
211 head, _, err := r.HeadCommit()
212 if err != nil {
213 return nil, object.ZeroID, err
214 }
215
216 paths, cacheMap, dirty, err := wc.snapshotInput()
217 if err != nil {
218 return nil, object.ZeroID, err
219 }
220
221 tx, err := r.Store.Begin()
222 if err != nil {
223 return nil, object.ZeroID, err
224 }
225
226 c, commitID, err := wc.snapshotIntoTx(tx, head, paths, cacheMap, dirty, message, now)
227 if err != nil {
228 r.Store.Rollback(tx)
229 return nil, object.ZeroID, err
230 }
231 if err := r.Store.Commit(tx); err != nil {
232 return nil, object.ZeroID, err
233 }
234 return c, commitID, nil
235}
236
237func (wc *WC) Snap(message string) (*object.Commit, [32]byte, error) {
238 r := wc.Repo
239 now := time.Now()
240
241 before, err := r.CaptureRefState()
242 if err != nil {
243 return nil, object.ZeroID, err
244 }
245
246 statusBefore, err := wc.Status()
247 if err != nil {
248 return nil, object.ZeroID, err
249 }
250 diffPaths := make(map[string]bool, len(statusBefore))
251 for _, fsEntry := range statusBefore {
252 diffPaths[fsEntry.Path] = true
253 }
254
255 useRestrictedPaths := len(r.Cfg.Hooks.PreSnap) > 0
256 if useRestrictedPaths {
257 if err := RunHooksSequential(r.Root, "pre-snap", r.Cfg.Hooks.PreSnap); err != nil {
258 return nil, object.ZeroID, fmt.Errorf("pre-snap hook failed: %w", err)
259 }
260 }
261
262 head, oldHeadID, err := r.HeadCommit()
263 if err != nil {
264 return nil, object.ZeroID, err
265 }
266
267 type snapshotFn func(tx *store.Tx) (*object.Commit, [32]byte, error)
268 var doSnapshot snapshotFn
269
270 if useRestrictedPaths {
271 headBlobs := make(map[string][32]byte)
272 headModes := make(map[string]object.EntryMode)
273 if err := flattenTree(r, head.TreeID, "", headBlobs); err != nil {
274 return nil, object.ZeroID, err
275 }
276 if err := flattenTreeModes(r, head.TreeID, "", headModes); err != nil {
277 return nil, object.ZeroID, err
278 }
279 doSnapshot = func(tx *store.Tx) (*object.Commit, [32]byte, error) {
280 return wc.snapshotRestrictedPathsIntoTx(tx, head, headBlobs, headModes, diffPaths, message, now)
281 }
282 } else {
283 paths, cacheMap, dirty, err := wc.snapshotInput()
284 if err != nil {
285 return nil, object.ZeroID, err
286 }
287 doSnapshot = func(tx *store.Tx) (*object.Commit, [32]byte, error) {
288 return wc.snapshotIntoTx(tx, head, paths, cacheMap, dirty, message, now)
289 }
290 }
291
292 existingBookmarks, _ := r.Store.ListBookmarks()
293
294 tx, err := r.Store.Begin()
295 if err != nil {
296 return nil, object.ZeroID, err
297 }
298
299 snapped, snappedID, err := doSnapshot(tx)
300 if err != nil {
301 r.Store.Rollback(tx)
302 return nil, object.ZeroID, err
303 }
304
305 if snappedID != oldHeadID {
306 for _, bm := range existingBookmarks {
307 if bm.CommitID == oldHeadID {
308 _ = r.Store.SetBookmark(tx, store.Bookmark{
309 Name: bm.Name,
310 CommitID: snappedID,
311 Remote: bm.Remote,
312 })
313 }
314 }
315 }
316
317 newChangeID, err := r.Store.AllocChangeID(tx)
318 if err != nil {
319 r.Store.Rollback(tx)
320 return nil, object.ZeroID, err
321 }
322
323 sig := object.Signature{Name: r.Cfg.User.Name, Email: r.Cfg.User.Email, Timestamp: now}
324 newDraft := &object.Commit{
325 TreeID: snapped.TreeID,
326 Parents: [][32]byte{snappedID},
327 ChangeID: newChangeID,
328 Author: sig,
329 Committer: sig,
330 Message: "",
331 Phase: object.PhaseDraft,
332 }
333
334 newDraftID, err := repo.WriteCommitTx(r.Store, tx, newDraft)
335 if err != nil {
336 r.Store.Rollback(tx)
337 return nil, object.ZeroID, err
338 }
339
340 if err := r.Store.SetChangeCommit(tx, newChangeID, newDraftID); err != nil {
341 r.Store.Rollback(tx)
342 return nil, object.ZeroID, err
343 }
344
345 after := buildRefState(snappedID, object.FormatChangeID(newChangeID))
346 op := store.Operation{
347 Kind: "snap",
348 Timestamp: now.Unix(),
349 Before: before,
350 After: after,
351 Metadata: "'" + firstLine(snapped.Message) + "'",
352 }
353 if _, err := r.Store.InsertOperation(tx, op); err != nil {
354 r.Store.Rollback(tx)
355 return nil, object.ZeroID, err
356 }
357
358 if err := r.Store.Commit(tx); err != nil {
359 return nil, object.ZeroID, err
360 }
361
362 if err := r.WriteHead(object.FormatChangeID(newChangeID)); err != nil {
363 return nil, object.ZeroID, err
364 }
365
366 if len(r.Cfg.Hooks.PostSnap) > 0 {
367 if err := RunHooksSequential(r.Root, "post-snap", r.Cfg.Hooks.PostSnap); err != nil {
368 fmt.Fprintf(os.Stderr, "arche snap: post-snap hook: %v\n", err)
369 }
370 }
371
372 return snapped, snappedID, nil
373}
374
375func (wc *WC) Status() ([]FileStatus, error) {
376 r := wc.Repo
377 head, _, err := r.HeadCommit()
378 if err != nil {
379 return nil, err
380 }
381
382 headFiles := make(map[string][32]byte)
383 if err := flattenTree(r, head.TreeID, "", headFiles); err != nil {
384 return nil, err
385 }
386
387 wcPaths, err := wc.trackedPaths()
388 if err != nil {
389 return nil, err
390 }
391 wcSet := make(map[string]bool, len(wcPaths))
392 for _, p := range wcPaths {
393 wcSet[p] = true
394 }
395
396 cacheEntries, _ := r.Store.ListWCacheEntries()
397 cacheMap := make(map[string]store.WCacheEntry, len(cacheEntries))
398 for _, e := range cacheEntries {
399 cacheMap[e.Path] = e
400 }
401 dirty, _ := dirtySet(r)
402
403 var out []FileStatus
404
405 for path, headBlobID := range headFiles {
406 if !wcSet[path] {
407 out = append(out, FileStatus{Path: path, Status: 'D'})
408 continue
409 }
410
411 if dirty != nil && !dirty[path] {
412 if cached, ok := cacheMap[path]; ok {
413 if cached.BlobID != headBlobID {
414 out = append(out, FileStatus{Path: path, Status: 'M'})
415 }
416 continue
417 }
418 }
419
420 curBlobID, err := wc.blobIDForPath(path)
421 if err != nil {
422 continue
423 }
424 if curBlobID != headBlobID {
425 out = append(out, FileStatus{Path: path, Status: 'M'})
426 }
427 }
428
429 ignore, _ := loadIgnore(r.Root)
430 for _, path := range wcPaths {
431 if _, inHead := headFiles[path]; !inHead {
432 if ignore.Match(path) {
433 continue
434 }
435 out = append(out, FileStatus{Path: path, Status: 'A'})
436 }
437 }
438
439 sort.Slice(out, func(i, j int) bool { return out[i].Path < out[j].Path })
440 return out, nil
441}
442
443func (wc *WC) materializeDisk(treeID [32]byte) (map[string][32]byte, map[string]object.EntryMode, error) {
444 r := wc.Repo
445
446 wantFiles := make(map[string][32]byte)
447 wantMode := make(map[string]object.EntryMode)
448 if err := flattenTree(r, treeID, "", wantFiles); err != nil {
449 return nil, nil, err
450 }
451
452 if err := flattenTreeModes(r, treeID, "", wantMode); err != nil {
453 return nil, nil, err
454 }
455
456 ignore, _ := loadIgnore(r.Root)
457 err := filepath.WalkDir(r.Root, func(path string, d fs.DirEntry, err error) error {
458 if err != nil {
459 return nil
460 }
461 rel, _ := filepath.Rel(r.Root, path)
462 if rel == "." {
463 return nil
464 }
465 if d.IsDir() {
466 if rel == archeDirName || strings.HasPrefix(rel, archeDirName+string(os.PathSeparator)) {
467 return filepath.SkipDir
468 }
469 return nil
470 }
471 if ignore.Match(rel) {
472 return nil
473 }
474 if _, ok := wantFiles[rel]; !ok {
475 return os.Remove(path)
476 }
477 return nil
478 })
479 if err != nil {
480 return nil, nil, err
481 }
482
483 var conflictPaths []string
484 for relPath, blobID := range wantFiles {
485 abs := filepath.Join(r.Root, relPath)
486 if err := os.MkdirAll(filepath.Dir(abs), 0o755); err != nil {
487 return nil, nil, err
488 }
489 content, err := r.ReadBlob(blobID)
490 if err != nil {
491 if conf, cErr := r.ReadConflict(blobID); cErr == nil {
492 content = renderConflictMarkers(r, conf)
493 conflictPaths = append(conflictPaths, relPath)
494 err = nil
495 }
496 }
497 if err != nil {
498 return nil, nil, err
499 }
500 perm := fs.FileMode(0o644)
501 if wantMode[relPath] == object.ModeExec {
502 perm = 0o755
503 }
504 if err := os.WriteFile(abs, content, perm); err != nil {
505 return nil, nil, err
506 }
507 }
508
509 for _, p := range conflictPaths {
510 delete(wantFiles, p)
511 }
512
513 return wantFiles, wantMode, nil
514}
515
516func renderConflictMarkers(r *repo.Repo, conf *object.Conflict) []byte {
517 readStr := func(id [32]byte) string {
518 if id == object.ZeroID {
519 return ""
520 }
521 b, _ := r.ReadBlob(id)
522 return string(b)
523 }
524 nl := func(s string) string {
525 if len(s) > 0 && s[len(s)-1] != '\n' {
526 return s + "\n"
527 }
528 return s
529 }
530 if conf.Ours.BlobID == object.ZeroID {
531 return []byte(fmt.Sprintf("<<<<<<< ours\n(deleted)\n=======\n%s>>>>>>> theirs\n", nl(readStr(conf.Theirs.BlobID))))
532 }
533 if conf.Theirs.BlobID == object.ZeroID {
534 return []byte(fmt.Sprintf("<<<<<<< ours\n%s=======\n(deleted)\n>>>>>>> theirs\n", nl(readStr(conf.Ours.BlobID))))
535 }
536 return []byte(fmt.Sprintf("<<<<<<< ours\n%s=======\n%s>>>>>>> theirs\n",
537 nl(readStr(conf.Ours.BlobID)),
538 nl(readStr(conf.Theirs.BlobID))))
539}
540
541func (wc *WC) populateWCacheInTx(tx *store.Tx, wantFiles map[string][32]byte) error {
542 r := wc.Repo
543 if err := r.Store.ClearWCache(tx); err != nil {
544 return err
545 }
546 for relPath, blobID := range wantFiles {
547 abs := filepath.Join(r.Root, relPath)
548 info, err := os.Lstat(abs)
549 if err != nil {
550 continue
551 }
552 st, ok := info.Sys().(*syscall.Stat_t)
553 if !ok {
554 continue
555 }
556 _ = r.Store.SetWCacheEntry(tx, store.WCacheEntry{
557 Path: relPath,
558 Inode: st.Ino,
559 MtimeNs: info.ModTime().UnixNano(),
560 Size: info.Size(),
561 BlobID: blobID,
562 Mode: uint8(fileMode(info)),
563 })
564 }
565 return nil
566}
567
568func (wc *WC) MaterializeQuiet(treeID [32]byte) error {
569 r := wc.Repo
570
571 wantFiles, _, err := wc.materializeDisk(treeID)
572 if err != nil {
573 return err
574 }
575
576 tx, err := r.Store.Begin()
577 if err != nil {
578 return err
579 }
580 if err := wc.populateWCacheInTx(tx, wantFiles); err != nil {
581 r.Store.Rollback(tx)
582 return err
583 }
584 return r.Store.Commit(tx)
585}
586
587func (wc *WC) Materialize(treeID [32]byte, newChangeID string) error {
588 r := wc.Repo
589
590 before, _ := r.CaptureRefState()
591 now := time.Now()
592
593 wantFiles, _, err := wc.materializeDisk(treeID)
594 if err != nil {
595 return err
596 }
597
598 bare := object.StripChangeIDPrefix(newChangeID)
599 commitID, _ := r.Store.GetChangeCommit(bare)
600 after := buildRefState(commitID, newChangeID)
601
602 tx, err := r.Store.Begin()
603 if err != nil {
604 return err
605 }
606 if err := wc.populateWCacheInTx(tx, wantFiles); err != nil {
607 r.Store.Rollback(tx)
608 return err
609 }
610
611 op := store.Operation{
612 Kind: "co",
613 Timestamp: now.Unix(),
614 Before: before,
615 After: after,
616 Metadata: "checked out " + newChangeID,
617 }
618 if _, err := r.Store.InsertOperation(tx, op); err != nil {
619 r.Store.Rollback(tx)
620 return err
621 }
622
623 return r.Store.Commit(tx)
624}
625
// archeDirName is the name of the directory, directly under the repository
// root, that working-copy walks in this file always skip.
const archeDirName = ".arche"
627
628func (wc *WC) trackedPaths() ([]string, error) {
629 r := wc.Repo
630 ignore, _ := loadIgnore(r.Root)
631
632 var paths []string
633 err := filepath.WalkDir(r.Root, func(path string, d fs.DirEntry, err error) error {
634 if err != nil {
635 return nil
636 }
637 rel, _ := filepath.Rel(r.Root, path)
638 if rel == "." {
639 return nil
640 }
641 if d.IsDir() {
642 if rel == archeDirName || strings.HasPrefix(rel, archeDirName+string(os.PathSeparator)) {
643 return filepath.SkipDir
644 }
645 if ignore.MatchDir(rel) {
646 return filepath.SkipDir
647 }
648 return nil
649 }
650 if ignore.Match(rel) {
651 return nil
652 }
653 paths = append(paths, filepath.ToSlash(rel))
654 return nil
655 })
656 return paths, err
657}
658
659func (wc *WC) blobIDForPath(rel string) ([32]byte, error) {
660 r := wc.Repo
661 abs := filepath.Join(r.Root, rel)
662 info, err := os.Lstat(abs)
663 if err != nil {
664 return object.ZeroID, err
665 }
666 st := info.Sys().(*syscall.Stat_t)
667
668 if cached, _ := r.Store.GetWCacheEntry(rel); cached != nil {
669 if cached.Inode == st.Ino &&
670 cached.MtimeNs == info.ModTime().UnixNano() &&
671 cached.Size == info.Size() {
672 return cached.BlobID, nil
673 }
674 }
675
676 content, err := readFileContent(abs, info)
677 if err != nil {
678 return object.ZeroID, err
679 }
680 b := &object.Blob{Content: content}
681 return object.HashBlob(b), nil
682}
683
684func flattenTree(r *repo.Repo, treeID [32]byte, prefix string, out map[string][32]byte) error {
685 if treeID == object.ZeroID {
686 return nil
687 }
688 t, err := r.ReadTree(treeID)
689 if err != nil {
690 return err
691 }
692 for _, e := range t.Entries {
693 rel := join(prefix, e.Name)
694 switch e.Mode {
695 case object.ModeDir:
696 if err := flattenTree(r, e.ObjectID, rel, out); err != nil {
697 return err
698 }
699 default:
700 out[rel] = e.ObjectID
701 }
702 }
703 return nil
704}
705
706func flattenTreeModes(r *repo.Repo, treeID [32]byte, prefix string, out map[string]object.EntryMode) error {
707 if treeID == object.ZeroID {
708 return nil
709 }
710 t, err := r.ReadTree(treeID)
711 if err != nil {
712 return err
713 }
714 for _, e := range t.Entries {
715 rel := join(prefix, e.Name)
716 switch e.Mode {
717 case object.ModeDir:
718 if err := flattenTreeModes(r, e.ObjectID, rel, out); err != nil {
719 return err
720 }
721 default:
722 out[rel] = e.Mode
723 }
724 }
725 return nil
726}
727
728type fileEntry struct {
729 path string
730 blobID [32]byte
731 mode object.EntryMode
732}
733
734func buildTree(r *repo.Repo, tx *store.Tx, entries []fileEntry) ([32]byte, error) {
735 type node struct {
736 isFile bool
737 blobID [32]byte
738 mode object.EntryMode
739 children map[string]*node
740 }
741 root := &node{children: make(map[string]*node)}
742
743 for _, e := range entries {
744 parts := strings.Split(e.path, "/")
745 cur := root
746 for i, part := range parts {
747 if i == len(parts)-1 {
748 cur.children[part] = &node{isFile: true, blobID: e.blobID, mode: e.mode}
749 } else {
750 if _, ok := cur.children[part]; !ok {
751 cur.children[part] = &node{children: make(map[string]*node)}
752 }
753 cur = cur.children[part]
754 }
755 }
756 }
757
758 var writeNode func(n *node) ([32]byte, error)
759 writeNode = func(n *node) ([32]byte, error) {
760 var treeEntries []object.TreeEntry
761 for name, child := range n.children {
762 if child.isFile {
763 treeEntries = append(treeEntries, object.TreeEntry{
764 Name: name,
765 Mode: child.mode,
766 ObjectID: child.blobID,
767 })
768 } else {
769 subID, err := writeNode(child)
770 if err != nil {
771 return object.ZeroID, err
772 }
773 treeEntries = append(treeEntries, object.TreeEntry{
774 Name: name,
775 Mode: object.ModeDir,
776 ObjectID: subID,
777 })
778 }
779 }
780 sort.Slice(treeEntries, func(i, j int) bool { return treeEntries[i].Name < treeEntries[j].Name })
781 t := &object.Tree{Entries: treeEntries}
782 id, err := repo.WriteTreeTx(r.Store, tx, t)
783 return id, err
784 }
785
786 return writeNode(root)
787}
788
789func fileMode(info os.FileInfo) object.EntryMode {
790 if info.Mode()&0o111 != 0 {
791 return object.ModeExec
792 }
793 if info.Mode()&os.ModeSymlink != 0 {
794 return object.ModeSymlink
795 }
796 return object.ModeFile
797}
798
// readFileContent returns the bytes stored for the file at abs: the link
// target for a symbolic link (as reported by info), the file's content
// otherwise.
func readFileContent(abs string, info os.FileInfo) ([]byte, error) {
	if info.Mode()&os.ModeSymlink == 0 {
		return os.ReadFile(abs)
	}

	target, err := os.Readlink(abs)
	if err != nil {
		return nil, err
	}
	return []byte(target), nil
}
809
// join appends name to prefix with a "/" separator; an empty prefix yields
// name unchanged.
func join(prefix, name string) string {
	if prefix != "" {
		return prefix + "/" + name
	}
	return name
}
816
// buildRefState encodes a ref state as a JSON object with two string members:
// "head" (changeID) and "tip" (commitID in lowercase hex).
func buildRefState(commitID [32]byte, changeID string) string {
	// Field order mirrors the sorted key order of the equivalent map.
	state := struct {
		Head string `json:"head"`
		Tip  string `json:"tip"`
	}{
		Head: changeID,
		Tip:  fmt.Sprintf("%x", commitID),
	}
	encoded, _ := json.Marshal(state)
	return string(encoded)
}
825
// firstLine returns s up to, but not including, its first newline, or all of s
// when it contains none.
func firstLine(s string) string {
	return strings.SplitN(s, "\n", 2)[0]
}
832
833func (wc *WC) Amend(message string) (*object.Commit, [32]byte, error) {
834 r := wc.Repo
835 now := time.Now()
836
837 head, oldHeadID, err := r.HeadCommit()
838 if err != nil {
839 return nil, object.ZeroID, err
840 }
841 if head.Phase == object.PhasePublic {
842 return nil, object.ZeroID, fmt.Errorf("cannot amend a public commit; use --force-rewrite if you are sure")
843 }
844
845 before, err := r.CaptureRefState()
846 if err != nil {
847 return nil, object.ZeroID, err
848 }
849
850 if message == "" {
851 message = head.Message
852 }
853
854 paths, cacheMap, dirty, err := wc.snapshotInput()
855 if err != nil {
856 return nil, object.ZeroID, err
857 }
858
859 tx, err := r.Store.Begin()
860 if err != nil {
861 return nil, object.ZeroID, err
862 }
863
864 amended, amendedID, err := wc.snapshotIntoTx(tx, head, paths, cacheMap, dirty, message, now)
865 if err != nil {
866 r.Store.Rollback(tx)
867 return nil, object.ZeroID, err
868 }
869
870 if oldHeadID != amendedID {
871 obs := &object.ObsoleteMarker{
872 Predecessor: oldHeadID,
873 Successors: [][32]byte{amendedID},
874 Reason: "amend",
875 Timestamp: now.Unix(),
876 }
877 if _, err := repo.WriteObsoleteTx(r.Store, tx, obs); err != nil {
878 r.Store.Rollback(tx)
879 return nil, object.ZeroID, err
880 }
881 }
882
883 after := buildRefState(amendedID, object.FormatChangeID(amended.ChangeID))
884 op := store.Operation{
885 Kind: "amend",
886 Timestamp: now.Unix(),
887 Before: before,
888 After: after,
889 Metadata: "'" + firstLine(amended.Message) + "'",
890 }
891 if _, err := r.Store.InsertOperation(tx, op); err != nil {
892 r.Store.Rollback(tx)
893 return nil, object.ZeroID, err
894 }
895
896 if err := r.Store.Commit(tx); err != nil {
897 return nil, object.ZeroID, err
898 }
899
900 if oldHeadID != amendedID {
901 if err := wc.autoRebaseDownstream(oldHeadID, amendedID, head.ChangeID, now); err != nil {
902 fmt.Fprintf(os.Stderr, "arche: warning: downstream rebase failed: %v\n", err)
903 }
904 }
905
906 return amended, amendedID, nil
907}
908
909func (wc *WC) autoRebaseDownstream(oldParentID, newParentID [32]byte, headChangeID string, now time.Time) error {
910 r := wc.Repo
911
912 allChanges, err := r.Store.ListChanges()
913 if err != nil {
914 return err
915 }
916
917 type draftEntry struct {
918 id [32]byte
919 changeID string
920 commit *object.Commit
921 }
922
923 children := make(map[[32]byte][]draftEntry)
924 for _, ch := range allChanges {
925 if ch.CommitID == object.ZeroID {
926 continue
927 }
928 c, err := r.ReadCommit(ch.CommitID)
929 if err != nil || c == nil {
930 continue
931 }
932 if c.Phase != object.PhaseDraft {
933 continue
934 }
935 if c.ChangeID == headChangeID {
936 continue
937 }
938 if len(c.Parents) == 0 {
939 continue
940 }
941 d := draftEntry{id: ch.CommitID, changeID: ch.Name, commit: c}
942 children[c.Parents[0]] = append(children[c.Parents[0]], d)
943 }
944
945 type rebaseTask struct {
946 entry draftEntry
947 newParent [32]byte
948 }
949 var tasks []rebaseTask
950 queue := []struct {
951 oldID [32]byte
952 newID [32]byte
953 }{{oldParentID, newParentID}}
954
955 for len(queue) > 0 {
956 cur := queue[0]
957 queue = queue[1:]
958 for _, child := range children[cur.oldID] {
959 tasks = append(tasks, rebaseTask{entry: child, newParent: cur.newID})
960 queue = append(queue, struct{ oldID, newID [32]byte }{child.id, child.id})
961 }
962 }
963
964 remapped := map[[32]byte][32]byte{oldParentID: newParentID}
965
966 for _, task := range tasks {
967 oldFirst := task.entry.commit.Parents[0]
968 newParent, ok := remapped[oldFirst]
969 if !ok {
970 newParent = oldFirst
971 }
972
973 var baseTreeID [32]byte
974 if pc, err2 := r.ReadCommit(oldFirst); err2 == nil {
975 baseTreeID = pc.TreeID
976 }
977 newParentCommit, err := r.ReadCommit(newParent)
978 if err != nil {
979 return fmt.Errorf("read new parent for %s: %w", object.FormatChangeID(task.entry.changeID), err)
980 }
981
982 result, err := merge.Trees(r, baseTreeID, task.entry.commit.TreeID, newParentCommit.TreeID)
983 if err != nil {
984 return fmt.Errorf("merge for %s: %w", object.FormatChangeID(task.entry.changeID), err)
985 }
986
987 newCommit := &object.Commit{
988 TreeID: result.TreeID,
989 Parents: [][32]byte{newParent},
990 ChangeID: task.entry.changeID,
991 Author: task.entry.commit.Author,
992 Committer: object.Signature{Name: r.Cfg.User.Name, Email: r.Cfg.User.Email, Timestamp: now},
993 Message: task.entry.commit.Message,
994 Phase: task.entry.commit.Phase,
995 }
996
997 tx, err := r.Store.Begin()
998 if err != nil {
999 return err
1000 }
1001 newCommitID, err := repo.WriteCommitTx(r.Store, tx, newCommit)
1002 if err != nil {
1003 r.Store.Rollback(tx)
1004 return err
1005 }
1006 if err := r.Store.SetChangeCommit(tx, task.entry.changeID, newCommitID); err != nil {
1007 r.Store.Rollback(tx)
1008 return err
1009 }
1010 obs := &object.ObsoleteMarker{
1011 Predecessor: task.entry.id,
1012 Successors: [][32]byte{newCommitID},
1013 Reason: "amend",
1014 Timestamp: now.Unix(),
1015 }
1016 if _, err := repo.WriteObsoleteTx(r.Store, tx, obs); err != nil {
1017 r.Store.Rollback(tx)
1018 return err
1019 }
1020 if err := r.Store.Commit(tx); err != nil {
1021 return err
1022 }
1023
1024 remapped[task.entry.id] = newCommitID
1025 conflictNote := ""
1026 if len(result.Conflicts) > 0 {
1027 conflictNote = fmt.Sprintf(" (%d conflict(s))", len(result.Conflicts))
1028 }
1029 fmt.Printf(" auto-rebased %s%s\n", object.FormatChangeID(task.entry.changeID), conflictNote)
1030 }
1031 return nil
1032}