package user import ( "context" "sync" "time" "github.com/rs/zerolog/log" ) // JWTSecret represents a JWT secret with metadata type JWTSecret struct { Secret string IsPrimary bool CreatedAt time.Time ExpiresAt *time.Time // Optional expiration time } // JWTSecretManager manages multiple JWT secrets for rotation. // All operations are mutex-protected so the cleanup goroutine // (StartCleanupLoop) can run alongside Generate / Validate calls. // ADR-0021 implements automatic removal of expired secrets while // always preserving the primary secret. type JWTSecretManager struct { mu sync.Mutex secrets []JWTSecret primarySecret string cleanupCancel context.CancelFunc } // NewJWTSecretManager creates a new JWT secret manager func NewJWTSecretManager(initialSecret string) *JWTSecretManager { return &JWTSecretManager{ secrets: []JWTSecret{ { Secret: initialSecret, IsPrimary: true, CreatedAt: time.Now(), }, }, primarySecret: initialSecret, } } // AddSecret adds a new JWT secret func (m *JWTSecretManager) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) { m.mu.Lock() defer m.mu.Unlock() m.addSecretLocked(secret, isPrimary, expiresIn) } // addSecretLocked is the internal helper that assumes the mutex is held. func (m *JWTSecretManager) addSecretLocked(secret string, isPrimary bool, expiresIn time.Duration) { var expiresAt *time.Time if expiresIn > 0 { expirationTime := time.Now().Add(expiresIn) expiresAt = &expirationTime } // expiresIn <= 0 means no expiration m.secrets = append(m.secrets, JWTSecret{ Secret: secret, IsPrimary: isPrimary, CreatedAt: time.Now(), ExpiresAt: expiresAt, }) if isPrimary { m.primarySecret = secret } } // RotateToSecret rotates to a new primary secret func (m *JWTSecretManager) RotateToSecret(newSecret string) { m.mu.Lock() defer m.mu.Unlock() for i, secret := range m.secrets { if secret.IsPrimary { m.secrets[i].IsPrimary = false break } } m.addSecretLocked(newSecret, true, 0) } // GetPrimarySecret returns the current primary secret func (m *JWTSecretManager) GetPrimarySecret() string { m.mu.Lock() defer m.mu.Unlock() return m.primarySecret } // GetAllValidSecrets returns all valid (non-expired) secrets func (m *JWTSecretManager) GetAllValidSecrets() []JWTSecret { m.mu.Lock() defer m.mu.Unlock() now := time.Now() valid := make([]JWTSecret, 0, len(m.secrets)) for _, secret := range m.secrets { if secret.ExpiresAt == nil || secret.ExpiresAt.After(now) { valid = append(valid, secret) } } return valid } // GetSecretByIndex returns a secret by index for testing func (m *JWTSecretManager) GetSecretByIndex(index int) (string, bool) { m.mu.Lock() defer m.mu.Unlock() if index < 0 || index >= len(m.secrets) { return "", false } return m.secrets[index].Secret, true } // Reset resets the secret manager to its initial state with only the primary // secret. Used for test cleanup so tests don't interfere with each other. func (m *JWTSecretManager) Reset(initialSecret string) { m.mu.Lock() defer m.mu.Unlock() if m.cleanupCancel != nil { m.cleanupCancel() m.cleanupCancel = nil } m.secrets = []JWTSecret{ { Secret: initialSecret, IsPrimary: true, CreatedAt: time.Now(), }, } m.primarySecret = initialSecret } // RemoveExpiredSecrets drops every non-primary secret whose ExpiresAt is // non-nil and in the past. Returns the count of secrets removed. // The primary secret is never removed regardless of expiration (ADR-0021). func (m *JWTSecretManager) RemoveExpiredSecrets() int { m.mu.Lock() defer m.mu.Unlock() now := time.Now() kept := make([]JWTSecret, 0, len(m.secrets)) removed := 0 for _, secret := range m.secrets { if !secret.IsPrimary && secret.ExpiresAt != nil && !secret.ExpiresAt.After(now) { removed++ continue } kept = append(kept, secret) } m.secrets = kept return removed } // StartCleanupLoop spawns a goroutine that calls RemoveExpiredSecrets at the // given interval. Stops when the parent context is cancelled. Calling again // cancels the previous loop's context and starts a fresh one. // If interval <= 0, the loop is disabled (cleanup must be triggered manually // via RemoveExpiredSecrets). func (m *JWTSecretManager) StartCleanupLoop(ctx context.Context, interval time.Duration) { m.mu.Lock() if m.cleanupCancel != nil { m.cleanupCancel() } loopCtx, cancel := context.WithCancel(ctx) m.cleanupCancel = cancel m.mu.Unlock() if interval <= 0 { log.Warn().Dur("interval", interval).Msg("JWT secret cleanup interval is non-positive, loop disabled") return } go func() { ticker := time.NewTicker(interval) defer ticker.Stop() log.Info().Dur("interval", interval).Msg("JWT secret cleanup loop started") for { select { case <-loopCtx.Done(): log.Info().Msg("JWT secret cleanup loop stopped") return case <-ticker.C: removed := m.RemoveExpiredSecrets() if removed > 0 { log.Info().Int("removed", removed).Msg("JWT secrets cleaned up") } else { log.Trace().Msg("JWT cleanup tick: no expired secrets") } } } }() }