🧪 test: implement JWT secret rotation BDD tests
All checks were successful
CI/CD Pipeline / Build Docker Cache (push) Successful in 11s
CI/CD Pipeline / CI Pipeline (push) Successful in 4m32s

- Fix admin handler to handle flexible boolean parsing

- Modify GenerateJWT to use latest secret for signing

- Update JWT secret manager for proper expiration handling

- Fix BDD test steps to use actual tokens instead of hardcoded ones

- Add comprehensive debug logging for JWT operations

Resolves JWT secret rotation feature implementation

Generated by Mistral Vibe.

Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
This commit is contained in:
2026-04-09 16:14:31 +02:00
parent 695cd407f2
commit 07f8bd65b7
9 changed files with 742 additions and 35 deletions

View File

@@ -0,0 +1,149 @@
package api
import (
"encoding/json"
"net/http"
"time"
"dance-lessons-coach/pkg/user"
"github.com/go-chi/chi/v5"
)
// AdminHandler handles admin-related HTTP requests
type AdminHandler struct {
authService user.AuthService
}
// NewAdminHandler creates a new admin handler
func NewAdminHandler(authService user.AuthService) *AdminHandler {
return &AdminHandler{
authService: authService,
}
}
// RegisterRoutes registers admin routes
func (h *AdminHandler) RegisterRoutes(router chi.Router) {
router.Route("/jwt", func(r chi.Router) {
r.Post("/secrets", h.handleAddJWTSecret)
r.Post("/secrets/rotate", h.handleRotateJWTSecret)
})
}
// AddJWTSecretRequest represents a request to add a new JWT secret
type AddJWTSecretRequest struct {
Secret string `json:"secret" validate:"required,min=16"`
IsPrimary bool `json:"is_primary"`
ExpiresIn int64 `json:"expires_in"` // Expiration time in hours
}
// handleAddJWTSecret godoc
//
// @Summary Add JWT secret
// @Description Add a new JWT secret for rotation purposes
// @Tags API/v1/Admin
// @Accept json
// @Produce json
// @Param request body AddJWTSecretRequest true "JWT secret details"
// @Success 200 {object} map[string]string "Secret added successfully"
// @Failure 400 {object} map[string]string "Invalid request"
// @Failure 401 {object} map[string]string "Unauthorized"
// @Failure 500 {object} map[string]string "Server error"
// @Router /v1/admin/jwt/secrets [post]
func (h *AdminHandler) handleAddJWTSecret(w http.ResponseWriter, r *http.Request) {
// Decode request body into a map to handle flexible boolean parsing
var body map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest)
return
}
// Extract and validate fields
secret, ok := body["secret"].(string)
if !ok || secret == "" {
http.Error(w, `{"error":"invalid_request","message":"secret is required and must be a string"}`, http.StatusBadRequest)
return
}
// Handle is_primary as either bool or string
isPrimary := false // default
if val, exists := body["is_primary"]; exists {
switch v := val.(type) {
case bool:
isPrimary = v
case string:
isPrimary = v == "true"
default:
http.Error(w, `{"error":"invalid_request","message":"is_primary must be a boolean or string"}`, http.StatusBadRequest)
return
}
}
// Handle expires_in as either int64 or float64 (JSON numbers)
expiresInHours := int64(0)
if val, exists := body["expires_in"]; exists {
switch v := val.(type) {
case int64:
expiresInHours = v
case float64:
expiresInHours = int64(v)
default:
http.Error(w, `{"error":"invalid_request","message":"expires_in must be a number"}`, http.StatusBadRequest)
return
}
}
// Convert expires_in from hours to time.Duration
expiresIn := time.Duration(expiresInHours) * time.Hour
if expiresIn <= 0 {
// If expires_in is 0 or not provided, set to no expiration for secondary secrets
// For primary secrets, use a reasonable default
if isPrimary {
expiresIn = 24 * 365 * time.Hour // 1 year for primary secrets
} else {
expiresIn = 0 // No expiration for secondary secrets
}
}
// Add the secret to the manager
h.authService.AddJWTSecret(secret, isPrimary, expiresIn)
// Return success
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"message": "JWT secret added successfully"})
}
// RotateJWTSecretRequest represents a request to rotate JWT secrets
type RotateJWTSecretRequest struct {
NewSecret string `json:"new_secret" validate:"required,min=16"`
}
// handleRotateJWTSecret godoc
//
// @Summary Rotate JWT secret
// @Description Rotate to a new primary JWT secret
// @Tags API/v1/Admin
// @Accept json
// @Produce json
// @Param request body RotateJWTSecretRequest true "New JWT secret"
// @Success 200 {object} map[string]string "Secret rotated successfully"
// @Failure 400 {object} map[string]string "Invalid request"
// @Failure 401 {object} map[string]string "Unauthorized"
// @Failure 500 {object} map[string]string "Server error"
// @Router /v1/admin/jwt/secrets/rotate [post]
func (h *AdminHandler) handleRotateJWTSecret(w http.ResponseWriter, r *http.Request) {
var req RotateJWTSecretRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest)
return
}
// Rotate to the new secret
h.authService.RotateJWTSecret(req.NewSecret)
// Return success
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"message": "JWT secret rotated successfully"})
}

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
)
@@ -22,6 +23,7 @@ type userServiceImpl struct {
repo UserRepository
jwtConfig JWTConfig
masterPassword string
secretManager *JWTSecretManager
}
// NewUserService creates a new user service with all functionality
@@ -30,6 +32,7 @@ func NewUserService(repo UserRepository, jwtConfig JWTConfig, masterPassword str
repo: repo,
jwtConfig: jwtConfig,
masterPassword: masterPassword,
secretManager: NewJWTSecretManager(jwtConfig.Secret),
}
}
@@ -74,38 +77,77 @@ func (s *userServiceImpl) GenerateJWT(ctx context.Context, user *User) (string,
// Create token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
// Get all valid secrets and use the most recently added one for signing
// This supports JWT secret rotation by signing new tokens with the latest secret
validSecrets := s.secretManager.GetAllValidSecrets()
if len(validSecrets) == 0 {
return "", errors.New("no valid JWT secrets available")
}
// Use the most recently added secret (last in the list)
// This ensures new tokens are signed with the latest secret
signingSecret := validSecrets[len(validSecrets)-1].Secret
log.Trace().Ctx(ctx).Str("signing_secret", signingSecret).Bool("is_primary", validSecrets[len(validSecrets)-1].IsPrimary).Msg("Generating JWT with latest secret")
// Sign and get the complete encoded token as a string
tokenString, err := token.SignedString([]byte(s.jwtConfig.Secret))
tokenString, err := token.SignedString([]byte(signingSecret))
if err != nil {
return "", fmt.Errorf("failed to sign JWT: %w", err)
}
log.Trace().Ctx(ctx).Str("token", tokenString).Msg("Generated JWT token")
return tokenString, nil
}
// ValidateJWT validates a JWT token and returns the user
func (s *userServiceImpl) ValidateJWT(ctx context.Context, tokenString string) (*User, error) {
// Parse the token
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Verify the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
log.Trace().Ctx(ctx).Str("token", tokenString).Msg("Validating JWT token")
return []byte(s.jwtConfig.Secret), nil
})
// Get all valid secrets for validation
validSecrets := s.secretManager.GetAllValidSecrets()
if err != nil {
return nil, fmt.Errorf("failed to parse JWT: %w", err)
log.Trace().Ctx(ctx).Int("num_secrets", len(validSecrets)).Msg("Validating JWT with multiple secrets")
for i, secret := range validSecrets {
log.Trace().Ctx(ctx).Int("secret_index", i).Str("secret", secret.Secret).Bool("is_primary", secret.IsPrimary).Msg("Trying secret")
}
// Check if token is valid
if !token.Valid {
// Try each valid secret until we find one that works
var parsedToken *jwt.Token
var validationError error
for i, secret := range validSecrets {
// Parse the token with current secret
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Verify the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(secret.Secret), nil
})
if err == nil && token.Valid {
log.Trace().Ctx(ctx).Int("secret_index", i).Str("secret", secret.Secret).Msg("JWT validation successful")
parsedToken = token
break
}
// Store the last error for reporting
validationError = err
if err != nil {
log.Trace().Ctx(ctx).Int("secret_index", i).Str("secret", secret.Secret).Err(err).Msg("JWT validation failed")
}
}
if parsedToken == nil {
if validationError != nil {
return nil, fmt.Errorf("failed to parse JWT: %w", validationError)
}
return nil, errors.New("invalid JWT token")
}
// Get claims
claims, ok := token.Claims.(jwt.MapClaims)
claims, ok := parsedToken.Claims.(jwt.MapClaims)
if !ok {
return nil, errors.New("invalid JWT claims")
}
@@ -156,6 +198,21 @@ func (s *userServiceImpl) AdminAuthenticate(ctx context.Context, masterPassword
return adminUser, nil
}
// AddJWTSecret adds a new JWT secret to the manager
func (s *userServiceImpl) AddJWTSecret(secret string, isPrimary bool, expiresIn time.Duration) {
s.secretManager.AddSecret(secret, isPrimary, expiresIn)
}
// RotateJWTSecret rotates to a new primary JWT secret
func (s *userServiceImpl) RotateJWTSecret(newSecret string) {
s.secretManager.RotateToSecret(newSecret)
}
// GetJWTSecretByIndex returns a JWT secret by index for testing
func (s *userServiceImpl) GetJWTSecretByIndex(index int) (string, bool) {
return s.secretManager.GetSecretByIndex(index)
}
// UserExists checks if a user exists by username
func (s *userServiceImpl) UserExists(ctx context.Context, username string) (bool, error) {
return s.repo.UserExists(ctx, username)

95
pkg/user/jwt_manager.go Normal file
View File

@@ -0,0 +1,95 @@
package user
import (
"time"
)
// JWTSecret represents a JWT secret with metadata
type JWTSecret struct {
Secret string
IsPrimary bool
CreatedAt time.Time
ExpiresAt *time.Time // Optional expiration time
}
// JWTSecretManager manages multiple JWT secrets for rotation
type JWTSecretManager struct {
secrets []JWTSecret
primarySecret string
}
// NewJWTSecretManager creates a new JWT secret manager
func NewJWTSecretManager(initialSecret string) *JWTSecretManager {
return &JWTSecretManager{
secrets: []JWTSecret{
{
Secret: initialSecret,
IsPrimary: true,
CreatedAt: time.Now(),
},
},
primarySecret: initialSecret,
}
}
// AddSecret adds a new JWT secret
func (m *JWTSecretManager) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) {
var expiresAt *time.Time
if expiresIn > 0 {
expirationTime := time.Now().Add(expiresIn)
expiresAt = &expirationTime
}
// If expiresIn is 0 or negative, expiresAt remains nil (no expiration)
m.secrets = append(m.secrets, JWTSecret{
Secret: secret,
IsPrimary: isPrimary,
CreatedAt: time.Now(),
ExpiresAt: expiresAt,
})
if isPrimary {
m.primarySecret = secret
}
}
// RotateToSecret rotates to a new primary secret
func (m *JWTSecretManager) RotateToSecret(newSecret string) {
// Mark existing primary as non-primary
for i, secret := range m.secrets {
if secret.IsPrimary {
m.secrets[i].IsPrimary = false
break
}
}
// Add new secret as primary
m.AddSecret(newSecret, true, 0) // No expiration for primary
}
// GetPrimarySecret returns the current primary secret
func (m *JWTSecretManager) GetPrimarySecret() string {
return m.primarySecret
}
// GetAllValidSecrets returns all valid (non-expired) secrets
func (m *JWTSecretManager) GetAllValidSecrets() []JWTSecret {
var validSecrets []JWTSecret
now := time.Now()
for _, secret := range m.secrets {
if secret.ExpiresAt == nil || secret.ExpiresAt.After(now) {
validSecrets = append(validSecrets, secret)
}
}
return validSecrets
}
// GetSecretByIndex returns a secret by index for testing
func (m *JWTSecretManager) GetSecretByIndex(index int) (string, bool) {
if index < 0 || index >= len(m.secrets) {
return "", false
}
return m.secrets[index].Secret, true
}

