// Package jwt issues and validates HMAC-signed JSON Web Tokens and declares
// the secret-manager contract used for signing-secret rotation.
package jwt

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/golang-jwt/jwt/v5"
)

// JWTConfig holds JWT configuration.
type JWTConfig struct {
	// Secret is the initial signing secret passed to NewJWTSecretManager.
	Secret string
	// ExpirationTime is the lifetime added to "now" for the "exp" claim.
	ExpirationTime time.Duration
	// Issuer is written to the "iss" claim of generated tokens.
	Issuer string
}

// JWTSecret represents a JWT secret with metadata.
type JWTSecret struct {
	Secret    string
	IsPrimary bool
	CreatedAt time.Time
	ExpiresAt *time.Time // Optional expiration time
}

// JWTSecretManager manages multiple JWT secrets for rotation.
// Secrets can carry an optional expiration; the cleanup loop removes them
// after expiry while always preserving the primary secret (ADR-0021).
//
// NOTE(review): the implementation is not part of this file; the exact
// semantics of AddSecret's expiresIn and of RotateToSecret should be
// confirmed against it.
type JWTSecretManager interface {
	AddSecret(secret string, isPrimary bool, expiresIn time.Duration)
	RotateToSecret(newSecret string)
	GetPrimarySecret() string
	GetAllValidSecrets() []JWTSecret
	GetSecretByIndex(index int) (string, bool)

	// RemoveExpiredSecrets drops every non-primary secret whose ExpiresAt is
	// non-nil and in the past. Returns the count of secrets removed.
	// The primary secret is never removed regardless of expiration.
	RemoveExpiredSecrets() int

	// StartCleanupLoop spawns a goroutine that calls RemoveExpiredSecrets at
	// the given interval. Stops when the context is cancelled. Safe to call
	// once at startup; calling again replaces the previous loop's context.
	StartCleanupLoop(ctx context.Context, interval time.Duration)
}

// JWTService defines the interface for JWT operations.
type JWTService interface {
	// GenerateJWT returns a signed token carrying the given user information.
	GenerateJWT(ctx context.Context, userID uint, username string, isAdmin bool) (string, error)
	// ValidateJWT verifies tokenString against the secrets held by
	// secretManager and returns the decoded claims.
	ValidateJWT(ctx context.Context, tokenString string, secretManager JWTSecretManager) (*JWTClaims, error)
	// GetJWTSecretManager returns the secret manager owned by the service.
	GetJWTSecretManager() JWTSecretManager
}

// JWTClaims represents the claims in a JWT token.
type JWTClaims struct {
	UserID    uint   `json:"sub"`
	Username  string `json:"name"`
	IsAdmin   bool   `json:"admin"`
	ExpiresAt int64  `json:"exp"` // Unix seconds
	IssuedAt  int64  `json:"iat"` // Unix seconds
	Issuer    string `json:"iss"`
}

// jwtServiceImpl implements the JWTService interface.
type jwtServiceImpl struct {
	config        JWTConfig
	secretManager JWTSecretManager
}

// Compile-time check that jwtServiceImpl satisfies JWTService.
var _ JWTService = (*jwtServiceImpl)(nil)

// NewJWTService creates a new JWT service whose secret manager is seeded
// with config.Secret.
func NewJWTService(config JWTConfig) JWTService {
	return &jwtServiceImpl{
		config:        config,
		secretManager: NewJWTSecretManager(config.Secret),
	}
}

