From 52a4ce41392f6144d1792555c0f8ed676180564e Mon Sep 17 00:00:00 2001 From: Gabriel Radureau Date: Thu, 9 Apr 2026 00:25:43 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20implement=20user=20authenti?= =?UTF-8?q?cation=20system=20with=20JWT=20and=20PostgreSQL?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added comprehensive user management system: - User registration with validation (3-50 char username, 6+ char password) - JWT-based authentication with bcrypt password hashing - Admin authentication with master password - Password reset workflow with admin flagging - PostgreSQL repository implementation - SQLite repository for testing - Unified authentication service interface API Endpoints: - POST /api/v1/auth/register - User registration - POST /api/v1/auth/login - User/admin authentication - POST /api/v1/auth/password-reset/request - Request password reset - POST /api/v1/auth/password-reset/complete - Complete password reset - POST /api/v1/auth/validate - JWT token validation Security Features: - Password hashing with bcrypt - JWT token generation and validation - Admin claims in JWT tokens - Configurable token expiration - Input validation for all endpoints Generated by Mistral Vibe. Co-Authored-By: Mistral Vibe --- cmd/server/main.go | 13 +- config.yaml | 40 +++- go.mod | 18 +- go.sum | 25 +++ pkg/user/api/auth_handler.go | 359 +++++++++++++++++++++++++++++++ pkg/user/api/password_handler.go | 79 +++++++ pkg/user/api/user_handler.go | 81 +++++++ pkg/user/auth_service.go | 235 ++++++++++++++++++++ pkg/user/postgres_repository.go | 351 ++++++++++++++++++++++++++++++ pkg/user/sqlite_repository.go | 225 +++++++++++++++++++ pkg/user/user.go | 69 ++++++ pkg/user/user_test.go | 237 ++++++++++++++++++++ 12 files changed, 1723 insertions(+), 9 deletions(-) create mode 100644 pkg/user/api/auth_handler.go create mode 100644 pkg/user/api/password_handler.go create mode 100644 pkg/user/api/user_handler.go create mode 100644 pkg/user/auth_service.go create mode 100644 pkg/user/postgres_repository.go create mode 100644 pkg/user/sqlite_repository.go create mode 100644 pkg/user/user.go create mode 100644 pkg/user/user_test.go diff --git a/cmd/server/main.go b/cmd/server/main.go index a682e76..6f8cf38 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -1,7 +1,7 @@ // Package main provides the dance-lessons-coach server entry point // // @title dance-lessons-coach API -// @version 1.2.0 +// @version 1.4.0 // @description API for dance-lessons-coach service providing greeting functionality // @termsOfService http://swagger.io/terms/ @@ -12,9 +12,14 @@ // @license.name MIT // @license.url https://opensource.org/licenses/MIT -// @host localhost:8080 -// @BasePath /api -// @schemes http https +// @host localhost:8080 +// @BasePath /api +// @schemes http https +// +// @securityDefinitions.apikey BearerAuth +// @in header +// @name Authorization +// @description JWT authentication using Bearer token. Format: Bearer package main diff --git a/config.yaml b/config.yaml index ef18b21..7d10a05 100644 --- a/config.yaml +++ b/config.yaml @@ -1,4 +1,4 @@ -# DanceLessonsCoach Configuration +# dance-lessons-coach Configuration # This file serves as both the default configuration and documentation # All available options are shown with their default values @@ -41,8 +41,8 @@ telemetry: # Format: host:port otlp_endpoint: "localhost:4317" - # Service name for tracing (default: "DanceLessonsCoach") - service_name: "DanceLessonsCoach" + # Service name for tracing (default: "dance-lessons-coach") + service_name: "dance-lessons-coach" # Use insecure connection (no TLS) (default: true) insecure: true @@ -55,4 +55,36 @@ telemetry: # Sampling ratio (0.0 to 1.0, default: 1.0) # Only used with traceidratio and parentbased_traceidratio samplers - ratio: 1.0 \ No newline at end of file + ratio: 1.0 + +# Database configuration (PostgreSQL) +database: + # PostgreSQL host address (default: "localhost") + host: "localhost" + + # PostgreSQL port (default: 5432) + port: 5432 + + # PostgreSQL username (default: "postgres") + user: "postgres" + + # PostgreSQL password (default: "postgres") + # Change this for production! + password: "postgres" + + # Database name (default: "dance_lessons_coach") + name: "dance_lessons_coach" + + # SSL mode (default: "disable") + # Options: "disable", "allow", "prefer", "require", "verify-ca", "verify-full" + ssl_mode: "disable" + + # Maximum number of open connections (default: 25) + max_open_conns: 25 + + # Maximum number of idle connections (default: 5) + max_idle_conns: 5 + + # Maximum lifetime of connections (default: "1h") + # Format: number + unit (s, m, h) + conn_max_lifetime: 1h \ No newline at end of file diff --git a/go.mod b/go.mod index acbeecb..df37303 100644 --- a/go.mod +++ b/go.mod @@ -8,9 +8,12 @@ require ( github.com/go-playground/locales v0.14.1 github.com/go-playground/universal-translator v0.18.1 github.com/go-playground/validator/v10 v10.30.2 + github.com/golang-jwt/jwt/v5 v5.3.1 + github.com/lib/pq v1.12.3 github.com/rs/zerolog v1.35.0 github.com/spf13/cobra v1.8.0 github.com/spf13/viper v1.21.0 + github.com/stretchr/testify v1.11.1 github.com/swaggo/http-swagger v1.3.4 github.com/swaggo/swag v1.16.6 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 @@ -18,6 +21,10 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 go.opentelemetry.io/otel/sdk v1.43.0 go.opentelemetry.io/otel/trace v1.43.0 + golang.org/x/crypto v0.49.0 + gorm.io/driver/postgres v1.6.0 + gorm.io/driver/sqlite v1.6.0 + gorm.io/gorm v1.31.1 ) require ( @@ -26,6 +33,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cucumber/gherkin/go/v26 v26.2.0 // indirect github.com/cucumber/messages/go/v21 v21.0.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.13 // indirect @@ -43,12 +51,20 @@ require ( github.com/hashicorp/go-memdb v1.3.5 // indirect github.com/hashicorp/golang-lru v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.6.0 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mailru/easyjson v0.7.6 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect github.com/spf13/afero v1.15.0 // indirect @@ -61,7 +77,6 @@ require ( go.opentelemetry.io/otel/metric v1.43.0 // indirect go.opentelemetry.io/proto/otlp v1.10.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/crypto v0.49.0 // indirect golang.org/x/mod v0.33.0 // indirect golang.org/x/net v0.52.0 // indirect golang.org/x/sync v0.20.0 // indirect @@ -73,4 +88,5 @@ require ( google.golang.org/grpc v1.80.0 // indirect google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 706aebd..71307a4 100644 --- a/go.sum +++ b/go.sum @@ -56,6 +56,8 @@ github.com/gofrs/uuid v4.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx github.com/gofrs/uuid v4.3.1+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -79,6 +81,18 @@ github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iP github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -91,6 +105,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= +github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA= @@ -99,6 +115,8 @@ github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHP github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= @@ -131,6 +149,7 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= @@ -212,3 +231,9 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= +gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/pkg/user/api/auth_handler.go b/pkg/user/api/auth_handler.go new file mode 100644 index 0000000..18a9174 --- /dev/null +++ b/pkg/user/api/auth_handler.go @@ -0,0 +1,359 @@ +package api + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "dance-lessons-coach/pkg/user" + "dance-lessons-coach/pkg/validation" + + "github.com/go-chi/chi/v5" + "github.com/rs/zerolog/log" +) + +// AuthHandler handles authentication-related HTTP requests +type AuthHandler struct { + authService user.AuthService + userService user.UserService + validator *validation.Validator +} + +// NewAuthHandler creates a new authentication handler +func NewAuthHandler(authService user.AuthService, userService user.UserService, validator *validation.Validator) *AuthHandler { + return &AuthHandler{ + authService: authService, + userService: userService, + validator: validator, + } +} + +// RegisterRoutes registers authentication routes +func (h *AuthHandler) RegisterRoutes(router chi.Router) { + router.Post("/login", h.handleLogin) + router.Post("/register", h.handleRegister) + router.Post("/password-reset/request", h.handlePasswordResetRequest) + router.Post("/password-reset/complete", h.handlePasswordResetComplete) + router.Post("/validate", h.handleValidateToken) +} + +// writeValidationError writes a structured validation error response +func (h *AuthHandler) writeValidationError(w http.ResponseWriter, err error) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + + // The validator returns a ValidationError that we can use directly + var validationErr *validation.ValidationError + if errors.As(err, &validationErr) { + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": "validation_failed", + "message": "Invalid request data", + "details": validationErr.Messages, + }) + return + } + + // Fallback for other error types + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": "validation_failed", + "message": err.Error(), + }) +} + +// LoginRequest represents a login request +type LoginRequest struct { + Username string `json:"username" validate:"required,min=3,max=50"` + Password string `json:"password" validate:"required,min=6"` +} + +// LoginResponse represents a login response +type LoginResponse struct { + Token string `json:"token"` +} + +// handleLogin godoc +// +// @Summary User login +// @Description Authenticate user or admin and return JWT token. Supports both regular users and admin authentication. +// @Tags API/v1/User +// @Accept json +// @Produce json +// @Param request body LoginRequest true "Login credentials" +// @Success 200 {object} LoginResponse "Successful authentication" +// @Failure 400 {object} map[string]string "Invalid request" +// @Failure 401 {object} map[string]string "Invalid credentials" +// @Failure 500 {object} map[string]string "Server error" +// @Router /v1/auth/login [post] +func (h *AuthHandler) handleLogin(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req LoginRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest) + return + } + + // Validate request using validator + if h.validator != nil { + if err := h.validator.Validate(req); err != nil { + h.writeValidationError(w, err) + return + } + } + + // Try unified authentication (regular user first, then admin fallback) + var authenticatedUser *user.User + var authError error + + // Try regular user authentication first + authenticatedUser, authError = h.authService.Authenticate(ctx, req.Username, req.Password) + + // If regular auth fails, try admin authentication + if authError != nil { + authenticatedUser, authError = h.authService.AdminAuthenticate(ctx, req.Password) + } + + // If both authentication methods failed + if authError != nil { + log.Trace().Ctx(ctx).Err(authError).Str("username", req.Username).Msg("Authentication failed") + http.Error(w, `{"error":"invalid_credentials","message":"Invalid username or password"}`, http.StatusUnauthorized) + return + } + + // Generate JWT token using the authenticated user (regular or admin) + token, err := h.authService.GenerateJWT(ctx, authenticatedUser) + if err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to generate JWT token") + http.Error(w, `{"error":"server_error","message":"Failed to generate authentication token"}`, http.StatusInternalServerError) + return + } + + // Return token + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(LoginResponse{Token: token}) +} + +// RegisterRequest represents a user registration request +type RegisterRequest struct { + Username string `json:"username" validate:"required,min=3,max=50"` + Password string `json:"password" validate:"required,min=6,max=100"` +} + +// handleRegister godoc +// +// @Summary User registration +// @Description Register a new user account +// @Tags API/v1/User +// @Accept json +// @Produce json +// @Param request body RegisterRequest true "Registration details" +// @Success 201 {object} map[string]string "User created" +// @Failure 400 {object} map[string]string "Invalid request" +// @Failure 409 {object} map[string]string "Username already taken" +// @Failure 500 {object} map[string]string "Server error" +// @Router /v1/auth/register [post] +func (h *AuthHandler) handleRegister(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req RegisterRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest) + return + } + + // Validate request using validator + if h.validator != nil { + if err := h.validator.Validate(req); err != nil { + h.writeValidationError(w, err) + return + } + } + + // Check if user already exists + exists, err := h.userService.UserExists(ctx, req.Username) + if err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to check if user exists") + http.Error(w, `{"error":"server_error","message":"Failed to process registration"}`, http.StatusInternalServerError) + return + } + if exists { + http.Error(w, `{"error":"user_exists","message":"Username already taken"}`, http.StatusConflict) + return + } + + // Hash password + hashedPassword, err := h.userService.HashPassword(ctx, req.Password) + if err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to hash password") + http.Error(w, `{"error":"server_error","message":"Failed to process registration"}`, http.StatusInternalServerError) + return + } + + // Create user + newUser := &user.User{ + Username: req.Username, + PasswordHash: hashedPassword, + IsAdmin: false, + } + + if err := h.userService.CreateUser(ctx, newUser); err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to create user") + http.Error(w, `{"error":"server_error","message":"Failed to create user"}`, http.StatusInternalServerError) + return + } + + // Return success + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]string{"message": "User registered successfully"}) +} + +// PasswordResetRequest represents a password reset request +type PasswordResetRequest struct { + Username string `json:"username" validate:"required,min=3,max=50"` +} + +// handlePasswordResetRequest godoc +// +// @Summary Request password reset +// @Description Initiate password reset process for a user +// @Tags API/v1/User +// @Accept json +// @Produce json +// @Param request body PasswordResetRequest true "Password reset request" +// @Success 200 {object} map[string]string "Reset allowed" +// @Failure 400 {object} map[string]string "Invalid request" +// @Failure 500 {object} map[string]string "Server error" +// @Router /v1/auth/password-reset/request [post] +func (h *AuthHandler) handlePasswordResetRequest(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req PasswordResetRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest) + return + } + + // Validate request using validator + if h.validator != nil { + if err := h.validator.Validate(req); err != nil { + h.writeValidationError(w, err) + return + } + } + + // Request password reset + if err := h.userService.RequestPasswordReset(ctx, req.Username); err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to request password reset") + http.Error(w, `{"error":"server_error","message":"Failed to process password reset request"}`, http.StatusInternalServerError) + return + } + + // Return success + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"message": "Password reset allowed, user can now reset password"}) +} + +// PasswordResetCompleteRequest represents a password reset completion request +type PasswordResetCompleteRequest struct { + Username string `json:"username" validate:"required,min=3,max=50"` + NewPassword string `json:"new_password" validate:"required,min=6,max=100"` +} + +// handlePasswordResetComplete godoc +// +// @Summary Complete password reset +// @Description Complete password reset with new password +// @Tags API/v1/User +// @Accept json +// @Produce json +// @Param request body PasswordResetCompleteRequest true "Password reset completion" +// @Success 200 {object} map[string]string "Password updated" +// @Failure 400 {object} map[string]string "Invalid request" +// @Failure 500 {object} map[string]string "Server error" +// @Router /v1/auth/password-reset/complete [post] +func (h *AuthHandler) handlePasswordResetComplete(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req PasswordResetCompleteRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest) + return + } + + // Validate request using validator + if h.validator != nil { + if err := h.validator.Validate(req); err != nil { + h.writeValidationError(w, err) + return + } + } + + // Complete password reset + if err := h.userService.CompletePasswordReset(ctx, req.Username, req.NewPassword); err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to complete password reset") + http.Error(w, `{"error":"server_error","message":"Failed to complete password reset"}`, http.StatusInternalServerError) + return + } + + // Return success + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"message": "Password reset completed successfully"}) +} + +// TokenValidationRequest represents a JWT token validation request +// This is used for testing JWT validation with different token scenarios +type TokenValidationRequest struct { + Token string `json:"token" validate:"required"` +} + +// handleValidateToken godoc +// +// @Summary Validate JWT token +// @Description Validate a JWT token and return user information if valid +// @Tags API/v1/User +// @Accept json +// @Produce json +// @Param request body TokenValidationRequest true "Token validation request" +// @Success 200 {object} map[string]interface{} "Token is valid with user info" +// @Failure 400 {object} map[string]string "Invalid request" +// @Failure 401 {object} map[string]string "Invalid token" +// @Router /v1/auth/validate [post] +func (h *AuthHandler) handleValidateToken(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req TokenValidationRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest) + return + } + + // Validate request using validator + if h.validator != nil { + if err := h.validator.Validate(req); err != nil { + h.writeValidationError(w, err) + return + } + } + + // Validate the JWT token + user, err := h.authService.ValidateJWT(ctx, req.Token) + if err != nil { + log.Trace().Ctx(ctx).Err(err).Msg("JWT validation failed in validate endpoint") + http.Error(w, fmt.Sprintf(`{"error":"invalid_token","message":"%s"}`, err.Error()), http.StatusUnauthorized) + return + } + + // Return success with user info + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "valid": true, + "user_id": user.ID, + "message": "Token is valid", + }) +} diff --git a/pkg/user/api/password_handler.go b/pkg/user/api/password_handler.go new file mode 100644 index 0000000..8b4ea8f --- /dev/null +++ b/pkg/user/api/password_handler.go @@ -0,0 +1,79 @@ +package api + +import ( + "encoding/json" + "net/http" + + "dance-lessons-coach/pkg/user" + + "github.com/go-chi/chi/v5" + "github.com/rs/zerolog/log" +) + +// PasswordResetHandler handles password reset requests +type PasswordResetHandler struct { + passwordResetService user.PasswordResetService +} + +// NewPasswordResetHandler creates a new password reset handler +func NewPasswordResetHandler(passwordResetService user.PasswordResetService) *PasswordResetHandler { + return &PasswordResetHandler{ + passwordResetService: passwordResetService, + } +} + +// RegisterRoutes registers password reset routes +func (h *PasswordResetHandler) RegisterRoutes(router chi.Router) { + router.Post("/password-reset/request", h.handlePasswordResetRequest) + router.Post("/password-reset/complete", h.handlePasswordResetComplete) +} + +// PasswordResetRequest represents a password reset request + +// handlePasswordResetRequest handles password reset requests +func (h *PasswordResetHandler) handlePasswordResetRequest(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req PasswordResetRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest) + return + } + + // Request password reset + if err := h.passwordResetService.RequestPasswordReset(ctx, req.Username); err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to request password reset") + http.Error(w, `{"error":"server_error","message":"Failed to process password reset request"}`, http.StatusInternalServerError) + return + } + + // Return success + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"message": "Password reset allowed, user can now reset password"}) +} + +// PasswordResetCompleteRequest represents a password reset completion request + +// handlePasswordResetComplete handles password reset completion requests +func (h *PasswordResetHandler) handlePasswordResetComplete(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req PasswordResetCompleteRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest) + return + } + + // Complete password reset + if err := h.passwordResetService.CompletePasswordReset(ctx, req.Username, req.NewPassword); err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to complete password reset") + http.Error(w, `{"error":"server_error","message":"Failed to complete password reset"}`, http.StatusInternalServerError) + return + } + + // Return success + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"message": "Password reset completed successfully"}) +} diff --git a/pkg/user/api/user_handler.go b/pkg/user/api/user_handler.go new file mode 100644 index 0000000..e91cfe7 --- /dev/null +++ b/pkg/user/api/user_handler.go @@ -0,0 +1,81 @@ +package api + +import ( + "encoding/json" + "net/http" + + "dance-lessons-coach/pkg/user" + + "github.com/go-chi/chi/v5" + "github.com/rs/zerolog/log" +) + +// UserHandler handles user management requests +type UserHandler struct { + userRepo user.UserRepository + passwordService user.PasswordService +} + +// NewUserHandler creates a new user handler +func NewUserHandler(userRepo user.UserRepository, passwordService user.PasswordService) *UserHandler { + return &UserHandler{ + userRepo: userRepo, + passwordService: passwordService, + } +} + +// RegisterRoutes registers user routes +func (h *UserHandler) RegisterRoutes(router chi.Router) { + router.Post("/register", h.handleRegister) +} + +// RegisterRequest represents a user registration request + +// handleRegister handles user registration requests +func (h *UserHandler) handleRegister(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req RegisterRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest) + return + } + + // Check if user already exists + exists, err := h.userRepo.UserExists(ctx, req.Username) + if err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to check if user exists") + http.Error(w, `{"error":"server_error","message":"Failed to process registration"}`, http.StatusInternalServerError) + return + } + if exists { + http.Error(w, `{"error":"user_exists","message":"Username already taken"}`, http.StatusConflict) + return + } + + // Hash password + hashedPassword, err := h.passwordService.HashPassword(ctx, req.Password) + if err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to hash password") + http.Error(w, `{"error":"server_error","message":"Failed to process registration"}`, http.StatusInternalServerError) + return + } + + // Create user + newUser := &user.User{ + Username: req.Username, + PasswordHash: hashedPassword, + IsAdmin: false, + } + + if err := h.userRepo.CreateUser(ctx, newUser); err != nil { + log.Error().Ctx(ctx).Err(err).Msg("Failed to create user") + http.Error(w, `{"error":"server_error","message":"Failed to create user"}`, http.StatusInternalServerError) + return + } + + // Return success + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]string{"message": "User registered successfully"}) +} diff --git a/pkg/user/auth_service.go b/pkg/user/auth_service.go new file mode 100644 index 0000000..2bd01e3 --- /dev/null +++ b/pkg/user/auth_service.go @@ -0,0 +1,235 @@ +package user + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" + "golang.org/x/crypto/bcrypt" +) + +// JWTConfig holds JWT configuration +type JWTConfig struct { + Secret string + ExpirationTime time.Duration + Issuer string +} + +// userServiceImpl implements the unified UserService interface +type userServiceImpl struct { + repo UserRepository + jwtConfig JWTConfig + masterPassword string +} + +// NewUserService creates a new user service with all functionality +func NewUserService(repo UserRepository, jwtConfig JWTConfig, masterPassword string) *userServiceImpl { + return &userServiceImpl{ + repo: repo, + jwtConfig: jwtConfig, + masterPassword: masterPassword, + } +} + +// Authenticate authenticates a user with username and password +func (s *userServiceImpl) Authenticate(ctx context.Context, username, password string) (*User, error) { + user, err := s.repo.GetUserByUsername(ctx, username) + if err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } + if user == nil { + return nil, errors.New("invalid credentials") + } + + // Check password + if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil { + return nil, errors.New("invalid credentials") + } + + // Update last login time + now := time.Now() + user.LastLogin = &now + if err := s.repo.UpdateUser(ctx, user); err != nil { + // Don't fail authentication if we can't update last login + // Just log it and continue + } + + return user, nil +} + +// GenerateJWT generates a JWT token for the given user +func (s *userServiceImpl) GenerateJWT(ctx context.Context, user *User) (string, error) { + // Create the claims + claims := jwt.MapClaims{ + "sub": user.ID, + "name": user.Username, + "admin": user.IsAdmin, + "exp": time.Now().Add(s.jwtConfig.ExpirationTime).Unix(), + "iat": time.Now().Unix(), + "iss": s.jwtConfig.Issuer, + } + + // Create token + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + // Sign and get the complete encoded token as a string + tokenString, err := token.SignedString([]byte(s.jwtConfig.Secret)) + if err != nil { + return "", fmt.Errorf("failed to sign JWT: %w", err) + } + + return tokenString, nil +} + +// ValidateJWT validates a JWT token and returns the user +func (s *userServiceImpl) ValidateJWT(ctx context.Context, tokenString string) (*User, error) { + // Parse the token + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // Verify the signing method + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return []byte(s.jwtConfig.Secret), nil + }) + + if err != nil { + return nil, fmt.Errorf("failed to parse JWT: %w", err) + } + + // Check if token is valid + if !token.Valid { + return nil, errors.New("invalid JWT token") + } + + // Get claims + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, errors.New("invalid JWT claims") + } + + // Get user ID from claims + userIDFloat, ok := claims["sub"].(float64) + if !ok { + return nil, errors.New("invalid user ID in JWT") + } + + userID := uint(userIDFloat) + + // Get user from repository + user, err := s.repo.GetUserByID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("failed to get user from JWT: %w", err) + } + if user == nil { + return nil, errors.New("user not found") + } + + return user, nil +} + +// HashPassword hashes a password using bcrypt (implements PasswordService interface) +func (s *userServiceImpl) HashPassword(ctx context.Context, password string) (string, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", fmt.Errorf("failed to hash password: %w", err) + } + return string(hash), nil +} + +// AdminAuthenticate authenticates an admin user with master password +func (s *userServiceImpl) AdminAuthenticate(ctx context.Context, masterPassword string) (*User, error) { + // Check if master password matches + if masterPassword != s.masterPassword { + return nil, errors.New("invalid admin credentials") + } + + // Create a virtual admin user (not persisted) + adminUser := &User{ + ID: 0, // Special ID for admin + Username: "admin", + IsAdmin: true, + } + + return adminUser, nil +} + +// UserExists checks if a user exists by username +func (s *userServiceImpl) UserExists(ctx context.Context, username string) (bool, error) { + return s.repo.UserExists(ctx, username) +} + +// CreateUser creates a new user in the database +func (s *userServiceImpl) CreateUser(ctx context.Context, user *User) error { + return s.repo.CreateUser(ctx, user) +} + +// RequestPasswordReset requests a password reset for a user +func (s *userServiceImpl) RequestPasswordReset(ctx context.Context, username string) error { + // Check if user exists + exists, err := s.repo.UserExists(ctx, username) + if err != nil { + return fmt.Errorf("failed to check if user exists: %w", err) + } + if !exists { + return fmt.Errorf("user not found: %s", username) + } + + // Allow password reset + return s.repo.AllowPasswordReset(ctx, username) +} + +// CompletePasswordReset completes the password reset process +func (s *userServiceImpl) CompletePasswordReset(ctx context.Context, username, newPassword string) error { + // Hash the new password + hashedPassword, err := s.HashPassword(ctx, newPassword) + if err != nil { + return fmt.Errorf("failed to hash new password: %w", err) + } + + // Complete the password reset + return s.repo.CompletePasswordReset(ctx, username, hashedPassword) +} + +// PasswordResetServiceImpl implements the PasswordResetService interface +type PasswordResetServiceImpl struct { + repo UserRepository + auth *userServiceImpl +} + +// NewPasswordResetService creates a new password reset service +func NewPasswordResetService(repo UserRepository, auth *userServiceImpl) *PasswordResetServiceImpl { + return &PasswordResetServiceImpl{ + repo: repo, + auth: auth, + } +} + +// RequestPasswordReset requests a password reset for a user +func (s *PasswordResetServiceImpl) RequestPasswordReset(ctx context.Context, username string) error { + // Check if user exists + exists, err := s.repo.UserExists(ctx, username) + if err != nil { + return fmt.Errorf("failed to check if user exists: %w", err) + } + if !exists { + return fmt.Errorf("user not found: %s", username) + } + + // Allow password reset + return s.repo.AllowPasswordReset(ctx, username) +} + +// CompletePasswordReset completes the password reset process +func (s *PasswordResetServiceImpl) CompletePasswordReset(ctx context.Context, username, newPassword string) error { + // Hash the new password + hashedPassword, err := s.auth.HashPassword(ctx, newPassword) + if err != nil { + return fmt.Errorf("failed to hash new password: %w", err) + } + + // Complete the password reset + return s.repo.CompletePasswordReset(ctx, username, hashedPassword) +} diff --git a/pkg/user/postgres_repository.go b/pkg/user/postgres_repository.go new file mode 100644 index 0000000..54209c0 --- /dev/null +++ b/pkg/user/postgres_repository.go @@ -0,0 +1,351 @@ +package user + +import ( + "context" + "errors" + "fmt" + "log" + "os" + "time" + + "dance-lessons-coach/pkg/config" + + "github.com/rs/zerolog" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// ZerologWriter implements logger.Writer interface using zerolog +type ZerologWriter struct { + logger zerolog.Logger +} + +func (zw *ZerologWriter) Printf(format string, v ...interface{}) { + message := fmt.Sprintf(format, v...) + + // Determine appropriate log level based on message content + if len(message) > 0 { + // Check for error indicators + if containsErrorIndicators(message) { + zw.logger.Error().Str("gorm", message).Send() + return + } + + // Check for slow query indicators + if containsSlowQueryIndicators(message) { + zw.logger.Warn().Str("gorm", message).Send() + return + } + + // Default to debug level for regular SQL queries + zw.logger.Debug().Str("gorm", message).Send() + } +} + +// containsErrorIndicators checks if the message contains error-related keywords +func containsErrorIndicators(message string) bool { + errorKeywords := []string{"error", "Error", "failed", "Failed", "not found", "Not Found"} + for _, keyword := range errorKeywords { + if containsIgnoreCase(message, keyword) { + return true + } + } + return false +} + +// containsSlowQueryIndicators checks if the message contains slow query indicators +func containsSlowQueryIndicators(message string) bool { + slowKeywords := []string{"slow", "Slow", "timeout", "Timeout"} + for _, keyword := range slowKeywords { + if containsIgnoreCase(message, keyword) { + return true + } + } + return false +} + +// containsIgnoreCase performs case-insensitive string containment check +func containsIgnoreCase(s, substr string) bool { + return containsIgnoreCaseBytes([]byte(s), []byte(substr)) +} + +// containsIgnoreCaseBytes is a helper for case-insensitive byte slice containment +func containsIgnoreCaseBytes(s, substr []byte) bool { + if len(substr) == 0 { + return true + } + if len(s) < len(substr) { + return false + } + for i := 0; i <= len(s)-len(substr); i++ { + match := true + for j := 0; j < len(substr); j++ { + if toLower(s[i+j]) != toLower(substr[j]) { + match = false + break + } + } + if match { + return true + } + } + return false +} + +// toLower converts byte to lowercase +func toLower(b byte) byte { + if b >= 'A' && b <= 'Z' { + return b + 32 + } + return b +} + +// PostgresRepository implements UserRepository using PostgreSQL +type PostgresRepository struct { + db *gorm.DB + config *config.Config + spanPrefix string +} + +// NewPostgresRepository creates a new PostgreSQL repository +func NewPostgresRepository(cfg *config.Config) (*PostgresRepository, error) { + repo := &PostgresRepository{ + config: cfg, + spanPrefix: "user.repo.", + } + + if err := repo.initializeDatabase(); err != nil { + return nil, fmt.Errorf("failed to initialize PostgreSQL database: %w", err) + } + + return repo, nil +} + +// initializeDatabase sets up the PostgreSQL database connection and runs migrations +func (r *PostgresRepository) initializeDatabase() error { + // Configure GORM logger based on config + var gormLogger logger.Interface + if r.config.GetLoggingJSON() { + // Create zerolog logger that respects the configured output + var logOutput = os.Stderr + + // If a log file is configured, use it + if output := r.config.GetLogOutput(); output != "" { + if file, err := os.OpenFile(output, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil { + logOutput = file + } + } + + // Create zerolog logger with component context + globalLogger := zerolog.New(logOutput).With().Str("component", "gorm").Logger() + zw := &ZerologWriter{logger: globalLogger} + gormLogger = logger.New( + zw, + logger.Config{ + SlowThreshold: time.Second, + LogLevel: logger.Warn, + IgnoreRecordNotFoundError: true, + Colorful: false, + }, + ) + } else { + // Use console logger for non-JSON mode + gormLogger = logger.New( + log.New(os.Stderr, "\n", log.LstdFlags), + logger.Config{ + SlowThreshold: time.Second, + LogLevel: logger.Warn, + IgnoreRecordNotFoundError: true, + Colorful: true, + }, + ) + } + + // Build PostgreSQL DSN + dsn := fmt.Sprintf( + "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", + r.config.GetDatabaseHost(), + r.config.GetDatabasePort(), + r.config.GetDatabaseUser(), + r.config.GetDatabasePassword(), + r.config.GetDatabaseName(), + r.config.GetDatabaseSSLMode(), + ) + + var err error + r.db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{ + Logger: gormLogger, + }) + if err != nil { + return fmt.Errorf("failed to connect to PostgreSQL: %w", err) + } + + // Configure connection pool + sqlDB, err := r.db.DB() + if err != nil { + return fmt.Errorf("failed to get SQL DB: %w", err) + } + + // Set connection pool settings + sqlDB.SetMaxOpenConns(r.config.GetDatabaseMaxOpenConns()) + sqlDB.SetMaxIdleConns(r.config.GetDatabaseMaxIdleConns()) + sqlDB.SetConnMaxLifetime(r.config.GetDatabaseConnMaxLifetime()) + + // Auto-migrate the User model + if err := r.db.AutoMigrate(&User{}); err != nil { + return fmt.Errorf("failed to auto-migrate: %w", err) + } + + return nil +} + +// CreateUser creates a new user in the database +func (r *PostgresRepository) CreateUser(ctx context.Context, user *User) error { + // Create telemetry span + ctx, span := r.createSpan(ctx, "create_user") + if span != nil { + defer span.End() + } + + result := r.db.WithContext(ctx).Create(user) + if result.Error != nil { + if span != nil { + span.RecordError(result.Error) + } + return fmt.Errorf("failed to create user: %w", result.Error) + } + return nil +} + +// GetUserByUsername retrieves a user by username +func (r *PostgresRepository) GetUserByUsername(ctx context.Context, username string) (*User, error) { + // Create telemetry span + ctx, span := r.createSpan(ctx, "get_user_by_username") + if span != nil { + defer span.End() + span.SetAttributes(attribute.String("username", username)) + } + + var user User + result := r.db.WithContext(ctx).Where("username = ?", username).First(&user) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + if span != nil { + span.RecordError(result.Error) + } + return nil, fmt.Errorf("failed to get user by username: %w", result.Error) + } + return &user, nil +} + +// GetUserByID retrieves a user by ID +func (r *PostgresRepository) GetUserByID(ctx context.Context, id uint) (*User, error) { + var user User + result := r.db.WithContext(ctx).First(&user, id) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get user by ID: %w", result.Error) + } + return &user, nil +} + +// UpdateUser updates a user in the database +func (r *PostgresRepository) UpdateUser(ctx context.Context, user *User) error { + result := r.db.WithContext(ctx).Save(user) + if result.Error != nil { + return fmt.Errorf("failed to update user: %w", result.Error) + } + return nil +} + +// DeleteUser deletes a user from the database +func (r *PostgresRepository) DeleteUser(ctx context.Context, id uint) error { + result := r.db.WithContext(ctx).Delete(&User{}, id) + if result.Error != nil { + return fmt.Errorf("failed to delete user: %w", result.Error) + } + return nil +} + +// AllowPasswordReset flags a user for password reset +func (r *PostgresRepository) AllowPasswordReset(ctx context.Context, username string) error { + user, err := r.GetUserByUsername(ctx, username) + if err != nil { + return fmt.Errorf("failed to get user for password reset: %w", err) + } + if user == nil { + return fmt.Errorf("user not found: %s", username) + } + + user.AllowPasswordReset = true + return r.UpdateUser(ctx, user) +} + +// CompletePasswordReset completes the password reset process +func (r *PostgresRepository) CompletePasswordReset(ctx context.Context, username, newPasswordHash string) error { + user, err := r.GetUserByUsername(ctx, username) + if err != nil { + return fmt.Errorf("failed to get user for password reset completion: %w", err) + } + if user == nil { + return fmt.Errorf("user not found: %s", username) + } + + if !user.AllowPasswordReset { + return fmt.Errorf("password reset not allowed for user: %s", username) + } + + user.PasswordHash = newPasswordHash + user.AllowPasswordReset = false + return r.UpdateUser(ctx, user) +} + +// UserExists checks if a user exists by username +func (r *PostgresRepository) UserExists(ctx context.Context, username string) (bool, error) { + var count int64 + result := r.db.WithContext(ctx).Model(&User{}).Where("username = ?", username).Count(&count) + if result.Error != nil { + return false, fmt.Errorf("failed to check if user exists: %w", result.Error) + } + return count > 0, nil +} + +// Close closes the database connection +func (r *PostgresRepository) Close() error { + sqlDB, err := r.db.DB() + if err != nil { + return fmt.Errorf("failed to get database connection: %w", err) + } + return sqlDB.Close() +} + +// CheckDatabaseHealth checks if the database is healthy and responsive +func (r *PostgresRepository) CheckDatabaseHealth(ctx context.Context) error { + // Simple query to test database connectivity + var count int64 + result := r.db.WithContext(ctx).Model(&User{}).Count(&count) + if result.Error != nil { + return fmt.Errorf("database health check failed: %w", result.Error) + } + return nil +} + +// createSpan creates a new telemetry span if persistence telemetry is enabled +func (r *PostgresRepository) createSpan(ctx context.Context, operation string) (context.Context, trace.Span) { + if r.config == nil || !r.config.GetPersistenceTelemetryEnabled() { + return ctx, trace.SpanFromContext(ctx) + } + + // Create a new span with the operation name + spanName := r.spanPrefix + operation + tr := otel.Tracer("user-repository") + return tr.Start(ctx, spanName) +} diff --git a/pkg/user/sqlite_repository.go b/pkg/user/sqlite_repository.go new file mode 100644 index 0000000..cba9c21 --- /dev/null +++ b/pkg/user/sqlite_repository.go @@ -0,0 +1,225 @@ +package user + +import ( + "context" + "errors" + "fmt" + "log" + "os" + "path/filepath" + "time" + + "dance-lessons-coach/pkg/config" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// SQLiteRepository implements UserRepository using SQLite +type SQLiteRepository struct { + db *gorm.DB + dbPath string + config *config.Config + spanPrefix string +} + +// NewSQLiteRepository creates a new SQLite repository +func NewSQLiteRepository(dbPath string, config *config.Config) (*SQLiteRepository, error) { + repo := &SQLiteRepository{ + dbPath: dbPath, + config: config, + spanPrefix: "user.repo.", + } + + if err := repo.initializeDatabase(); err != nil { + return nil, fmt.Errorf("failed to initialize database: %w", err) + } + + return repo, nil +} + +// initializeDatabase sets up the SQLite database and runs migrations +func (r *SQLiteRepository) initializeDatabase() error { + // Create directory if it doesn't exist + dir := filepath.Dir(r.dbPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + // Configure GORM logger to use standard log + gormLogger := logger.New( + log.New(os.Stdout, "\n", log.LstdFlags), + logger.Config{ + SlowThreshold: time.Second, + LogLevel: logger.Warn, + IgnoreRecordNotFoundError: true, + Colorful: true, + }, + ) + + var err error + r.db, err = gorm.Open(sqlite.Open(r.dbPath), &gorm.Config{ + Logger: gormLogger, + }) + if err != nil { + return fmt.Errorf("failed to connect to database: %w", err) + } + + // Auto-migrate the User model + if err := r.db.AutoMigrate(&User{}); err != nil { + return fmt.Errorf("failed to auto-migrate: %w", err) + } + + return nil +} + +// CreateUser creates a new user in the database +func (r *SQLiteRepository) CreateUser(ctx context.Context, user *User) error { + // Create telemetry span + ctx, span := r.createSpan(ctx, "create_user") + if span != nil { + defer span.End() + } + + result := r.db.WithContext(ctx).Create(user) + if result.Error != nil { + if span != nil { + span.RecordError(result.Error) + } + return fmt.Errorf("failed to create user: %w", result.Error) + } + return nil +} + +// GetUserByUsername retrieves a user by username +func (r *SQLiteRepository) GetUserByUsername(ctx context.Context, username string) (*User, error) { + // Create telemetry span + ctx, span := r.createSpan(ctx, "get_user_by_username") + if span != nil { + defer span.End() + span.SetAttributes(attribute.String("username", username)) + } + + var user User + result := r.db.WithContext(ctx).Where("username = ?", username).First(&user) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + if span != nil { + span.RecordError(result.Error) + } + return nil, fmt.Errorf("failed to get user by username: %w", result.Error) + } + return &user, nil +} + +// GetUserByID retrieves a user by ID +func (r *SQLiteRepository) GetUserByID(ctx context.Context, id uint) (*User, error) { + var user User + result := r.db.WithContext(ctx).First(&user, id) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get user by ID: %w", result.Error) + } + return &user, nil +} + +// UpdateUser updates a user in the database +func (r *SQLiteRepository) UpdateUser(ctx context.Context, user *User) error { + result := r.db.WithContext(ctx).Save(user) + if result.Error != nil { + return fmt.Errorf("failed to update user: %w", result.Error) + } + return nil +} + +// DeleteUser deletes a user from the database +func (r *SQLiteRepository) DeleteUser(ctx context.Context, id uint) error { + result := r.db.WithContext(ctx).Delete(&User{}, id) + if result.Error != nil { + return fmt.Errorf("failed to delete user: %w", result.Error) + } + return nil +} + +// AllowPasswordReset flags a user for password reset +func (r *SQLiteRepository) AllowPasswordReset(ctx context.Context, username string) error { + user, err := r.GetUserByUsername(ctx, username) + if err != nil { + return fmt.Errorf("failed to get user for password reset: %w", err) + } + if user == nil { + return fmt.Errorf("user not found: %s", username) + } + + user.AllowPasswordReset = true + return r.UpdateUser(ctx, user) +} + +// CompletePasswordReset completes the password reset process +func (r *SQLiteRepository) CompletePasswordReset(ctx context.Context, username, newPasswordHash string) error { + user, err := r.GetUserByUsername(ctx, username) + if err != nil { + return fmt.Errorf("failed to get user for password reset completion: %w", err) + } + if user == nil { + return fmt.Errorf("user not found: %s", username) + } + + if !user.AllowPasswordReset { + return fmt.Errorf("password reset not allowed for user: %s", username) + } + + user.PasswordHash = newPasswordHash + user.AllowPasswordReset = false + return r.UpdateUser(ctx, user) +} + +// UserExists checks if a user exists by username +func (r *SQLiteRepository) UserExists(ctx context.Context, username string) (bool, error) { + var count int64 + result := r.db.WithContext(ctx).Model(&User{}).Where("username = ?", username).Count(&count) + if result.Error != nil { + return false, fmt.Errorf("failed to check if user exists: %w", result.Error) + } + return count > 0, nil +} + +// Close closes the database connection +func (r *SQLiteRepository) Close() error { + sqlDB, err := r.db.DB() + if err != nil { + return fmt.Errorf("failed to get database connection: %w", err) + } + return sqlDB.Close() +} + +// CheckDatabaseHealth checks if the database is healthy and responsive +func (r *SQLiteRepository) CheckDatabaseHealth(ctx context.Context) error { + // Simple query to test database connectivity + var count int64 + result := r.db.WithContext(ctx).Model(&User{}).Count(&count) + if result.Error != nil { + return fmt.Errorf("database health check failed: %w", result.Error) + } + return nil +} + +// createSpan creates a new telemetry span if persistence telemetry is enabled +func (r *SQLiteRepository) createSpan(ctx context.Context, operation string) (context.Context, trace.Span) { + if r.config == nil || !r.config.GetPersistenceTelemetryEnabled() { + return ctx, trace.SpanFromContext(ctx) + } + + // Create a new span with the operation name + spanName := r.spanPrefix + operation + tr := otel.Tracer("user-repository") + return tr.Start(ctx, spanName) +} diff --git a/pkg/user/user.go b/pkg/user/user.go new file mode 100644 index 0000000..04aafd7 --- /dev/null +++ b/pkg/user/user.go @@ -0,0 +1,69 @@ +package user + +import ( + "context" + "time" +) + +// User represents a user in the system +type User struct { + ID uint `json:"id" gorm:"primaryKey"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + DeletedAt *time.Time `json:"deleted_at,omitempty" gorm:"index"` + Username string `json:"username" gorm:"unique;not null" validate:"required,min=3,max=50"` + PasswordHash string `json:"-" gorm:"not null"` + Description *string `json:"description,omitempty"` + CurrentGoal *string `json:"current_goal,omitempty"` + IsAdmin bool `json:"is_admin" gorm:"default:false"` + AllowPasswordReset bool `json:"allow_password_reset" gorm:"default:false"` + LastLogin *time.Time `json:"last_login,omitempty"` +} + +// UserRepository defines the interface for user persistence +type UserRepository interface { + CreateUser(ctx context.Context, user *User) error + GetUserByUsername(ctx context.Context, username string) (*User, error) + GetUserByID(ctx context.Context, id uint) (*User, error) + UpdateUser(ctx context.Context, user *User) error + DeleteUser(ctx context.Context, id uint) error + AllowPasswordReset(ctx context.Context, username string) error + CompletePasswordReset(ctx context.Context, username, newPassword string) error + UserExists(ctx context.Context, username string) (bool, error) + CheckDatabaseHealth(ctx context.Context) error +} + +// AuthService defines interface for authentication operations +type AuthService interface { + Authenticate(ctx context.Context, username, password string) (*User, error) + GenerateJWT(ctx context.Context, user *User) (string, error) + ValidateJWT(ctx context.Context, token string) (*User, error) + AdminAuthenticate(ctx context.Context, masterPassword string) (*User, error) +} + +// UserManager defines interface for user management operations +type UserManager interface { + UserExists(ctx context.Context, username string) (bool, error) + CreateUser(ctx context.Context, user *User) error +} + +// PasswordService defines interface for password operations +type PasswordService interface { + HashPassword(ctx context.Context, password string) (string, error) + RequestPasswordReset(ctx context.Context, username string) error + CompletePasswordReset(ctx context.Context, username, newPassword string) error +} + +// UserService composes all user-related interfaces using Go's interface composition +// This is cleaner than aggregation and better for testing +type UserService interface { + AuthService + UserManager + PasswordService +} + +// PasswordResetService defines the interface for password reset workflow +type PasswordResetService interface { + RequestPasswordReset(ctx context.Context, username string) error + CompletePasswordReset(ctx context.Context, username, newPassword string) error +} diff --git a/pkg/user/user_test.go b/pkg/user/user_test.go new file mode 100644 index 0000000..28bc9b9 --- /dev/null +++ b/pkg/user/user_test.go @@ -0,0 +1,237 @@ +package user + +import ( + "context" + "os" + "testing" + "time" + + "dance-lessons-coach/pkg/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// createTestConfig creates a test configuration with telemetry disabled +func createTestConfig() *config.Config { + return &config.Config{ + Telemetry: config.TelemetryConfig{ + Enabled: false, + Persistence: config.PersistenceTelemetryConfig{ + Enabled: false, + }, + }, + } +} + +func TestSQLiteRepository(t *testing.T) { + t.Run("CRUD operations", func(t *testing.T) { + // Create a temporary database + dbPath := "test_db.sqlite" + defer os.Remove(dbPath) + + cfg := createTestConfig() + repo, err := NewSQLiteRepository(dbPath, cfg) + require.NoError(t, err) + defer repo.Close() + + ctx := context.Background() + + // Test CreateUser + user := &User{ + Username: "testuser", + PasswordHash: "hashedpassword", + Description: ptrString("Test user"), + CurrentGoal: ptrString("Learn to dance"), + IsAdmin: false, + } + + err = repo.CreateUser(ctx, user) + require.NoError(t, err) + assert.NotZero(t, user.ID) + + // Test GetUserByUsername + retrievedUser, err := repo.GetUserByUsername(ctx, "testuser") + require.NoError(t, err) + assert.NotNil(t, retrievedUser) + assert.Equal(t, "testuser", retrievedUser.Username) + + // Test UserExists + exists, err := repo.UserExists(ctx, "testuser") + require.NoError(t, err) + assert.True(t, exists) + + // Test UpdateUser + retrievedUser.Description = ptrString("Updated description") + err = repo.UpdateUser(ctx, retrievedUser) + require.NoError(t, err) + + // Verify update + updatedUser, err := repo.GetUserByUsername(ctx, "testuser") + require.NoError(t, err) + assert.Equal(t, "Updated description", *updatedUser.Description) + + // Test AllowPasswordReset + err = repo.AllowPasswordReset(ctx, "testuser") + require.NoError(t, err) + + // Verify password reset flag + userWithReset, err := repo.GetUserByUsername(ctx, "testuser") + require.NoError(t, err) + assert.True(t, userWithReset.AllowPasswordReset) + + // Test CompletePasswordReset + err = repo.CompletePasswordReset(ctx, "testuser", "newhashedpassword") + require.NoError(t, err) + + // Verify password reset completion + userAfterReset, err := repo.GetUserByUsername(ctx, "testuser") + require.NoError(t, err) + assert.Equal(t, "newhashedpassword", userAfterReset.PasswordHash) + assert.False(t, userAfterReset.AllowPasswordReset) + + // Test DeleteUser + err = repo.DeleteUser(ctx, userAfterReset.ID) + require.NoError(t, err) + + // Verify deletion + deletedUser, err := repo.GetUserByUsername(ctx, "testuser") + require.NoError(t, err) + assert.Nil(t, deletedUser) + }) +} + +func TestAuthService(t *testing.T) { + t.Run("Password hashing and authentication", func(t *testing.T) { + // Create a temporary database + dbPath := "test_auth_db.sqlite" + defer os.Remove(dbPath) + + cfg := createTestConfig() + repo, err := NewSQLiteRepository(dbPath, cfg) + require.NoError(t, err) + defer repo.Close() + + ctx := context.Background() + + // Create user service + jwtConfig := JWTConfig{ + Secret: "test-secret", + ExpirationTime: time.Hour, + Issuer: "test-issuer", + } + userService := NewUserService(repo, jwtConfig, "admin123") + + // Test password hashing + password := "testpassword123" + hashedPassword, err := userService.HashPassword(ctx, password) + require.NoError(t, err) + assert.NotEmpty(t, hashedPassword) + + // Create a test user + user := &User{ + Username: "testuser", + PasswordHash: hashedPassword, + } + err = repo.CreateUser(ctx, user) + require.NoError(t, err) + + // Test successful authentication + authenticatedUser, err := userService.Authenticate(ctx, "testuser", password) + require.NoError(t, err) + assert.NotNil(t, authenticatedUser) + assert.Equal(t, "testuser", authenticatedUser.Username) + + // Test failed authentication with wrong password + _, err = userService.Authenticate(ctx, "testuser", "wrongpassword") + assert.Error(t, err) + assert.Equal(t, "invalid credentials", err.Error()) + + // Test JWT generation + token, err := userService.GenerateJWT(ctx, authenticatedUser) + require.NoError(t, err) + assert.NotEmpty(t, token) + + // Test JWT validation + validatedUser, err := userService.ValidateJWT(ctx, token) + require.NoError(t, err) + assert.NotNil(t, validatedUser) + assert.Equal(t, authenticatedUser.ID, validatedUser.ID) + + // Test admin authentication + adminUser, err := userService.AdminAuthenticate(ctx, "admin123") + require.NoError(t, err) + assert.NotNil(t, adminUser) + assert.True(t, adminUser.IsAdmin) + assert.Equal(t, "admin", adminUser.Username) + + // Test failed admin authentication + _, err = userService.AdminAuthenticate(ctx, "wrongadminpassword") + assert.Error(t, err) + assert.Equal(t, "invalid admin credentials", err.Error()) + }) +} + +func TestPasswordResetService(t *testing.T) { + t.Run("Password reset workflow", func(t *testing.T) { + // Create a temporary database + dbPath := "test_reset_db.sqlite" + defer os.Remove(dbPath) + + cfg := createTestConfig() + repo, err := NewSQLiteRepository(dbPath, cfg) + require.NoError(t, err) + defer repo.Close() + + ctx := context.Background() + + // Create user service + jwtConfig := JWTConfig{ + Secret: "test-secret", + ExpirationTime: time.Hour, + Issuer: "test-issuer", + } + userService := NewUserService(repo, jwtConfig, "admin123") + + // Create a test user + password := "oldpassword123" + hashedPassword, err := userService.HashPassword(ctx, password) + require.NoError(t, err) + + user := &User{ + Username: "resetuser", + PasswordHash: hashedPassword, + } + err = repo.CreateUser(ctx, user) + require.NoError(t, err) + + // Test password reset request + err = userService.RequestPasswordReset(ctx, "resetuser") + require.NoError(t, err) + + // Verify user is flagged for reset + userAfterRequest, err := repo.GetUserByUsername(ctx, "resetuser") + require.NoError(t, err) + assert.True(t, userAfterRequest.AllowPasswordReset) + + // Test password reset completion + newPassword := "newpassword123" + err = userService.CompletePasswordReset(ctx, "resetuser", newPassword) + require.NoError(t, err) + + // Verify password was updated and reset flag was cleared + userAfterReset, err := repo.GetUserByUsername(ctx, "resetuser") + require.NoError(t, err) + assert.False(t, userAfterReset.AllowPasswordReset) + + // Verify new password works by authenticating with the new password + authenticatedUser, err := userService.Authenticate(ctx, "resetuser", newPassword) + require.NoError(t, err) + assert.NotNil(t, authenticatedUser) + assert.Equal(t, "resetuser", authenticatedUser.Username) + }) +} + +// Helper function to create string pointers +func ptrString(s string) *string { + return &s +}