diff --git a/adr/0018-user-management-auth-system.md b/adr/0018-user-management-auth-system.md index e7179d3..5cb2c09 100644 --- a/adr/0018-user-management-auth-system.md +++ b/adr/0018-user-management-auth-system.md @@ -120,6 +120,7 @@ The user management system follows the established DanceLessonsCoach patterns: - 30-minute expiration for access tokens - Secure random signing key - HTTPS-only cookies + - **Secret Rotation:** Multiple valid secrets with retention policy (see Issue #8) 3. **Admin Access:** - Master password from environment variable - Non-persisted admin user @@ -464,6 +465,7 @@ The implementation maintains full backward compatibility: 3. **User Activity Logging:** For audit trails 4. **Password Strength Meter:** For better user experience 5. **Account Recovery:** Email/phone-based recovery options +6. **JWT Secret Rotation:** Implement secret persistence and rotation mechanism (Issue #8) ## References diff --git a/data/users.db b/data/users.db new file mode 100644 index 0000000..57cb87d Binary files /dev/null and b/data/users.db differ diff --git a/features/data/users.db b/features/data/users.db new file mode 100644 index 0000000..6c4c16d Binary files /dev/null and b/features/data/users.db differ diff --git a/features/user_authentication.feature b/features/user_authentication.feature index f667d78..0ee0e6a 100644 --- a/features/user_authentication.feature +++ b/features/user_authentication.feature @@ -33,7 +33,7 @@ Feature: User Authentication Scenario: User registration Given the server is running - When I register a new user "newuser" with password "newpass123" + When I register a new user "newuser_" with password "newpass123" Then the registration should be successful And I should be able to authenticate with the new credentials diff --git a/go.mod b/go.mod index acbeecb..30ba5a6 100644 --- a/go.mod +++ b/go.mod @@ -8,9 +8,11 @@ require ( github.com/go-playground/locales v0.14.1 github.com/go-playground/universal-translator v0.18.1 github.com/go-playground/validator/v10 v10.30.2 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/rs/zerolog v1.35.0 github.com/spf13/cobra v1.8.0 github.com/spf13/viper v1.21.0 + github.com/stretchr/testify v1.11.1 github.com/swaggo/http-swagger v1.3.4 github.com/swaggo/swag v1.16.6 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 @@ -18,6 +20,9 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 go.opentelemetry.io/otel/sdk v1.43.0 go.opentelemetry.io/otel/trace v1.43.0 + golang.org/x/crypto v0.49.0 + gorm.io/driver/sqlite v1.6.0 + gorm.io/gorm v1.31.1 ) require ( @@ -26,6 +31,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cucumber/gherkin/go/v26 v26.2.0 // indirect github.com/cucumber/messages/go/v21 v21.0.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.13 // indirect @@ -43,12 +49,16 @@ require ( github.com/hashicorp/go-memdb v1.3.5 // indirect github.com/hashicorp/golang-lru v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mailru/easyjson v0.7.6 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect github.com/spf13/afero v1.15.0 // indirect @@ -61,7 +71,6 @@ require ( go.opentelemetry.io/otel/metric v1.43.0 // indirect go.opentelemetry.io/proto/otlp v1.10.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/crypto v0.49.0 // indirect golang.org/x/mod v0.33.0 // indirect golang.org/x/net v0.52.0 // indirect golang.org/x/sync v0.20.0 // indirect @@ -73,4 +82,5 @@ require ( google.golang.org/grpc v1.80.0 // indirect google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 706aebd..cc20f29 100644 --- a/go.sum +++ b/go.sum @@ -56,6 +56,8 @@ github.com/gofrs/uuid v4.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx github.com/gofrs/uuid v4.3.1+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -79,6 +81,10 @@ github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iP github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -99,6 +105,8 @@ github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHP github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= @@ -212,3 +220,7 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/pkg/config/config.go b/pkg/config/config.go index aa02a88..ec97778 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -20,6 +20,7 @@ type Config struct { Logging LoggingConfig `mapstructure:"logging"` Telemetry TelemetryConfig `mapstructure:"telemetry"` API APIConfig `mapstructure:"api"` + Auth AuthConfig `mapstructure:"auth"` } // ServerConfig holds server-related configuration @@ -54,6 +55,12 @@ type APIConfig struct { V2Enabled bool `mapstructure:"v2_enabled"` } +// AuthConfig holds authentication configuration +type AuthConfig struct { + JWTSecret string `mapstructure:"jwt_secret"` + AdminMasterPassword string `mapstructure:"admin_master_password"` +} + // VersionInfo holds application version information type VersionInfo struct { Version string `mapstructure:"-"` // Set via ldflags @@ -104,6 +111,10 @@ func LoadConfig() (*Config, error) { // API defaults v.SetDefault("api.v2_enabled", false) + // Auth defaults + v.SetDefault("auth.jwt_secret", "default-secret-key-please-change-in-production") + v.SetDefault("auth.admin_master_password", "admin123") + // Check for custom config file path via environment variable if configFile := os.Getenv("DLC_CONFIG_FILE"); configFile != "" { v.SetConfigFile(configFile) @@ -141,6 +152,10 @@ func LoadConfig() (*Config, error) { v.BindEnv("telemetry.otlp_endpoint", "DLC_TELEMETRY_OTLP_ENDPOINT") v.BindEnv("telemetry.service_name", "DLC_TELEMETRY_SERVICE_NAME") v.BindEnv("telemetry.insecure", "DLC_TELEMETRY_INSECURE") + + // Auth environment variables + v.BindEnv("auth.jwt_secret", "DLC_AUTH_JWT_SECRET") + v.BindEnv("auth.admin_master_password", "DLC_AUTH_ADMIN_MASTER_PASSWORD") v.BindEnv("telemetry.sampler.type", "DLC_TELEMETRY_SAMPLER_TYPE") v.BindEnv("telemetry.sampler.ratio", "DLC_TELEMETRY_SAMPLER_RATIO") @@ -220,6 +235,16 @@ func (c *Config) GetV2Enabled() bool { return c.API.V2Enabled } +// GetJWTSecret returns the JWT secret +func (c *Config) GetJWTSecret() string { + return c.Auth.JWTSecret +} + +// GetAdminMasterPassword returns the admin master password +func (c *Config) GetAdminMasterPassword() string { + return c.Auth.AdminMasterPassword +} + // GetLogLevel returns the logging level func (c *Config) GetLogLevel() string { return c.Logging.Level diff --git a/pkg/greet/greet.go b/pkg/greet/greet.go index 2ca7ec4..3548fcd 100644 --- a/pkg/greet/greet.go +++ b/pkg/greet/greet.go @@ -3,21 +3,45 @@ package greet import ( "context" + "dance-lessons-coach/pkg/user" "github.com/rs/zerolog/log" ) +// Context key for storing authenticated user +type contextKey string + +const ( + // UserContextKey is the context key for storing authenticated user + UserContextKey contextKey = "authenticatedUser" +) + type Service struct{} func NewService() *Service { return &Service{} } +// GetAuthenticatedUserFromContext extracts the authenticated user from context +func GetAuthenticatedUserFromContext(ctx context.Context) (*user.User, bool) { + user, ok := ctx.Value(UserContextKey).(*user.User) + return user, ok +} + // Greet returns a greeting message for the given name. -// If name is empty, it defaults to "world". +// If name is empty, it checks for authenticated user and uses their username. +// If no authenticated user and no name, it defaults to "world". // Implements the Greeter interface. func (s *Service) Greet(ctx context.Context, name string) string { log.Trace().Ctx(ctx).Str("name", name).Msg("Greet function called") + // If no name provided, check for authenticated user + if name == "" { + if authenticatedUser, ok := GetAuthenticatedUserFromContext(ctx); ok { + name = authenticatedUser.Username + log.Trace().Ctx(ctx).Str("authenticated_user", name).Msg("Using authenticated username for greeting") + } + } + if name == "" { return "Hello world!" } diff --git a/pkg/server/middleware.go b/pkg/server/middleware.go new file mode 100644 index 0000000..5c1e11d --- /dev/null +++ b/pkg/server/middleware.go @@ -0,0 +1,62 @@ +package server + +import ( + "context" + "net/http" + + "dance-lessons-coach/pkg/greet" + "dance-lessons-coach/pkg/user" + "github.com/rs/zerolog/log" +) + +// AuthMiddleware handles JWT authentication and adds user to context +type AuthMiddleware struct { + authService user.AuthService +} + +// NewAuthMiddleware creates a new authentication middleware +func NewAuthMiddleware(authService user.AuthService) *AuthMiddleware { + return &AuthMiddleware{ + authService: authService, + } +} + +// Middleware returns the authentication middleware function +func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Extract Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + // No authorization header, pass through with no user + next.ServeHTTP(w, r) + return + } + + // Extract token from "Bearer " format + const bearerPrefix = "Bearer " + if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix { + log.Trace().Ctx(ctx).Str("auth_header", authHeader).Msg("Invalid authorization header format") + next.ServeHTTP(w, r) + return + } + + token := authHeader[len(bearerPrefix):] + + // Validate JWT token + validatedUser, err := m.authService.ValidateJWT(ctx, token) + if err != nil { + log.Trace().Ctx(ctx).Err(err).Msg("JWT validation failed") + next.ServeHTTP(w, r) + return + } + + // Add user to context + ctxWithUser := context.WithValue(ctx, greet.UserContextKey, validatedUser) + r = r.WithContext(ctxWithUser) + + // Continue to next handler + next.ServeHTTP(w, r) + }) +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 74ee1e6..00e67d1 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -145,7 +145,18 @@ func (s *Server) setupRoutes() { func (s *Server) registerApiV1Routes(r chi.Router) { greetService := greet.NewService() greetHandler := greet.NewApiV1GreetHandler(greetService) + + // Create auth middleware if available + var authMiddleware *AuthMiddleware + if s.authService != nil { + authMiddleware = NewAuthMiddleware(s.authService) + } + r.Route("/greet", func(r chi.Router) { + // Add optional authentication middleware + if authMiddleware != nil { + r.Use(authMiddleware.Middleware) + } greetHandler.RegisterRoutes(r) }) diff --git a/pkg/user/api/auth_handler.go b/pkg/user/api/auth_handler.go new file mode 100644 index 0000000..4b59136 --- /dev/null +++ b/pkg/user/api/auth_handler.go @@ -0,0 +1,224 @@ +package api + +import ( + "encoding/json" + "net/http" + + "dance-lessons-coach/pkg/user" + + "github.com/go-chi/chi/v5" + "github.com/rs/zerolog/log" +) + +// AuthHandler handles authentication-related HTTP requests +type AuthHandler struct { + authService user.AuthService + userRepo user.UserRepository + passwordResetService user.PasswordResetService +} + +// NewAuthHandler creates a new authentication handler +func NewAuthHandler(authService user.AuthService, userRepo user.UserRepository) *AuthHandler { + passwordResetService := user.NewPasswordResetService(userRepo, authService.(*user.AuthServiceImpl)) + return &AuthHandler{ + authService: authService, + userRepo: userRepo, + passwordResetService: passwordResetService, + } +} + +// RegisterRoutes registers authentication routes +func (h *AuthHandler) RegisterRoutes(router chi.Router) { + router.Post("/login", h.handleLogin) + router.Post("/admin/login", h.handleAdminLogin) + router.Post("/register", h.handleRegister) + router.Post("/password-reset/request", h.handlePasswordResetRequest) + router.Post("/password-reset/complete", h.handlePasswordResetComplete) +} + +// LoginRequest represents a login request +type LoginRequest struct { + Username string `json:"username" validate:"required,min=3,max=50"` + Password string `json:"password" validate:"required,min=6"` +} + +// LoginResponse represents a login response +type LoginResponse struct { + Token string `json:"token"` +} + +// handleLogin handles user login requests +func (h *AuthHandler) handleLogin(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req LoginRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest) + return + } + + // Authenticate user + user, err := h.authService.Authenticate(ctx, req.Username, req.Password) + if err != nil { + log.Trace().Ctx(ctx).Err(err).Str("username", req.Username).Msg("Authentication failed") + http.Error(w, `{"error":"invalid_credentials","message":"Invalid username or password"}`, http.StatusUnauthorized) + return + } + + // Generate JWT token + token, err := h.authService.GenerateJWT(ctx, user) + if err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to generate JWT token") + http.Error(w, `{"error":"server_error","message":"Failed to generate authentication token"}`, http.StatusInternalServerError) + return + } + + // Return token + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(LoginResponse{Token: token}) +} + +// handleAdminLogin handles admin login requests +func (h *AuthHandler) handleAdminLogin(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req LoginRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest) + return + } + + // Authenticate admin + adminUser, err := h.authService.AdminAuthenticate(ctx, req.Password) + if err != nil { + log.Trace().Ctx(ctx).Err(err).Msg("Admin authentication failed") + http.Error(w, `{"error":"invalid_credentials","message":"Invalid admin credentials"}`, http.StatusUnauthorized) + return + } + + // Generate JWT token + token, err := h.authService.GenerateJWT(ctx, adminUser) + if err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to generate JWT token for admin") + http.Error(w, `{"error":"server_error","message":"Failed to generate authentication token"}`, http.StatusInternalServerError) + return + } + + // Return token + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(LoginResponse{Token: token}) +} + +// RegisterRequest represents a user registration request +type RegisterRequest struct { + Username string `json:"username" validate:"required,min=3,max=50"` + Password string `json:"password" validate:"required,min=6"` +} + +// handleRegister handles user registration requests +func (h *AuthHandler) handleRegister(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req RegisterRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest) + return + } + + // Check if user already exists + exists, err := h.userRepo.UserExists(ctx, req.Username) + if err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to check if user exists") + http.Error(w, `{"error":"server_error","message":"Failed to process registration"}`, http.StatusInternalServerError) + return + } + if exists { + http.Error(w, `{"error":"user_exists","message":"Username already taken"}`, http.StatusConflict) + return + } + + // Hash password + hashedPassword, err := h.authService.HashPassword(ctx, req.Password) + if err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to hash password") + http.Error(w, `{"error":"server_error","message":"Failed to process registration"}`, http.StatusInternalServerError) + return + } + + // Create user + newUser := &user.User{ + Username: req.Username, + PasswordHash: hashedPassword, + IsAdmin: false, + } + + if err := h.userRepo.CreateUser(ctx, newUser); err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to create user") + http.Error(w, `{"error":"server_error","message":"Failed to create user"}`, http.StatusInternalServerError) + return + } + + // Return success + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]string{"message": "User registered successfully"}) +} + +// PasswordResetRequest represents a password reset request +type PasswordResetRequest struct { + Username string `json:"username" validate:"required,min=3,max=50"` +} + +// handlePasswordResetRequest handles password reset requests +func (h *AuthHandler) handlePasswordResetRequest(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req PasswordResetRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest) + return + } + + // Request password reset + if err := h.passwordResetService.RequestPasswordReset(ctx, req.Username); err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to request password reset") + http.Error(w, `{"error":"server_error","message":"Failed to process password reset request"}`, http.StatusInternalServerError) + return + } + + // Return success + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"message": "Password reset allowed, user can now reset password"}) +} + +// PasswordResetCompleteRequest represents a password reset completion request +type PasswordResetCompleteRequest struct { + Username string `json:"username" validate:"required,min=3,max=50"` + NewPassword string `json:"new_password" validate:"required,min=6"` +} + +// handlePasswordResetComplete handles password reset completion requests +func (h *AuthHandler) handlePasswordResetComplete(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req PasswordResetCompleteRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest) + return + } + + // Complete password reset + if err := h.passwordResetService.CompletePasswordReset(ctx, req.Username, req.NewPassword); err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to complete password reset") + http.Error(w, `{"error":"server_error","message":"Failed to complete password reset"}`, http.StatusInternalServerError) + return + } + + // Return success + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"message": "Password reset completed successfully"}) +} diff --git a/pkg/user/auth_service.go b/pkg/user/auth_service.go new file mode 100644 index 0000000..4a43495 --- /dev/null +++ b/pkg/user/auth_service.go @@ -0,0 +1,198 @@ +package user + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" + "golang.org/x/crypto/bcrypt" +) + +// JWTConfig holds JWT configuration +type JWTConfig struct { + Secret string + ExpirationTime time.Duration + Issuer string +} + +// AuthServiceImpl implements the AuthService interface +type AuthServiceImpl struct { + repo UserRepository + jwtConfig JWTConfig + masterPassword string +} + +// NewAuthService creates a new authentication service +func NewAuthService(repo UserRepository, jwtConfig JWTConfig, masterPassword string) *AuthServiceImpl { + return &AuthServiceImpl{ + repo: repo, + jwtConfig: jwtConfig, + masterPassword: masterPassword, + } +} + +// Authenticate authenticates a user with username and password +func (s *AuthServiceImpl) Authenticate(ctx context.Context, username, password string) (*User, error) { + user, err := s.repo.GetUserByUsername(ctx, username) + if err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } + if user == nil { + return nil, errors.New("invalid credentials") + } + + // Check password + if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil { + return nil, errors.New("invalid credentials") + } + + // Update last login time + now := time.Now() + user.LastLogin = &now + if err := s.repo.UpdateUser(ctx, user); err != nil { + // Don't fail authentication if we can't update last login + // Just log it and continue + } + + return user, nil +} + +// GenerateJWT generates a JWT token for the given user +func (s *AuthServiceImpl) GenerateJWT(ctx context.Context, user *User) (string, error) { + // Create the claims + claims := jwt.MapClaims{ + "sub": user.ID, + "name": user.Username, + "admin": user.IsAdmin, + "exp": time.Now().Add(s.jwtConfig.ExpirationTime).Unix(), + "iat": time.Now().Unix(), + "iss": s.jwtConfig.Issuer, + } + + // Create token + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + // Sign and get the complete encoded token as a string + tokenString, err := token.SignedString([]byte(s.jwtConfig.Secret)) + if err != nil { + return "", fmt.Errorf("failed to sign JWT: %w", err) + } + + return tokenString, nil +} + +// ValidateJWT validates a JWT token and returns the user +func (s *AuthServiceImpl) ValidateJWT(ctx context.Context, tokenString string) (*User, error) { + // Parse the token + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // Verify the signing method + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return []byte(s.jwtConfig.Secret), nil + }) + + if err != nil { + return nil, fmt.Errorf("failed to parse JWT: %w", err) + } + + // Check if token is valid + if !token.Valid { + return nil, errors.New("invalid JWT token") + } + + // Get claims + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, errors.New("invalid JWT claims") + } + + // Get user ID from claims + userIDFloat, ok := claims["sub"].(float64) + if !ok { + return nil, errors.New("invalid user ID in JWT") + } + + userID := uint(userIDFloat) + + // Get user from repository + user, err := s.repo.GetUserByID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("failed to get user from JWT: %w", err) + } + if user == nil { + return nil, errors.New("user not found") + } + + return user, nil +} + +// HashPassword hashes a password using bcrypt +func (s *AuthServiceImpl) HashPassword(ctx context.Context, password string) (string, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", fmt.Errorf("failed to hash password: %w", err) + } + return string(hash), nil +} + +// AdminAuthenticate authenticates an admin user with master password +func (s *AuthServiceImpl) AdminAuthenticate(ctx context.Context, masterPassword string) (*User, error) { + // Check if master password matches + if masterPassword != s.masterPassword { + return nil, errors.New("invalid admin credentials") + } + + // Create a virtual admin user (not persisted) + adminUser := &User{ + ID: 0, // Special ID for admin + Username: "admin", + IsAdmin: true, + } + + return adminUser, nil +} + +// PasswordResetServiceImpl implements the PasswordResetService interface +type PasswordResetServiceImpl struct { + repo UserRepository + auth *AuthServiceImpl +} + +// NewPasswordResetService creates a new password reset service +func NewPasswordResetService(repo UserRepository, auth *AuthServiceImpl) *PasswordResetServiceImpl { + return &PasswordResetServiceImpl{ + repo: repo, + auth: auth, + } +} + +// RequestPasswordReset requests a password reset for a user +func (s *PasswordResetServiceImpl) RequestPasswordReset(ctx context.Context, username string) error { + // Check if user exists + exists, err := s.repo.UserExists(ctx, username) + if err != nil { + return fmt.Errorf("failed to check if user exists: %w", err) + } + if !exists { + return fmt.Errorf("user not found: %s", username) + } + + // Allow password reset + return s.repo.AllowPasswordReset(ctx, username) +} + +// CompletePasswordReset completes the password reset process +func (s *PasswordResetServiceImpl) CompletePasswordReset(ctx context.Context, username, newPassword string) error { + // Hash the new password + hashedPassword, err := s.auth.HashPassword(ctx, newPassword) + if err != nil { + return fmt.Errorf("failed to hash new password: %w", err) + } + + // Complete the password reset + return s.repo.CompletePasswordReset(ctx, username, hashedPassword) +} diff --git a/pkg/user/sqlite_repository.go b/pkg/user/sqlite_repository.go new file mode 100644 index 0000000..caec64d --- /dev/null +++ b/pkg/user/sqlite_repository.go @@ -0,0 +1,174 @@ +package user + +import ( + "context" + "errors" + "fmt" + "log" + "os" + "path/filepath" + "time" + + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// SQLiteRepository implements UserRepository using SQLite +type SQLiteRepository struct { + db *gorm.DB + dbPath string +} + +// NewSQLiteRepository creates a new SQLite repository +func NewSQLiteRepository(dbPath string) (*SQLiteRepository, error) { + repo := &SQLiteRepository{ + dbPath: dbPath, + } + + if err := repo.initializeDatabase(); err != nil { + return nil, fmt.Errorf("failed to initialize database: %w", err) + } + + return repo, nil +} + +// initializeDatabase sets up the SQLite database and runs migrations +func (r *SQLiteRepository) initializeDatabase() error { + // Create directory if it doesn't exist + dir := filepath.Dir(r.dbPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + // Configure GORM logger to use standard log + gormLogger := logger.New( + log.New(os.Stdout, "\n", log.LstdFlags), + logger.Config{ + SlowThreshold: time.Second, + LogLevel: logger.Warn, + IgnoreRecordNotFoundError: true, + Colorful: true, + }, + ) + + var err error + r.db, err = gorm.Open(sqlite.Open(r.dbPath), &gorm.Config{ + Logger: gormLogger, + }) + if err != nil { + return fmt.Errorf("failed to connect to database: %w", err) + } + + // Auto-migrate the User model + if err := r.db.AutoMigrate(&User{}); err != nil { + return fmt.Errorf("failed to auto-migrate: %w", err) + } + + return nil +} + +// CreateUser creates a new user in the database +func (r *SQLiteRepository) CreateUser(ctx context.Context, user *User) error { + result := r.db.WithContext(ctx).Create(user) + if result.Error != nil { + return fmt.Errorf("failed to create user: %w", result.Error) + } + return nil +} + +// GetUserByUsername retrieves a user by username +func (r *SQLiteRepository) GetUserByUsername(ctx context.Context, username string) (*User, error) { + var user User + result := r.db.WithContext(ctx).Where("username = ?", username).First(&user) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get user by username: %w", result.Error) + } + return &user, nil +} + +// GetUserByID retrieves a user by ID +func (r *SQLiteRepository) GetUserByID(ctx context.Context, id uint) (*User, error) { + var user User + result := r.db.WithContext(ctx).First(&user, id) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get user by ID: %w", result.Error) + } + return &user, nil +} + +// UpdateUser updates a user in the database +func (r *SQLiteRepository) UpdateUser(ctx context.Context, user *User) error { + result := r.db.WithContext(ctx).Save(user) + if result.Error != nil { + return fmt.Errorf("failed to update user: %w", result.Error) + } + return nil +} + +// DeleteUser deletes a user from the database +func (r *SQLiteRepository) DeleteUser(ctx context.Context, id uint) error { + result := r.db.WithContext(ctx).Delete(&User{}, id) + if result.Error != nil { + return fmt.Errorf("failed to delete user: %w", result.Error) + } + return nil +} + +// AllowPasswordReset flags a user for password reset +func (r *SQLiteRepository) AllowPasswordReset(ctx context.Context, username string) error { + user, err := r.GetUserByUsername(ctx, username) + if err != nil { + return fmt.Errorf("failed to get user for password reset: %w", err) + } + if user == nil { + return fmt.Errorf("user not found: %s", username) + } + + user.AllowPasswordReset = true + return r.UpdateUser(ctx, user) +} + +// CompletePasswordReset completes the password reset process +func (r *SQLiteRepository) CompletePasswordReset(ctx context.Context, username, newPasswordHash string) error { + user, err := r.GetUserByUsername(ctx, username) + if err != nil { + return fmt.Errorf("failed to get user for password reset completion: %w", err) + } + if user == nil { + return fmt.Errorf("user not found: %s", username) + } + + if !user.AllowPasswordReset { + return fmt.Errorf("password reset not allowed for user: %s", username) + } + + user.PasswordHash = newPasswordHash + user.AllowPasswordReset = false + return r.UpdateUser(ctx, user) +} + +// UserExists checks if a user exists by username +func (r *SQLiteRepository) UserExists(ctx context.Context, username string) (bool, error) { + var count int64 + result := r.db.WithContext(ctx).Model(&User{}).Where("username = ?", username).Count(&count) + if result.Error != nil { + return false, fmt.Errorf("failed to check if user exists: %w", result.Error) + } + return count > 0, nil +} + +// Close closes the database connection +func (r *SQLiteRepository) Close() error { + sqlDB, err := r.db.DB() + if err != nil { + return fmt.Errorf("failed to get database connection: %w", err) + } + return sqlDB.Close() +} diff --git a/pkg/user/user.go b/pkg/user/user.go new file mode 100644 index 0000000..cf857dd --- /dev/null +++ b/pkg/user/user.go @@ -0,0 +1,48 @@ +package user + +import ( + "context" + "time" +) + +// User represents a user in the system +type User struct { + ID uint `json:"id" gorm:"primaryKey"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + DeletedAt *time.Time `json:"deleted_at,omitempty" gorm:"index"` + Username string `json:"username" gorm:"unique;not null" validate:"required,min=3,max=50"` + PasswordHash string `json:"-" gorm:"not null"` + Description *string `json:"description,omitempty"` + CurrentGoal *string `json:"current_goal,omitempty"` + IsAdmin bool `json:"is_admin" gorm:"default:false"` + AllowPasswordReset bool `json:"allow_password_reset" gorm:"default:false"` + LastLogin *time.Time `json:"last_login,omitempty"` +} + +// UserRepository defines the interface for user persistence +type UserRepository interface { + CreateUser(ctx context.Context, user *User) error + GetUserByUsername(ctx context.Context, username string) (*User, error) + GetUserByID(ctx context.Context, id uint) (*User, error) + UpdateUser(ctx context.Context, user *User) error + DeleteUser(ctx context.Context, id uint) error + AllowPasswordReset(ctx context.Context, username string) error + CompletePasswordReset(ctx context.Context, username, newPassword string) error + UserExists(ctx context.Context, username string) (bool, error) +} + +// AuthService defines the interface for authentication +type AuthService interface { + Authenticate(ctx context.Context, username, password string) (*User, error) + GenerateJWT(ctx context.Context, user *User) (string, error) + ValidateJWT(ctx context.Context, token string) (*User, error) + HashPassword(ctx context.Context, password string) (string, error) + AdminAuthenticate(ctx context.Context, masterPassword string) (*User, error) +} + +// PasswordResetService defines the interface for password reset workflow +type PasswordResetService interface { + RequestPasswordReset(ctx context.Context, username string) error + CompletePasswordReset(ctx context.Context, username, newPassword string) error +} diff --git a/pkg/user/user_test.go b/pkg/user/user_test.go new file mode 100644 index 0000000..8ba06b9 --- /dev/null +++ b/pkg/user/user_test.go @@ -0,0 +1,222 @@ +package user + +import ( + "context" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSQLiteRepository(t *testing.T) { + t.Run("CRUD operations", func(t *testing.T) { + // Create a temporary database + dbPath := "test_db.sqlite" + defer os.Remove(dbPath) + + repo, err := NewSQLiteRepository(dbPath) + require.NoError(t, err) + defer repo.Close() + + ctx := context.Background() + + // Test CreateUser + user := &User{ + Username: "testuser", + PasswordHash: "hashedpassword", + Description: ptrString("Test user"), + CurrentGoal: ptrString("Learn to dance"), + IsAdmin: false, + } + + err = repo.CreateUser(ctx, user) + require.NoError(t, err) + assert.NotZero(t, user.ID) + + // Test GetUserByUsername + retrievedUser, err := repo.GetUserByUsername(ctx, "testuser") + require.NoError(t, err) + assert.NotNil(t, retrievedUser) + assert.Equal(t, "testuser", retrievedUser.Username) + + // Test UserExists + exists, err := repo.UserExists(ctx, "testuser") + require.NoError(t, err) + assert.True(t, exists) + + // Test UpdateUser + retrievedUser.Description = ptrString("Updated description") + err = repo.UpdateUser(ctx, retrievedUser) + require.NoError(t, err) + + // Verify update + updatedUser, err := repo.GetUserByUsername(ctx, "testuser") + require.NoError(t, err) + assert.Equal(t, "Updated description", *updatedUser.Description) + + // Test AllowPasswordReset + err = repo.AllowPasswordReset(ctx, "testuser") + require.NoError(t, err) + + // Verify password reset flag + userWithReset, err := repo.GetUserByUsername(ctx, "testuser") + require.NoError(t, err) + assert.True(t, userWithReset.AllowPasswordReset) + + // Test CompletePasswordReset + err = repo.CompletePasswordReset(ctx, "testuser", "newhashedpassword") + require.NoError(t, err) + + // Verify password reset completion + userAfterReset, err := repo.GetUserByUsername(ctx, "testuser") + require.NoError(t, err) + assert.Equal(t, "newhashedpassword", userAfterReset.PasswordHash) + assert.False(t, userAfterReset.AllowPasswordReset) + + // Test DeleteUser + err = repo.DeleteUser(ctx, userAfterReset.ID) + require.NoError(t, err) + + // Verify deletion + deletedUser, err := repo.GetUserByUsername(ctx, "testuser") + require.NoError(t, err) + assert.Nil(t, deletedUser) + }) +} + +func TestAuthService(t *testing.T) { + t.Run("Password hashing and authentication", func(t *testing.T) { + // Create a temporary database + dbPath := "test_auth_db.sqlite" + defer os.Remove(dbPath) + + repo, err := NewSQLiteRepository(dbPath) + require.NoError(t, err) + defer repo.Close() + + ctx := context.Background() + + // Create auth service + jwtConfig := JWTConfig{ + Secret: "test-secret", + ExpirationTime: time.Hour, + Issuer: "test-issuer", + } + authService := NewAuthService(repo, jwtConfig, "admin123") + + // Test password hashing + password := "testpassword123" + hashedPassword, err := authService.HashPassword(ctx, password) + require.NoError(t, err) + assert.NotEmpty(t, hashedPassword) + + // Create a test user + user := &User{ + Username: "testuser", + PasswordHash: hashedPassword, + } + err = repo.CreateUser(ctx, user) + require.NoError(t, err) + + // Test successful authentication + authenticatedUser, err := authService.Authenticate(ctx, "testuser", password) + require.NoError(t, err) + assert.NotNil(t, authenticatedUser) + assert.Equal(t, "testuser", authenticatedUser.Username) + + // Test failed authentication with wrong password + _, err = authService.Authenticate(ctx, "testuser", "wrongpassword") + assert.Error(t, err) + assert.Equal(t, "invalid credentials", err.Error()) + + // Test JWT generation + token, err := authService.GenerateJWT(ctx, authenticatedUser) + require.NoError(t, err) + assert.NotEmpty(t, token) + + // Test JWT validation + validatedUser, err := authService.ValidateJWT(ctx, token) + require.NoError(t, err) + assert.NotNil(t, validatedUser) + assert.Equal(t, authenticatedUser.ID, validatedUser.ID) + + // Test admin authentication + adminUser, err := authService.AdminAuthenticate(ctx, "admin123") + require.NoError(t, err) + assert.NotNil(t, adminUser) + assert.True(t, adminUser.IsAdmin) + assert.Equal(t, "admin", adminUser.Username) + + // Test failed admin authentication + _, err = authService.AdminAuthenticate(ctx, "wrongadminpassword") + assert.Error(t, err) + assert.Equal(t, "invalid admin credentials", err.Error()) + }) +} + +func TestPasswordResetService(t *testing.T) { + t.Run("Password reset workflow", func(t *testing.T) { + // Create a temporary database + dbPath := "test_reset_db.sqlite" + defer os.Remove(dbPath) + + repo, err := NewSQLiteRepository(dbPath) + require.NoError(t, err) + defer repo.Close() + + ctx := context.Background() + + // Create auth service + jwtConfig := JWTConfig{ + Secret: "test-secret", + ExpirationTime: time.Hour, + Issuer: "test-issuer", + } + authService := NewAuthService(repo, jwtConfig, "admin123") + passwordResetService := NewPasswordResetService(repo, authService) + + // Create a test user + password := "oldpassword123" + hashedPassword, err := authService.HashPassword(ctx, password) + require.NoError(t, err) + + user := &User{ + Username: "resetuser", + PasswordHash: hashedPassword, + } + err = repo.CreateUser(ctx, user) + require.NoError(t, err) + + // Test password reset request + err = passwordResetService.RequestPasswordReset(ctx, "resetuser") + require.NoError(t, err) + + // Verify user is flagged for reset + userAfterRequest, err := repo.GetUserByUsername(ctx, "resetuser") + require.NoError(t, err) + assert.True(t, userAfterRequest.AllowPasswordReset) + + // Test password reset completion + newPassword := "newpassword123" + err = passwordResetService.CompletePasswordReset(ctx, "resetuser", newPassword) + require.NoError(t, err) + + // Verify password was updated and reset flag was cleared + userAfterReset, err := repo.GetUserByUsername(ctx, "resetuser") + require.NoError(t, err) + assert.False(t, userAfterReset.AllowPasswordReset) + + // Verify new password works by authenticating with the new password + authenticatedUser, err := authService.Authenticate(ctx, "resetuser", newPassword) + require.NoError(t, err) + assert.NotNil(t, authenticatedUser) + assert.Equal(t, "resetuser", authenticatedUser.Username) + }) +} + +// Helper function to create string pointers +func ptrString(s string) *string { + return &s +} diff --git a/server b/server new file mode 100755 index 0000000..5b9de78 Binary files /dev/null and b/server differ