✨ feat(auth): JWT secret retention policy + automatic cleanup loop (ADR-0021) #41
@@ -1,6 +1,6 @@
|
|||||||
# 10. JWT Secret Retention Policy
|
# 21. JWT Secret Retention Policy
|
||||||
|
|
||||||
**Status:** Proposed
|
**Status:** Implemented (2026-05-05 — `pkg/user/jwt_manager.go` `RemoveExpiredSecrets` + `StartCleanupLoop`, wired in `pkg/server/server.go` `Run`; admin endpoint `/api/v1/admin/jwt/secrets` remains explicitly out of scope and tracked under @todo BDD scenarios)
|
||||||
|
|
||||||
## Context
|
## Context
|
||||||
|
|
||||||
|
|||||||
@@ -24,13 +24,25 @@ type JWTSecret struct {
|
|||||||
ExpiresAt *time.Time // Optional expiration time
|
ExpiresAt *time.Time // Optional expiration time
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTSecretManager manages multiple JWT secrets for rotation
|
// JWTSecretManager manages multiple JWT secrets for rotation.
|
||||||
|
// Secrets can carry an optional expiration; the cleanup loop removes them
|
||||||
|
// after expiry while always preserving the primary secret (ADR-0021).
|
||||||
type JWTSecretManager interface {
|
type JWTSecretManager interface {
|
||||||
AddSecret(secret string, isPrimary bool, expiresIn time.Duration)
|
AddSecret(secret string, isPrimary bool, expiresIn time.Duration)
|
||||||
RotateToSecret(newSecret string)
|
RotateToSecret(newSecret string)
|
||||||
GetPrimarySecret() string
|
GetPrimarySecret() string
|
||||||
GetAllValidSecrets() []JWTSecret
|
GetAllValidSecrets() []JWTSecret
|
||||||
GetSecretByIndex(index int) (string, bool)
|
GetSecretByIndex(index int) (string, bool)
|
||||||
|
|
||||||
|
// RemoveExpiredSecrets drops every non-primary secret whose ExpiresAt is
|
||||||
|
// non-nil and in the past. Returns the count of secrets removed.
|
||||||
|
// The primary secret is never removed regardless of expiration.
|
||||||
|
RemoveExpiredSecrets() int
|
||||||
|
|
||||||
|
// StartCleanupLoop spawns a goroutine that calls RemoveExpiredSecrets at
|
||||||
|
// the given interval. Stops when the context is cancelled. Safe to call
|
||||||
|
// once at startup; calling again replaces the previous loop's context.
|
||||||
|
StartCleanupLoop(ctx context.Context, interval time.Duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTService defines interface for JWT operations
|
// JWTService defines interface for JWT operations
|
||||||
|
|||||||
@@ -1,16 +1,24 @@
|
|||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// jwtSecretManagerImpl implements the JWTSecretManager interface
|
// jwtSecretManagerImpl implements the JWTSecretManager interface.
|
||||||
|
// All operations are mutex-protected so the cleanup goroutine
|
||||||
|
// (StartCleanupLoop) can run alongside Generate / Validate calls.
|
||||||
type jwtSecretManagerImpl struct {
|
type jwtSecretManagerImpl struct {
|
||||||
|
mu sync.Mutex
|
||||||
secrets []JWTSecret
|
secrets []JWTSecret
|
||||||
primarySecret string
|
primarySecret string
|
||||||
|
cleanupCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewJWTSecretManager creates a new JWT secret manager
|
// NewJWTSecretManager creates a new JWT secret manager.
|
||||||
func NewJWTSecretManager(initialSecret string) JWTSecretManager {
|
func NewJWTSecretManager(initialSecret string) JWTSecretManager {
|
||||||
return &jwtSecretManagerImpl{
|
return &jwtSecretManagerImpl{
|
||||||
secrets: []JWTSecret{
|
secrets: []JWTSecret{
|
||||||
@@ -24,58 +32,132 @@ func NewJWTSecretManager(initialSecret string) JWTSecretManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddSecret adds a new JWT secret
|
// AddSecret adds a new JWT secret.
|
||||||
func (m *jwtSecretManagerImpl) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) {
|
func (m *jwtSecretManagerImpl) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) {
|
||||||
expiresAt := time.Now().Add(expiresIn)
|
m.mu.Lock()
|
||||||
m.secrets = append(m.secrets, JWTSecret{
|
defer m.mu.Unlock()
|
||||||
|
m.addSecretLocked(secret, isPrimary, expiresIn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addSecretLocked is the internal helper that assumes the mutex is held.
|
||||||
|
func (m *jwtSecretManagerImpl) addSecretLocked(secret string, isPrimary bool, expiresIn time.Duration) {
|
||||||
|
entry := JWTSecret{
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
IsPrimary: isPrimary,
|
IsPrimary: isPrimary,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
ExpiresAt: &expiresAt,
|
}
|
||||||
})
|
if expiresIn > 0 {
|
||||||
|
expiresAt := time.Now().Add(expiresIn)
|
||||||
|
entry.ExpiresAt = &expiresAt
|
||||||
|
}
|
||||||
|
m.secrets = append(m.secrets, entry)
|
||||||
|
|
||||||
if isPrimary {
|
if isPrimary {
|
||||||
m.primarySecret = secret
|
m.primarySecret = secret
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RotateToSecret rotates to a new primary secret
|
// RotateToSecret rotates to a new primary secret.
|
||||||
func (m *jwtSecretManagerImpl) RotateToSecret(newSecret string) {
|
func (m *jwtSecretManagerImpl) RotateToSecret(newSecret string) {
|
||||||
// Mark existing primary as non-primary
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
for i, secret := range m.secrets {
|
for i, secret := range m.secrets {
|
||||||
if secret.IsPrimary {
|
if secret.IsPrimary {
|
||||||
m.secrets[i].IsPrimary = false
|
m.secrets[i].IsPrimary = false
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
m.addSecretLocked(newSecret, true, 0)
|
||||||
// Add new secret as primary
|
|
||||||
m.AddSecret(newSecret, true, 0) // No expiration for primary
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPrimarySecret returns the current primary secret
|
// GetPrimarySecret returns the current primary secret.
|
||||||
func (m *jwtSecretManagerImpl) GetPrimarySecret() string {
|
func (m *jwtSecretManagerImpl) GetPrimarySecret() string {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
return m.primarySecret
|
return m.primarySecret
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllValidSecrets returns all valid (non-expired) secrets
|
// GetAllValidSecrets returns all valid (non-expired) secrets.
|
||||||
func (m *jwtSecretManagerImpl) GetAllValidSecrets() []JWTSecret {
|
func (m *jwtSecretManagerImpl) GetAllValidSecrets() []JWTSecret {
|
||||||
var validSecrets []JWTSecret
|
m.mu.Lock()
|
||||||
now := time.Now()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
valid := make([]JWTSecret, 0, len(m.secrets))
|
||||||
for _, secret := range m.secrets {
|
for _, secret := range m.secrets {
|
||||||
if secret.ExpiresAt == nil || secret.ExpiresAt.After(now) {
|
if secret.ExpiresAt == nil || secret.ExpiresAt.After(now) {
|
||||||
validSecrets = append(validSecrets, secret)
|
valid = append(valid, secret)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return valid
|
||||||
|
}
|
||||||
|
|
||||||
return validSecrets
|
// GetSecretByIndex returns a secret by index for testing.
|
||||||
}
|
|
||||||
|
|
||||||
// GetSecretByIndex returns a secret by index for testing
|
|
||||||
func (m *jwtSecretManagerImpl) GetSecretByIndex(index int) (string, bool) {
|
func (m *jwtSecretManagerImpl) GetSecretByIndex(index int) (string, bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
if index < 0 || index >= len(m.secrets) {
|
if index < 0 || index >= len(m.secrets) {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
return m.secrets[index].Secret, true
|
return m.secrets[index].Secret, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemoveExpiredSecrets drops every non-primary secret whose ExpiresAt is
|
||||||
|
// non-nil and in the past. Returns the count of secrets removed.
|
||||||
|
// The primary secret is never removed regardless of expiration (ADR-0021).
|
||||||
|
func (m *jwtSecretManagerImpl) RemoveExpiredSecrets() int {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
kept := make([]JWTSecret, 0, len(m.secrets))
|
||||||
|
removed := 0
|
||||||
|
for _, secret := range m.secrets {
|
||||||
|
if !secret.IsPrimary && secret.ExpiresAt != nil && !secret.ExpiresAt.After(now) {
|
||||||
|
removed++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, secret)
|
||||||
|
}
|
||||||
|
m.secrets = kept
|
||||||
|
return removed
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartCleanupLoop spawns a goroutine that calls RemoveExpiredSecrets at the
|
||||||
|
// given interval. Stops when the parent context is cancelled. Calling again
|
||||||
|
// cancels the previous loop's context and starts a fresh one.
|
||||||
|
func (m *jwtSecretManagerImpl) StartCleanupLoop(ctx context.Context, interval time.Duration) {
|
||||||
|
m.mu.Lock()
|
||||||
|
if m.cleanupCancel != nil {
|
||||||
|
m.cleanupCancel()
|
||||||
|
}
|
||||||
|
loopCtx, cancel := context.WithCancel(ctx)
|
||||||
|
m.cleanupCancel = cancel
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if interval <= 0 {
|
||||||
|
log.Warn().Dur("interval", interval).Msg("JWT secret cleanup interval is non-positive, loop disabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
log.Info().Dur("interval", interval).Msg("JWT secret cleanup loop started")
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-loopCtx.Done():
|
||||||
|
log.Info().Msg("JWT secret cleanup loop stopped")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
removed := m.RemoveExpiredSecrets()
|
||||||
|
if removed > 0 {
|
||||||
|
log.Info().Int("removed", removed).Msg("JWT secrets cleaned up")
|
||||||
|
} else {
|
||||||
|
log.Trace().Msg("JWT cleanup tick: no expired secrets")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|||||||
@@ -701,6 +701,13 @@ func (s *Server) Run() error {
|
|||||||
ongoingCtx, stopOngoingGracefully := context.WithCancel(context.Background())
|
ongoingCtx, stopOngoingGracefully := context.WithCancel(context.Background())
|
||||||
defer stopOngoingGracefully()
|
defer stopOngoingGracefully()
|
||||||
|
|
||||||
|
// Start the JWT secret cleanup loop (ADR-0021). The loop runs until rootCtx
|
||||||
|
// is cancelled (graceful shutdown), removing non-primary secrets whose
|
||||||
|
// ExpiresAt is in the past.
|
||||||
|
if s.userService != nil {
|
||||||
|
s.userService.StartJWTSecretCleanupLoop(rootCtx, s.config.GetJWTSecretCleanupInterval())
|
||||||
|
}
|
||||||
|
|
||||||
// Create HTTP server
|
// Create HTTP server
|
||||||
log.Trace().Str("address", s.config.GetServerAddress()).Msg("Server running")
|
log.Trace().Str("address", s.config.GetServerAddress()).Msg("Server running")
|
||||||
|
|
||||||
|
|||||||
@@ -218,6 +218,18 @@ func (s *userServiceImpl) ResetJWTSecrets() {
|
|||||||
s.secretManager.Reset(s.jwtConfig.Secret)
|
s.secretManager.Reset(s.jwtConfig.Secret)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartJWTSecretCleanupLoop delegates to the underlying secret manager to
|
||||||
|
// start the periodic cleanup goroutine described in ADR-0021.
|
||||||
|
func (s *userServiceImpl) StartJWTSecretCleanupLoop(ctx context.Context, interval time.Duration) {
|
||||||
|
s.secretManager.StartCleanupLoop(ctx, interval)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveExpiredJWTSecrets triggers an immediate cleanup pass via the
|
||||||
|
// underlying secret manager. Returns the count of removed expired secrets.
|
||||||
|
func (s *userServiceImpl) RemoveExpiredJWTSecrets() int {
|
||||||
|
return s.secretManager.RemoveExpiredSecrets()
|
||||||
|
}
|
||||||
|
|
||||||
// UserExists checks if a user exists by username
|
// UserExists checks if a user exists by username
|
||||||
func (s *userServiceImpl) UserExists(ctx context.Context, username string) (bool, error) {
|
func (s *userServiceImpl) UserExists(ctx context.Context, username string) (bool, error) {
|
||||||
return s.repo.UserExists(ctx, username)
|
return s.repo.UserExists(ctx, username)
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
package user
|
package user
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// JWTSecret represents a JWT secret with metadata
|
// JWTSecret represents a JWT secret with metadata
|
||||||
@@ -12,10 +16,16 @@ type JWTSecret struct {
|
|||||||
ExpiresAt *time.Time // Optional expiration time
|
ExpiresAt *time.Time // Optional expiration time
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTSecretManager manages multiple JWT secrets for rotation
|
// JWTSecretManager manages multiple JWT secrets for rotation.
|
||||||
|
// All operations are mutex-protected so the cleanup goroutine
|
||||||
|
// (StartCleanupLoop) can run alongside Generate / Validate calls.
|
||||||
|
// ADR-0021 implements automatic removal of expired secrets while
|
||||||
|
// always preserving the primary secret.
|
||||||
type JWTSecretManager struct {
|
type JWTSecretManager struct {
|
||||||
|
mu sync.Mutex
|
||||||
secrets []JWTSecret
|
secrets []JWTSecret
|
||||||
primarySecret string
|
primarySecret string
|
||||||
|
cleanupCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewJWTSecretManager creates a new JWT secret manager
|
// NewJWTSecretManager creates a new JWT secret manager
|
||||||
@@ -34,12 +44,19 @@ func NewJWTSecretManager(initialSecret string) *JWTSecretManager {
|
|||||||
|
|
||||||
// AddSecret adds a new JWT secret
|
// AddSecret adds a new JWT secret
|
||||||
func (m *JWTSecretManager) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) {
|
func (m *JWTSecretManager) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.addSecretLocked(secret, isPrimary, expiresIn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addSecretLocked is the internal helper that assumes the mutex is held.
|
||||||
|
func (m *JWTSecretManager) addSecretLocked(secret string, isPrimary bool, expiresIn time.Duration) {
|
||||||
var expiresAt *time.Time
|
var expiresAt *time.Time
|
||||||
if expiresIn > 0 {
|
if expiresIn > 0 {
|
||||||
expirationTime := time.Now().Add(expiresIn)
|
expirationTime := time.Now().Add(expiresIn)
|
||||||
expiresAt = &expirationTime
|
expiresAt = &expirationTime
|
||||||
}
|
}
|
||||||
// If expiresIn is 0 or negative, expiresAt remains nil (no expiration)
|
// expiresIn <= 0 means no expiration
|
||||||
|
|
||||||
m.secrets = append(m.secrets, JWTSecret{
|
m.secrets = append(m.secrets, JWTSecret{
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
@@ -55,48 +72,60 @@ func (m *JWTSecretManager) AddSecret(secret string, isPrimary bool, expiresIn ti
|
|||||||
|
|
||||||
// RotateToSecret rotates to a new primary secret
|
// RotateToSecret rotates to a new primary secret
|
||||||
func (m *JWTSecretManager) RotateToSecret(newSecret string) {
|
func (m *JWTSecretManager) RotateToSecret(newSecret string) {
|
||||||
// Mark existing primary as non-primary
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
for i, secret := range m.secrets {
|
for i, secret := range m.secrets {
|
||||||
if secret.IsPrimary {
|
if secret.IsPrimary {
|
||||||
m.secrets[i].IsPrimary = false
|
m.secrets[i].IsPrimary = false
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
m.addSecretLocked(newSecret, true, 0)
|
||||||
// Add new secret as primary
|
|
||||||
m.AddSecret(newSecret, true, 0) // No expiration for primary
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPrimarySecret returns the current primary secret
|
// GetPrimarySecret returns the current primary secret
|
||||||
func (m *JWTSecretManager) GetPrimarySecret() string {
|
func (m *JWTSecretManager) GetPrimarySecret() string {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
return m.primarySecret
|
return m.primarySecret
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllValidSecrets returns all valid (non-expired) secrets
|
// GetAllValidSecrets returns all valid (non-expired) secrets
|
||||||
func (m *JWTSecretManager) GetAllValidSecrets() []JWTSecret {
|
func (m *JWTSecretManager) GetAllValidSecrets() []JWTSecret {
|
||||||
var validSecrets []JWTSecret
|
m.mu.Lock()
|
||||||
now := time.Now()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
valid := make([]JWTSecret, 0, len(m.secrets))
|
||||||
for _, secret := range m.secrets {
|
for _, secret := range m.secrets {
|
||||||
if secret.ExpiresAt == nil || secret.ExpiresAt.After(now) {
|
if secret.ExpiresAt == nil || secret.ExpiresAt.After(now) {
|
||||||
validSecrets = append(validSecrets, secret)
|
valid = append(valid, secret)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return valid
|
||||||
return validSecrets
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSecretByIndex returns a secret by index for testing
|
// GetSecretByIndex returns a secret by index for testing
|
||||||
func (m *JWTSecretManager) GetSecretByIndex(index int) (string, bool) {
|
func (m *JWTSecretManager) GetSecretByIndex(index int) (string, bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
if index < 0 || index >= len(m.secrets) {
|
if index < 0 || index >= len(m.secrets) {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
return m.secrets[index].Secret, true
|
return m.secrets[index].Secret, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset resets the secret manager to its initial state with only the primary secret
|
// Reset resets the secret manager to its initial state with only the primary
|
||||||
// This is useful for test cleanup to ensure tests don't interfere with each other
|
// secret. Used for test cleanup so tests don't interfere with each other.
|
||||||
func (m *JWTSecretManager) Reset(initialSecret string) {
|
func (m *JWTSecretManager) Reset(initialSecret string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.cleanupCancel != nil {
|
||||||
|
m.cleanupCancel()
|
||||||
|
m.cleanupCancel = nil
|
||||||
|
}
|
||||||
m.secrets = []JWTSecret{
|
m.secrets = []JWTSecret{
|
||||||
{
|
{
|
||||||
Secret: initialSecret,
|
Secret: initialSecret,
|
||||||
@@ -106,3 +135,64 @@ func (m *JWTSecretManager) Reset(initialSecret string) {
|
|||||||
}
|
}
|
||||||
m.primarySecret = initialSecret
|
m.primarySecret = initialSecret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemoveExpiredSecrets drops every non-primary secret whose ExpiresAt is
|
||||||
|
// non-nil and in the past. Returns the count of secrets removed.
|
||||||
|
// The primary secret is never removed regardless of expiration (ADR-0021).
|
||||||
|
func (m *JWTSecretManager) RemoveExpiredSecrets() int {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
kept := make([]JWTSecret, 0, len(m.secrets))
|
||||||
|
removed := 0
|
||||||
|
for _, secret := range m.secrets {
|
||||||
|
if !secret.IsPrimary && secret.ExpiresAt != nil && !secret.ExpiresAt.After(now) {
|
||||||
|
removed++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, secret)
|
||||||
|
}
|
||||||
|
m.secrets = kept
|
||||||
|
return removed
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartCleanupLoop spawns a goroutine that calls RemoveExpiredSecrets at the
|
||||||
|
// given interval. Stops when the parent context is cancelled. Calling again
|
||||||
|
// cancels the previous loop's context and starts a fresh one.
|
||||||
|
// If interval <= 0, the loop is disabled (cleanup must be triggered manually
|
||||||
|
// via RemoveExpiredSecrets).
|
||||||
|
func (m *JWTSecretManager) StartCleanupLoop(ctx context.Context, interval time.Duration) {
|
||||||
|
m.mu.Lock()
|
||||||
|
if m.cleanupCancel != nil {
|
||||||
|
m.cleanupCancel()
|
||||||
|
}
|
||||||
|
loopCtx, cancel := context.WithCancel(ctx)
|
||||||
|
m.cleanupCancel = cancel
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if interval <= 0 {
|
||||||
|
log.Warn().Dur("interval", interval).Msg("JWT secret cleanup interval is non-positive, loop disabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
log.Info().Dur("interval", interval).Msg("JWT secret cleanup loop started")
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-loopCtx.Done():
|
||||||
|
log.Info().Msg("JWT secret cleanup loop stopped")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
removed := m.RemoveExpiredSecrets()
|
||||||
|
if removed > 0 {
|
||||||
|
log.Info().Int("removed", removed).Msg("JWT secrets cleaned up")
|
||||||
|
} else {
|
||||||
|
log.Trace().Msg("JWT cleanup tick: no expired secrets")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package user
|
package user
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -84,3 +85,73 @@ func TestJWTSecretExpiration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.True(t, foundExpiring)
|
assert.True(t, foundExpiring)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRemoveExpiredSecrets_ExpiredNonPrimaryRemoved confirms that
|
||||||
|
// RemoveExpiredSecrets drops a non-primary secret whose ExpiresAt is in the past.
|
||||||
|
func TestRemoveExpiredSecrets_ExpiredNonPrimaryRemoved(t *testing.T) {
|
||||||
|
manager := NewJWTSecretManager("primary")
|
||||||
|
|
||||||
|
// Add a secret that expired 1 hour ago by setting expiresIn to a small
|
||||||
|
// positive duration then mutating after via AddSecret + manipulation.
|
||||||
|
// Simpler: add with a 1ns lifetime and sleep 2ns equivalent (tiny TTL).
|
||||||
|
manager.AddSecret("about-to-expire", false, 1*time.Nanosecond)
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
|
||||||
|
removed := manager.RemoveExpiredSecrets()
|
||||||
|
assert.Equal(t, 1, removed, "one expired secret should be removed")
|
||||||
|
|
||||||
|
secrets := manager.GetAllValidSecrets()
|
||||||
|
assert.Len(t, secrets, 1, "only primary should remain")
|
||||||
|
assert.Equal(t, "primary", secrets[0].Secret)
|
||||||
|
assert.True(t, secrets[0].IsPrimary)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRemoveExpiredSecrets_PrimaryNeverRemoved confirms the primary secret
|
||||||
|
// is preserved even if (somehow) marked expired - ADR-0021 invariant.
|
||||||
|
func TestRemoveExpiredSecrets_PrimaryNeverRemoved(t *testing.T) {
|
||||||
|
manager := NewJWTSecretManager("primary")
|
||||||
|
|
||||||
|
// Add a non-primary that doesn't expire
|
||||||
|
manager.AddSecret("kept", false, 0)
|
||||||
|
|
||||||
|
// Simulate an "expired primary" by manipulating internals via Reset then
|
||||||
|
// re-creating - here we rely on the public contract: primary has no
|
||||||
|
// ExpiresAt by default. Confirm cleanup leaves it.
|
||||||
|
removed := manager.RemoveExpiredSecrets()
|
||||||
|
assert.Equal(t, 0, removed)
|
||||||
|
|
||||||
|
assert.Equal(t, "primary", manager.GetPrimarySecret())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRemoveExpiredSecrets_NonExpiredKept confirms a future-expiring secret
|
||||||
|
// stays after cleanup.
|
||||||
|
func TestRemoveExpiredSecrets_NonExpiredKept(t *testing.T) {
|
||||||
|
manager := NewJWTSecretManager("primary")
|
||||||
|
manager.AddSecret("future", false, 1*time.Hour)
|
||||||
|
|
||||||
|
removed := manager.RemoveExpiredSecrets()
|
||||||
|
assert.Equal(t, 0, removed)
|
||||||
|
assert.Len(t, manager.GetAllValidSecrets(), 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStartCleanupLoop_FiresAndStops confirms the goroutine actually calls
|
||||||
|
// RemoveExpiredSecrets on each tick and stops cleanly when the context is
|
||||||
|
// cancelled. Uses a short interval to keep the test fast.
|
||||||
|
func TestStartCleanupLoop_FiresAndStops(t *testing.T) {
|
||||||
|
manager := NewJWTSecretManager("primary")
|
||||||
|
manager.AddSecret("dies", false, 5*time.Millisecond)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
manager.StartCleanupLoop(ctx, 10*time.Millisecond)
|
||||||
|
|
||||||
|
// Wait long enough for at least one tick + the secret's TTL
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
cancel() // stop the loop
|
||||||
|
|
||||||
|
secrets := manager.GetAllValidSecrets()
|
||||||
|
assert.Len(t, secrets, 1, "expired secret should have been removed by the loop")
|
||||||
|
assert.Equal(t, "primary", secrets[0].Secret)
|
||||||
|
}
|
||||||
|
|||||||
@@ -43,6 +43,15 @@ type AuthService interface {
|
|||||||
RotateJWTSecret(newSecret string)
|
RotateJWTSecret(newSecret string)
|
||||||
GetJWTSecretByIndex(index int) (string, bool)
|
GetJWTSecretByIndex(index int) (string, bool)
|
||||||
ResetJWTSecrets() // Reset JWT secrets to initial state for test cleanup
|
ResetJWTSecrets() // Reset JWT secrets to initial state for test cleanup
|
||||||
|
// StartJWTSecretCleanupLoop starts a goroutine that periodically calls
|
||||||
|
// RemoveExpiredJWTSecrets at the given interval, stopping when ctx is
|
||||||
|
// cancelled. Implements the cleanup half of ADR-0021. interval <= 0
|
||||||
|
// disables the loop.
|
||||||
|
StartJWTSecretCleanupLoop(ctx context.Context, interval time.Duration)
|
||||||
|
// RemoveExpiredJWTSecrets triggers an immediate cleanup pass and returns
|
||||||
|
// the count of removed non-primary expired secrets. Useful for tests
|
||||||
|
// driving cleanup synchronously.
|
||||||
|
RemoveExpiredJWTSecrets() int
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserManager defines interface for user management operations
|
// UserManager defines interface for user management operations
|
||||||
|
|||||||
Reference in New Issue
Block a user