arche / internal/archesrv/ssh.go

commit 154431fd
  1package archesrv
  2
import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"encoding/hex"
	"encoding/pem"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strings"

	"arche/internal/syncpkg"

	"golang.org/x/crypto/ssh"
)
 20
 21func (s *forgeServer) RunSSH(listenAddr string) error {
 22	hostKey, err := s.loadOrCreateHostKey()
 23	if err != nil {
 24		return fmt.Errorf("host key: %w", err)
 25	}
 26
 27	cfg := &ssh.ServerConfig{
 28		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
 29			user, err := s.db.AuthorizeSSHKey(key)
 30			if err != nil || user == nil {
 31				return nil, fmt.Errorf("unauthorized")
 32			}
 33			fp := hex.EncodeToString(key.Marshal())
 34			return &ssh.Permissions{
 35				Extensions: map[string]string{
 36					"user-id":  fmt.Sprintf("%d", user.ID),
 37					"username": user.Username,
 38					"is-admin": fmt.Sprintf("%v", user.IsAdmin),
 39					"key-fp":   fp,
 40				},
 41			}, nil
 42		},
 43	}
 44	cfg.AddHostKey(hostKey)
 45
 46	ln, err := net.Listen("tcp", listenAddr)
 47	if err != nil {
 48		return fmt.Errorf("ssh listen %s: %w", listenAddr, err)
 49	}
 50	defer ln.Close()
 51	s.log.Info("SSH listening", "addr", listenAddr)
 52
 53	for {
 54		conn, err := ln.Accept()
 55		if err != nil {
 56			return err
 57		}
 58		go s.handleSSHConn(conn, cfg)
 59	}
 60}
 61
 62func (s *forgeServer) handleSSHConn(netConn net.Conn, cfg *ssh.ServerConfig) {
 63	defer netConn.Close()
 64
 65	sshConn, chans, reqs, err := ssh.NewServerConn(netConn, cfg)
 66	if err != nil {
 67		return
 68	}
 69	defer sshConn.Close()
 70	go ssh.DiscardRequests(reqs)
 71
 72	for newChan := range chans {
 73		if newChan.ChannelType() != "session" {
 74			newChan.Reject(ssh.UnknownChannelType, "unknown channel type") //nolint:errcheck
 75			continue
 76		}
 77		ch, requests, err := newChan.Accept()
 78		if err != nil {
 79			return
 80		}
 81		go s.handleSSHSession(sshConn, ch, requests)
 82	}
 83}
 84
 85func (s *forgeServer) handleSSHSession(conn *ssh.ServerConn, ch ssh.Channel, requests <-chan *ssh.Request) {
 86	defer ch.Close()
 87
 88	for req := range requests {
 89		if req.Type != "exec" {
 90			if req.WantReply {
 91				req.Reply(false, nil) //nolint:errcheck
 92			}
 93			continue
 94		}
 95
 96		if len(req.Payload) < 4 {
 97			req.Reply(false, nil) //nolint:errcheck
 98			return
 99		}
100		cmdLen := int(req.Payload[0])<<24 | int(req.Payload[1])<<16 | int(req.Payload[2])<<8 | int(req.Payload[3])
101		if len(req.Payload) < 4+cmdLen {
102			req.Reply(false, nil) //nolint:errcheck
103			return
104		}
105		cmd := string(req.Payload[4 : 4+cmdLen])
106		req.Reply(true, nil) //nolint:errcheck
107
108		exitCode := s.execSSHCommand(conn, ch, cmd)
109		exitPayload := []byte{0, 0, 0, byte(exitCode)}
110		ch.SendRequest("exit-status", false, exitPayload) //nolint:errcheck
111		return
112	}
113}
114
115func (s *forgeServer) execSSHCommand(conn *ssh.ServerConn, ch ssh.Channel, cmd string) int {
116	const prefix = "arche-sync "
117	if !strings.HasPrefix(cmd, prefix) {
118		fmt.Fprintf(ch.Stderr(), "arche-server: unknown command %q\n", cmd)
119		return 1
120	}
121	repoName := strings.TrimSpace(strings.TrimPrefix(cmd, prefix))
122	repoName = strings.Trim(repoName, "/'\"")
123
124	rec, err := s.db.GetRepo(repoName)
125	if err != nil || rec == nil {
126		fmt.Fprintf(ch.Stderr(), "arche-server: repository %q not found\n", repoName)
127		return 1
128	}
129
130	username := conn.Permissions.Extensions["username"]
131	isAdmin := conn.Permissions.Extensions["is-admin"] == "true"
132	var user *User
133	if username != "" {
134		user = &User{Username: username, IsAdmin: isAdmin}
135	}
136	canRead := s.db.CanRead(rec, user)
137	canWrite := s.db.CanWrite(rec, user)
138	if !canRead {
139		fmt.Fprintf(ch.Stderr(), "arche-server: access denied\n")
140		return 1
141	}
142
143	repoObj, err := openRepo(s.dataDir(), repoName)
144	if err != nil {
145		fmt.Fprintf(ch.Stderr(), "arche-server: open repo: %v\n", err)
146		return 1
147	}
148	defer repoObj.Close()
149
150	srv := syncpkg.NewServerAuth(repoObj, canWrite)
151	if s.cfg.Hooks.PreReceive != "" || s.cfg.Hooks.Update != "" {
152		srv.PreUpdateHook = func(bm, oldHex, newHex string) error {
153			if err := runPreReceiveHook(s.cfg.Hooks.PreReceive, bm, oldHex, newHex, s.cfg.Hooks.TimeoutSec); err != nil {
154				return err
155			}
156			return runPreReceiveHook(s.cfg.Hooks.Update, bm, oldHex, newHex, s.cfg.Hooks.TimeoutSec)
157		}
158	}
159	srv.OnBookmarkUpdated = func(bm, oldHex, newHex string) {
160		s.db.FirePushWebhooks(repoName, username, bm, oldHex, newHex, collectPushCommits(repoObj, oldHex, newHex))
161		runPostReceiveHook(s.cfg.Hooks.PostReceive, bm, oldHex, newHex, s.cfg.Hooks.TimeoutSec)
162
163		if allowed, script, _ := s.db.GetRepoHookConfig(rec.ID); allowed && script != "" {
164			if !s.db.hasWriteCollaborator(rec.ID) {
165				runPostReceiveHook(script, bm, oldHex, newHex, s.cfg.Hooks.TimeoutSec)
166			}
167		}
168	}
169
170	serveHTTPOverSSH(ch, srv.Handler())
171	return 0
172}
173
174func serveHTTPOverSSH(ch ssh.Channel, handler http.Handler) {
175	local, remote := net.Pipe()
176	defer local.Close()
177
178	go func() {
179		io.Copy(local, ch) //nolint:errcheck
180		local.Close()
181	}()
182	go func() {
183		io.Copy(ch, local) //nolint:errcheck
184		ch.CloseWrite()    //nolint:errcheck
185	}()
186
187	l := &singleConnListener{conn: remote}
188	httpSrv := &http.Server{Handler: handler}
189	httpSrv.Serve(l) //nolint:errcheck
190}
191
// singleConnListener is a net.Listener that hands out one pre-existing
// connection and then blocks until it is closed.
type singleConnListener struct {
	// conn is the connection returned by the first call to Accept.
	conn net.Conn
	// done is created by the first Accept and closed by Close; it is nil
	// until the connection has been handed out.
	done chan struct{}
}

