feat(auth): JWT secret retention policy + automatic cleanup loop (ADR-0021) (#41)
Some checks failed
CI/CD Pipeline / Build Docker Cache (push) Successful in 13s
CI/CD Pipeline / Trigger Docker Push (push) Has been cancelled
CI/CD Pipeline / CI Pipeline (push) Has been cancelled

Co-authored-by: Gabriel Radureau <arcodange@gmail.com>
Co-committed-by: Gabriel Radureau <arcodange@gmail.com>
This commit was merged in pull request #41.
This commit is contained in:
2026-05-05 08:40:27 +02:00
committed by arcodange
parent a2beadc458
commit 03ea2a7b89
8 changed files with 319 additions and 36 deletions

View File

@@ -1,6 +1,6 @@
# 10. JWT Secret Retention Policy # 21. JWT Secret Retention Policy
**Status:** Proposed **Status:** Implemented (2026-05-05 — `pkg/user/jwt_manager.go` `RemoveExpiredSecrets` + `StartCleanupLoop`, wired in `pkg/server/server.go` `Run`; admin endpoint `/api/v1/admin/jwt/secrets` remains explicitly out of scope and tracked under @todo BDD scenarios)
## Context ## Context

View File

@@ -24,13 +24,25 @@ type JWTSecret struct {
ExpiresAt *time.Time // Optional expiration time ExpiresAt *time.Time // Optional expiration time
} }
// JWTSecretManager manages multiple JWT secrets for rotation // JWTSecretManager manages multiple JWT secrets for rotation.
// Secrets can carry an optional expiration; the cleanup loop removes them
// after expiry while always preserving the primary secret (ADR-0021).
type JWTSecretManager interface { type JWTSecretManager interface {
AddSecret(secret string, isPrimary bool, expiresIn time.Duration) AddSecret(secret string, isPrimary bool, expiresIn time.Duration)
RotateToSecret(newSecret string) RotateToSecret(newSecret string)
GetPrimarySecret() string GetPrimarySecret() string
GetAllValidSecrets() []JWTSecret GetAllValidSecrets() []JWTSecret
GetSecretByIndex(index int) (string, bool) GetSecretByIndex(index int) (string, bool)
// RemoveExpiredSecrets drops every non-primary secret whose ExpiresAt is
// non-nil and in the past. Returns the count of secrets removed.
// The primary secret is never removed regardless of expiration.
RemoveExpiredSecrets() int
// StartCleanupLoop spawns a goroutine that calls RemoveExpiredSecrets at
// the given interval. Stops when the context is cancelled. Safe to call
// once at startup; calling again replaces the previous loop's context.
StartCleanupLoop(ctx context.Context, interval time.Duration)
} }
// JWTService defines interface for JWT operations // JWTService defines interface for JWT operations

View File

