arche / internal/syncpkg/ssh_transport.go

commit 154431fd
  1package syncpkg
  2
  3import (
  4	"bufio"
  5	"fmt"
  6	"io"
  7	"net"
  8	"net/http"
  9	"os"
 10	"os/exec"
 11	"time"
 12)
 13
 14type SSHTransport struct {
 15	Host    string
 16	Repo    string
 17	KeyFile string
 18}
 19
 20func NewSSHClient(host, repo, keyFile string) (*http.Client, error) {
 21	t := &SSHTransport{Host: host, Repo: repo, KeyFile: keyFile}
 22	return &http.Client{Transport: t, Timeout: 0}, nil
 23}
 24
 25type sshConn struct {
 26	cmd    *exec.Cmd
 27	reader io.ReadCloser
 28	writer io.WriteCloser
 29	local  net.Addr
 30	remote net.Addr
 31}
 32
 33func dialSSH(host, repo, keyFile string) (*sshConn, error) {
 34	args := []string{
 35		"-o", "StrictHostKeyChecking=accept-new",
 36		"-o", "BatchMode=yes",
 37	}
 38	if keyFile != "" {
 39		args = append(args, "-i", keyFile)
 40	}
 41	args = append(args, host, "arche-sync", repo)
 42
 43	cmd := exec.Command("ssh", args...)
 44	cmd.Stderr = os.Stderr
 45
 46	stdin, err := cmd.StdinPipe()
 47	if err != nil {
 48		return nil, fmt.Errorf("ssh stdin: %w", err)
 49	}
 50	stdout, err := cmd.StdoutPipe()
 51	if err != nil {
 52		return nil, fmt.Errorf("ssh stdout: %w", err)
 53	}
 54	if err := cmd.Start(); err != nil {
 55		return nil, fmt.Errorf("ssh start: %w", err)
 56	}
 57
 58	return &sshConn{
 59		cmd:    cmd,
 60		reader: stdout,
 61		writer: stdin,
 62		local:  &net.TCPAddr{},
 63		remote: &net.TCPAddr{},
 64	}, nil
 65}
 66
 67func (c *sshConn) Read(b []byte) (int, error)  { return c.reader.Read(b) }
 68func (c *sshConn) Write(b []byte) (int, error) { return c.writer.Write(b) }
 69func (c *sshConn) Close() error {
 70	c.writer.Close() //nolint:errcheck
 71	c.reader.Close() //nolint:errcheck
 72	return c.cmd.Wait()
 73}
 74func (c *sshConn) LocalAddr() net.Addr                { return c.local }
 75func (c *sshConn) RemoteAddr() net.Addr               { return c.remote }
 76func (c *sshConn) SetDeadline(t time.Time) error      { return nil }
 77func (c *sshConn) SetReadDeadline(t time.Time) error  { return nil }
 78func (c *sshConn) SetWriteDeadline(t time.Time) error { return nil }
 79
 80func (t *SSHTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 81	conn, err := dialSSH(t.Host, t.Repo, t.KeyFile)
 82	if err != nil {
 83		return nil, err
 84	}
 85
 86	local, remote := net.Pipe()
 87
 88	go func() {
 89		io.Copy(local, conn) //nolint:errcheck
 90		local.Close()
 91	}()
 92	go func() {
 93		io.Copy(conn, local) //nolint:errcheck
 94	}()
 95
 96	if err := req.Write(remote); err != nil {
 97		remote.Close()
 98		conn.Close()
 99		return nil, fmt.Errorf("write request: %w", err)
100	}
101
102	resp, err := http.ReadResponse(bufio.NewReader(remote), req)
103	if err != nil {
104		remote.Close()
105		conn.Close()
106		return nil, fmt.Errorf("read response: %w", err)
107	}
108
109	resp.Body = &closeOnBodyClose{resp.Body, remote, conn}
110	return resp, nil
111}
112
113type closeOnBodyClose struct {
114	io.ReadCloser
115	pipe net.Conn
116	conn io.Closer
117}
118
119func (c *closeOnBodyClose) Close() error {
120	err := c.ReadCloser.Close()
121	c.pipe.Close() //nolint:errcheck
122	c.conn.Close() //nolint:errcheck
123	return err
124}