package user import ( "context" "crypto/sha256" "encoding/hex" "errors" "fmt" "time" "github.com/golang-jwt/jwt/v5" "github.com/rs/zerolog/log" "golang.org/x/crypto/bcrypt" ) // JWTConfig holds JWT configuration. // // GetTTL, when non-nil, is called on every token generation to read the // current TTL — this enables ADR-0023 Phase 2 hot-reload of `auth.jwt.ttl`. // If nil, ExpirationTime is used as a static fallback. type JWTConfig struct { Secret string ExpirationTime time.Duration GetTTL func() time.Duration Issuer string } // effectiveTTL returns the live TTL: GetTTL() when wired, else // ExpirationTime as a static fallback (used by tests that don't go // through the server-level wiring). func (c JWTConfig) effectiveTTL() time.Duration { if c.GetTTL != nil { if ttl := c.GetTTL(); ttl > 0 { return ttl } } return c.ExpirationTime } // userServiceImpl implements the unified UserService interface type userServiceImpl struct { repo UserRepository jwtConfig JWTConfig masterPassword string secretManager *JWTSecretManager } // NewUserService creates a new user service with all functionality func NewUserService(repo UserRepository, jwtConfig JWTConfig, masterPassword string) *userServiceImpl { return &userServiceImpl{ repo: repo, jwtConfig: jwtConfig, masterPassword: masterPassword, secretManager: NewJWTSecretManager(jwtConfig.Secret), } } // Authenticate authenticates a user with username and password func (s *userServiceImpl) Authenticate(ctx context.Context, username, password string) (*User, error) { user, err := s.repo.GetUserByUsername(ctx, username) if err != nil { return nil, fmt.Errorf("failed to get user: %w", err) } if user == nil { return nil, errors.New("invalid credentials") } // Check password if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil { return nil, errors.New("invalid credentials") } // Update last login time now := time.Now() user.LastLogin = &now if err := s.repo.UpdateUser(ctx, user); err != nil { // Don't fail authentication if we can't update last login // Just log it and continue } return user, nil } // GenerateJWT generates a JWT token for the given user func (s *userServiceImpl) GenerateJWT(ctx context.Context, user *User) (string, error) { // Create the claims claims := jwt.MapClaims{ "sub": user.ID, "name": user.Username, "admin": user.IsAdmin, "exp": time.Now().Add(s.jwtConfig.effectiveTTL()).Unix(), "iat": time.Now().Unix(), "iss": s.jwtConfig.Issuer, } // Create token token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) // Get all valid secrets and use the most recently added one for signing // This supports JWT secret rotation by signing new tokens with the latest secret validSecrets := s.secretManager.GetAllValidSecrets() if len(validSecrets) == 0 { return "", errors.New("no valid JWT secrets available") } // Use the most recently added secret (last in the list) // This ensures new tokens are signed with the latest secret signingSecret := validSecrets[len(validSecrets)-1].Secret log.Trace().Ctx(ctx).Str("signing_secret_fp", tokenFingerprint(signingSecret)).Bool("is_primary", validSecrets[len(validSecrets)-1].IsPrimary).Msg("Generating JWT with latest secret") // Sign and get the complete encoded token as a string tokenString, err := token.SignedString([]byte(signingSecret)) if err != nil { return "", fmt.Errorf("failed to sign JWT: %w", err) } log.Trace().Ctx(ctx).Str("token_fp", tokenFingerprint(tokenString)).Int("token_len", len(tokenString)).Msg("Generated JWT token") return tokenString, nil } // ValidateJWT validates a JWT token and returns the user func (s *userServiceImpl) ValidateJWT(ctx context.Context, tokenString string) (*User, error) { log.Trace().Ctx(ctx).Str("token_fp", tokenFingerprint(tokenString)).Int("token_len", len(tokenString)).Msg("Validating JWT token") // Get all valid secrets for validation validSecrets := s.secretManager.GetAllValidSecrets() log.Trace().Ctx(ctx).Int("num_secrets", len(validSecrets)).Msg("Validating JWT with multiple secrets") for i, secret := range validSecrets { log.Trace().Ctx(ctx).Int("secret_index", i).Str("secret_fp", tokenFingerprint(secret.Secret)).Bool("is_primary", secret.IsPrimary).Msg("Trying secret") } // Try each valid secret until we find one that works var parsedToken *jwt.Token var validationError error for i, secret := range validSecrets { // Parse the token with current secret token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { // Verify the signing method if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(secret.Secret), nil }) if err == nil && token.Valid { log.Trace().Ctx(ctx).Int("secret_index", i).Str("secret_fp", tokenFingerprint(secret.Secret)).Msg("JWT validation successful") parsedToken = token break } // Store the last error for reporting validationError = err if err != nil { log.Trace().Ctx(ctx).Int("secret_index", i).Str("secret_fp", tokenFingerprint(secret.Secret)).Err(err).Msg("JWT validation failed") } } if parsedToken == nil { if validationError != nil { return nil, fmt.Errorf("failed to parse JWT: %w", validationError) } return nil, errors.New("invalid JWT token") } // Get claims claims, ok := parsedToken.Claims.(jwt.MapClaims) if !ok { return nil, errors.New("invalid JWT claims") } // Get user ID from claims userIDFloat, ok := claims["sub"].(float64) if !ok { return nil, errors.New("invalid user ID in JWT") } userID := uint(userIDFloat) // Get user from repository user, err := s.repo.GetUserByID(ctx, userID) if err != nil { return nil, fmt.Errorf("failed to get user from JWT: %w", err) } if user == nil { return nil, errors.New("user not found") } return user, nil } // HashPassword hashes a password using bcrypt (implements PasswordService interface) func (s *userServiceImpl) HashPassword(ctx context.Context, password string) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return "", fmt.Errorf("failed to hash password: %w", err) } return string(hash), nil } // AdminAuthenticate authenticates an admin user with master password func (s *userServiceImpl) AdminAuthenticate(ctx context.Context, masterPassword string) (*User, error) { // Check if master password matches if masterPassword != s.masterPassword { return nil, errors.New("invalid admin credentials") } // Create a virtual admin user (not persisted) adminUser := &User{ ID: 0, // Special ID for admin Username: "admin", IsAdmin: true, } return adminUser, nil } // AddJWTSecret adds a new JWT secret to the manager func (s *userServiceImpl) AddJWTSecret(secret string, isPrimary bool, expiresIn time.Duration) { s.secretManager.AddSecret(secret, isPrimary, expiresIn) } // RotateJWTSecret rotates to a new primary JWT secret func (s *userServiceImpl) RotateJWTSecret(newSecret string) { s.secretManager.RotateToSecret(newSecret) } // GetJWTSecretByIndex returns a JWT secret by index for testing func (s *userServiceImpl) GetJWTSecretByIndex(index int) (string, bool) { return s.secretManager.GetSecretByIndex(index) } // ResetJWTSecrets resets JWT secrets to initial state for test cleanup func (s *userServiceImpl) ResetJWTSecrets() { s.secretManager.Reset(s.jwtConfig.Secret) } // StartJWTSecretCleanupLoop delegates to the underlying secret manager to // start the periodic cleanup goroutine described in ADR-0021. func (s *userServiceImpl) StartJWTSecretCleanupLoop(ctx context.Context, interval time.Duration) { s.secretManager.StartCleanupLoop(ctx, interval) } // RemoveExpiredJWTSecrets triggers an immediate cleanup pass via the // underlying secret manager. Returns the count of removed expired secrets. func (s *userServiceImpl) RemoveExpiredJWTSecrets() int { return s.secretManager.RemoveExpiredSecrets() } // ListJWTSecretsInfo returns metadata about every currently-tracked JWT // secret WITHOUT exposing the secret values. Used by the admin // introspection endpoint and BDD tests verifying cleanup behavior. func (s *userServiceImpl) ListJWTSecretsInfo() []JWTSecretInfo { all := s.secretManager.GetAllValidSecrets() now := time.Now() out := make([]JWTSecretInfo, 0, len(all)) for _, sec := range all { hash := sha256.Sum256([]byte(sec.Secret)) info := JWTSecretInfo{ IsPrimary: sec.IsPrimary, CreatedAtUnix: sec.CreatedAt.Unix(), AgeSeconds: int64(now.Sub(sec.CreatedAt).Seconds()), SecretSHA256: hex.EncodeToString(hash[:8]), // 16 hex chars = 8 bytes — fingerprint } if sec.ExpiresAt != nil { exp := sec.ExpiresAt.Unix() info.ExpiresAtUnix = &exp info.IsExpired = !sec.ExpiresAt.After(now) } out = append(out, info) } return out } // UserExists checks if a user exists by username func (s *userServiceImpl) UserExists(ctx context.Context, username string) (bool, error) { return s.repo.UserExists(ctx, username) } // CreateUser creates a new user in the database func (s *userServiceImpl) CreateUser(ctx context.Context, user *User) error { return s.repo.CreateUser(ctx, user) } // RequestPasswordReset requests a password reset for a user func (s *userServiceImpl) RequestPasswordReset(ctx context.Context, username string) error { // Check if user exists exists, err := s.repo.UserExists(ctx, username) if err != nil { return fmt.Errorf("failed to check if user exists: %w", err) } if !exists { return fmt.Errorf("user not found: %s", username) } // Allow password reset return s.repo.AllowPasswordReset(ctx, username) } // CompletePasswordReset completes the password reset process func (s *userServiceImpl) CompletePasswordReset(ctx context.Context, username, newPassword string) error { // Hash the new password hashedPassword, err := s.HashPassword(ctx, newPassword) if err != nil { return fmt.Errorf("failed to hash new password: %w", err) } // Complete the password reset return s.repo.CompletePasswordReset(ctx, username, hashedPassword) } // PasswordResetServiceImpl implements the PasswordResetService interface type PasswordResetServiceImpl struct { repo UserRepository auth *userServiceImpl } // NewPasswordResetService creates a new password reset service func NewPasswordResetService(repo UserRepository, auth *userServiceImpl) *PasswordResetServiceImpl { return &PasswordResetServiceImpl{ repo: repo, auth: auth, } } // RequestPasswordReset requests a password reset for a user func (s *PasswordResetServiceImpl) RequestPasswordReset(ctx context.Context, username string) error { // Check if user exists exists, err := s.repo.UserExists(ctx, username) if err != nil { return fmt.Errorf("failed to check if user exists: %w", err) } if !exists { return fmt.Errorf("user not found: %s", username) } // Allow password reset return s.repo.AllowPasswordReset(ctx, username) } // CompletePasswordReset completes the password reset process func (s *PasswordResetServiceImpl) CompletePasswordReset(ctx context.Context, username, newPassword string) error { // Hash the new password hashedPassword, err := s.auth.HashPassword(ctx, newPassword) if err != nil { return fmt.Errorf("failed to hash new password: %w", err) } // Complete the password reset return s.repo.CompletePasswordReset(ctx, username, hashedPassword) } // tokenFingerprint returns the first 16 hex chars of SHA-256 hash of a token/secret. // Used for safe logging correlation without leaking sensitive values. func tokenFingerprint(tok string) string { if tok == "" { return "" } sum := sha256.Sum256([]byte(tok)) return hex.EncodeToString(sum[:8]) // 16 hex chars = 8 bytes }