// Accept returns the wrapped connection on its first call. Every later call
// blocks until Close has been called and then returns io.EOF.
func (l *singleConnListener) Accept() (net.Conn, error) {
	if l.done != nil {
		// The connection was already handed out: wait for Close.
		<-l.done
		return nil, io.EOF
	}
	l.done = make(chan struct{})
	return l.conn, nil
}

// Close releases any Accept call that is waiting. It is a no-op if the
// connection has not been handed out yet or if the listener was already
// closed, and it always returns nil.
func (l *singleConnListener) Close() error {
	if l.done == nil {
		return nil
	}
	select {
	case <-l.done:
		// Already closed; closing the channel again would panic.
	default:
		close(l.done)
	}
	return nil
}

// Addr reports the local address of the wrapped connection.
func (l *singleConnListener) Addr() net.Addr {
	return l.conn.LocalAddr()
}
218
219func (s *forgeServer) loadOrCreateHostKey() (ssh.Signer, error) {
220	keyPath := filepath.Join(s.cfg.Storage.DataDir, "ssh_host_ecdsa_key")
221
222	if _, err := os.Stat(keyPath); os.IsNotExist(err) {
223		privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
224		if err != nil {
225			return nil, fmt.Errorf("generate host key: %w", err)
226		}
227		signer, err := ssh.NewSignerFromKey(privKey)
228		if err != nil {
229			return nil, err
230		}
231
232		f, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
233		if err != nil {
234			return nil, fmt.Errorf("save host key: %w", err)
235		}
236		f.Write(signer.PublicKey().Marshal()) //nolint:errcheck
237		f.Close()
238
239		return signer, nil
240	}
241
242	privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
243	if err != nil {
244		return nil, fmt.Errorf("generate host key: %w", err)
245	}
246	return ssh.NewSignerFromKey(privKey)
247}