Co-authored-by: Gabriel Radureau <arcodange@gmail.com> Co-committed-by: Gabriel Radureau <arcodange@gmail.com>
164 lines
4.2 KiB
Go
164 lines
4.2 KiB
Go
package jwt
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// jwtSecretManagerImpl implements the JWTSecretManager interface.
|
|
// All operations are mutex-protected so the cleanup goroutine
|
|
// (StartCleanupLoop) can run alongside Generate / Validate calls.
|
|
type jwtSecretManagerImpl struct {
|
|
mu sync.Mutex
|
|
secrets []JWTSecret
|
|
primarySecret string
|
|
cleanupCancel context.CancelFunc
|
|
}
|
|
|
|
// NewJWTSecretManager creates a new JWT secret manager.
|
|
func NewJWTSecretManager(initialSecret string) JWTSecretManager {
|
|
return &jwtSecretManagerImpl{
|
|
secrets: []JWTSecret{
|
|
{
|
|
Secret: initialSecret,
|
|
IsPrimary: true,
|
|
CreatedAt: time.Now(),
|
|
},
|
|
},
|
|
primarySecret: initialSecret,
|
|
}
|
|
}
|
|
|
|
// AddSecret adds a new JWT secret.
|
|
func (m *jwtSecretManagerImpl) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.addSecretLocked(secret, isPrimary, expiresIn)
|
|
}
|
|
|
|
// addSecretLocked is the internal helper that assumes the mutex is held.
|
|
func (m *jwtSecretManagerImpl) addSecretLocked(secret string, isPrimary bool, expiresIn time.Duration) {
|
|
entry := JWTSecret{
|
|
Secret: secret,
|
|
IsPrimary: isPrimary,
|
|
CreatedAt: time.Now(),
|
|
}
|
|
if expiresIn > 0 {
|
|
expiresAt := time.Now().Add(expiresIn)
|
|
entry.ExpiresAt = &expiresAt
|
|
}
|
|
m.secrets = append(m.secrets, entry)
|
|
|
|
if isPrimary {
|
|
m.primarySecret = secret
|
|
}
|
|
}
|
|
|
|
// RotateToSecret rotates to a new primary secret.
|
|
func (m *jwtSecretManagerImpl) RotateToSecret(newSecret string) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
for i, secret := range m.secrets {
|
|
if secret.IsPrimary {
|
|
m.secrets[i].IsPrimary = false
|
|
break
|
|
}
|
|
}
|
|
m.addSecretLocked(newSecret, true, 0)
|
|
}
|
|
|
|
// GetPrimarySecret returns the current primary secret.
|
|
func (m *jwtSecretManagerImpl) GetPrimarySecret() string {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.primarySecret
|
|
}
|
|
|
|
// GetAllValidSecrets returns all valid (non-expired) secrets.
|
|
func (m *jwtSecretManagerImpl) GetAllValidSecrets() []JWTSecret {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
valid := make([]JWTSecret, 0, len(m.secrets))
|
|
for _, secret := range m.secrets {
|
|
if secret.ExpiresAt == nil || secret.ExpiresAt.After(now) {
|
|
valid = append(valid, secret)
|
|
}
|
|
}
|
|
return valid
|
|
}
|
|
|
|
// GetSecretByIndex returns a secret by index for testing.
|
|
func (m *jwtSecretManagerImpl) GetSecretByIndex(index int) (string, bool) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
if index < 0 || index >= len(m.secrets) {
|
|
return "", false
|
|
}
|
|
return m.secrets[index].Secret, true
|
|
}
|
|
|
|
// RemoveExpiredSecrets drops every non-primary secret whose ExpiresAt is
|
|
// non-nil and in the past. Returns the count of secrets removed.
|
|
// The primary secret is never removed regardless of expiration (ADR-0021).
|
|
func (m *jwtSecretManagerImpl) RemoveExpiredSecrets() int {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
kept := make([]JWTSecret, 0, len(m.secrets))
|
|
removed := 0
|
|
for _, secret := range m.secrets {
|
|
if !secret.IsPrimary && secret.ExpiresAt != nil && !secret.ExpiresAt.After(now) {
|
|
removed++
|
|
continue
|
|
}
|
|
kept = append(kept, secret)
|
|
}
|
|
m.secrets = kept
|
|
return removed
|
|
}
|
|
|
|
// StartCleanupLoop spawns a goroutine that calls RemoveExpiredSecrets at the
|
|
// given interval. Stops when the parent context is cancelled. Calling again
|
|
// cancels the previous loop's context and starts a fresh one.
|
|
func (m *jwtSecretManagerImpl) StartCleanupLoop(ctx context.Context, interval time.Duration) {
|
|
m.mu.Lock()
|
|
if m.cleanupCancel != nil {
|
|
m.cleanupCancel()
|
|
}
|
|
loopCtx, cancel := context.WithCancel(ctx)
|
|
m.cleanupCancel = cancel
|
|
m.mu.Unlock()
|
|
|
|
if interval <= 0 {
|
|
log.Warn().Dur("interval", interval).Msg("JWT secret cleanup interval is non-positive, loop disabled")
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
log.Info().Dur("interval", interval).Msg("JWT secret cleanup loop started")
|
|
for {
|
|
select {
|
|
case <-loopCtx.Done():
|
|
log.Info().Msg("JWT secret cleanup loop stopped")
|
|
return
|
|
case <-ticker.C:
|
|
removed := m.RemoveExpiredSecrets()
|
|
if removed > 0 {
|
|
log.Info().Int("removed", removed).Msg("JWT secrets cleaned up")
|
|
} else {
|
|
log.Trace().Msg("JWT cleanup tick: no expired secrets")
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
}
|