🧪 test: add JWT secret rotation BDD scenarios and step implementations #12
@@ -3,6 +3,7 @@ package steps
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"dance-lessons-coach/pkg/bdd/testserver"
|
"dance-lessons-coach/pkg/bdd/testserver"
|
||||||
@@ -14,6 +15,7 @@ import (
|
|||||||
type AuthSteps struct {
|
type AuthSteps struct {
|
||||||
client *testserver.Client
|
client *testserver.Client
|
||||||
lastToken string
|
lastToken string
|
||||||
|
firstToken string // Store the first token for rotation testing
|
||||||
lastUserID uint
|
lastUserID uint
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -334,8 +336,12 @@ func (s *AuthSteps) iUseAMalformedJWTTokenForAuthentication() error {
|
|||||||
|
|
||||||
// JWT Validation Steps
|
// JWT Validation Steps
|
||||||
func (s *AuthSteps) iValidateTheReceivedJWTToken() error {
|
func (s *AuthSteps) iValidateTheReceivedJWTToken() error {
|
||||||
// Extract and parse the JWT token
|
// Validate the received JWT token by sending it to the validation endpoint
|
||||||
return s.iShouldReceiveAValidJWTToken()
|
if s.lastToken == "" {
|
||||||
|
return fmt.Errorf("no token to validate")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.client.Request("POST", "/api/v1/auth/validate", map[string]string{"token": s.lastToken})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthSteps) theTokenShouldBeValid() error {
|
func (s *AuthSteps) theTokenShouldBeValid() error {
|
||||||
@@ -344,23 +350,53 @@ func (s *AuthSteps) theTokenShouldBeValid() error {
|
|||||||
return fmt.Errorf("expected status 200, got %d", s.client.GetLastStatusCode())
|
return fmt.Errorf("expected status 200, got %d", s.client.GetLastStatusCode())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if response contains a token
|
// Check if response contains validation confirmation
|
||||||
body := string(s.client.GetLastBody())
|
body := string(s.client.GetLastBody())
|
||||||
if !strings.Contains(body, "token") {
|
if !strings.Contains(body, "valid") {
|
||||||
return fmt.Errorf("expected response to contain token, got %s", body)
|
return fmt.Errorf("expected response to contain valid token confirmation, got %s", body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Only try to parse a JWT token if this is an authentication response (contains "token" field)
|
||||||
|
if strings.Contains(body, "token") {
|
||||||
// Extract and parse the JWT token
|
// Extract and parse the JWT token
|
||||||
if err := s.iShouldReceiveAValidJWTToken(); err != nil {
|
if err := s.iShouldReceiveAValidJWTToken(); err != nil {
|
||||||
return fmt.Errorf("failed to parse JWT token: %w", err)
|
return fmt.Errorf("failed to parse JWT token: %w", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If we got here, the token is valid and parsed successfully
|
// If we got here, the token is valid
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthSteps) itShouldContainTheCorrectUserID() error {
|
func (s *AuthSteps) itShouldContainTheCorrectUserID() error {
|
||||||
// Verify that we have a stored user ID from the last token
|
// Check if this is a token validation response (contains user_id)
|
||||||
|
body := string(s.client.GetLastBody())
|
||||||
|
if strings.Contains(body, "user_id") {
|
||||||
|
// This is a token validation response, extract user_id from it
|
||||||
|
startIdx := strings.Index(body, `"user_id":`)
|
||||||
|
if startIdx == -1 {
|
||||||
|
return fmt.Errorf("no user_id found in validation response: %s", body)
|
||||||
|
}
|
||||||
|
startIdx += 10 // Skip "user_id":
|
||||||
|
endIdx := strings.Index(body[startIdx:], ",")
|
||||||
|
if endIdx == -1 {
|
||||||
|
endIdx = strings.Index(body[startIdx:], "}")
|
||||||
|
}
|
||||||
|
if endIdx == -1 {
|
||||||
|
return fmt.Errorf("malformed user_id in validation response: %s", body)
|
||||||
|
}
|
||||||
|
userIDStr := strings.TrimSpace(body[startIdx : startIdx+endIdx])
|
||||||
|
userID, err := strconv.Atoi(userIDStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse user_id from validation response: %s", body)
|
||||||
|
}
|
||||||
|
if userID <= 0 {
|
||||||
|
return fmt.Errorf("invalid user_id in validation response: %d", userID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, verify that we have a stored user ID from the last token
|
||||||
if s.lastUserID == 0 {
|
if s.lastUserID == 0 {
|
||||||
return fmt.Errorf("no user ID stored from previous token")
|
return fmt.Errorf("no user ID stored from previous token")
|
||||||
}
|
}
|
||||||
@@ -439,7 +475,17 @@ func (s *AuthSteps) iShouldReceiveAValidJWTTokenSignedWithThePrimarySecret() err
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Extract and store the token
|
// Extract and store the token
|
||||||
return s.iShouldReceiveAValidJWTToken()
|
err := s.iShouldReceiveAValidJWTToken()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store this as the first token if not already set (for rotation testing)
|
||||||
|
if s.firstToken == "" {
|
||||||
|
s.firstToken = s.lastToken
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthSteps) iValidateAJWTTokenSignedWithTheSecondarySecret() error {
|
func (s *AuthSteps) iValidateAJWTTokenSignedWithTheSecondarySecret() error {
|
||||||
@@ -516,24 +562,26 @@ func (s *AuthSteps) iUseAJWTTokenSignedWithTheExpiredSecondarySecretForAuthentic
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthSteps) iUseTheOldJWTTokenSignedWithPrimarySecret() error {
|
func (s *AuthSteps) iUseTheOldJWTTokenSignedWithPrimarySecret() error {
|
||||||
// This step assumes we have stored the old token from previous authentication
|
// Use the actual token from the first authentication (stored in firstToken)
|
||||||
// For now, we'll simulate by using a token that would have been signed with primary secret
|
if s.firstToken == "" {
|
||||||
oldPrimaryToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOjIsImV4cCI6MjIwMDAwMDAwMCwiaXNzIjoiZGFuY2UtbGVzc29ucy1jb2FjaCJ9.old-primary-secret-signature"
|
return fmt.Errorf("no old token stored from first authentication")
|
||||||
|
}
|
||||||
|
|
||||||
// Set the Authorization header with the old primary token
|
// Set the Authorization header with the old primary token
|
||||||
req := map[string]string{"token": oldPrimaryToken}
|
req := map[string]string{"token": s.firstToken}
|
||||||
return s.client.RequestWithHeader("POST", "/api/v1/auth/validate", req, map[string]string{
|
return s.client.RequestWithHeader("POST", "/api/v1/auth/validate", req, map[string]string{
|
||||||
"Authorization": "Bearer " + oldPrimaryToken,
|
"Authorization": "Bearer " + s.firstToken,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthSteps) iValidateTheOldJWTTokenSignedWithPrimarySecret() error {
|
func (s *AuthSteps) iValidateTheOldJWTTokenSignedWithPrimarySecret() error {
|
||||||
// This would validate the old token signed with primary secret
|
// Use the actual token from the first authentication (stored in firstToken)
|
||||||
// For now, we'll simulate by validating a token
|
if s.firstToken == "" {
|
||||||
oldPrimaryToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOjIsImV4cCI6MjIwMDAwMDAwMCwiaXNzIjoiZGFuY2UtbGVzc29ucy1jb2FjaCJ9.old-primary-secret-signature"
|
return fmt.Errorf("no old token stored from first authentication")
|
||||||
|
}
|
||||||
|
|
||||||
return s.client.RequestWithHeader("POST", "/api/v1/auth/validate", map[string]string{"token": oldPrimaryToken}, map[string]string{
|
return s.client.RequestWithHeader("POST", "/api/v1/auth/validate", map[string]string{"token": s.firstToken}, map[string]string{
|
||||||
"Authorization": "Bearer " + oldPrimaryToken,
|
"Authorization": "Bearer " + s.firstToken,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
182
pkg/jwt/jwt.go
Normal file
182
pkg/jwt/jwt.go
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JWTConfig holds JWT configuration
|
||||||
|
type JWTConfig struct {
|
||||||
|
Secret string
|
||||||
|
ExpirationTime time.Duration
|
||||||
|
Issuer string
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTSecret represents a JWT secret with metadata
|
||||||
|
type JWTSecret struct {
|
||||||
|
Secret string
|
||||||
|
IsPrimary bool
|
||||||
|
CreatedAt time.Time
|
||||||
|
ExpiresAt *time.Time // Optional expiration time
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTSecretManager manages multiple JWT secrets for rotation
|
||||||
|
type JWTSecretManager interface {
|
||||||
|
AddSecret(secret string, isPrimary bool, expiresIn time.Duration)
|
||||||
|
RotateToSecret(newSecret string)
|
||||||
|
GetPrimarySecret() string
|
||||||
|
GetAllValidSecrets() []JWTSecret
|
||||||
|
GetSecretByIndex(index int) (string, bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTService defines interface for JWT operations
|
||||||
|
type JWTService interface {
|
||||||
|
GenerateJWT(ctx context.Context, userID uint, username string, isAdmin bool) (string, error)
|
||||||
|
ValidateJWT(ctx context.Context, tokenString string, secretManager JWTSecretManager) (*JWTClaims, error)
|
||||||
|
GetJWTSecretManager() JWTSecretManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTClaims represents the claims in a JWT token
|
||||||
|
type JWTClaims struct {
|
||||||
|
UserID uint `json:"sub"`
|
||||||
|
Username string `json:"name"`
|
||||||
|
IsAdmin bool `json:"admin"`
|
||||||
|
ExpiresAt int64 `json:"exp"`
|
||||||
|
IssuedAt int64 `json:"iat"`
|
||||||
|
Issuer string `json:"iss"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// jwtServiceImpl implements the JWTService interface
|
||||||
|
type jwtServiceImpl struct {
|
||||||
|
config JWTConfig
|
||||||
|
secretManager JWTSecretManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJWTService creates a new JWT service
|
||||||
|
func NewJWTService(config JWTConfig) JWTService {
|
||||||
|
return &jwtServiceImpl{
|
||||||
|
config: config,
|
||||||
|
secretManager: NewJWTSecretManager(config.Secret),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateJWT generates a JWT token for the given user information
|
||||||
|
func (s *jwtServiceImpl) GenerateJWT(ctx context.Context, userID uint, username string, isAdmin bool) (string, error) {
|
||||||
|
// Create the claims
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"sub": userID,
|
||||||
|
"name": username,
|
||||||
|
"admin": isAdmin,
|
||||||
|
"exp": time.Now().Add(s.config.ExpirationTime).Unix(),
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
"iss": s.config.Issuer,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create token
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
|
||||||
|
// Sign and get the complete encoded token as a string using primary secret
|
||||||
|
tokenString, err := token.SignedString([]byte(s.secretManager.GetPrimarySecret()))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to sign JWT: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenString, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateJWT validates a JWT token and returns the claims
|
||||||
|
func (s *jwtServiceImpl) ValidateJWT(ctx context.Context, tokenString string, secretManager JWTSecretManager) (*JWTClaims, error) {
|
||||||
|
// Get all valid secrets for validation
|
||||||
|
validSecrets := secretManager.GetAllValidSecrets()
|
||||||
|
|
||||||
|
// Try each valid secret until we find one that works
|
||||||
|
var parsedToken *jwt.Token
|
||||||
|
var validationError error
|
||||||
|
|
||||||
|
for _, secret := range validSecrets {
|
||||||
|
// Parse the token with current secret
|
||||||
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
// Verify the signing method
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte(secret.Secret), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err == nil && token.Valid {
|
||||||
|
parsedToken = token
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the last error for reporting
|
||||||
|
validationError = err
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsedToken == nil {
|
||||||
|
if validationError != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse JWT: %w", validationError)
|
||||||
|
}
|
||||||
|
return nil, errors.New("invalid JWT token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get claims
|
||||||
|
claims, ok := parsedToken.Claims.(jwt.MapClaims)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("invalid JWT claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract user ID from claims
|
||||||
|
userIDFloat, ok := claims["sub"].(float64)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("invalid user ID in JWT")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract username from claims
|
||||||
|
username, ok := claims["name"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("invalid username in JWT")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract admin status from claims
|
||||||
|
isAdmin, ok := claims["admin"].(bool)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("invalid admin status in JWT")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract expiration time from claims
|
||||||
|
expiresAt, ok := claims["exp"].(float64)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("invalid expiration time in JWT")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract issued at time from claims
|
||||||
|
issuedAt, ok := claims["iat"].(float64)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("invalid issued at time in JWT")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract issuer from claims
|
||||||
|
issuer, ok := claims["iss"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("invalid issuer in JWT")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &JWTClaims{
|
||||||
|
UserID: uint(userIDFloat),
|
||||||
|
Username: username,
|
||||||
|
IsAdmin: isAdmin,
|
||||||
|
ExpiresAt: int64(expiresAt),
|
||||||
|
IssuedAt: int64(issuedAt),
|
||||||
|
Issuer: issuer,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJWTSecretManager returns the JWT secret manager
|
||||||
|
func (s *jwtServiceImpl) GetJWTSecretManager() JWTSecretManager {
|
||||||
|
return s.secretManager
|
||||||
|
}
|
||||||
81
pkg/jwt/jwt_secret_manager.go
Normal file
81
pkg/jwt/jwt_secret_manager.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// jwtSecretManagerImpl implements the JWTSecretManager interface
|
||||||
|
type jwtSecretManagerImpl struct {
|
||||||
|
secrets []JWTSecret
|
||||||
|
primarySecret string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJWTSecretManager creates a new JWT secret manager
|
||||||
|
func NewJWTSecretManager(initialSecret string) JWTSecretManager {
|
||||||
|
return &jwtSecretManagerImpl{
|
||||||
|
secrets: []JWTSecret{
|
||||||
|
{
|
||||||
|
Secret: initialSecret,
|
||||||
|
IsPrimary: true,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
primarySecret: initialSecret,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSecret adds a new JWT secret
|
||||||
|
func (m *jwtSecretManagerImpl) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) {
|
||||||
|
expiresAt := time.Now().Add(expiresIn)
|
||||||
|
m.secrets = append(m.secrets, JWTSecret{
|
||||||
|
Secret: secret,
|
||||||
|
IsPrimary: isPrimary,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
ExpiresAt: &expiresAt,
|
||||||
|
})
|
||||||
|
|
||||||
|
if isPrimary {
|
||||||
|
m.primarySecret = secret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RotateToSecret rotates to a new primary secret
|
||||||
|
func (m *jwtSecretManagerImpl) RotateToSecret(newSecret string) {
|
||||||
|
// Mark existing primary as non-primary
|
||||||
|
for i, secret := range m.secrets {
|
||||||
|
if secret.IsPrimary {
|
||||||
|
m.secrets[i].IsPrimary = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new secret as primary
|
||||||
|
m.AddSecret(newSecret, true, 0) // No expiration for primary
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPrimarySecret returns the current primary secret
|
||||||
|
func (m *jwtSecretManagerImpl) GetPrimarySecret() string {
|
||||||
|
return m.primarySecret
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllValidSecrets returns all valid (non-expired) secrets
|
||||||
|
func (m *jwtSecretManagerImpl) GetAllValidSecrets() []JWTSecret {
|
||||||
|
var validSecrets []JWTSecret
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
for _, secret := range m.secrets {
|
||||||
|
if secret.ExpiresAt == nil || secret.ExpiresAt.After(now) {
|
||||||
|
validSecrets = append(validSecrets, secret)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return validSecrets
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSecretByIndex returns a secret by index for testing
|
||||||
|
func (m *jwtSecretManagerImpl) GetSecretByIndex(index int) (string, bool) {
|
||||||
|
if index < 0 || index >= len(m.secrets) {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return m.secrets[index].Secret, true
|
||||||
|
}
|
||||||
@@ -166,6 +166,12 @@ func (s *Server) registerApiV1Routes(r chi.Router) {
|
|||||||
r.Route("/auth", func(r chi.Router) {
|
r.Route("/auth", func(r chi.Router) {
|
||||||
handler.RegisterRoutes(r)
|
handler.RegisterRoutes(r)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Register admin routes
|
||||||
|
adminHandler := userapi.NewAdminHandler(s.userService)
|
||||||
|
r.Route("/admin", func(r chi.Router) {
|
||||||
|
adminHandler.RegisterRoutes(r)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
149
pkg/user/api/admin_handler.go
Normal file
149
pkg/user/api/admin_handler.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"dance-lessons-coach/pkg/user"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AdminHandler handles admin-related HTTP requests
|
||||||
|
type AdminHandler struct {
|
||||||
|
authService user.AuthService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAdminHandler creates a new admin handler
|
||||||
|
func NewAdminHandler(authService user.AuthService) *AdminHandler {
|
||||||
|
return &AdminHandler{
|
||||||
|
authService: authService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes registers admin routes
|
||||||
|
func (h *AdminHandler) RegisterRoutes(router chi.Router) {
|
||||||
|
router.Route("/jwt", func(r chi.Router) {
|
||||||
|
r.Post("/secrets", h.handleAddJWTSecret)
|
||||||
|
r.Post("/secrets/rotate", h.handleRotateJWTSecret)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddJWTSecretRequest represents a request to add a new JWT secret
|
||||||
|
type AddJWTSecretRequest struct {
|
||||||
|
Secret string `json:"secret" validate:"required,min=16"`
|
||||||
|
IsPrimary bool `json:"is_primary"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"` // Expiration time in hours
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAddJWTSecret godoc
|
||||||
|
//
|
||||||
|
// @Summary Add JWT secret
|
||||||
|
// @Description Add a new JWT secret for rotation purposes
|
||||||
|
// @Tags API/v1/Admin
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param request body AddJWTSecretRequest true "JWT secret details"
|
||||||
|
// @Success 200 {object} map[string]string "Secret added successfully"
|
||||||
|
// @Failure 400 {object} map[string]string "Invalid request"
|
||||||
|
// @Failure 401 {object} map[string]string "Unauthorized"
|
||||||
|
// @Failure 500 {object} map[string]string "Server error"
|
||||||
|
// @Router /v1/admin/jwt/secrets [post]
|
||||||
|
func (h *AdminHandler) handleAddJWTSecret(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Decode request body into a map to handle flexible boolean parsing
|
||||||
|
var body map[string]interface{}
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||||
|
http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and validate fields
|
||||||
|
secret, ok := body["secret"].(string)
|
||||||
|
if !ok || secret == "" {
|
||||||
|
http.Error(w, `{"error":"invalid_request","message":"secret is required and must be a string"}`, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle is_primary as either bool or string
|
||||||
|
isPrimary := false // default
|
||||||
|
if val, exists := body["is_primary"]; exists {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case bool:
|
||||||
|
isPrimary = v
|
||||||
|
case string:
|
||||||
|
isPrimary = v == "true"
|
||||||
|
default:
|
||||||
|
http.Error(w, `{"error":"invalid_request","message":"is_primary must be a boolean or string"}`, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle expires_in as either int64 or float64 (JSON numbers)
|
||||||
|
expiresInHours := int64(0)
|
||||||
|
if val, exists := body["expires_in"]; exists {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case int64:
|
||||||
|
expiresInHours = v
|
||||||
|
case float64:
|
||||||
|
expiresInHours = int64(v)
|
||||||
|
default:
|
||||||
|
http.Error(w, `{"error":"invalid_request","message":"expires_in must be a number"}`, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert expires_in from hours to time.Duration
|
||||||
|
expiresIn := time.Duration(expiresInHours) * time.Hour
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
// If expires_in is 0 or not provided, set to no expiration for secondary secrets
|
||||||
|
// For primary secrets, use a reasonable default
|
||||||
|
if isPrimary {
|
||||||
|
expiresIn = 24 * 365 * time.Hour // 1 year for primary secrets
|
||||||
|
} else {
|
||||||
|
expiresIn = 0 // No expiration for secondary secrets
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the secret to the manager
|
||||||
|
h.authService.AddJWTSecret(secret, isPrimary, expiresIn)
|
||||||
|
|
||||||
|
// Return success
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"message": "JWT secret added successfully"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RotateJWTSecretRequest represents a request to rotate JWT secrets
|
||||||
|
type RotateJWTSecretRequest struct {
|
||||||
|
NewSecret string `json:"new_secret" validate:"required,min=16"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRotateJWTSecret godoc
|
||||||
|
//
|
||||||
|
// @Summary Rotate JWT secret
|
||||||
|
// @Description Rotate to a new primary JWT secret
|
||||||
|
// @Tags API/v1/Admin
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param request body RotateJWTSecretRequest true "New JWT secret"
|
||||||
|
// @Success 200 {object} map[string]string "Secret rotated successfully"
|
||||||
|
// @Failure 400 {object} map[string]string "Invalid request"
|
||||||
|
// @Failure 401 {object} map[string]string "Unauthorized"
|
||||||
|
// @Failure 500 {object} map[string]string "Server error"
|
||||||
|
// @Router /v1/admin/jwt/secrets/rotate [post]
|
||||||
|
func (h *AdminHandler) handleRotateJWTSecret(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req RotateJWTSecretRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rotate to the new secret
|
||||||
|
h.authService.RotateJWTSecret(req.NewSecret)
|
||||||
|
|
||||||
|
// Return success
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"message": "JWT secret rotated successfully"})
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,6 +23,7 @@ type userServiceImpl struct {
|
|||||||
repo UserRepository
|
repo UserRepository
|
||||||
jwtConfig JWTConfig
|
jwtConfig JWTConfig
|
||||||
masterPassword string
|
masterPassword string
|
||||||
|
secretManager *JWTSecretManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUserService creates a new user service with all functionality
|
// NewUserService creates a new user service with all functionality
|
||||||
@@ -30,6 +32,7 @@ func NewUserService(repo UserRepository, jwtConfig JWTConfig, masterPassword str
|
|||||||
repo: repo,
|
repo: repo,
|
||||||
jwtConfig: jwtConfig,
|
jwtConfig: jwtConfig,
|
||||||
masterPassword: masterPassword,
|
masterPassword: masterPassword,
|
||||||
|
secretManager: NewJWTSecretManager(jwtConfig.Secret),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,38 +77,77 @@ func (s *userServiceImpl) GenerateJWT(ctx context.Context, user *User) (string,
|
|||||||
// Create token
|
// Create token
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
|
||||||
|
// Get all valid secrets and use the most recently added one for signing
|
||||||
|
// This supports JWT secret rotation by signing new tokens with the latest secret
|
||||||
|
validSecrets := s.secretManager.GetAllValidSecrets()
|
||||||
|
if len(validSecrets) == 0 {
|
||||||
|
return "", errors.New("no valid JWT secrets available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the most recently added secret (last in the list)
|
||||||
|
// This ensures new tokens are signed with the latest secret
|
||||||
|
signingSecret := validSecrets[len(validSecrets)-1].Secret
|
||||||
|
log.Trace().Ctx(ctx).Str("signing_secret", signingSecret).Bool("is_primary", validSecrets[len(validSecrets)-1].IsPrimary).Msg("Generating JWT with latest secret")
|
||||||
|
|
||||||
// Sign and get the complete encoded token as a string
|
// Sign and get the complete encoded token as a string
|
||||||
tokenString, err := token.SignedString([]byte(s.jwtConfig.Secret))
|
tokenString, err := token.SignedString([]byte(signingSecret))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to sign JWT: %w", err)
|
return "", fmt.Errorf("failed to sign JWT: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Trace().Ctx(ctx).Str("token", tokenString).Msg("Generated JWT token")
|
||||||
return tokenString, nil
|
return tokenString, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateJWT validates a JWT token and returns the user
|
// ValidateJWT validates a JWT token and returns the user
|
||||||
func (s *userServiceImpl) ValidateJWT(ctx context.Context, tokenString string) (*User, error) {
|
func (s *userServiceImpl) ValidateJWT(ctx context.Context, tokenString string) (*User, error) {
|
||||||
// Parse the token
|
log.Trace().Ctx(ctx).Str("token", tokenString).Msg("Validating JWT token")
|
||||||
|
|
||||||
|
// Get all valid secrets for validation
|
||||||
|
validSecrets := s.secretManager.GetAllValidSecrets()
|
||||||
|
|
||||||
|
log.Trace().Ctx(ctx).Int("num_secrets", len(validSecrets)).Msg("Validating JWT with multiple secrets")
|
||||||
|
for i, secret := range validSecrets {
|
||||||
|
log.Trace().Ctx(ctx).Int("secret_index", i).Str("secret", secret.Secret).Bool("is_primary", secret.IsPrimary).Msg("Trying secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try each valid secret until we find one that works
|
||||||
|
var parsedToken *jwt.Token
|
||||||
|
var validationError error
|
||||||
|
|
||||||
|
for i, secret := range validSecrets {
|
||||||
|
// Parse the token with current secret
|
||||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||||
// Verify the signing method
|
// Verify the signing method
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
}
|
}
|
||||||
|
|
||||||
return []byte(s.jwtConfig.Secret), nil
|
return []byte(secret.Secret), nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err == nil && token.Valid {
|
||||||
return nil, fmt.Errorf("failed to parse JWT: %w", err)
|
log.Trace().Ctx(ctx).Int("secret_index", i).Str("secret", secret.Secret).Msg("JWT validation successful")
|
||||||
|
parsedToken = token
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if token is valid
|
// Store the last error for reporting
|
||||||
if !token.Valid {
|
validationError = err
|
||||||
|
if err != nil {
|
||||||
|
log.Trace().Ctx(ctx).Int("secret_index", i).Str("secret", secret.Secret).Err(err).Msg("JWT validation failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsedToken == nil {
|
||||||
|
if validationError != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse JWT: %w", validationError)
|
||||||
|
}
|
||||||
return nil, errors.New("invalid JWT token")
|
return nil, errors.New("invalid JWT token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get claims
|
// Get claims
|
||||||
claims, ok := token.Claims.(jwt.MapClaims)
|
claims, ok := parsedToken.Claims.(jwt.MapClaims)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("invalid JWT claims")
|
return nil, errors.New("invalid JWT claims")
|
||||||
}
|
}
|
||||||
@@ -156,6 +198,21 @@ func (s *userServiceImpl) AdminAuthenticate(ctx context.Context, masterPassword
|
|||||||
return adminUser, nil
|
return adminUser, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddJWTSecret adds a new JWT secret to the manager
|
||||||
|
func (s *userServiceImpl) AddJWTSecret(secret string, isPrimary bool, expiresIn time.Duration) {
|
||||||
|
s.secretManager.AddSecret(secret, isPrimary, expiresIn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RotateJWTSecret rotates to a new primary JWT secret
|
||||||
|
func (s *userServiceImpl) RotateJWTSecret(newSecret string) {
|
||||||
|
s.secretManager.RotateToSecret(newSecret)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJWTSecretByIndex returns a JWT secret by index for testing
|
||||||
|
func (s *userServiceImpl) GetJWTSecretByIndex(index int) (string, bool) {
|
||||||
|
return s.secretManager.GetSecretByIndex(index)
|
||||||
|
}
|
||||||
|
|
||||||
// UserExists checks if a user exists by username
|
// UserExists checks if a user exists by username
|
||||||
func (s *userServiceImpl) UserExists(ctx context.Context, username string) (bool, error) {
|
func (s *userServiceImpl) UserExists(ctx context.Context, username string) (bool, error) {
|
||||||
return s.repo.UserExists(ctx, username)
|
return s.repo.UserExists(ctx, username)
|
||||||
|
|||||||
95
pkg/user/jwt_manager.go
Normal file
95
pkg/user/jwt_manager.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JWTSecret represents a JWT secret with metadata
|
||||||
|
type JWTSecret struct {
|
||||||
|
Secret string
|
||||||
|
IsPrimary bool
|
||||||
|
CreatedAt time.Time
|
||||||
|
ExpiresAt *time.Time // Optional expiration time
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTSecretManager manages multiple JWT secrets for rotation
|
||||||
|
type JWTSecretManager struct {
|
||||||
|
secrets []JWTSecret
|
||||||
|
primarySecret string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJWTSecretManager creates a new JWT secret manager
|
||||||
|
func NewJWTSecretManager(initialSecret string) *JWTSecretManager {
|
||||||
|
return &JWTSecretManager{
|
||||||
|
secrets: []JWTSecret{
|
||||||
|
{
|
||||||
|
Secret: initialSecret,
|
||||||
|
IsPrimary: true,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
primarySecret: initialSecret,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSecret adds a new JWT secret
|
||||||
|
func (m *JWTSecretManager) AddSecret(secret string, isPrimary bool, expiresIn time.Duration) {
|
||||||
|
var expiresAt *time.Time
|
||||||
|
if expiresIn > 0 {
|
||||||
|
expirationTime := time.Now().Add(expiresIn)
|
||||||
|
expiresAt = &expirationTime
|
||||||
|
}
|
||||||
|
// If expiresIn is 0 or negative, expiresAt remains nil (no expiration)
|
||||||
|
|
||||||
|
m.secrets = append(m.secrets, JWTSecret{
|
||||||
|
Secret: secret,
|
||||||
|
IsPrimary: isPrimary,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
})
|
||||||
|
|
||||||
|
if isPrimary {
|
||||||
|
m.primarySecret = secret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RotateToSecret rotates to a new primary secret
|
||||||
|
func (m *JWTSecretManager) RotateToSecret(newSecret string) {
|
||||||
|
// Mark existing primary as non-primary
|
||||||
|
for i, secret := range m.secrets {
|
||||||
|
if secret.IsPrimary {
|
||||||
|
m.secrets[i].IsPrimary = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new secret as primary
|
||||||
|
m.AddSecret(newSecret, true, 0) // No expiration for primary
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPrimarySecret returns the current primary secret
|
||||||
|
func (m *JWTSecretManager) GetPrimarySecret() string {
|
||||||
|
return m.primarySecret
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllValidSecrets returns all valid (non-expired) secrets
|
||||||
|
func (m *JWTSecretManager) GetAllValidSecrets() []JWTSecret {
|
||||||
|
var validSecrets []JWTSecret
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
for _, secret := range m.secrets {
|
||||||
|
if secret.ExpiresAt == nil || secret.ExpiresAt.After(now) {
|
||||||
|
validSecrets = append(validSecrets, secret)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return validSecrets
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSecretByIndex returns a secret by index for testing
|
||||||
|
func (m *JWTSecretManager) GetSecretByIndex(index int) (string, bool) {
|
||||||
|
if index < 0 || index >= len(m.secrets) {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return m.secrets[index].Secret, true
|
||||||
|
}
|
||||||
86
pkg/user/jwt_manager_test.go
Normal file
86
pkg/user/jwt_manager_test.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJWTSecretManager(t *testing.T) {
|
||||||
|
// Create a new secret manager with initial secret
|
||||||
|
manager := NewJWTSecretManager("primary-secret")
|
||||||
|
|
||||||
|
// Test initial state
|
||||||
|
assert.Equal(t, "primary-secret", manager.GetPrimarySecret())
|
||||||
|
|
||||||
|
// Test GetAllValidSecrets initially
|
||||||
|
secrets := manager.GetAllValidSecrets()
|
||||||
|
assert.Len(t, secrets, 1)
|
||||||
|
assert.Equal(t, "primary-secret", secrets[0].Secret)
|
||||||
|
assert.True(t, secrets[0].IsPrimary)
|
||||||
|
assert.Nil(t, secrets[0].ExpiresAt)
|
||||||
|
|
||||||
|
// Add a secondary secret
|
||||||
|
manager.AddSecret("secondary-secret", false, 0) // 0 means no expiration
|
||||||
|
|
||||||
|
// Test after adding secondary secret
|
||||||
|
assert.Equal(t, "primary-secret", manager.GetPrimarySecret()) // Primary should not change
|
||||||
|
|
||||||
|
secrets = manager.GetAllValidSecrets()
|
||||||
|
assert.Len(t, secrets, 2)
|
||||||
|
|
||||||
|
// Find the secondary secret
|
||||||
|
foundSecondary := false
|
||||||
|
for _, secret := range secrets {
|
||||||
|
if secret.Secret == "secondary-secret" {
|
||||||
|
foundSecondary = true
|
||||||
|
assert.False(t, secret.IsPrimary)
|
||||||
|
assert.Nil(t, secret.ExpiresAt) // Should have no expiration
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, foundSecondary, "Secondary secret should be found in valid secrets")
|
||||||
|
|
||||||
|
// Test rotation
|
||||||
|
manager.RotateToSecret("new-primary-secret")
|
||||||
|
assert.Equal(t, "new-primary-secret", manager.GetPrimarySecret())
|
||||||
|
|
||||||
|
secrets = manager.GetAllValidSecrets()
|
||||||
|
assert.Len(t, secrets, 3) // Should have 3 secrets now
|
||||||
|
|
||||||
|
// Find the new primary secret
|
||||||
|
foundNewPrimary := false
|
||||||
|
for _, secret := range secrets {
|
||||||
|
if secret.Secret == "new-primary-secret" {
|
||||||
|
foundNewPrimary = true
|
||||||
|
assert.True(t, secret.IsPrimary)
|
||||||
|
assert.Nil(t, secret.ExpiresAt) // Should have no expiration
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, foundNewPrimary, "New primary secret should be found in valid secrets")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTSecretExpiration(t *testing.T) {
|
||||||
|
manager := NewJWTSecretManager("primary-secret")
|
||||||
|
|
||||||
|
// Add a secret with expiration
|
||||||
|
manager.AddSecret("expiring-secret", false, 1*time.Hour) // Expires in 1 hour
|
||||||
|
|
||||||
|
// Should have 2 secrets initially
|
||||||
|
secrets := manager.GetAllValidSecrets()
|
||||||
|
assert.Len(t, secrets, 2)
|
||||||
|
|
||||||
|
// Test expiration logic
|
||||||
|
foundExpiring := false
|
||||||
|
for _, secret := range secrets {
|
||||||
|
if secret.Secret == "expiring-secret" {
|
||||||
|
foundExpiring = true
|
||||||
|
assert.NotNil(t, secret.ExpiresAt)
|
||||||
|
assert.True(t, secret.ExpiresAt.After(time.Now()))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, foundExpiring)
|
||||||
|
}
|
||||||
@@ -39,6 +39,9 @@ type AuthService interface {
|
|||||||
GenerateJWT(ctx context.Context, user *User) (string, error)
|
GenerateJWT(ctx context.Context, user *User) (string, error)
|
||||||
ValidateJWT(ctx context.Context, token string) (*User, error)
|
ValidateJWT(ctx context.Context, token string) (*User, error)
|
||||||
AdminAuthenticate(ctx context.Context, masterPassword string) (*User, error)
|
AdminAuthenticate(ctx context.Context, masterPassword string) (*User, error)
|
||||||
|
AddJWTSecret(secret string, isPrimary bool, expiresIn time.Duration)
|
||||||
|
RotateJWTSecret(newSecret string)
|
||||||
|
GetJWTSecretByIndex(index int) (string, bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserManager defines interface for user management operations
|
// UserManager defines interface for user management operations
|
||||||
|
|||||||
Reference in New Issue
Block a user