Files
dance-lessons-coach/pkg/user/auth_service.go
Gabriel Radureau 3c73ca39d6
Some checks failed
CI/CD Pipeline / Build Docker Cache (push) Successful in 23s
CI/CD Pipeline / CI Pipeline (push) Failing after 5m23s
CI/CD Pipeline / Trigger Docker Push (push) Has been skipped
feat(auth): JWT TTL hot-reload + fix hardcoded 24h bug (ADR-0023 Phase 2) (#44)
Co-authored-by: Gabriel Radureau <arcodange@gmail.com>
Co-committed-by: Gabriel Radureau <arcodange@gmail.com>
2026-05-05 09:09:22 +02:00

327 lines
10 KiB
Go

package user
import (
	"context"
	"crypto/subtle"
	"errors"
	"fmt"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"github.com/rs/zerolog/log"
	"golang.org/x/crypto/bcrypt"
)
// JWTConfig carries the settings used when issuing JWTs.
//
// Secret is the initial signing secret and Issuer is written to the "iss"
// claim. The token lifetime is taken from GetTTL when that hook is set and
// reports a positive duration — this is what enables ADR-0023 Phase 2
// hot-reload of `auth.jwt.ttl`. In every other case (hook unset, or hook
// returning zero or a negative duration) ExpirationTime is used instead.
type JWTConfig struct {
	Secret         string
	ExpirationTime time.Duration
	GetTTL         func() time.Duration
	Issuer         string
}

// effectiveTTL resolves the lifetime to apply to a newly issued token.
//
// GetTTL is consulted at most once per call. Its value wins only when it is
// strictly positive; otherwise the static ExpirationTime is returned, which
// is also the path taken by callers that never wire GetTTL (e.g. tests that
// bypass the server-level wiring).
func (c JWTConfig) effectiveTTL() time.Duration {
	if c.GetTTL == nil {
		return c.ExpirationTime
	}
	live := c.GetTTL()
	if live <= 0 {
		return c.ExpirationTime
	}
	return live
}
// userServiceImpl implements the unified UserService interface.
//
// It combines credential checks, JWT issuing/validation, JWT secret
// management and password-reset operations on top of a UserRepository.
type userServiceImpl struct {
	// repo is the persistence layer used for user lookups and updates.
	repo UserRepository
	// jwtConfig supplies the token TTL, the issuer and the initial secret.
	jwtConfig JWTConfig
	// masterPassword is the credential checked by AdminAuthenticate.
	masterPassword string
	// secretManager provides the secrets used to sign and validate JWTs.
	secretManager *JWTSecretManager
}
// NewUserService assembles a user service around repo.
//
// jwtConfig and masterPassword are stored as given, and a fresh secret
// manager is created from jwtConfig.Secret.
func NewUserService(repo UserRepository, jwtConfig JWTConfig, masterPassword string) *userServiceImpl {
	svc := new(userServiceImpl)
	svc.repo = repo
	svc.jwtConfig = jwtConfig
	svc.masterPassword = masterPassword
	svc.secretManager = NewJWTSecretManager(jwtConfig.Secret)
	return svc
}
// Authenticate verifies username/password against the stored bcrypt hash and
// returns the matching user.
//
// A repository lookup failure is returned wrapped. An unknown user and a
// wrong password both yield the same generic "invalid credentials" error so
// the response does not reveal which of the two was wrong.
//
// On success the user's LastLogin is set to the current time and persisted
// on a best-effort basis: a failure to persist is logged but does not fail
// the authentication.
func (s *userServiceImpl) Authenticate(ctx context.Context, username, password string) (*User, error) {
	user, err := s.repo.GetUserByUsername(ctx, username)
	if err != nil {
		return nil, fmt.Errorf("failed to get user: %w", err)
	}
	if user == nil {
		// NOTE(review): this path skips the bcrypt comparison, so it returns
		// faster than a wrong-password attempt — consider whether that timing
		// difference matters for this deployment.
		return nil, errors.New("invalid credentials")
	}

	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
		return nil, errors.New("invalid credentials")
	}

	now := time.Now()
	user.LastLogin = &now
	if err := s.repo.UpdateUser(ctx, user); err != nil {
		// Best effort only: the credentials were valid, so the login succeeds
		// even when the timestamp cannot be stored. Record the failure instead
		// of dropping it silently.
		log.Warn().Ctx(ctx).Err(err).Str("username", user.Username).Msg("Failed to update last login time")
	}

	return user, nil
}
// GenerateJWT issues an HS256-signed JWT for user.
//
// The token carries the claims "sub" (user ID), "name", "admin", "iss",
// "iat" and "exp", where "exp" is "iat" plus the configured effective TTL.
// It is signed with the last entry returned by the secret manager, so that
// after a rotation new tokens use the most recently added secret.
//
// An error is returned when user is nil, when no valid secret is available,
// or when signing fails.
//
// Neither the signing secret nor the resulting token is written to the logs:
// both are credentials.
func (s *userServiceImpl) GenerateJWT(ctx context.Context, user *User) (string, error) {
	if user == nil {
		return "", errors.New("cannot generate JWT for nil user")
	}

	// Use a single instant so that exp - iat is exactly the TTL.
	now := time.Now()
	claims := jwt.MapClaims{
		"sub":   user.ID,
		"name":  user.Username,
		"admin": user.IsAdmin,
		"exp":   now.Add(s.jwtConfig.effectiveTTL()).Unix(),
		"iat":   now.Unix(),
		"iss":   s.jwtConfig.Issuer,
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)

	validSecrets := s.secretManager.GetAllValidSecrets()
	if len(validSecrets) == 0 {
		return "", errors.New("no valid JWT secrets available")
	}

	// The most recently added secret is last in the list.
	latest := validSecrets[len(validSecrets)-1]
	log.Trace().Ctx(ctx).Bool("is_primary", latest.IsPrimary).Msg("Generating JWT with latest secret")

	tokenString, err := token.SignedString([]byte(latest.Secret))
	if err != nil {
		return "", fmt.Errorf("failed to sign JWT: %w", err)
	}

	log.Trace().Ctx(ctx).Msg("Generated JWT token")
	return tokenString, nil
}
// ValidateJWT verifies tokenString and returns the user named by its "sub"
// claim.
//
// The token is checked against each currently valid secret in turn, which
// keeps tokens signed before a secret rotation usable. Only HMAC signing
// methods are accepted. When no secret validates the token, the error from
// the last attempt is returned wrapped.
//
// The "sub" claim must be a non-negative whole number; the user it names is
// then loaded from the repository. A missing user yields "user not found".
//
// Neither the token nor any secret is written to the logs: both are
// credentials.
func (s *userServiceImpl) ValidateJWT(ctx context.Context, tokenString string) (*User, error) {
	validSecrets := s.secretManager.GetAllValidSecrets()
	log.Trace().Ctx(ctx).Int("num_secrets", len(validSecrets)).Msg("Validating JWT with multiple secrets")

	// Try each valid secret until one verifies the signature.
	var parsedToken *jwt.Token
	var validationError error
	for i, secret := range validSecrets {
		key := []byte(secret.Secret)
		token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
			// Reject anything that is not HMAC so a token cannot pick its own algorithm.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
			}
			return key, nil
		})
		if err == nil && token.Valid {
			log.Trace().Ctx(ctx).Int("secret_index", i).Bool("is_primary", secret.IsPrimary).Msg("JWT validation successful")
			parsedToken = token
			break
		}

		// Remember the most recent failure for the error returned below.
		validationError = err
		if err != nil {
			log.Trace().Ctx(ctx).Int("secret_index", i).Err(err).Msg("JWT validation failed")
		}
	}

	if parsedToken == nil {
		if validationError != nil {
			return nil, fmt.Errorf("failed to parse JWT: %w", validationError)
		}
		return nil, errors.New("invalid JWT token")
	}

	claims, ok := parsedToken.Claims.(jwt.MapClaims)
	if !ok {
		return nil, errors.New("invalid JWT claims")
	}

	// JSON numbers decode to float64; accept only values that survive the
	// round trip through uint, which rules out negative, fractional, NaN and
	// out-of-range subjects.
	userIDFloat, ok := claims["sub"].(float64)
	if !ok || userIDFloat < 0 {
		return nil, errors.New("invalid user ID in JWT")
	}
	userID := uint(userIDFloat)
	if float64(userID) != userIDFloat {
		return nil, errors.New("invalid user ID in JWT")
	}

	user, err := s.repo.GetUserByID(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get user from JWT: %w", err)
	}
	if user == nil {
		return nil, errors.New("user not found")
	}

	return user, nil
}
// HashPassword returns the bcrypt hash of password, computed at
// bcrypt.DefaultCost (implements the PasswordService interface).
//
// A hashing failure is returned wrapped, together with an empty string.
func (s *userServiceImpl) HashPassword(ctx context.Context, password string) (string, error) {
	digest, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err == nil {
		return string(digest), nil
	}
	return "", fmt.Errorf("failed to hash password: %w", err)
}
// AdminAuthenticate checks masterPassword against the configured master
// password and, on a match, returns a virtual admin user.
//
// The returned user is not persisted: it has ID 0, username "admin" and
// IsAdmin set. Any mismatch yields the generic "invalid admin credentials"
// error.
//
// Two safeguards apply to the comparison:
//   - when no master password is configured (empty string) every attempt is
//     rejected, so an empty input can never authenticate as admin;
//   - the comparison runs in constant time with respect to the contents, so
//     response timing does not reveal how many leading bytes matched.
func (s *userServiceImpl) AdminAuthenticate(ctx context.Context, masterPassword string) (*User, error) {
	if s.masterPassword == "" ||
		subtle.ConstantTimeCompare([]byte(masterPassword), []byte(s.masterPassword)) != 1 {
		return nil, errors.New("invalid admin credentials")
	}

	// Virtual admin user (not persisted); ID 0 is the special admin ID.
	adminUser := &User{
		ID:       0,
		Username: "admin",
		IsAdmin:  true,
	}
	return adminUser, nil
}
// AddJWTSecret registers an additional JWT secret by forwarding all three
// arguments unchanged to the secret manager's AddSecret method.
//
// NOTE(review): how isPrimary and expiresIn are interpreted is decided by
// JWTSecretManager, which is not visible in this section — confirm there.
func (s *userServiceImpl) AddJWTSecret(secret string, isPrimary bool, expiresIn time.Duration) {
	s.secretManager.AddSecret(secret, isPrimary, expiresIn)
}
// RotateJWTSecret rotates to newSecret as the primary JWT secret by
// delegating to the secret manager's RotateToSecret method.
//
// NOTE(review): what happens to previously registered secrets is decided by
// JWTSecretManager, which is not visible in this section — confirm there.
func (s *userServiceImpl) RotateJWTSecret(newSecret string) {
	s.secretManager.RotateToSecret(newSecret)
}
// GetJWTSecretByIndex looks up a JWT secret by its index in the secret
// manager and returns both results of that lookup unchanged. It exists for
// tests.
func (s *userServiceImpl) GetJWTSecretByIndex(index int) (string, bool) {
	secret, ok := s.secretManager.GetSecretByIndex(index)
	return secret, ok
}
// ResetJWTSecrets resets the secret manager using the secret from the
// service's JWT configuration (jwtConfig.Secret). It is intended for test
// cleanup.
func (s *userServiceImpl) ResetJWTSecrets() {
	s.secretManager.Reset(s.jwtConfig.Secret)
}
// StartJWTSecretCleanupLoop delegates to the underlying secret manager to
// start the periodic cleanup goroutine described in ADR-0021, passing ctx
// and interval through unchanged.
//
// NOTE(review): presumably cancelling ctx stops the loop — confirm against
// JWTSecretManager.StartCleanupLoop, which is not visible in this section.
func (s *userServiceImpl) StartJWTSecretCleanupLoop(ctx context.Context, interval time.Duration) {
	s.secretManager.StartCleanupLoop(ctx, interval)
}
// RemoveExpiredJWTSecrets runs one cleanup pass on the underlying secret
// manager right away and reports how many expired secrets it removed.
func (s *userServiceImpl) RemoveExpiredJWTSecrets() int {
	removed := s.secretManager.RemoveExpiredSecrets()
	return removed
}
// UserExists reports whether a user with the given username exists,
// returning the repository's answer and error unchanged.
func (s *userServiceImpl) UserExists(ctx context.Context, username string) (bool, error) {
	exists, err := s.repo.UserExists(ctx, username)
	return exists, err
}
// CreateUser stores user through the repository and returns the
// repository's error unchanged.
func (s *userServiceImpl) CreateUser(ctx context.Context, user *User) error {
	err := s.repo.CreateUser(ctx, user)
	return err
}
// RequestPasswordReset opens the password-reset flow for username.
//
// The user must exist: a failed existence check is returned wrapped, and an
// unknown username yields a "user not found" error. Otherwise the result of
// the repository's AllowPasswordReset call is returned.
func (s *userServiceImpl) RequestPasswordReset(ctx context.Context, username string) error {
	known, err := s.repo.UserExists(ctx, username)
	switch {
	case err != nil:
		return fmt.Errorf("failed to check if user exists: %w", err)
	case !known:
		return fmt.Errorf("user not found: %s", username)
	}

	return s.repo.AllowPasswordReset(ctx, username)
}
// CompletePasswordReset finishes a password reset for username.
//
// newPassword is hashed with HashPassword and the hash is handed to the
// repository's CompletePasswordReset. A hashing failure is returned wrapped
// and nothing is written in that case.
func (s *userServiceImpl) CompletePasswordReset(ctx context.Context, username, newPassword string) error {
	newHash, hashErr := s.HashPassword(ctx, newPassword)
	if hashErr == nil {
		return s.repo.CompletePasswordReset(ctx, username, newHash)
	}
	return fmt.Errorf("failed to hash new password: %w", hashErr)
}
// PasswordResetServiceImpl implements the PasswordResetService interface.
type PasswordResetServiceImpl struct {
	// repo is used for the existence check and the reset bookkeeping.
	repo UserRepository
	// auth is used to hash the new password when a reset is completed.
	auth *userServiceImpl
}
// NewPasswordResetService builds a password reset service that stores repo
// and auth as given.
func NewPasswordResetService(repo UserRepository, auth *userServiceImpl) *PasswordResetServiceImpl {
	svc := new(PasswordResetServiceImpl)
	svc.repo = repo
	svc.auth = auth
	return svc
}
// RequestPasswordReset opens the password-reset flow for username.
//
// A failed existence check is returned wrapped and an unknown username
// yields a "user not found" error. For a known user the result of the
// repository's AllowPasswordReset call is returned.
func (s *PasswordResetServiceImpl) RequestPasswordReset(ctx context.Context, username string) error {
	found, lookupErr := s.repo.UserExists(ctx, username)
	if lookupErr != nil {
		return fmt.Errorf("failed to check if user exists: %w", lookupErr)
	}
	if found {
		return s.repo.AllowPasswordReset(ctx, username)
	}
	return fmt.Errorf("user not found: %s", username)
}
// CompletePasswordReset finishes a password reset for username.
//
// newPassword is hashed through the wrapped auth service and the hash is
// handed to the repository's CompletePasswordReset. A hashing failure is
// returned wrapped and nothing is written in that case.
func (s *PasswordResetServiceImpl) CompletePasswordReset(ctx context.Context, username, newPassword string) error {
	newHash, hashErr := s.auth.HashPassword(ctx, newPassword)
	if hashErr == nil {
		return s.repo.CompletePasswordReset(ctx, username, newHash)
	}
	return fmt.Errorf("failed to hash new password: %w", hashErr)
}