Files
dance-lessons-coach/pkg/jwt/jwt_secret_manager.go
Gabriel Radureau 03ea2a7b89
Some checks failed
CI/CD Pipeline / Build Docker Cache (push) Successful in 13s
CI/CD Pipeline / Trigger Docker Push (push) Has been cancelled
CI/CD Pipeline / CI Pipeline (push) Has been cancelled
feat(auth): JWT secret retention policy + automatic cleanup loop (ADR-0021) (#41)
Co-authored-by: Gabriel Radureau <arcodange@gmail.com>
Co-committed-by: Gabriel Radureau <arcodange@gmail.com>
2026-05-05 08:40:27 +02:00

164 lines
4.2 KiB
Go

package jwt
import (
"context"
"sync"
"time"
"github.com/rs/zerolog/log"
)
// jwtSecretManagerImpl implements the JWTSecretManager interface.
// It keeps every registered secret in memory together with the value of the
// current primary secret. All operations are mutex-protected so the cleanup
// goroutine (StartCleanupLoop) can run alongside Generate / Validate calls.
type jwtSecretManagerImpl struct {
// mu guards every field below.
mu sync.Mutex
// secrets holds the registered secrets in insertion order.
secrets []JWTSecret
// primarySecret is the value of the secret most recently added as primary.
primarySecret string
// cleanupCancel, when non-nil, cancels the context created by a previous
// StartCleanupLoop call.
cleanupCancel context.CancelFunc
}
// NewJWTSecretManager builds a manager seeded with initialSecret, which is
// registered as the primary secret with no expiration.
func NewJWTSecretManager(initialSecret string) JWTSecretManager {
	seed := JWTSecret{
		Secret:    initialSecret,
		IsPrimary: true,
		CreatedAt: time.Now(),
	}
	manager := &jwtSecretManagerImpl{
		primarySecret: initialSecret,
		secrets:       []JWTSecret{seed},
	}
	return manager
}
// AddSecret registers secret with the manager while holding the lock.
// When isPrimary is true the secret also becomes the value returned by
// GetPrimarySecret. A positive expiresIn gives the secret an expiration that
// far in the future; zero or a negative value leaves it without expiration.
func (m *jwtSecretManagerImpl) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.addSecretLocked(secret, isPrimary, expiresIn)
}
// addSecretLocked appends secret to the store. The caller must hold m.mu.
//
// A positive expiresIn sets ExpiresAt to that far after the creation time;
// otherwise the entry has no expiration. When isPrimary is true, entries
// previously flagged as primary are demoted first so that at most one entry
// carries IsPrimary, and primarySecret is updated to secret.
func (m *jwtSecretManagerImpl) addSecretLocked(secret string, isPrimary bool, expiresIn time.Duration) {
	// Take a single timestamp so CreatedAt and ExpiresAt are exactly
	// expiresIn apart.
	now := time.Now()
	entry := JWTSecret{
		Secret:    secret,
		IsPrimary: isPrimary,
		CreatedAt: now,
	}
	if expiresIn > 0 {
		expiresAt := now.Add(expiresIn)
		entry.ExpiresAt = &expiresAt
	}

	if isPrimary {
		// Demote former primaries: otherwise they would keep IsPrimary
		// forever and RemoveExpiredSecrets would never drop them once they
		// expire.
		for i := range m.secrets {
			m.secrets[i].IsPrimary = false
		}
		m.primarySecret = secret
	}
	m.secrets = append(m.secrets, entry)
}
// RotateToSecret makes newSecret the primary secret. Every entry currently
// flagged as primary is demoted but kept in the store, then newSecret is
// appended as a primary entry without expiration.
func (m *jwtSecretManagerImpl) RotateToSecret(newSecret string) {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Clear the flag on all entries rather than stopping at the first match:
	// AddSecret with isPrimary=true can leave more than one entry flagged.
	for i := range m.secrets {
		m.secrets[i].IsPrimary = false
	}
	m.addSecretLocked(newSecret, true, 0)
}
// GetPrimarySecret reports the value of the secret that was most recently
// registered as primary. The read is performed under the manager's lock.
func (m *jwtSecretManagerImpl) GetPrimarySecret() string {
	m.mu.Lock()
	current := m.primarySecret
	m.mu.Unlock()
	return current
}
// GetAllValidSecrets returns, in insertion order, a new slice containing a
// copy of every secret that has no expiration or whose expiration is still in
// the future. The IsPrimary flag is not taken into account.
func (m *jwtSecretManagerImpl) GetAllValidSecrets() []JWTSecret {
	m.mu.Lock()
	defer m.mu.Unlock()

	now := time.Now()
	valid := make([]JWTSecret, 0, len(m.secrets))
	for i := range m.secrets {
		if exp := m.secrets[i].ExpiresAt; exp != nil && !exp.After(now) {
			continue // expiration reached
		}
		valid = append(valid, m.secrets[i])
	}
	return valid
}
// GetSecretByIndex is a testing aid that returns the secret value stored at
// position index. The boolean is false, and the string empty, when index is
// outside the stored range.
func (m *jwtSecretManagerImpl) GetSecretByIndex(index int) (string, bool) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if index >= 0 && index < len(m.secrets) {
		return m.secrets[index].Secret, true
	}
	return "", false
}
// RemoveExpiredSecrets rebuilds the secret list without the non-primary
// entries whose ExpiresAt is set and no longer in the future, and reports how
// many entries were dropped. Entries flagged as primary are always kept,
// whatever their expiration (ADR-0021).
func (m *jwtSecretManagerImpl) RemoveExpiredSecrets() int {
	m.mu.Lock()
	defer m.mu.Unlock()

	now := time.Now()
	kept := make([]JWTSecret, 0, len(m.secrets))
	for _, s := range m.secrets {
		expired := s.ExpiresAt != nil && !s.ExpiresAt.After(now)
		if s.IsPrimary || !expired {
			kept = append(kept, s)
		}
	}

	// Whatever was not kept was removed.
	removed := len(m.secrets) - len(kept)
	m.secrets = kept
	return removed
}
// StartCleanupLoop spawns a goroutine that calls RemoveExpiredSecrets every
// interval until ctx is cancelled. Calling it again first cancels the loop
// started by the previous call and then starts a fresh one.
//
// A non-positive interval disables the loop: any previous loop is still
// cancelled, a warning is logged, and no goroutine or derived context is
// created.
func (m *jwtSecretManagerImpl) StartCleanupLoop(ctx context.Context, interval time.Duration) {
	m.mu.Lock()
	if m.cleanupCancel != nil {
		m.cleanupCancel()
		m.cleanupCancel = nil
	}
	if interval <= 0 {
		// Bail out before deriving a context: a child context that nobody
		// cancels would stay registered with ctx for no purpose.
		m.mu.Unlock()
		log.Warn().Dur("interval", interval).Msg("JWT secret cleanup interval is non-positive, loop disabled")
		return
	}
	loopCtx, cancel := context.WithCancel(ctx)
	m.cleanupCancel = cancel
	m.mu.Unlock()

	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		log.Info().Dur("interval", interval).Msg("JWT secret cleanup loop started")
		for {
			select {
			case <-loopCtx.Done():
				log.Info().Msg("JWT secret cleanup loop stopped")
				return
			case <-ticker.C:
				removed := m.RemoveExpiredSecrets()
				if removed > 0 {
					log.Info().Int("removed", removed).Msg("JWT secrets cleaned up")
				} else {
					log.Trace().Msg("JWT cleanup tick: no expired secrets")
				}
			}
		}
	}()
}