// GenerateJWT generates an HS256-signed JWT for the given user information.
//
// The token carries the claims "sub" (userID), "name" (username), "admin"
// (isAdmin), "exp", "iat" and "iss", and is signed with the secret manager's
// primary secret. ctx is currently unused. An error is returned only when
// signing fails.
func (s *jwtServiceImpl) GenerateJWT(ctx context.Context, userID uint, username string, isAdmin bool) (string, error) {
	// Read the clock once so "iat" and "exp" are derived from the same instant.
	now := time.Now()

	claims := jwt.MapClaims{
		"sub":   userID,
		"name":  username,
		"admin": isAdmin,
		"exp":   now.Add(s.config.ExpirationTime).Unix(),
		"iat":   now.Unix(),
		"iss":   s.config.Issuer,
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)

	// Sign and get the complete encoded token as a string using the primary secret.
	tokenString, err := token.SignedString([]byte(s.secretManager.GetPrimarySecret()))
	if err != nil {
		return "", fmt.Errorf("failed to sign JWT: %w", err)
	}
	return tokenString, nil
}

// ValidateJWT validates a JWT token and returns the claims.
//
// The token is checked against every secret returned by
// secretManager.GetAllValidSecrets, in order, so that tokens signed with a
// previous secret keep validating during rotation. When secretManager is nil
// the service's own secret manager is used. ctx is currently unused.
//
// An error is returned when no secret verifies the token, when the token is
// otherwise rejected by the parser (malformed, expired, non-HMAC algorithm),
// or when any of the expected claims is missing or has the wrong type.
func (s *jwtServiceImpl) ValidateJWT(ctx context.Context, tokenString string, secretManager JWTSecretManager) (*JWTClaims, error) {
	// Fall back to the service's own manager instead of panicking on nil.
	if secretManager == nil {
		secretManager = s.secretManager
	}

	var (
		parsedToken *jwt.Token
		lastErr     error
	)
	for _, secret := range secretManager.GetAllValidSecrets() {
		key := []byte(secret.Secret)
		token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
			// Only HMAC algorithms are accepted.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
			}
			return key, nil
		})
		if err == nil && token.Valid {
			parsedToken = token
			break
		}

		// Keep the most recent error for reporting.
		lastErr = err

		// Only a signature mismatch can be cured by trying another secret.
		// Any other failure (malformed token, expired claims, rejected
		// algorithm) is independent of the key, so stop here; otherwise a
		// later secret would overwrite the meaningful error with a
		// misleading "signature is invalid".
		if err != nil && !errors.Is(err, jwt.ErrTokenSignatureInvalid) {
			break
		}
	}

	if parsedToken == nil {
		if lastErr != nil {
			return nil, fmt.Errorf("failed to parse JWT: %w", lastErr)
		}
		return nil, errors.New("invalid JWT token")
	}

	claims, ok := parsedToken.Claims.(jwt.MapClaims)
	if !ok {
		return nil, errors.New("invalid JWT claims")
	}

	// JSON numbers decode to float64. A negative value cannot be represented
	// as uint and its conversion result is implementation-dependent, so it is
	// rejected rather than converted.
	userIDFloat, ok := claims["sub"].(float64)
	if !ok || userIDFloat < 0 {
		return nil, errors.New("invalid user ID in JWT")
	}

	username, ok := claims["name"].(string)
	if !ok {
		return nil, errors.New("invalid username in JWT")
	}

	isAdmin, ok := claims["admin"].(bool)
	if !ok {
		return nil, errors.New("invalid admin status in JWT")
	}

	expiresAt, ok := claims["exp"].(float64)
	if !ok {
		return nil, errors.New("invalid expiration time in JWT")
	}

	issuedAt, ok := claims["iat"].(float64)
	if !ok {
		return nil, errors.New("invalid issued at time in JWT")
	}

	issuer, ok := claims["iss"].(string)
	if !ok {
		return nil, errors.New("invalid issuer in JWT")
	}

	return &JWTClaims{
		UserID:    uint(userIDFloat),
		Username:  username,
		IsAdmin:   isAdmin,
		ExpiresAt: int64(expiresAt),
		IssuedAt:  int64(issuedAt),
		Issuer:    issuer,
	}, nil
}

// GetJWTSecretManager returns the JWT secret manager.
func (s *jwtServiceImpl) GetJWTSecretManager() JWTSecretManager {
	return s.secretManager
}