1package merge
2
3import (
4 "bytes"
5 "fmt"
6 "sort"
7 "strings"
8
9 "arche/internal/object"
10 "arche/internal/repo"
11 "arche/internal/store"
12
13 "github.com/sergi/go-diff/diffmatchpatch"
14)
15
// MergeResult is the outcome of MergeText.
//
// Content holds the merged text when Clean is true, or the conflict-marked
// text when Clean is false. Ours and Theirs are only populated for a
// conflict, where they carry the two unmerged inputs.
type MergeResult struct {
	Content      string
	Clean        bool
	Ours, Theirs string
}
22
23func MergeText(base, ours, theirs string) MergeResult {
24 if ours == theirs {
25 return MergeResult{Content: ours, Clean: true}
26 }
27 if base == ours {
28 return MergeResult{Content: theirs, Clean: true}
29 }
30 if base == theirs {
31 return MergeResult{Content: ours, Clean: true}
32 }
33
34 dmp := diffmatchpatch.New()
35
36 patches := dmp.PatchMake(base, ours)
37 result, applied := dmp.PatchApply(patches, theirs)
38
39 allApplied := true
40 for _, ok := range applied {
41 if !ok {
42 allApplied = false
43 break
44 }
45 }
46
47 if allApplied && !strings.Contains(result, "<<<<<<<") {
48 return MergeResult{Content: result, Clean: true}
49 }
50
51 conflict := fmt.Sprintf(
52 "<<<<<<< ours\n%s=======\n%s>>>>>>> theirs\n",
53 ensureNewline(ours),
54 ensureNewline(theirs),
55 )
56 return MergeResult{
57 Content: conflict,
58 Clean: false,
59 Ours: ours,
60 Theirs: theirs,
61 }
62}
63
// TreeMergeResult is the outcome of Trees.
type TreeMergeResult struct {
	// TreeID is the ID of the merged root tree written to the store.
	TreeID [32]byte
	// Conflicts lists the slash-separated paths that could not be merged
	// cleanly; each such path holds a conflict object in the merged tree.
	Conflicts []string
}
68
69func Trees(r *repo.Repo, base, ours, theirs [32]byte) (*TreeMergeResult, error) {
70 baseFiles := make(map[string][32]byte)
71 oursFiles := make(map[string][32]byte)
72 theirsFiles := make(map[string][32]byte)
73 oursMode := make(map[string]object.EntryMode)
74 theirsMode := make(map[string]object.EntryMode)
75
76 if base != object.ZeroID {
77 if err := flattenTree(r, base, "", baseFiles, nil); err != nil {
78 return nil, fmt.Errorf("merge base: %w", err)
79 }
80 }
81 if err := flattenTree(r, ours, "", oursFiles, oursMode); err != nil {
82 return nil, fmt.Errorf("merge ours: %w", err)
83 }
84 if err := flattenTree(r, theirs, "", theirsFiles, theirsMode); err != nil {
85 return nil, fmt.Errorf("merge theirs: %w", err)
86 }
87
88 allPaths := make(map[string]bool)
89 for p := range oursFiles {
90 allPaths[p] = true
91 }
92 for p := range theirsFiles {
93 allPaths[p] = true
94 }
95 for p := range baseFiles {
96 allPaths[p] = true
97 }
98
99 tx, err := r.Store.Begin()
100 if err != nil {
101 return nil, err
102 }
103
104 var mergedFiles []mergedFile
105 var conflictPaths []string
106
107 for path := range allPaths {
108 bBase := baseFiles[path]
109 bOurs := oursFiles[path]
110 bTheirs := theirsFiles[path]
111
112 switch {
113 case bOurs == bTheirs:
114 if bOurs != object.ZeroID {
115 mergedFiles = append(mergedFiles, mergedFile{path: path, blobID: bOurs, mode: oursMode[path]})
116 }
117
118 case bBase == bOurs:
119 if bTheirs != object.ZeroID {
120 mergedFiles = append(mergedFiles, mergedFile{path: path, blobID: bTheirs, mode: theirsMode[path]})
121 }
122
123 case bBase == bTheirs:
124 if bOurs != object.ZeroID {
125 mergedFiles = append(mergedFiles, mergedFile{path: path, blobID: bOurs, mode: oursMode[path]})
126 }
127
128 default:
129 if bOurs == object.ZeroID || bTheirs == object.ZeroID {
130 conflictPaths = append(conflictPaths, path)
131 conf := &object.Conflict{
132 Ours: object.ConflictSide{BlobID: bOurs},
133 Theirs: object.ConflictSide{BlobID: bTheirs},
134 }
135 if bBase != object.ZeroID {
136 conf.Base = &object.ConflictSide{BlobID: bBase}
137 }
138 conflictID, err := repo.WriteConflictTx(r.Store, tx, conf)
139 if err != nil {
140 r.Store.Rollback(tx)
141 return nil, err
142 }
143 if err := r.Store.AddConflict(tx, path); err != nil {
144 r.Store.Rollback(tx)
145 return nil, err
146 }
147 mergedFiles = append(mergedFiles, mergedFile{path: path, blobID: conflictID, mode: object.ModeFile})
148 continue
149 }
150
151 baseContent, _ := readBlobStr(r, bBase)
152 oursContent, _ := readBlobStr(r, bOurs)
153 theirsContent, _ := readBlobStr(r, bTheirs)
154
155 result := MergeText(baseContent, oursContent, theirsContent)
156 if !result.Clean {
157 conflictPaths = append(conflictPaths, path)
158 conf := &object.Conflict{
159 Ours: object.ConflictSide{BlobID: bOurs},
160 Theirs: object.ConflictSide{BlobID: bTheirs},
161 }
162 if bBase != object.ZeroID {
163 conf.Base = &object.ConflictSide{BlobID: bBase}
164 }
165 conflictID, err := repo.WriteConflictTx(r.Store, tx, conf)
166 if err != nil {
167 r.Store.Rollback(tx)
168 return nil, err
169 }
170 if err := r.Store.AddConflict(tx, path); err != nil {
171 r.Store.Rollback(tx)
172 return nil, err
173 }
174 mergedFiles = append(mergedFiles, mergedFile{path: path, blobID: conflictID, mode: oursMode[path]})
175 } else {
176 id, err := writeBlobTx(r, tx, []byte(result.Content))
177 if err != nil {
178 r.Store.Rollback(tx)
179 return nil, err
180 }
181 mergedFiles = append(mergedFiles, mergedFile{path: path, blobID: id, mode: oursMode[path]})
182 }
183 }
184 }
185
186 rootID, err := buildMergeTree(r, tx, mergedFiles)
187 if err != nil {
188 r.Store.Rollback(tx)
189 return nil, err
190 }
191
192 if err := r.Store.Commit(tx); err != nil {
193 return nil, err
194 }
195 return &TreeMergeResult{TreeID: rootID, Conflicts: conflictPaths}, nil
196}
197
198func readBlobStr(r *repo.Repo, id [32]byte) (string, error) {
199 if id == object.ZeroID {
200 return "", nil
201 }
202 content, err := r.ReadBlob(id)
203 if err != nil {
204 return "", err
205 }
206 return string(content), nil
207}
208
209func writeBlobTx(r *repo.Repo, tx *store.Tx, content []byte) ([32]byte, error) {
210 b := &object.Blob{Content: content}
211 var buf bytes.Buffer
212 object.EncodeBlob(&buf, b)
213 id := object.HashBlob(b)
214 return id, r.Store.WriteObject(tx, id, string(object.KindBlob), buf.Bytes())
215}
216
// mergedFile is one file entry of a merge result, collected by Trees and
// turned into tree objects by buildMergeTree.
type mergedFile struct {
	path   string           // slash-separated path relative to the tree root
	blobID [32]byte         // blob ID, or conflict object ID for conflicted paths
	mode   object.EntryMode // entry mode recorded in the parent tree
}
222
223func buildMergeTree(r *repo.Repo, tx *store.Tx, files []mergedFile) ([32]byte, error) {
224 type node struct {
225 isFile bool
226 blobID [32]byte
227 mode object.EntryMode
228 children map[string]*node
229 }
230 root := &node{children: make(map[string]*node)}
231
232 for _, f := range files {
233 parts := strings.Split(f.path, "/")
234 cur := root
235 for i, part := range parts {
236 if i == len(parts)-1 {
237 cur.children[part] = &node{isFile: true, blobID: f.blobID, mode: f.mode}
238 } else {
239 if _, ok := cur.children[part]; !ok {
240 cur.children[part] = &node{children: make(map[string]*node)}
241 }
242 cur = cur.children[part]
243 }
244 }
245 }
246
247 var writeNode func(n *node) ([32]byte, error)
248 writeNode = func(n *node) ([32]byte, error) {
249 var entries []object.TreeEntry
250 for name, child := range n.children {
251 if child.isFile {
252 entries = append(entries, object.TreeEntry{Name: name, Mode: child.mode, ObjectID: child.blobID})
253 } else {
254 subID, err := writeNode(child)
255 if err != nil {
256 return object.ZeroID, err
257 }
258 entries = append(entries, object.TreeEntry{Name: name, Mode: object.ModeDir, ObjectID: subID})
259 }
260 }
261 sort.Slice(entries, func(i, j int) bool { return entries[i].Name < entries[j].Name })
262 t := &object.Tree{Entries: entries}
263 id, err := repo.WriteTreeTx(r.Store, tx, t)
264 return id, err
265 }
266
267 return writeNode(root)
268}
269
270func flattenTree(r *repo.Repo, treeID [32]byte, prefix string, blobs map[string][32]byte, modes map[string]object.EntryMode) error {
271 if treeID == object.ZeroID {
272 return nil
273 }
274 t, err := r.ReadTree(treeID)
275 if err != nil {
276 return err
277 }
278 for _, e := range t.Entries {
279 rel := joinPath(prefix, e.Name)
280 if e.Mode == object.ModeDir {
281 if err := flattenTree(r, e.ObjectID, rel, blobs, modes); err != nil {
282 return err
283 }
284 } else {
285 blobs[rel] = e.ObjectID
286 if modes != nil {
287 modes[rel] = e.Mode
288 }
289 }
290 }
291 return nil
292}
293
// joinPath joins prefix and name with a "/" separator. An empty prefix
// yields name unchanged.
func joinPath(prefix, name string) string {
	if prefix != "" {
		return prefix + "/" + name
	}
	return name
}
300
// ensureNewline appends "\n" to a non-empty s that does not already end in
// one. The empty string is returned unchanged.
func ensureNewline(s string) string {
	if s == "" || strings.HasSuffix(s, "\n") {
		return s
	}
	return s + "\n"
}