1package archesrv
2
3import (
4 "crypto/sha256"
5 "crypto/tls"
6 "crypto/x509"
7 "encoding/pem"
8 "fmt"
9 "time"
10)
11
// MTLSCert is one client certificate registered to a user for mutual-TLS
// login, mirroring a row of the mtls_certs table.
type MTLSCert struct {
	ID, UserID  int64     // row ID and the owning user's ID
	Label       string    // free-form label passed to AddMTLSCert
	Fingerprint string    // lowercase hex SHA-256 of the certificate's DER bytes
	CertPEM     string    // PEM text exactly as supplied to AddMTLSCert
	AddedAt     time.Time // registration time, whole-second precision
}
20
// certFingerprint identifies a certificate by the SHA-256 digest of its raw
// DER encoding (cert.Raw), rendered as 64 lowercase hexadecimal characters.
func certFingerprint(cert *x509.Certificate) string {
	digest := sha256.Sum256(cert.Raw)
	return fmt.Sprintf("%x", digest[:])
}
25
26func certFingerprintFromPEM(certPEM string) (string, *x509.Certificate, error) {
27 block, _ := pem.Decode([]byte(certPEM))
28 if block == nil || block.Type != "CERTIFICATE" {
29 return "", nil, fmt.Errorf("invalid certificate PEM (expected CERTIFICATE block)")
30 }
31 cert, err := x509.ParseCertificate(block.Bytes)
32 if err != nil {
33 return "", nil, fmt.Errorf("parse certificate: %w", err)
34 }
35 return certFingerprint(cert), cert, nil
36}
37
38func (d *DB) AddMTLSCert(userID int64, label, certPEM string) (*MTLSCert, error) {
39 fp, _, err := certFingerprintFromPEM(certPEM)
40 if err != nil {
41 return nil, err
42 }
43 now := time.Now().Unix()
44 res, err := d.db.Exec(
45 "INSERT INTO mtls_certs(user_id,label,fingerprint,cert_pem,added_at) VALUES(?,?,?,?,?)",
46 userID, label, fp, certPEM, now,
47 )
48 if err != nil {
49 return nil, fmt.Errorf("add mTLS cert: %w", err)
50 }
51 id, _ := res.LastInsertId()
52 return &MTLSCert{
53 ID: id,
54 UserID: userID,
55 Label: label,
56 Fingerprint: fp,
57 CertPEM: certPEM,
58 AddedAt: time.Unix(now, 0),
59 }, nil
60}
61
62func (d *DB) ListMTLSCerts(userID int64) ([]MTLSCert, error) {
63 rows, err := d.db.Query(
64 "SELECT id, user_id, label, fingerprint, cert_pem, added_at FROM mtls_certs WHERE user_id=? ORDER BY id",
65 userID,
66 )
67 if err != nil {
68 return nil, err
69 }
70 defer rows.Close()
71 var out []MTLSCert
72 for rows.Next() {
73 var c MTLSCert
74 var ts int64
75 if err := rows.Scan(&c.ID, &c.UserID, &c.Label, &c.Fingerprint, &c.CertPEM, &ts); err != nil {
76 return nil, err
77 }
78 c.AddedAt = time.Unix(ts, 0)
79 out = append(out, c)
80 }
81 return out, rows.Err()
82}
83
84func (d *DB) DeleteMTLSCert(certID, userID int64) error {
85 _, err := d.db.Exec("DELETE FROM mtls_certs WHERE id=? AND user_id=?", certID, userID)
86 return err
87}
88
89func (d *DB) AuthorizeMTLSCert(fingerprint string) (*User, error) {
90 var userID int64
91 err := d.db.QueryRow(
92 "SELECT user_id FROM mtls_certs WHERE fingerprint=?", fingerprint,
93 ).Scan(&userID)
94 if err != nil {
95 return nil, nil //nolint:nilerr
96 }
97 return d.GetUserByID(userID)
98}
99
// tlsConfigMTLS returns a fresh TLS config requiring TLS 1.2 or newer that
// makes the handshake fail unless the client presents a certificate.
//
// With tls.RequireAnyClientCert the client certificate is NOT verified against
// any CA. NOTE(review): authorization presumably happens afterwards by looking
// up the certificate's fingerprint (e.g. via AuthorizeMTLSCert); confirm at the
// call site. Server certificates are not set here and must be filled in by the
// caller.
func tlsConfigMTLS() *tls.Config {
	cfg := new(tls.Config)
	cfg.MinVersion = tls.VersionTLS12
	cfg.ClientAuth = tls.RequireAnyClientCert
	return cfg
}