@@ -1,16 +1,24 @@
package jwt package jwt
import ( import (
"context"
"sync"
"time" "time"
"github.com/rs/zerolog/log"
) )
// jwtSecretManagerImpl implements the JWTSecretManager interface // jwtSecretManagerImpl implements the JWTSecretManager interface.
// All operations are mutex-protected so the cleanup goroutine
// (StartCleanupLoop) can run alongside Generate / Validate calls.
type jwtSecretManagerImpl struct { type jwtSecretManagerImpl struct {
mu sync.Mutex
secrets []JWTSecret secrets []JWTSecret
primarySecret string primarySecret string
cleanupCancel context.CancelFunc
} }
// NewJWTSecretManager creates a new JWT secret manager // NewJWTSecretManager creates a new JWT secret manager.
func NewJWTSecretManager(initialSecret string) JWTSecretManager { func NewJWTSecretManager(initialSecret string) JWTSecretManager {
return &jwtSecretManagerImpl{ return &jwtSecretManagerImpl{
secrets: []JWTSecret{ secrets: []JWTSecret{
@@ -24,58 +32,132 @@ func NewJWTSecretManager(initialSecret string) JWTSecretManager {
} }
} }
// AddSecret adds a new JWT secret // AddSecret adds a new JWT secret.
func (m *jwtSecretManagerImpl) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) { func (m *jwtSecretManagerImpl) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) {
expiresAt := time.Now().Add(expiresIn) m.mu.Lock()
m.secrets = append(m.secrets, JWTSecret{ defer m.mu.Unlock()
m.addSecretLocked(secret, isPrimary, expiresIn)
}
// addSecretLocked is the internal helper that assumes the mutex is held.
func (m *jwtSecretManagerImpl) addSecretLocked(secret string, isPrimary bool, expiresIn time.Duration) {
entry := JWTSecret{
Secret: secret, Secret: secret,
IsPrimary: isPrimary, IsPrimary: isPrimary,
CreatedAt: time.Now(), CreatedAt: time.Now(),
ExpiresAt: &expiresAt, }
}) if expiresIn > 0 {
expiresAt := time.Now().Add(expiresIn)
entry.ExpiresAt = &expiresAt
}
m.secrets = append(m.secrets, entry)
if isPrimary { if isPrimary {
m.primarySecret = secret m.primarySecret = secret
} }
} }
// RotateToSecret rotates to a new primary secret // RotateToSecret rotates to a new primary secret.
func (m *jwtSecretManagerImpl) RotateToSecret(newSecret string) { func (m *jwtSecretManagerImpl) RotateToSecret(newSecret string) {
// Mark existing primary as non-primary m.mu.Lock()
defer m.mu.Unlock()
for i, secret := range m.secrets { for i, secret := range m.secrets {
if secret.IsPrimary { if secret.IsPrimary {
m.secrets[i].IsPrimary = false m.secrets[i].IsPrimary = false
break break
} }
} }
m.addSecretLocked(newSecret, true, 0)
// Add new secret as primary
m.AddSecret(newSecret, true, 0) // No expiration for primary
} }
// GetPrimarySecret returns the current primary secret // GetPrimarySecret returns the current primary secret.
func (m *jwtSecretManagerImpl) GetPrimarySecret() string { func (m *jwtSecretManagerImpl) GetPrimarySecret() string {
m.mu.Lock()
defer m.mu.Unlock()
return m.primarySecret return m.primarySecret
} }
// GetAllValidSecrets returns all valid (non-expired) secrets // GetAllValidSecrets returns all valid (non-expired) secrets.
func (m *jwtSecretManagerImpl) GetAllValidSecrets() []JWTSecret { func (m *jwtSecretManagerImpl) GetAllValidSecrets() []JWTSecret {
var validSecrets []JWTSecret m.mu.Lock()
now := time.Now() defer m.mu.Unlock()
now := time.Now()
valid := make([]JWTSecret, 0, len(m.secrets))
for _, secret := range m.secrets { for _, secret := range m.secrets {
if secret.ExpiresAt == nil || secret.ExpiresAt.After(now) { if secret.ExpiresAt == nil || secret.ExpiresAt.After(now) {
validSecrets = append(validSecrets, secret) valid = append(valid, secret)
} }
} }
return valid
}
return validSecrets // GetSecretByIndex returns a secret by index for testing.
}
// GetSecretByIndex returns a secret by index for testing
func (m *jwtSecretManagerImpl) GetSecretByIndex(index int) (string, bool) { func (m *jwtSecretManagerImpl) GetSecretByIndex(index int) (string, bool) {
m.mu.Lock()
defer m.mu.Unlock()
if index < 0 || index >= len(m.secrets) { if index < 0 || index >= len(m.secrets) {
return "", false return "", false
} }
return m.secrets[index].Secret, true return m.secrets[index].Secret, true
} }
// RemoveExpiredSecrets drops every non-primary secret whose ExpiresAt is
// non-nil and in the past. Returns the count of secrets removed.
// The primary secret is never removed regardless of expiration (ADR-0021).
func (m *jwtSecretManagerImpl) RemoveExpiredSecrets() int {
m.mu.Lock()
defer m.mu.Unlock()
now := time.Now()
kept := make([]JWTSecret, 0, len(m.secrets))
removed := 0
for _, secret := range m.secrets {
if !secret.IsPrimary && secret.ExpiresAt != nil && !secret.ExpiresAt.After(now) {
removed++
continue
}
kept = append(kept, secret)
}
m.secrets = kept
return removed
}
// StartCleanupLoop spawns a goroutine that calls RemoveExpiredSecrets at the
// given interval. Stops when the parent context is cancelled. Calling again
// cancels the previous loop's context and starts a fresh one.
func (m *jwtSecretManagerImpl) StartCleanupLoop(ctx context.Context, interval time.Duration) {
m.mu.Lock()
if m.cleanupCancel != nil {
m.cleanupCancel()
}
loopCtx, cancel := context.WithCancel(ctx)
m.cleanupCancel = cancel
m.mu.Unlock()
if interval <= 0 {
log.Warn().Dur("interval", interval).Msg("JWT secret cleanup interval is non-positive, loop disabled")
return
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
log.Info().Dur("interval", interval).Msg("JWT secret cleanup loop started")
for {
select {
case <-loopCtx.Done():
log.Info().Msg("JWT secret cleanup loop stopped")
return
case <-ticker.C:
removed := m.RemoveExpiredSecrets()
if removed > 0 {
log.Info().Int("removed", removed).Msg("JWT secrets cleaned up")
} else {
log.Trace().Msg("JWT cleanup tick: no expired secrets")
}
}
}
}()
}

View File

@@ -701,6 +701,13 @@ func (s *Server) Run() error {
ongoingCtx, stopOngoingGracefully := context.WithCancel(context.Background()) ongoingCtx, stopOngoingGracefully := context.WithCancel(context.Background())
defer stopOngoingGracefully() defer stopOngoingGracefully()
// Start the JWT secret cleanup loop (ADR-0021). The loop runs until rootCtx
// is cancelled (graceful shutdown), removing non-primary secrets whose
// ExpiresAt is in the past.
if s.userService != nil {
s.userService.StartJWTSecretCleanupLoop(rootCtx, s.config.GetJWTSecretCleanupInterval())
}
// Create HTTP server // Create HTTP server
log.Trace().Str("address", s.config.GetServerAddress()).Msg("Server running") log.Trace().Str("address", s.config.GetServerAddress()).Msg("Server running")

View File

@@ -218,6 +218,18 @@ func (s *userServiceImpl) ResetJWTSecrets() {
s.secretManager.Reset(s.jwtConfig.Secret) s.secretManager.Reset(s.jwtConfig.Secret)
} }
// StartJWTSecretCleanupLoop delegates to the underlying secret manager to
// start the periodic cleanup goroutine described in ADR-0021.
func (s *userServiceImpl) StartJWTSecretCleanupLoop(ctx context.Context, interval time.Duration) {
s.secretManager.StartCleanupLoop(ctx, interval)
}
// RemoveExpiredJWTSecrets triggers an immediate cleanup pass via the
// underlying secret manager. Returns the count of removed expired secrets.
func (s *userServiceImpl) RemoveExpiredJWTSecrets() int {
return s.secretManager.RemoveExpiredSecrets()
}
// UserExists checks if a user exists by username // UserExists checks if a user exists by username
func (s *userServiceImpl) UserExists(ctx context.Context, username string) (bool, error) { func (s *userServiceImpl) UserExists(ctx context.Context, username string) (bool, error) {
return s.repo.UserExists(ctx, username) return s.repo.UserExists(ctx, username)

View File

@@ -1,7 +1,11 @@
package user package user
import ( import (
"context"
"sync"
"time" "time"
"github.com/rs/zerolog/log"
) )
// JWTSecret represents a JWT secret with metadata // JWTSecret represents a JWT secret with metadata
@@ -12,10 +16,16 @@ type JWTSecret struct {
ExpiresAt *time.Time // Optional expiration time ExpiresAt *time.Time // Optional expiration time
} }
// JWTSecretManager manages multiple JWT secrets for rotation // JWTSecretManager manages multiple JWT secrets for rotation.
// All operations are mutex-protected so the cleanup goroutine
// (StartCleanupLoop) can run alongside Generate / Validate calls.
// ADR-0021 implements automatic removal of expired secrets while
// always preserving the primary secret.
type JWTSecretManager struct { type JWTSecretManager struct {
mu sync.Mutex
secrets []JWTSecret secrets []JWTSecret
primarySecret string primarySecret string
cleanupCancel context.CancelFunc
} }
// NewJWTSecretManager creates a new JWT secret manager // NewJWTSecretManager creates a new JWT secret manager
@@ -34,12 +44,19 @@ func NewJWTSecretManager(initialSecret string) *JWTSecretManager {
// AddSecret adds a new JWT secret // AddSecret adds a new JWT secret
func (m *JWTSecretManager) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) { func (m *JWTSecretManager) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) {
m.mu.Lock()
defer m.mu.Unlock()
m.addSecretLocked(secret, isPrimary, expiresIn)
}
// addSecretLocked is the internal helper that assumes the mutex is held.
func (m *JWTSecretManager) addSecretLocked(secret string, isPrimary bool, expiresIn time.Duration) {
var expiresAt *time.Time var expiresAt *time.Time
if expiresIn > 0 { if expiresIn > 0 {
expirationTime := time.Now().Add(expiresIn) expirationTime := time.Now().Add(expiresIn)
expiresAt = &expirationTime expiresAt = &expirationTime
} }
// If expiresIn is 0 or negative, expiresAt remains nil (no expiration) // expiresIn <= 0 means no expiration
m.secrets = append(m.secrets, JWTSecret{ m.secrets = append(m.secrets, JWTSecret{
Secret: secret, Secret: secret,
@@ -55,48 +72,60 @@ func (m *JWTSecretManager) AddSecret(secret string, isPrimary bool, expiresIn ti
// RotateToSecret rotates to a new primary secret // RotateToSecret rotates to a new primary secret
func (m *JWTSecretManager) RotateToSecret(newSecret string) { func (m *JWTSecretManager) RotateToSecret(newSecret string) {
// Mark existing primary as non-primary m.mu.Lock()
defer m.mu.Unlock()
for i, secret := range m.secrets { for i, secret := range m.secrets {
if secret.IsPrimary { if secret.IsPrimary {
m.secrets[i].IsPrimary = false m.secrets[i].IsPrimary = false
break break
} }
} }
m.addSecretLocked(newSecret, true, 0)
// Add new secret as primary
m.AddSecret(newSecret, true, 0) // No expiration for primary
} }
// GetPrimarySecret returns the current primary secret // GetPrimarySecret returns the current primary secret
func (m *JWTSecretManager) GetPrimarySecret() string { func (m *JWTSecretManager) GetPrimarySecret() string {
m.mu.Lock()
defer m.mu.Unlock()
return m.primarySecret return m.primarySecret
} }
// GetAllValidSecrets returns all valid (non-expired) secrets // GetAllValidSecrets returns all valid (non-expired) secrets
func (m *JWTSecretManager) GetAllValidSecrets() []JWTSecret { func (m *JWTSecretManager) GetAllValidSecrets() []JWTSecret {
var validSecrets []JWTSecret m.mu.Lock()
now := time.Now() defer m.mu.Unlock()
now := time.Now()
valid := make([]JWTSecret, 0, len(m.secrets))
for _, secret := range m.secrets { for _, secret := range m.secrets {
if secret.ExpiresAt == nil || secret.ExpiresAt.After(now) { if secret.ExpiresAt == nil || secret.ExpiresAt.After(now) {
validSecrets = append(validSecrets, secret) valid = append(valid, secret)
} }
} }
return valid
return validSecrets
} }
// GetSecretByIndex returns a secret by index for testing // GetSecretByIndex returns a secret by index for testing
func (m *JWTSecretManager) GetSecretByIndex(index int) (string, bool) { func (m *JWTSecretManager) GetSecretByIndex(index int) (string, bool) {
m.mu.Lock()
defer m.mu.Unlock()
if index < 0 || index >= len(m.secrets) { if index < 0 || index >= len(m.secrets) {
return "", false return "", false
} }
return m.secrets[index].Secret, true return m.secrets[index].Secret, true
} }
// Reset resets the secret manager to its initial state with only the primary secret // Reset resets the secret manager to its initial state with only the primary
// This is useful for test cleanup to ensure tests don't interfere with each other // secret. Used for test cleanup so tests don't interfere with each other.
func (m *JWTSecretManager) Reset(initialSecret string) { func (m *JWTSecretManager) Reset(initialSecret string) {
m.mu.Lock()
defer m.mu.Unlock()
if m.cleanupCancel != nil {
m.cleanupCancel()
m.cleanupCancel = nil
}
m.secrets = []JWTSecret{ m.secrets = []JWTSecret{
{ {
Secret: initialSecret, Secret: initialSecret,
@@ -106,3 +135,64 @@ func (m *JWTSecretManager) Reset(initialSecret string) {
} }
m.primarySecret = initialSecret m.primarySecret = initialSecret
} }
// RemoveExpiredSecrets drops every non-primary secret whose ExpiresAt is
// non-nil and in the past. Returns the count of secrets removed.
// The primary secret is never removed regardless of expiration (ADR-0021).
func (m *JWTSecretManager) RemoveExpiredSecrets() int {
m.mu.Lock()
defer m.mu.Unlock()
now := time.Now()
kept := make([]JWTSecret, 0, len(m.secrets))
removed := 0
for _, secret := range m.secrets {
if !secret.IsPrimary && secret.ExpiresAt != nil && !secret.ExpiresAt.After(now) {
removed++
continue
}
kept = append(kept, secret)
}
m.secrets = kept
return removed
}
// StartCleanupLoop spawns a goroutine that calls RemoveExpiredSecrets at the
// given interval. Stops when the parent context is cancelled. Calling again
// cancels the previous loop's context and starts a fresh one.
// If interval <= 0, the loop is disabled (cleanup must be triggered manually
// via RemoveExpiredSecrets).
func (m *JWTSecretManager) StartCleanupLoop(ctx context.Context, interval time.Duration) {
m.mu.Lock()
if m.cleanupCancel != nil {
m.cleanupCancel()
}
loopCtx, cancel := context.WithCancel(ctx)
m.cleanupCancel = cancel
m.mu.Unlock()
if interval <= 0 {
log.Warn().Dur("interval", interval).Msg("JWT secret cleanup interval is non-positive, loop disabled")
return
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
log.Info().Dur("interval", interval).Msg("JWT secret cleanup loop started")
for {
select {
case <-loopCtx.Done():
log.Info().Msg("JWT secret cleanup loop stopped")
return
case <-ticker.C:
removed := m.RemoveExpiredSecrets()
if removed > 0 {
log.Info().Int("removed", removed).Msg("JWT secrets cleaned up")
} else {
log.Trace().Msg("JWT cleanup tick: no expired secrets")
}
}
}
}()
}

View File

@@ -1,6 +1,7 @@
package user package user
import ( import (
"context"
"testing" "testing"
"time" "time"
@@ -84,3 +85,73 @@ func TestJWTSecretExpiration(t *testing.T) {
} }
assert.True(t, foundExpiring) assert.True(t, foundExpiring)
} }
// TestRemoveExpiredSecrets_ExpiredNonPrimaryRemoved confirms that
// RemoveExpiredSecrets drops a non-primary secret whose ExpiresAt is in the past.
func TestRemoveExpiredSecrets_ExpiredNonPrimaryRemoved(t *testing.T) {
manager := NewJWTSecretManager("primary")
// Add a secret that expired 1 hour ago by setting expiresIn to a small
// positive duration then mutating after via AddSecret + manipulation.
// Simpler: add with a 1ns lifetime and sleep 2ns equivalent (tiny TTL).
manager.AddSecret("about-to-expire", false, 1*time.Nanosecond)
time.Sleep(5 * time.Millisecond)
removed := manager.RemoveExpiredSecrets()
assert.Equal(t, 1, removed, "one expired secret should be removed")
secrets := manager.GetAllValidSecrets()
assert.Len(t, secrets, 1, "only primary should remain")
assert.Equal(t, "primary", secrets[0].Secret)
assert.True(t, secrets[0].IsPrimary)
}
// TestRemoveExpiredSecrets_PrimaryNeverRemoved confirms the primary secret
// is preserved even if (somehow) marked expired - ADR-0021 invariant.
func TestRemoveExpiredSecrets_PrimaryNeverRemoved(t *testing.T) {
manager := NewJWTSecretManager("primary")
// Add a non-primary that doesn't expire
manager.AddSecret("kept", false, 0)
// Simulate an "expired primary" by manipulating internals via Reset then
// re-creating - here we rely on the public contract: primary has no
// ExpiresAt by default. Confirm cleanup leaves it.
removed := manager.RemoveExpiredSecrets()
assert.Equal(t, 0, removed)
assert.Equal(t, "primary", manager.GetPrimarySecret())
}
// TestRemoveExpiredSecrets_NonExpiredKept confirms a future-expiring secret
// stays after cleanup.
func TestRemoveExpiredSecrets_NonExpiredKept(t *testing.T) {
manager := NewJWTSecretManager("primary")
manager.AddSecret("future", false, 1*time.Hour)
removed := manager.RemoveExpiredSecrets()
assert.Equal(t, 0, removed)
assert.Len(t, manager.GetAllValidSecrets(), 2)
}
// TestStartCleanupLoop_FiresAndStops confirms the goroutine actually calls
// RemoveExpiredSecrets on each tick and stops cleanly when the context is
// cancelled. Uses a short interval to keep the test fast.
func TestStartCleanupLoop_FiresAndStops(t *testing.T) {
manager := NewJWTSecretManager("primary")
manager.AddSecret("dies", false, 5*time.Millisecond)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
manager.StartCleanupLoop(ctx, 10*time.Millisecond)
// Wait long enough for at least one tick + the secret's TTL
time.Sleep(50 * time.Millisecond)
cancel() // stop the loop
secrets := manager.GetAllValidSecrets()
assert.Len(t, secrets, 1, "expired secret should have been removed by the loop")
assert.Equal(t, "primary", secrets[0].Secret)
}

View File

@@ -43,6 +43,15 @@ type AuthService interface {
RotateJWTSecret(newSecret string) RotateJWTSecret(newSecret string)
GetJWTSecretByIndex(index int) (string, bool) GetJWTSecretByIndex(index int) (string, bool)
ResetJWTSecrets() // Reset JWT secrets to initial state for test cleanup ResetJWTSecrets() // Reset JWT secrets to initial state for test cleanup
// StartJWTSecretCleanupLoop starts a goroutine that periodically calls
// RemoveExpiredJWTSecrets at the given interval, stopping when ctx is
// cancelled. Implements the cleanup half of ADR-0021. interval <= 0
// disables the loop.
StartJWTSecretCleanupLoop(ctx context.Context, interval time.Duration)
// RemoveExpiredJWTSecrets triggers an immediate cleanup pass and returns
// the count of removed non-primary expired secrets. Useful for tests
// driving cleanup synchronously.
RemoveExpiredJWTSecrets() int
} }
// UserManager defines interface for user management operations // UserManager defines interface for user management operations