arche / internal/archesrv/auth.go

commit 154431fd
  1package archesrv
  2
import (
	"crypto/rand"
	"crypto/subtle"
	"database/sql"
	"encoding/base64"
	"encoding/hex"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"time"

	"golang.org/x/crypto/argon2"
)
 16
 17const (
 18	argonTime    = 3
 19	argonMemory  = 64 * 1024
 20	argonThreads = 4
 21	argonKeyLen  = 32
 22	argonSaltLen = 16
 23)
 24
// User is the account record handed to the rest of the server. The password
// hash is intentionally not a field; GetUserByName returns it separately.
type User struct {
	ID       int64  // users.id
	Username string // users.username
	IsAdmin  bool   // users.is_admin (written as 0/1 by CreateUser)
}
 30
 31func hashPassword(plain string) (string, error) {
 32	salt := make([]byte, argonSaltLen)
 33	if _, err := rand.Read(salt); err != nil {
 34		return "", fmt.Errorf("generate salt: %w", err)
 35	}
 36	hash := argon2.IDKey([]byte(plain), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
 37	encoded := fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
 38		argonMemory, argonTime, argonThreads,
 39		base64.RawStdEncoding.EncodeToString(salt),
 40		base64.RawStdEncoding.EncodeToString(hash),
 41	)
 42	return encoded, nil
 43}
 44
 45func checkPassword(encoded, plain string) bool {
 46	parts := strings.Split(encoded, "$")
 47	if len(parts) != 6 || parts[1] != "argon2id" {
 48		return false
 49	}
 50	var m, t uint32
 51	var p uint8
 52	if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &m, &t, &p); err != nil {
 53		return false
 54	}
 55	salt, err := base64.RawStdEncoding.DecodeString(parts[4])
 56	if err != nil {
 57		return false
 58	}
 59	wantHash, err := base64.RawStdEncoding.DecodeString(parts[5])
 60	if err != nil {
 61		return false
 62	}
 63	gotHash := argon2.IDKey([]byte(plain), salt, t, m, p, uint32(len(wantHash)))
 64	return subtle.ConstantTimeCompare(gotHash, wantHash) == 1
 65}
 66
 67func (d *DB) CreateUser(username, password string, isAdmin bool) (*User, error) {
 68	hash, err := hashPassword(password)
 69	if err != nil {
 70		return nil, fmt.Errorf("hash password: %w", err)
 71	}
 72	admin := 0
 73	if isAdmin {
 74		admin = 1
 75	}
 76	res, err := d.db.Exec(
 77		"INSERT INTO users(username,password_hash,is_admin,created_at) VALUES(?,?,?,?)",
 78		username, hash, admin, time.Now().Unix(),
 79	)
 80	if err != nil {
 81		return nil, fmt.Errorf("create user: %w", err)
 82	}
 83	id, _ := res.LastInsertId()
 84	return &User{ID: id, Username: username, IsAdmin: isAdmin}, nil
 85}
 86
 87func (d *DB) GetUserByName(username string) (*User, string, error) {
 88	var u User
 89	var hash string
 90	err := d.db.QueryRow(
 91		"SELECT id, username, password_hash, is_admin FROM users WHERE username=?",
 92		username,
 93	).Scan(&u.ID, &u.Username, &hash, &u.IsAdmin)
 94	if err == sql.ErrNoRows {
 95		return nil, "", nil
 96	}
 97	return &u, hash, err
 98}
 99
