arche / internal/merge/merge.go

commit 154431fd
  1package merge
  2
  3import (
  4	"bytes"
  5	"fmt"
  6	"sort"
  7	"strings"
  8
  9	"arche/internal/object"
 10	"arche/internal/repo"
 11	"arche/internal/store"
 12
 13	"github.com/sergi/go-diff/diffmatchpatch"
 14)
 15
 16type MergeResult struct {
 17	Content string
 18	Clean   bool
 19	Ours    string
 20	Theirs  string
 21}
 22
 23func MergeText(base, ours, theirs string) MergeResult {
 24	if ours == theirs {
 25		return MergeResult{Content: ours, Clean: true}
 26	}
 27	if base == ours {
 28		return MergeResult{Content: theirs, Clean: true}
 29	}
 30	if base == theirs {
 31		return MergeResult{Content: ours, Clean: true}
 32	}
 33
 34	dmp := diffmatchpatch.New()
 35
 36	patches := dmp.PatchMake(base, ours)
 37	result, applied := dmp.PatchApply(patches, theirs)
 38
 39	allApplied := true
 40	for _, ok := range applied {
 41		if !ok {
 42			allApplied = false
 43			break
 44		}
 45	}
 46
 47	if allApplied && !strings.Contains(result, "<<<<<<<") {
 48		return MergeResult{Content: result, Clean: true}
 49	}
 50
 51	conflict := fmt.Sprintf(
 52		"<<<<<<< ours\n%s=======\n%s>>>>>>> theirs\n",
 53		ensureNewline(ours),
 54		ensureNewline(theirs),
 55	)
 56	return MergeResult{
 57		Content: conflict,
 58		Clean:   false,
 59		Ours:    ours,
 60		Theirs:  theirs,
 61	}
 62}
 63
 64type TreeMergeResult struct {
 65	TreeID    [32]byte
 66	Conflicts []string
 67}
 68
 69func Trees(r *repo.Repo, base, ours, theirs [32]byte) (*TreeMergeResult, error) {
 70	baseFiles := make(map[string][32]byte)
 71	oursFiles := make(map[string][32]byte)
 72	theirsFiles := make(map[string][32]byte)
 73	oursMode := make(map[string]object.EntryMode)
 74	theirsMode := make(map[string]object.EntryMode)
 75
 76	if base != object.ZeroID {
 77		if err := flattenTree(r, base, "", baseFiles, nil); err != nil {
 78			return nil, fmt.Errorf("merge base: %w", err)
 79		}
 80	}
 81	if err := flattenTree(r, ours, "", oursFiles, oursMode); err != nil {
 82		return nil, fmt.Errorf("merge ours: %w", err)
 83	}
 84	if err := flattenTree(r, theirs, "", theirsFiles, theirsMode); err != nil {
 85		return nil, fmt.Errorf("merge theirs: %w", err)
 86	}
 87
 88	allPaths := make(map[string]bool)
 89	for p := range oursFiles {
 90		allPaths[p] = true
 91	}
 92	for p := range theirsFiles {
 93		allPaths[p] = true
 94	}
 95	for p := range baseFiles {
 96		allPaths[p] = true
 97	}
 98
 99	tx, err := r.Store.Begin()
100	if err != nil {
101		return nil, err
102	}
103
104	var mergedFiles []mergedFile
105	var conflictPaths []string
106
107	for path := range allPaths {
108		bBase := baseFiles[path]
109		bOurs := oursFiles[path]
110		bTheirs := theirsFiles[path]
111
112		switch {
113		case bOurs == bTheirs:
114			if bOurs != object.ZeroID {
115				mergedFiles = append(mergedFiles, mergedFile{path: path, blobID: bOurs, mode: oursMode[path]})
116			}
117
118		case bBase == bOurs:
119			if bTheirs != object.ZeroID {
120				mergedFiles = append(mergedFiles, mergedFile{path: path, blobID: bTheirs, mode: theirsMode[path]})
121			}
122
123		case bBase == bTheirs:
124			if bOurs != object.ZeroID {
125				mergedFiles = append(mergedFiles, mergedFile{path: path, blobID: bOurs, mode: oursMode[path]})
126			}
127
128		default:
129			if bOurs == object.ZeroID || bTheirs == object.ZeroID {
130				conflictPaths = append(conflictPaths, path)
131				conf := &object.Conflict{
132					Ours:   object.ConflictSide{BlobID: bOurs},
133					Theirs: object.ConflictSide{BlobID: bTheirs},
134				}
135				if bBase != object.ZeroID {
136					conf.Base = &object.ConflictSide{BlobID: bBase}
137				}
138				conflictID, err := repo.WriteConflictTx(r.Store, tx, conf)
139				if err != nil {
140					r.Store.Rollback(tx)
141					return nil, err
142				}
143				if err := r.Store.AddConflict(tx, path); err != nil {
144					r.Store.Rollback(tx)
145					return nil, err
146				}
147				mergedFiles = append(mergedFiles, mergedFile{path: path, blobID: conflictID, mode: object.ModeFile})
148				continue
149			}
150
151			baseContent, _ := readBlobStr(r, bBase)
152			oursContent, _ := readBlobStr(r, bOurs)
153			theirsContent, _ := readBlobStr(r, bTheirs)
154
155			result := MergeText(baseContent, oursContent, theirsContent)
156			if !result.Clean {
157				conflictPaths = append(conflictPaths, path)
158				conf := &object.Conflict{
159					Ours:   object.ConflictSide{BlobID: bOurs},
160					Theirs: object.ConflictSide{BlobID: bTheirs},
161				}
162				if bBase != object.ZeroID {
163					conf.Base = &object.ConflictSide{BlobID: bBase}
164				}
165				conflictID, err := repo.WriteConflictTx(r.Store, tx, conf)
166				if err != nil {
167					r.Store.Rollback(tx)
168					return nil, err
169				}
170				if err := r.Store.AddConflict(tx, path); err != nil {
171					r.Store.Rollback(tx)
172					return nil, err
173				}
174				mergedFiles = append(mergedFiles, mergedFile{path: path, blobID: conflictID, mode: oursMode[path]})
175			} else {
176				id, err := writeBlobTx(r, tx, []byte(result.Content))
177				if err != nil {
178					r.Store.Rollback(tx)
179					return nil, err
180				}
181				mergedFiles = append(mergedFiles, mergedFile{path: path, blobID: id, mode: oursMode[path]})
182			}
183		}
184	}
185
186	rootID, err := buildMergeTree(r, tx, mergedFiles)
187	if err != nil {
188		r.Store.Rollback(tx)
189		return nil, err
190	}
191
192	if err := r.Store.Commit(tx); err != nil {
193		return nil, err
194	}
195	return &TreeMergeResult{TreeID: rootID, Conflicts: conflictPaths}, nil
196}
197
198func readBlobStr(r *repo.Repo, id [32]byte) (string, error) {
199	if id == object.ZeroID {
200		return "", nil
201	}
202	content, err := r.ReadBlob(id)
203	if err != nil {
204		return "", err
205	}
206	return string(content), nil
207}
208
209func writeBlobTx(r *repo.Repo, tx *store.Tx, content []byte) ([32]byte, error) {
210	b := &object.Blob{Content: content}
211	var buf bytes.Buffer
212	object.EncodeBlob(&buf, b)
213	id := object.HashBlob(b)
214	return id, r.Store.WriteObject(tx, id, string(object.KindBlob), buf.Bytes())
215}
216
217type mergedFile struct {
218	path   string
219	blobID [32]byte
220	mode   object.EntryMode
221}
222
223func buildMergeTree(r *repo.Repo, tx *store.Tx, files []mergedFile) ([32]byte, error) {
224	type node struct {
225		isFile   bool
226		blobID   [32]byte
227		mode     object.EntryMode
228		children map[string]*node
229	}
230	root := &node{children: make(map[string]*node)}
231
232	for _, f := range files {
233		parts := strings.Split(f.path, "/")
234		cur := root
235		for i, part := range parts {
236			if i == len(parts)-1 {
237				cur.children[part] = &node{isFile: true, blobID: f.blobID, mode: f.mode}
238			} else {
239				if _, ok := cur.children[part]; !ok {
240					cur.children[part] = &node{children: make(map[string]*node)}
241				}
242				cur = cur.children[part]
243			}
244		}
245	}
246
247	var writeNode func(n *node) ([32]byte, error)
248	writeNode = func(n *node) ([32]byte, error) {
249		var entries []object.TreeEntry
250		for name, child := range n.children {
251			if child.isFile {
252				entries = append(entries, object.TreeEntry{Name: name, Mode: child.mode, ObjectID: child.blobID})
253			} else {
254				subID, err := writeNode(child)
255				if err != nil {
256					return object.ZeroID, err
257				}
258				entries = append(entries, object.TreeEntry{Name: name, Mode: object.ModeDir, ObjectID: subID})
259			}
260		}
261		sort.Slice(entries, func(i, j int) bool { return entries[i].Name < entries[j].Name })
262		t := &object.Tree{Entries: entries}
263		id, err := repo.WriteTreeTx(r.Store, tx, t)
264		return id, err
265	}
266
267	return writeNode(root)
268}
269
270func flattenTree(r *repo.Repo, treeID [32]byte, prefix string, blobs map[string][32]byte, modes map[string]object.EntryMode) error {
271	if treeID == object.ZeroID {
272		return nil
273	}
274	t, err := r.ReadTree(treeID)
275	if err != nil {
276		return err
277	}
278	for _, e := range t.Entries {
279		rel := joinPath(prefix, e.Name)
280		if e.Mode == object.ModeDir {
281			if err := flattenTree(r, e.ObjectID, rel, blobs, modes); err != nil {
282				return err
283			}
284		} else {
285			blobs[rel] = e.ObjectID
286			if modes != nil {
287				modes[rel] = e.Mode
288			}
289		}
290	}
291	return nil
292}
293
// joinPath appends name to the slash-separated prefix; an empty prefix
// means the root, so name is returned unchanged.
func joinPath(prefix, name string) string {
	if prefix != "" {
		return prefix + "/" + name
	}
	return name
}
300
// ensureNewline returns s with a trailing "\n" appended if it lacks one.
// The empty string is returned as-is so that an empty side of a conflict
// block stays empty.
func ensureNewline(s string) string {
	if len(s) == 0 || s[len(s)-1] == '\n' {
		return s
	}
	return s + "\n"
}