arche / internal/archesrv/mtlscerts.go

commit 154431fd
  1package archesrv
  2
import (
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"database/sql"
	"encoding/pem"
	"errors"
	"fmt"
	"time"
)
 11
 12type MTLSCert struct {
 13	ID          int64
 14	UserID      int64
 15	Label       string
 16	Fingerprint string
 17	CertPEM     string
 18	AddedAt     time.Time
 19}
 20
 21func certFingerprint(cert *x509.Certificate) string {
 22	sum := sha256.Sum256(cert.Raw)
 23	return fmt.Sprintf("%x", sum)
 24}
 25
 26func certFingerprintFromPEM(certPEM string) (string, *x509.Certificate, error) {
 27	block, _ := pem.Decode([]byte(certPEM))
 28	if block == nil || block.Type != "CERTIFICATE" {
 29		return "", nil, fmt.Errorf("invalid certificate PEM (expected CERTIFICATE block)")
 30	}
 31	cert, err := x509.ParseCertificate(block.Bytes)
 32	if err != nil {
 33		return "", nil, fmt.Errorf("parse certificate: %w", err)
 34	}
 35	return certFingerprint(cert), cert, nil
 36}
 37
 38func (d *DB) AddMTLSCert(userID int64, label, certPEM string) (*MTLSCert, error) {
 39	fp, _, err := certFingerprintFromPEM(certPEM)
 40	if err != nil {
 41		return nil, err
 42	}
 43	now := time.Now().Unix()
 44	res, err := d.db.Exec(
 45		"INSERT INTO mtls_certs(user_id,label,fingerprint,cert_pem,added_at) VALUES(?,?,?,?,?)",
 46		userID, label, fp, certPEM, now,
 47	)
 48	if err != nil {
 49		return nil, fmt.Errorf("add mTLS cert: %w", err)
 50	}
 51	id, _ := res.LastInsertId()
 52	return &MTLSCert{
 53		ID:          id,
 54		UserID:      userID,
 55		Label:       label,
 56		Fingerprint: fp,
 57		CertPEM:     certPEM,
 58		AddedAt:     time.Unix(now, 0),
 59	}, nil
 60}
 61
 62func (d *DB) ListMTLSCerts(userID int64) ([]MTLSCert, error) {
 63	rows, err := d.db.Query(
 64		"SELECT id, user_id, label, fingerprint, cert_pem, added_at FROM mtls_certs WHERE user_id=? ORDER BY id",
 65		userID,
 66	)
 67	if err != nil {
 68		return nil, err
 69	}
 70	defer rows.Close()
 71	var out []MTLSCert
 72	for rows.Next() {
 73		var c MTLSCert
 74		var ts int64
 75		if err := rows.Scan(&c.ID, &c.UserID, &c.Label, &c.Fingerprint, &c.CertPEM, &ts); err != nil {
 76			return nil, err
 77		}
 78		c.AddedAt = time.Unix(ts, 0)
 79		out = append(out, c)
 80	}
 81	return out, rows.Err()
 82}
 83
 84func (d *DB) DeleteMTLSCert(certID, userID int64) error {
 85	_, err := d.db.Exec("DELETE FROM mtls_certs WHERE id=? AND user_id=?", certID, userID)
 86	return err
 87}
 88
 89func (d *DB) AuthorizeMTLSCert(fingerprint string) (*User, error) {
 90	var userID int64
 91	err := d.db.QueryRow(
 92		"SELECT user_id FROM mtls_certs WHERE fingerprint=?", fingerprint,
 93	).Scan(&userID)
 94	if err != nil {
 95		return nil, nil //nolint:nilerr
 96	}
 97	return d.GetUserByID(userID)
 98}
 99
// tlsConfigMTLS returns a new TLS server configuration that requires the
// client to send a certificate and only negotiates TLS 1.2 or newer.
//
// With tls.RequireAnyClientCert the presented certificate is not validated
// against any CA, so trust must be established afterwards. Presumably this
// happens by looking up its fingerprint with AuthorizeMTLSCert — confirm at
// the call site.
func tlsConfigMTLS() *tls.Config {
	cfg := new(tls.Config)
	cfg.MinVersion = tls.VersionTLS12
	cfg.ClientAuth = tls.RequireAnyClientCert
	return cfg
}