100func (d *DB) GetUserByID(id int64) (*User, error) {
101	var u User
102	err := d.db.QueryRow(
103		"SELECT id, username, is_admin FROM users WHERE id=?", id,
104	).Scan(&u.ID, &u.Username, &u.IsAdmin)
105	if err == sql.ErrNoRows {
106		return nil, nil
107	}
108	return &u, err
109}
110
111func (d *DB) HasAnyUser() (bool, error) {
112	var count int
113	err := d.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
114	return count > 0, err
115}
116
117func (d *DB) ListUsers() ([]User, error) {
118	rows, err := d.db.Query("SELECT id, username, is_admin FROM users ORDER BY id")
119	if err != nil {
120		return nil, err
121	}
122	defer rows.Close()
123	var users []User
124	for rows.Next() {
125		var u User
126		if err := rows.Scan(&u.ID, &u.Username, &u.IsAdmin); err != nil {
127			return nil, err
128		}
129		users = append(users, u)
130	}
131	return users, rows.Err()
132}
133
134func (d *DB) DeleteUser(id int64) error {
135	_, err := d.db.Exec("DELETE FROM users WHERE id=?", id)
136	return err
137}
138
// Session settings. sessionCookie is the name of the cookie carrying the
// session token; sessionTTL is how far in the future CreateSession sets a new
// session's expiry (seven days).
const (
	sessionTTL    = 168 * time.Hour
	sessionCookie = "arche_session"
)
143
// generateToken returns 32 bytes from crypto/rand, hex-encoded as 64 lowercase
// characters. It is used for both session tokens and API tokens.
func generateToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
151
152func (d *DB) CreateSession(userID int64) (string, error) {
153	tok, err := generateToken()
154	if err != nil {
155		return "", err
156	}
157	exp := time.Now().Add(sessionTTL).Unix()
158	_, err = d.db.Exec(
159		"INSERT INTO sessions(user_id,token,expires_at) VALUES(?,?,?)",
160		userID, tok, exp,
161	)
162	return tok, err
163}
164
165func (d *DB) GetSessionUser(token string) (*User, error) {
166	var userID int64
167	var exp int64
168	err := d.db.QueryRow(
169		"SELECT user_id, expires_at FROM sessions WHERE token=?", token,
170	).Scan(&userID, &exp)
171	if err == sql.ErrNoRows {
172		return nil, nil
173	}
174	if err != nil {
175		return nil, err
176	}
177	if time.Now().Unix() > exp {
178		d.db.Exec("DELETE FROM sessions WHERE token=?", token) //nolint:errcheck
179		return nil, nil
180	}
181	return d.GetUserByID(userID)
182}
183
184func (d *DB) DeleteSession(token string) error {
185	_, err := d.db.Exec("DELETE FROM sessions WHERE token=?", token)
186	return err
187}
188
189func (d *DB) currentUser(r *http.Request) *User {
190	if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
191		fp := certFingerprint(r.TLS.PeerCertificates[0])
192		if u, _ := d.AuthorizeMTLSCert(fp); u != nil {
193			return u
194		}
195	}
196
197	auth := r.Header.Get("Authorization")
198	if strings.HasPrefix(auth, "Bearer ") {
199		tok := strings.TrimPrefix(auth, "Bearer ")
200		u, _ := d.lookupAPIToken(tok)
201		return u
202	}
203
204	c, err := r.Cookie(sessionCookie)
205	if err != nil {
206		return nil
207	}
208
209	u, _ := d.GetSessionUser(c.Value)
210	return u
211}
212
213func (d *DB) CreateAPIToken(userID int64, label string) (string, error) {
214	tok, err := generateToken()
215	if err != nil {
216		return "", err
217	}
218	hash, err := hashPassword(tok)
219	if err != nil {
220		return "", err
221	}
222	_, err = d.db.Exec(
223		"INSERT INTO api_tokens(user_id,token_hash,label,created_at) VALUES(?,?,?,?)",
224		userID, hash, label, time.Now().Unix(),
225	)
226	if err != nil {
227		return "", err
228	}
229	return tok, nil
230}
231
232func (d *DB) lookupAPIToken(plain string) (*User, error) {
233	rows, err := d.db.Query("SELECT user_id, token_hash FROM api_tokens")
234	if err != nil {
235		return nil, err
236	}
237	var foundUID int64
238	for rows.Next() {
239		var uid int64
240		var hash string
241		if err := rows.Scan(&uid, &hash); err != nil {
242			continue
243		}
244		if checkPassword(hash, plain) {
245			foundUID = uid
246			break
247		}
248	}
249
250	rows.Close()
251	if foundUID == 0 {
252		return nil, nil
253	}
254
255	return d.GetUserByID(foundUID)
256}
257
// APIToken is the listing metadata for a stored API token, as returned by
// ListAPITokens. The token itself is stored only as a hash and is not part of
// this struct.
type APIToken struct {
	ID        int64  // api_tokens.id
	Label     string // label supplied to CreateAPIToken
	CreatedAt int64  // creation time in Unix seconds
}
263
264func (d *DB) ListAPITokens(userID int64) ([]APIToken, error) {
265	rows, err := d.db.Query(
266		"SELECT id, label, created_at FROM api_tokens WHERE user_id=? ORDER BY created_at DESC",
267		userID,
268	)
269	if err != nil {
270		return nil, err
271	}
272	defer rows.Close()
273	var tokens []APIToken
274	for rows.Next() {
275		var t APIToken
276		if err := rows.Scan(&t.ID, &t.Label, &t.CreatedAt); err != nil {
277			return nil, err
278		}
279		tokens = append(tokens, t)
280	}
281	return tokens, rows.Err()
282}
283
284func (d *DB) DeleteAPIToken(id int64, userID int64) error {
285	_, err := d.db.Exec("DELETE FROM api_tokens WHERE id=? AND user_id=?", id, userID)
286	return err
287}