✨ feat(auth): JWT secret retention policy + automatic cleanup loop (ADR-0021) (#41)
Co-authored-by: Gabriel Radureau <arcodange@gmail.com> Co-committed-by: Gabriel Radureau <arcodange@gmail.com>
This commit was merged in pull request #41.
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// JWTSecret represents a JWT secret with metadata
|
||||
@@ -12,10 +16,16 @@ type JWTSecret struct {
|
||||
ExpiresAt *time.Time // Optional expiration time
|
||||
}
|
||||
|
||||
// JWTSecretManager manages multiple JWT secrets for rotation
|
||||
// JWTSecretManager manages multiple JWT secrets for rotation.
|
||||
// All operations are mutex-protected so the cleanup goroutine
|
||||
// (StartCleanupLoop) can run alongside Generate / Validate calls.
|
||||
// ADR-0021 implements automatic removal of expired secrets while
|
||||
// always preserving the primary secret.
|
||||
type JWTSecretManager struct {
|
||||
mu sync.Mutex
|
||||
secrets []JWTSecret
|
||||
primarySecret string
|
||||
cleanupCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewJWTSecretManager creates a new JWT secret manager
|
||||
@@ -34,12 +44,19 @@ func NewJWTSecretManager(initialSecret string) *JWTSecretManager {
|
||||
|
||||
// AddSecret adds a new JWT secret
|
||||
func (m *JWTSecretManager) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.addSecretLocked(secret, isPrimary, expiresIn)
|
||||
}
|
||||
|
||||
// addSecretLocked is the internal helper that assumes the mutex is held.
|
||||
func (m *JWTSecretManager) addSecretLocked(secret string, isPrimary bool, expiresIn time.Duration) {
|
||||
var expiresAt *time.Time
|
||||
if expiresIn > 0 {
|
||||
expirationTime := time.Now().Add(expiresIn)
|
||||
expiresAt = &expirationTime
|
||||
}
|
||||
// If expiresIn is 0 or negative, expiresAt remains nil (no expiration)
|
||||
// expiresIn <= 0 means no expiration
|
||||
|
||||
m.secrets = append(m.secrets, JWTSecret{
|
||||
Secret: secret,
|
||||
@@ -55,48 +72,60 @@ func (m *JWTSecretManager) AddSecret(secret string, isPrimary bool, expiresIn ti
|
||||
|
||||
// RotateToSecret rotates to a new primary secret
|
||||
func (m *JWTSecretManager) RotateToSecret(newSecret string) {
|
||||
// Mark existing primary as non-primary
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for i, secret := range m.secrets {
|
||||
if secret.IsPrimary {
|
||||
m.secrets[i].IsPrimary = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Add new secret as primary
|
||||
m.AddSecret(newSecret, true, 0) // No expiration for primary
|
||||
m.addSecretLocked(newSecret, true, 0)
|
||||
}
|
||||
|
||||
// GetPrimarySecret returns the current primary secret
|
||||
func (m *JWTSecretManager) GetPrimarySecret() string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.primarySecret
|
||||
}
|
||||
|
||||
// GetAllValidSecrets returns all valid (non-expired) secrets
|
||||
func (m *JWTSecretManager) GetAllValidSecrets() []JWTSecret {
|
||||
var validSecrets []JWTSecret
|
||||
now := time.Now()
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
valid := make([]JWTSecret, 0, len(m.secrets))
|
||||
for _, secret := range m.secrets {
|
||||
if secret.ExpiresAt == nil || secret.ExpiresAt.After(now) {
|
||||
validSecrets = append(validSecrets, secret)
|
||||
valid = append(valid, secret)
|
||||
}
|
||||
}
|
||||
|
||||
return validSecrets
|
||||
return valid
|
||||
}
|
||||
|
||||
// GetSecretByIndex returns a secret by index for testing
|
||||
func (m *JWTSecretManager) GetSecretByIndex(index int) (string, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if index < 0 || index >= len(m.secrets) {
|
||||
return "", false
|
||||
}
|
||||
return m.secrets[index].Secret, true
|
||||
}
|
||||
|
||||
// Reset resets the secret manager to its initial state with only the primary secret
|
||||
// This is useful for test cleanup to ensure tests don't interfere with each other
|
||||
// Reset resets the secret manager to its initial state with only the primary
|
||||
// secret. Used for test cleanup so tests don't interfere with each other.
|
||||
func (m *JWTSecretManager) Reset(initialSecret string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.cleanupCancel != nil {
|
||||
m.cleanupCancel()
|
||||
m.cleanupCancel = nil
|
||||
}
|
||||
m.secrets = []JWTSecret{
|
||||
{
|
||||
Secret: initialSecret,
|
||||
@@ -106,3 +135,64 @@ func (m *JWTSecretManager) Reset(initialSecret string) {
|
||||
}
|
||||
m.primarySecret = initialSecret
|
||||
}
|
||||
|
||||
// RemoveExpiredSecrets drops every non-primary secret whose ExpiresAt is
|
||||
// non-nil and in the past. Returns the count of secrets removed.
|
||||
// The primary secret is never removed regardless of expiration (ADR-0021).
|
||||
func (m *JWTSecretManager) RemoveExpiredSecrets() int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
kept := make([]JWTSecret, 0, len(m.secrets))
|
||||
removed := 0
|
||||
for _, secret := range m.secrets {
|
||||
if !secret.IsPrimary && secret.ExpiresAt != nil && !secret.ExpiresAt.After(now) {
|
||||
removed++
|
||||
continue
|
||||
}
|
||||
kept = append(kept, secret)
|
||||
}
|
||||
m.secrets = kept
|
||||
return removed
|
||||
}
|
||||
|
||||
// StartCleanupLoop spawns a goroutine that calls RemoveExpiredSecrets at the
|
||||
// given interval. Stops when the parent context is cancelled. Calling again
|
||||
// cancels the previous loop's context and starts a fresh one.
|
||||
// If interval <= 0, the loop is disabled (cleanup must be triggered manually
|
||||
// via RemoveExpiredSecrets).
|
||||
func (m *JWTSecretManager) StartCleanupLoop(ctx context.Context, interval time.Duration) {
|
||||
m.mu.Lock()
|
||||
if m.cleanupCancel != nil {
|
||||
m.cleanupCancel()
|
||||
}
|
||||
loopCtx, cancel := context.WithCancel(ctx)
|
||||
m.cleanupCancel = cancel
|
||||
m.mu.Unlock()
|
||||
|
||||
if interval <= 0 {
|
||||
log.Warn().Dur("interval", interval).Msg("JWT secret cleanup interval is non-positive, loop disabled")
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
log.Info().Dur("interval", interval).Msg("JWT secret cleanup loop started")
|
||||
for {
|
||||
select {
|
||||
case <-loopCtx.Done():
|
||||
log.Info().Msg("JWT secret cleanup loop stopped")
|
||||
return
|
||||
case <-ticker.C:
|
||||
removed := m.RemoveExpiredSecrets()
|
||||
if removed > 0 {
|
||||
log.Info().Int("removed", removed).Msg("JWT secrets cleaned up")
|
||||
} else {
|
||||
log.Trace().Msg("JWT cleanup tick: no expired secrets")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user