View File

@@ -0,0 +1,86 @@
package user
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestJWTSecretManager(t *testing.T) {
// Create a new secret manager with initial secret
manager := NewJWTSecretManager("primary-secret")
// Test initial state
assert.Equal(t, "primary-secret", manager.GetPrimarySecret())
// Test GetAllValidSecrets initially
secrets := manager.GetAllValidSecrets()
assert.Len(t, secrets, 1)
assert.Equal(t, "primary-secret", secrets[0].Secret)
assert.True(t, secrets[0].IsPrimary)
assert.Nil(t, secrets[0].ExpiresAt)
// Add a secondary secret
manager.AddSecret("secondary-secret", false, 0) // 0 means no expiration
// Test after adding secondary secret
assert.Equal(t, "primary-secret", manager.GetPrimarySecret()) // Primary should not change
secrets = manager.GetAllValidSecrets()
assert.Len(t, secrets, 2)
// Find the secondary secret
foundSecondary := false
for _, secret := range secrets {
if secret.Secret == "secondary-secret" {
foundSecondary = true
assert.False(t, secret.IsPrimary)
assert.Nil(t, secret.ExpiresAt) // Should have no expiration
break
}
}
assert.True(t, foundSecondary, "Secondary secret should be found in valid secrets")
// Test rotation
manager.RotateToSecret("new-primary-secret")
assert.Equal(t, "new-primary-secret", manager.GetPrimarySecret())
secrets = manager.GetAllValidSecrets()
assert.Len(t, secrets, 3) // Should have 3 secrets now
// Find the new primary secret
foundNewPrimary := false
for _, secret := range secrets {
if secret.Secret == "new-primary-secret" {
foundNewPrimary = true
assert.True(t, secret.IsPrimary)
assert.Nil(t, secret.ExpiresAt) // Should have no expiration
break
}
}
assert.True(t, foundNewPrimary, "New primary secret should be found in valid secrets")
}
func TestJWTSecretExpiration(t *testing.T) {
manager := NewJWTSecretManager("primary-secret")
// Add a secret with expiration
manager.AddSecret("expiring-secret", false, 1*time.Hour) // Expires in 1 hour
// Should have 2 secrets initially
secrets := manager.GetAllValidSecrets()
assert.Len(t, secrets, 2)
// Test expiration logic
foundExpiring := false
for _, secret := range secrets {
if secret.Secret == "expiring-secret" {
foundExpiring = true
assert.NotNil(t, secret.ExpiresAt)
assert.True(t, secret.ExpiresAt.After(time.Now()))
break
}
}
assert.True(t, foundExpiring)
}

View File

@@ -39,6 +39,9 @@ type AuthService interface {
GenerateJWT(ctx context.Context, user *User) (string, error)
ValidateJWT(ctx context.Context, token string) (*User, error)
AdminAuthenticate(ctx context.Context, masterPassword string) (*User, error)
AddJWTSecret(secret string, isPrimary bool, expiresIn time.Duration)
RotateJWTSecret(newSecret string)
GetJWTSecretByIndex(index int) (string, bool)
}
// UserManager defines interface for user management operations