arche / internal/archesrv/sshkeys.go

commit 154431fd
 1package archesrv
 2
import (
	"bytes"
	"errors"
	"time"

	"golang.org/x/crypto/ssh"
)
 9
// SSHKey is a public key registered to a user account, as persisted in the
// ssh_keys table.
type SSHKey struct {
	// ID is the ssh_keys row id; UserID is the id of the owning user.
	ID, UserID int64
	// Label is the caller-supplied name for the key; PublicKey is the key
	// text exactly as submitted (validated as authorized_keys format).
	Label, PublicKey string
	// AddedAt is when the key was registered; it is persisted as Unix
	// seconds, so values read back have one-second precision.
	AddedAt time.Time
}
17
18func (d *DB) AddSSHKey(userID int64, label, publicKey string) (*SSHKey, error) {
19	if _, _, _, _, err := ssh.ParseAuthorizedKey([]byte(publicKey)); err != nil {
20		return nil, err
21	}
22	res, err := d.db.Exec(
23		"INSERT INTO ssh_keys(user_id,label,public_key,added_at) VALUES(?,?,?,?)",
24		userID, label, publicKey, time.Now().Unix(),
25	)
26	if err != nil {
27		return nil, err
28	}
29	id, _ := res.LastInsertId()
30	return &SSHKey{ID: id, UserID: userID, Label: label, PublicKey: publicKey}, nil
31}
32
33func (d *DB) ListSSHKeys(userID int64) ([]SSHKey, error) {
34	rows, err := d.db.Query(
35		"SELECT id, user_id, label, public_key, added_at FROM ssh_keys WHERE user_id=? ORDER BY id",
36		userID,
37	)
38	if err != nil {
39		return nil, err
40	}
41	defer rows.Close()
42	var keys []SSHKey
43	for rows.Next() {
44		var k SSHKey
45		var ts int64
46		if err := rows.Scan(&k.ID, &k.UserID, &k.Label, &k.PublicKey, &ts); err != nil {
47			return nil, err
48		}
49		k.AddedAt = time.Unix(ts, 0)
50		keys = append(keys, k)
51	}
52	return keys, rows.Err()
53}
54
55func (d *DB) DeleteSSHKey(keyID, userID int64) error {
56	_, err := d.db.Exec("DELETE FROM ssh_keys WHERE id=? AND user_id=?", keyID, userID)
57	return err
58}
59
60func (d *DB) AuthorizeSSHKey(pubKey ssh.PublicKey) (*User, error) {
61	keyBytes := pubKey.Marshal()
62
63	rows, err := d.db.Query("SELECT user_id, public_key FROM ssh_keys")
64	if err != nil {
65		return nil, err
66	}
67	defer rows.Close()
68
69	for rows.Next() {
70		var userID int64
71		var keyStr string
72		if err := rows.Scan(&userID, &keyStr); err != nil {
73			continue
74		}
75		stored, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyStr))
76		if err != nil {
77			continue
78		}
79		if bytes.Equal(stored.Marshal(), keyBytes) {
80			return d.GetUserByID(userID)
81		}
82	}
83	return nil, rows.Err()
84}