Files
dance-lessons-coach/pkg/user/auth_service.go
Gabriel Radureau 405a9fc937 feat(auth): JWT TTL hot-reload + fix hardcoded 24h bug (ADR-0023 Phase 2)
Two changes in one diff because they share the same surface (JWTConfig
plumbing):

1. **Bug fix**: pkg/server/server.go hardcoded ExpirationTime to 24h,
   ignoring the auth.jwt.ttl config value (default 1h) entirely.
   Production has therefore been signing tokens with a 24h TTL regardless
   of config ever since the config field was added.

2. **Hot-reload (ADR-0023 Phase 2)**: extends JWTConfig with a GetTTL
   func() time.Duration callback. effectiveTTL() uses GetTTL when it is
   set and returns a positive duration, and falls back to ExpirationTime
   otherwise (test-friendly). server.go wires GetTTL = cfg.GetJWTTTL — a
   method value that captures the *Config, so when WatchAndApply
   re-unmarshals, the next token generation reads the new TTL
   automatically. Tokens already issued keep their original expiry.

WatchAndApply now also logs the new jwt_ttl on every reload event.

Tests:
- New TestWatchAndApply_JWTTTL in pkg/config/config_hot_reload_test.go
  rewrites the config file and asserts the in-memory ttl flips within
  2s. Polling (no fixed sleep), race-clean.
- Existing pkg/user tests (including JWT manager + cleanup loop) all
  pass with -race.
- Full BDD suite (auth/config/greet/health/info/jwt) green.

ADR-0023 status: Phase 1+2 Implemented. Phase 3 (telemetry sampler)
and Phase 4 (api.v2_enabled — needs router refactor) remain Proposed.
2026-05-05 09:08:19 +02:00

327 lines
10 KiB
Go

package user
import (
	"context"
	"crypto/subtle"
	"errors"
	"fmt"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"github.com/rs/zerolog/log"
	"golang.org/x/crypto/bcrypt"
)
// JWTConfig carries the settings used to sign and validate JWTs.
//
//   - Secret: initial signing secret; NewUserService seeds the secret
//     manager with it and ResetJWTSecrets resets back to it.
//   - ExpirationTime: static token lifetime, used whenever GetTTL is nil
//     or yields a non-positive duration.
//   - GetTTL: optional callback consulted on every token generation so a
//     reloaded `auth.jwt.ttl` takes effect without a restart (ADR-0023
//     Phase 2 hot-reload).
//   - Issuer: value written into the "iss" claim of generated tokens.
type JWTConfig struct {
	Secret         string
	ExpirationTime time.Duration
	GetTTL         func() time.Duration
	Issuer         string
}

