1package archesrv
2
3import (
4 "crypto/rand"
5 "crypto/subtle"
6 "database/sql"
7 "encoding/base64"
8 "encoding/hex"
9 "fmt"
10 "net/http"
11 "strings"
12 "time"
13
14 "golang.org/x/crypto/argon2"
15)
16
// Argon2id cost parameters used by hashPassword for every new hash. They are
// embedded in each encoded hash and checkPassword reads them back from there,
// so changing these values does not invalidate hashes created earlier.
const (
	argonTime    = 3        // passes over memory
	argonMemory  = 64 << 10 // memory cost in KiB (64 MiB)
	argonThreads = 4        // degree of parallelism
	argonKeyLen  = 32       // derived key length in bytes
	argonSaltLen = 16       // random salt length in bytes
)
24
// User is an account row from the users table. The password hash is not part
// of this struct; GetUserByName returns it as a separate value.
type User struct {
	ID       int64  // users.id
	Username string // users.username
	IsAdmin  bool   // users.is_admin (CreateUser stores it as 0 or 1)
}
30
31func hashPassword(plain string) (string, error) {
32 salt := make([]byte, argonSaltLen)
33 if _, err := rand.Read(salt); err != nil {
34 return "", fmt.Errorf("generate salt: %w", err)
35 }
36 hash := argon2.IDKey([]byte(plain), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
37 encoded := fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
38 argonMemory, argonTime, argonThreads,
39 base64.RawStdEncoding.EncodeToString(salt),
40 base64.RawStdEncoding.EncodeToString(hash),
41 )
42 return encoded, nil
43}
44
45func checkPassword(encoded, plain string) bool {
46 parts := strings.Split(encoded, "$")
47 if len(parts) != 6 || parts[1] != "argon2id" {
48 return false
49 }
50 var m, t uint32
51 var p uint8
52 if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &m, &t, &p); err != nil {
53 return false
54 }
55 salt, err := base64.RawStdEncoding.DecodeString(parts[4])
56 if err != nil {
57 return false
58 }
59 wantHash, err := base64.RawStdEncoding.DecodeString(parts[5])
60 if err != nil {
61 return false
62 }
63 gotHash := argon2.IDKey([]byte(plain), salt, t, m, p, uint32(len(wantHash)))
64 return subtle.ConstantTimeCompare(gotHash, wantHash) == 1
65}
66
67func (d *DB) CreateUser(username, password string, isAdmin bool) (*User, error) {
68 hash, err := hashPassword(password)
69 if err != nil {
70 return nil, fmt.Errorf("hash password: %w", err)
71 }
72 admin := 0
73 if isAdmin {
74 admin = 1
75 }
76 res, err := d.db.Exec(
77 "INSERT INTO users(username,password_hash,is_admin,created_at) VALUES(?,?,?,?)",
78 username, hash, admin, time.Now().Unix(),
79 )
80 if err != nil {
81 return nil, fmt.Errorf("create user: %w", err)
82 }
83 id, _ := res.LastInsertId()
84 return &User{ID: id, Username: username, IsAdmin: isAdmin}, nil
85}
86
87func (d *DB) GetUserByName(username string) (*User, string, error) {
88 var u User
89 var hash string
90 err := d.db.QueryRow(
91 "SELECT id, username, password_hash, is_admin FROM users WHERE username=?",
92 username,
93 ).Scan(&u.ID, &u.Username, &hash, &u.IsAdmin)
94 if err == sql.ErrNoRows {
95 return nil, "", nil
96 }
97 return &u, hash, err
98}
99
100func (d *DB) GetUserByID(id int64) (*User, error) {
101 var u User
102 err := d.db.QueryRow(
103 "SELECT id, username, is_admin FROM users WHERE id=?", id,
104 ).Scan(&u.ID, &u.Username, &u.IsAdmin)
105 if err == sql.ErrNoRows {
106 return nil, nil
107 }
108 return &u, err
109}
110
111func (d *DB) HasAnyUser() (bool, error) {
112 var count int
113 err := d.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
114 return count > 0, err
115}
116
117func (d *DB) ListUsers() ([]User, error) {
118 rows, err := d.db.Query("SELECT id, username, is_admin FROM users ORDER BY id")
119 if err != nil {
120 return nil, err
121 }
122 defer rows.Close()
123 var users []User
124 for rows.Next() {
125 var u User
126 if err := rows.Scan(&u.ID, &u.Username, &u.IsAdmin); err != nil {
127 return nil, err
128 }
129 users = append(users, u)
130 }
131 return users, rows.Err()
132}
133
134func (d *DB) DeleteUser(id int64) error {
135 _, err := d.db.Exec("DELETE FROM users WHERE id=?", id)
136 return err
137}
138
const (
	// sessionCookie is the name of the cookie that carries the session token.
	sessionCookie = "arche_session"
	// sessionTTL is how long a newly created session remains valid.
	sessionTTL = 168 * time.Hour // one week
)
143
// generateToken returns 32 bytes from crypto/rand encoded as 64 lowercase
// hexadecimal characters. It is used for both session and API tokens.
func generateToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
151
152func (d *DB) CreateSession(userID int64) (string, error) {
153 tok, err := generateToken()
154 if err != nil {
155 return "", err
156 }
157 exp := time.Now().Add(sessionTTL).Unix()
158 _, err = d.db.Exec(
159 "INSERT INTO sessions(user_id,token,expires_at) VALUES(?,?,?)",
160 userID, tok, exp,
161 )
162 return tok, err
163}
164
165func (d *DB) GetSessionUser(token string) (*User, error) {
166 var userID int64
167 var exp int64
168 err := d.db.QueryRow(
169 "SELECT user_id, expires_at FROM sessions WHERE token=?", token,
170 ).Scan(&userID, &exp)
171 if err == sql.ErrNoRows {
172 return nil, nil
173 }
174 if err != nil {
175 return nil, err
176 }
177 if time.Now().Unix() > exp {
178 d.db.Exec("DELETE FROM sessions WHERE token=?", token) //nolint:errcheck
179 return nil, nil
180 }
181 return d.GetUserByID(userID)
182}
183
184func (d *DB) DeleteSession(token string) error {
185 _, err := d.db.Exec("DELETE FROM sessions WHERE token=?", token)
186 return err
187}
188
189func (d *DB) currentUser(r *http.Request) *User {
190 if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
191 fp := certFingerprint(r.TLS.PeerCertificates[0])
192 if u, _ := d.AuthorizeMTLSCert(fp); u != nil {
193 return u
194 }
195 }
196
197 auth := r.Header.Get("Authorization")
198 if strings.HasPrefix(auth, "Bearer ") {
199 tok := strings.TrimPrefix(auth, "Bearer ")
200 u, _ := d.lookupAPIToken(tok)
201 return u
202 }
203
204 c, err := r.Cookie(sessionCookie)
205 if err != nil {
206 return nil
207 }
208
209 u, _ := d.GetSessionUser(c.Value)
210 return u
211}
212
213func (d *DB) CreateAPIToken(userID int64, label string) (string, error) {
214 tok, err := generateToken()
215 if err != nil {
216 return "", err
217 }
218 hash, err := hashPassword(tok)
219 if err != nil {
220 return "", err
221 }
222 _, err = d.db.Exec(
223 "INSERT INTO api_tokens(user_id,token_hash,label,created_at) VALUES(?,?,?,?)",
224 userID, hash, label, time.Now().Unix(),
225 )
226 if err != nil {
227 return "", err
228 }
229 return tok, nil
230}
231
232func (d *DB) lookupAPIToken(plain string) (*User, error) {
233 rows, err := d.db.Query("SELECT user_id, token_hash FROM api_tokens")
234 if err != nil {
235 return nil, err
236 }
237 var foundUID int64
238 for rows.Next() {
239 var uid int64
240 var hash string
241 if err := rows.Scan(&uid, &hash); err != nil {
242 continue
243 }
244 if checkPassword(hash, plain) {
245 foundUID = uid
246 break
247 }
248 }
249
250 rows.Close()
251 if foundUID == 0 {
252 return nil, nil
253 }
254
255 return d.GetUserByID(foundUID)
256}
257
// APIToken is the metadata ListAPITokens returns for one API token. It holds
// no secret material: ListAPITokens does not select the stored token_hash.
type APIToken struct {
	ID        int64  // api_tokens.id
	Label     string // user-supplied label given to CreateAPIToken
	CreatedAt int64  // Unix seconds; CreateAPIToken sets it from time.Now().Unix()
}
263
264func (d *DB) ListAPITokens(userID int64) ([]APIToken, error) {
265 rows, err := d.db.Query(
266 "SELECT id, label, created_at FROM api_tokens WHERE user_id=? ORDER BY created_at DESC",
267 userID,
268 )
269 if err != nil {
270 return nil, err
271 }
272 defer rows.Close()
273 var tokens []APIToken
274 for rows.Next() {
275 var t APIToken
276 if err := rows.Scan(&t.ID, &t.Label, &t.CreatedAt); err != nil {
277 return nil, err
278 }
279 tokens = append(tokens, t)
280 }
281 return tokens, rows.Err()
282}
283
284func (d *DB) DeleteAPIToken(id int64, userID int64) error {
285 _, err := d.db.Exec("DELETE FROM api_tokens WHERE id=? AND user_id=?", id, userID)
286 return err
287}