// effectiveTTL reports the token lifetime to apply right now.
//
// A wired GetTTL callback wins as long as it returns a positive duration.
// In every other case (no callback, or a zero/negative result) the static
// ExpirationTime is returned unchanged; configs built without server-level
// wiring (e.g. in tests) therefore always get ExpirationTime.
func (c JWTConfig) effectiveTTL() time.Duration {
	if c.GetTTL == nil {
		return c.ExpirationTime
	}
	if live := c.GetTTL(); live > 0 {
		return live
	}
	return c.ExpirationTime
}
// userServiceImpl is the concrete implementation of the unified UserService
// interface: credential checks, JWT issuance/validation and password-reset
// flows, all backed by a UserRepository.
type userServiceImpl struct {
	// repo is the persistence layer for user records.
	repo UserRepository
	// jwtConfig holds signing/TTL settings; its Secret also seeds secretManager.
	jwtConfig JWTConfig
	// masterPassword is what AdminAuthenticate compares its input against.
	masterPassword string
	// secretManager tracks the JWT secrets used for signing and validation
	// (supports secret rotation).
	secretManager *JWTSecretManager
}
// NewUserService builds a userServiceImpl around repo. The JWT secret
// manager is seeded with jwtConfig.Secret, and masterPassword becomes the
// value AdminAuthenticate checks against.
func NewUserService(repo UserRepository, jwtConfig JWTConfig, masterPassword string) *userServiceImpl {
	svc := &userServiceImpl{
		repo:           repo,
		jwtConfig:      jwtConfig,
		masterPassword: masterPassword,
	}
	svc.secretManager = NewJWTSecretManager(jwtConfig.Secret)
	return svc
}
// Authenticate verifies username/password against the stored bcrypt hash
// and, on success, stamps LastLogin on the user and tries to persist it.
//
// Both an unknown username and a wrong password yield the same
// "invalid credentials" error message. Repository lookup errors are wrapped
// and returned. Failing to persist LastLogin is logged but does not fail
// the authentication.
func (s *userServiceImpl) Authenticate(ctx context.Context, username, password string) (*User, error) {
	user, err := s.repo.GetUserByUsername(ctx, username)
	if err != nil {
		return nil, fmt.Errorf("failed to get user: %w", err)
	}
	if user == nil {
		return nil, errors.New("invalid credentials")
	}

	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
		return nil, errors.New("invalid credentials")
	}

	now := time.Now()
	user.LastLogin = &now
	if err := s.repo.UpdateUser(ctx, user); err != nil {
		// Best effort: a stale LastLogin must not block a valid login, but the
		// failure should be visible to operators rather than silently dropped.
		log.Warn().Ctx(ctx).Err(err).Str("username", user.Username).Msg("failed to update last login time")
	}
	return user, nil
}
// GenerateJWT issues an HS256-signed JWT for user.
//
// Claims: sub (user ID), name (username), admin (IsAdmin), iat (now),
// exp (now + jwtConfig.effectiveTTL(), so a hot-reloaded TTL applies to the
// next token) and iss (jwtConfig.Issuer).
//
// The token is signed with the last entry returned by
// secretManager.GetAllValidSecrets (the most recently added secret), so new
// tokens use the latest secret during a rotation. An error is returned when
// user is nil, when no valid secret is available, or when signing fails.
//
// Neither the secret nor the encoded token is logged: both are credentials.
func (s *userServiceImpl) GenerateJWT(ctx context.Context, user *User) (string, error) {
	if user == nil {
		return "", errors.New("cannot generate JWT for nil user")
	}

	// Read the clock once so iat and exp derive from the same instant.
	now := time.Now()
	claims := jwt.MapClaims{
		"sub":   user.ID,
		"name":  user.Username,
		"admin": user.IsAdmin,
		"exp":   now.Add(s.jwtConfig.effectiveTTL()).Unix(),
		"iat":   now.Unix(),
		"iss":   s.jwtConfig.Issuer,
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)

	validSecrets := s.secretManager.GetAllValidSecrets()
	if len(validSecrets) == 0 {
		return "", errors.New("no valid JWT secrets available")
	}
	// The most recently added secret is last in the list.
	signingIndex := len(validSecrets) - 1
	signing := validSecrets[signingIndex]
	log.Trace().Ctx(ctx).Int("secret_index", signingIndex).Bool("is_primary", signing.IsPrimary).Msg("Generating JWT with latest secret")

	tokenString, err := token.SignedString([]byte(signing.Secret))
	if err != nil {
		return "", fmt.Errorf("failed to sign JWT: %w", err)
	}
	log.Trace().Ctx(ctx).Int("token_length", len(tokenString)).Msg("Generated JWT token")
	return tokenString, nil
}
// ValidateJWT parses tokenString, verifies its signature and registered
// claims, and returns the user identified by its "sub" claim.
//
// Every secret from secretManager.GetAllValidSecrets is tried in turn, so
// tokens signed with an older secret keep validating while that secret is
// still listed. Only HMAC signing methods are accepted. If no secret
// verifies the token, the last parse error is wrapped and returned (or
// "invalid JWT token" when there was none, e.g. no secrets at all).
// The "sub" claim must be a non-negative whole number that fits in a uint,
// and the user must exist in the repository.
//
// Neither the token nor any secret is logged: both are credentials.
func (s *userServiceImpl) ValidateJWT(ctx context.Context, tokenString string) (*User, error) {
	log.Trace().Ctx(ctx).Int("token_length", len(tokenString)).Msg("Validating JWT token")

	validSecrets := s.secretManager.GetAllValidSecrets()
	log.Trace().Ctx(ctx).Int("num_secrets", len(validSecrets)).Msg("Validating JWT with multiple secrets")

	var parsedToken *jwt.Token
	var validationError error
	for i, secret := range validSecrets {
		log.Trace().Ctx(ctx).Int("secret_index", i).Bool("is_primary", secret.IsPrimary).Msg("Trying secret")
		token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
			// Reject any non-HMAC algorithm named in the header (e.g. "none"
			// or RS256) so a token cannot choose how it gets verified.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
			}
			return []byte(secret.Secret), nil
		})
		if err == nil && token.Valid {
			log.Trace().Ctx(ctx).Int("secret_index", i).Msg("JWT validation successful")
			parsedToken = token
			break
		}
		if err != nil {
			// Remember the latest real failure; a nil error never overwrites it.
			validationError = err
			log.Trace().Ctx(ctx).Int("secret_index", i).Err(err).Msg("JWT validation failed")
		}
	}

	if parsedToken == nil {
		if validationError != nil {
			return nil, fmt.Errorf("failed to parse JWT: %w", validationError)
		}
		return nil, errors.New("invalid JWT token")
	}

	claims, ok := parsedToken.Claims.(jwt.MapClaims)
	if !ok {
		return nil, errors.New("invalid JWT claims")
	}

	// The numeric "sub" claim arrives as float64. Converting a negative,
	// fractional-looking or out-of-range float to uint gives an
	// implementation-dependent result, so reject those values first.
	userIDFloat, ok := claims["sub"].(float64)
	if !ok || userIDFloat < 0 || userIDFloat >= float64(^uint(0)) || userIDFloat != float64(uint(userIDFloat)) {
		return nil, errors.New("invalid user ID in JWT")
	}
	userID := uint(userIDFloat)

	user, err := s.repo.GetUserByID(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get user from JWT: %w", err)
	}
	if user == nil {
		return nil, errors.New("user not found")
	}
	return user, nil
}
// HashPassword returns the bcrypt hash of password at bcrypt.DefaultCost.
// It satisfies the PasswordService interface; ctx is not used.
func (s *userServiceImpl) HashPassword(ctx context.Context, password string) (string, error) {
	digest, hashErr := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if hashErr == nil {
		return string(digest), nil
	}
	return "", fmt.Errorf("failed to hash password: %w", hashErr)
}
// AdminAuthenticate returns a virtual admin user when masterPassword
// matches the configured master password.
//
// The returned user is not persisted: ID 0, username "admin", IsAdmin true.
// The comparison uses crypto/subtle so its duration does not depend on how
// many leading bytes match. When no master password is configured (empty
// string), admin login is refused outright. Otherwise an empty input would
// match it and grant admin access.
func (s *userServiceImpl) AdminAuthenticate(ctx context.Context, masterPassword string) (*User, error) {
	if s.masterPassword == "" {
		log.Warn().Ctx(ctx).Msg("admin authentication attempted but no master password is configured")
		return nil, errors.New("invalid admin credentials")
	}
	if subtle.ConstantTimeCompare([]byte(masterPassword), []byte(s.masterPassword)) != 1 {
		return nil, errors.New("invalid admin credentials")
	}

	adminUser := &User{
		ID:       0, // Special ID for admin
		Username: "admin",
		IsAdmin:  true,
	}
	return adminUser, nil
}
// AddJWTSecret registers secret with the JWT secret manager. isPrimary and
// expiresIn are forwarded unchanged to JWTSecretManager.AddSecret.
func (s *userServiceImpl) AddJWTSecret(secret string, isPrimary bool, expiresIn time.Duration) {
	manager := s.secretManager
	manager.AddSecret(secret, isPrimary, expiresIn)
}
// RotateJWTSecret makes newSecret the new primary JWT secret by delegating
// to JWTSecretManager.RotateToSecret.
func (s *userServiceImpl) RotateJWTSecret(newSecret string) {
	manager := s.secretManager
	manager.RotateToSecret(newSecret)
}
// GetJWTSecretByIndex fetches the secret stored at index in the secret
// manager; intended for tests. The boolean is passed through from
// JWTSecretManager.GetSecretByIndex (presumably "index was in range" —
// confirm against that method).
func (s *userServiceImpl) GetJWTSecretByIndex(index int) (string, bool) {
	secret, found := s.secretManager.GetSecretByIndex(index)
	return secret, found
}
// ResetJWTSecrets resets the secret manager using jwtConfig.Secret, the
// same value NewUserService seeds it with. Used for test cleanup.
func (s *userServiceImpl) ResetJWTSecrets() {
	initialSecret := s.jwtConfig.Secret
	s.secretManager.Reset(initialSecret)
}
// StartJWTSecretCleanupLoop asks the secret manager to start its periodic
// cleanup goroutine (ADR-0021) with the given interval. The goroutine's
// lifetime is owned by JWTSecretManager.StartCleanupLoop; presumably it
// stops when ctx is cancelled (confirm there).
func (s *userServiceImpl) StartJWTSecretCleanupLoop(ctx context.Context, interval time.Duration) {
	manager := s.secretManager
	manager.StartCleanupLoop(ctx, interval)
}
// RemoveExpiredJWTSecrets runs one cleanup pass on the secret manager right
// away and returns how many expired secrets it removed.
func (s *userServiceImpl) RemoveExpiredJWTSecrets() int {
	removed := s.secretManager.RemoveExpiredSecrets()
	return removed
}
// UserExists reports whether a user with the given username exists. It is
// a direct pass-through to the repository.
func (s *userServiceImpl) UserExists(ctx context.Context, username string) (bool, error) {
	exists, err := s.repo.UserExists(ctx, username)
	return exists, err
}
// CreateUser persists user through the repository and returns the
// repository's error unchanged.
func (s *userServiceImpl) CreateUser(ctx context.Context, user *User) error {
	err := s.repo.CreateUser(ctx, user)
	return err
}
// RequestPasswordReset enables a password reset for username.
//
// A failed existence check is wrapped and returned. A missing user yields
// "user not found: <username>". Otherwise the result of
// repo.AllowPasswordReset is returned as-is.
func (s *userServiceImpl) RequestPasswordReset(ctx context.Context, username string) error {
	found, lookupErr := s.repo.UserExists(ctx, username)
	switch {
	case lookupErr != nil:
		return fmt.Errorf("failed to check if user exists: %w", lookupErr)
	case !found:
		return fmt.Errorf("user not found: %s", username)
	default:
		return s.repo.AllowPasswordReset(ctx, username)
	}
}
// CompletePasswordReset hashes newPassword via HashPassword and passes the
// hash to repo.CompletePasswordReset for username. Hashing failures are
// wrapped; the repository's error is returned unchanged.
func (s *userServiceImpl) CompletePasswordReset(ctx context.Context, username, newPassword string) error {
	hash, hashErr := s.HashPassword(ctx, newPassword)
	if hashErr != nil {
		return fmt.Errorf("failed to hash new password: %w", hashErr)
	}
	return s.repo.CompletePasswordReset(ctx, username, hash)
}
// PasswordResetServiceImpl implements the PasswordResetService interface.
// It keeps reset state in repo and borrows password hashing from auth.
type PasswordResetServiceImpl struct {
	// repo is used for existence checks and reset-state updates.
	repo UserRepository
	// auth provides HashPassword; CompletePasswordReset dereferences it.
	auth *userServiceImpl
}
// NewPasswordResetService builds a PasswordResetServiceImpl from repo and
// auth (the latter is used for password hashing).
func NewPasswordResetService(repo UserRepository, auth *userServiceImpl) *PasswordResetServiceImpl {
	svc := new(PasswordResetServiceImpl)
	svc.repo = repo
	svc.auth = auth
	return svc
}
// RequestPasswordReset enables a password reset for username.
//
// A failed existence check is wrapped and returned. A missing user yields
// "user not found: <username>". Otherwise the result of
// repo.AllowPasswordReset is returned as-is.
func (s *PasswordResetServiceImpl) RequestPasswordReset(ctx context.Context, username string) error {
	found, lookupErr := s.repo.UserExists(ctx, username)
	if lookupErr != nil {
		return fmt.Errorf("failed to check if user exists: %w", lookupErr)
	}
	if found {
		return s.repo.AllowPasswordReset(ctx, username)
	}
	return fmt.Errorf("user not found: %s", username)
}
// CompletePasswordReset hashes newPassword with the auth service and hands
// the hash to repo.CompletePasswordReset for username. Hashing failures are
// wrapped; the repository's error is returned unchanged.
func (s *PasswordResetServiceImpl) CompletePasswordReset(ctx context.Context, username, newPassword string) error {
	hashed, hashErr := s.auth.HashPassword(ctx, newPassword)
	if hashErr != nil {
		return fmt.Errorf("failed to hash new password: %w", hashErr)
	}
	return s.repo.CompletePasswordReset(ctx, username, hashed)
}