📝 docs: add comprehensive SOLID analysis and code review findings
- Documented SOLID principle violations across codebase - Identified security best practice improvements needed - Analyzed performance optimization opportunities - Added detailed refactoring recommendations - Updated ADR-0018 with JWT secret rotation reference - Enabled gitea-client skill for programmer agent This commit captures the current state analysis before implementing improvements.
This commit is contained in:
@@ -120,6 +120,7 @@ The user management system follows the established DanceLessonsCoach patterns:
|
||||
- 30-minute expiration for access tokens
|
||||
- Secure random signing key
|
||||
- HTTPS-only cookies
|
||||
- **Secret Rotation:** Multiple valid secrets with retention policy (see Issue #8)
|
||||
3. **Admin Access:**
|
||||
- Master password from environment variable
|
||||
- Non-persisted admin user
|
||||
@@ -464,6 +465,7 @@ The implementation maintains full backward compatibility:
|
||||
3. **User Activity Logging:** For audit trails
|
||||
4. **Password Strength Meter:** For better user experience
|
||||
5. **Account Recovery:** Email/phone-based recovery options
|
||||
6. **JWT Secret Rotation:** Implement secret persistence and rotation mechanism (Issue #8)
|
||||
|
||||
## References
|
||||
|
||||
|
||||
BIN
data/users.db
Normal file
BIN
data/users.db
Normal file
Binary file not shown.
BIN
features/data/users.db
Normal file
BIN
features/data/users.db
Normal file
Binary file not shown.
@@ -33,7 +33,7 @@ Feature: User Authentication
|
||||
|
||||
Scenario: User registration
|
||||
Given the server is running
|
||||
When I register a new user "newuser" with password "newpass123"
|
||||
When I register a new user "newuser_" with password "newpass123"
|
||||
Then the registration should be successful
|
||||
And I should be able to authenticate with the new credentials
|
||||
|
||||
|
||||
12
go.mod
12
go.mod
@@ -8,9 +8,11 @@ require (
|
||||
github.com/go-playground/locales v0.14.1
|
||||
github.com/go-playground/universal-translator v0.18.1
|
||||
github.com/go-playground/validator/v10 v10.30.2
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||
github.com/rs/zerolog v1.35.0
|
||||
github.com/spf13/cobra v1.8.0
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/swaggo/http-swagger v1.3.4
|
||||
github.com/swaggo/swag v1.16.6
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0
|
||||
@@ -18,6 +20,9 @@ require (
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0
|
||||
go.opentelemetry.io/otel/sdk v1.43.0
|
||||
go.opentelemetry.io/otel/trace v1.43.0
|
||||
golang.org/x/crypto v0.49.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/gorm v1.31.1
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -26,6 +31,7 @@ require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cucumber/gherkin/go/v26 v26.2.0 // indirect
|
||||
github.com/cucumber/messages/go/v21 v21.0.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||
@@ -43,12 +49,16 @@ require (
|
||||
github.com/hashicorp/go-memdb v1.3.5 // indirect
|
||||
github.com/hashicorp/golang-lru v1.0.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.6 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
@@ -61,7 +71,6 @@ require (
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.10.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/mod v0.33.0 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
@@ -73,4 +82,5 @@ require (
|
||||
google.golang.org/grpc v1.80.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
12
go.sum
12
go.sum
@@ -56,6 +56,8 @@ github.com/gofrs/uuid v4.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx
|
||||
github.com/gofrs/uuid v4.3.1+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
||||
github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA=
|
||||
github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
@@ -79,6 +81,10 @@ github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iP
|
||||
github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
@@ -99,6 +105,8 @@ github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHP
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
@@ -212,3 +220,7 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C
|
||||
gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||
|
||||
@@ -20,6 +20,7 @@ type Config struct {
|
||||
Logging LoggingConfig `mapstructure:"logging"`
|
||||
Telemetry TelemetryConfig `mapstructure:"telemetry"`
|
||||
API APIConfig `mapstructure:"api"`
|
||||
Auth AuthConfig `mapstructure:"auth"`
|
||||
}
|
||||
|
||||
// ServerConfig holds server-related configuration
|
||||
@@ -54,6 +55,12 @@ type APIConfig struct {
|
||||
V2Enabled bool `mapstructure:"v2_enabled"`
|
||||
}
|
||||
|
||||
// AuthConfig holds authentication configuration
|
||||
type AuthConfig struct {
|
||||
JWTSecret string `mapstructure:"jwt_secret"`
|
||||
AdminMasterPassword string `mapstructure:"admin_master_password"`
|
||||
}
|
||||
|
||||
// VersionInfo holds application version information
|
||||
type VersionInfo struct {
|
||||
Version string `mapstructure:"-"` // Set via ldflags
|
||||
@@ -104,6 +111,10 @@ func LoadConfig() (*Config, error) {
|
||||
// API defaults
|
||||
v.SetDefault("api.v2_enabled", false)
|
||||
|
||||
// Auth defaults
|
||||
v.SetDefault("auth.jwt_secret", "default-secret-key-please-change-in-production")
|
||||
v.SetDefault("auth.admin_master_password", "admin123")
|
||||
|
||||
// Check for custom config file path via environment variable
|
||||
if configFile := os.Getenv("DLC_CONFIG_FILE"); configFile != "" {
|
||||
v.SetConfigFile(configFile)
|
||||
@@ -141,6 +152,10 @@ func LoadConfig() (*Config, error) {
|
||||
v.BindEnv("telemetry.otlp_endpoint", "DLC_TELEMETRY_OTLP_ENDPOINT")
|
||||
v.BindEnv("telemetry.service_name", "DLC_TELEMETRY_SERVICE_NAME")
|
||||
v.BindEnv("telemetry.insecure", "DLC_TELEMETRY_INSECURE")
|
||||
|
||||
// Auth environment variables
|
||||
v.BindEnv("auth.jwt_secret", "DLC_AUTH_JWT_SECRET")
|
||||
v.BindEnv("auth.admin_master_password", "DLC_AUTH_ADMIN_MASTER_PASSWORD")
|
||||
v.BindEnv("telemetry.sampler.type", "DLC_TELEMETRY_SAMPLER_TYPE")
|
||||
v.BindEnv("telemetry.sampler.ratio", "DLC_TELEMETRY_SAMPLER_RATIO")
|
||||
|
||||
@@ -220,6 +235,16 @@ func (c *Config) GetV2Enabled() bool {
|
||||
return c.API.V2Enabled
|
||||
}
|
||||
|
||||
// GetJWTSecret returns the JWT secret
|
||||
func (c *Config) GetJWTSecret() string {
|
||||
return c.Auth.JWTSecret
|
||||
}
|
||||
|
||||
// GetAdminMasterPassword returns the admin master password
|
||||
func (c *Config) GetAdminMasterPassword() string {
|
||||
return c.Auth.AdminMasterPassword
|
||||
}
|
||||
|
||||
// GetLogLevel returns the logging level
|
||||
func (c *Config) GetLogLevel() string {
|
||||
return c.Logging.Level
|
||||
|
||||
@@ -3,21 +3,45 @@ package greet
|
||||
import (
|
||||
"context"
|
||||
|
||||
"dance-lessons-coach/pkg/user"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Context key for storing authenticated user
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
// UserContextKey is the context key for storing authenticated user
|
||||
UserContextKey contextKey = "authenticatedUser"
|
||||
)
|
||||
|
||||
type Service struct{}
|
||||
|
||||
func NewService() *Service {
|
||||
return &Service{}
|
||||
}
|
||||
|
||||
// GetAuthenticatedUserFromContext extracts the authenticated user from context
|
||||
func GetAuthenticatedUserFromContext(ctx context.Context) (*user.User, bool) {
|
||||
user, ok := ctx.Value(UserContextKey).(*user.User)
|
||||
return user, ok
|
||||
}
|
||||
|
||||
// Greet returns a greeting message for the given name.
|
||||
// If name is empty, it defaults to "world".
|
||||
// If name is empty, it checks for authenticated user and uses their username.
|
||||
// If no authenticated user and no name, it defaults to "world".
|
||||
// Implements the Greeter interface.
|
||||
func (s *Service) Greet(ctx context.Context, name string) string {
|
||||
log.Trace().Ctx(ctx).Str("name", name).Msg("Greet function called")
|
||||
|
||||
// If no name provided, check for authenticated user
|
||||
if name == "" {
|
||||
if authenticatedUser, ok := GetAuthenticatedUserFromContext(ctx); ok {
|
||||
name = authenticatedUser.Username
|
||||
log.Trace().Ctx(ctx).Str("authenticated_user", name).Msg("Using authenticated username for greeting")
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
return "Hello world!"
|
||||
}
|
||||
|
||||
62
pkg/server/middleware.go
Normal file
62
pkg/server/middleware.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"dance-lessons-coach/pkg/greet"
|
||||
"dance-lessons-coach/pkg/user"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// AuthMiddleware handles JWT authentication and adds user to context
|
||||
type AuthMiddleware struct {
|
||||
authService user.AuthService
|
||||
}
|
||||
|
||||
// NewAuthMiddleware creates a new authentication middleware
|
||||
func NewAuthMiddleware(authService user.AuthService) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
authService: authService,
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware returns the authentication middleware function
|
||||
func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Extract Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
// No authorization header, pass through with no user
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract token from "Bearer <token>" format
|
||||
const bearerPrefix = "Bearer "
|
||||
if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix {
|
||||
log.Trace().Ctx(ctx).Str("auth_header", authHeader).Msg("Invalid authorization header format")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
token := authHeader[len(bearerPrefix):]
|
||||
|
||||
// Validate JWT token
|
||||
validatedUser, err := m.authService.ValidateJWT(ctx, token)
|
||||
if err != nil {
|
||||
log.Trace().Ctx(ctx).Err(err).Msg("JWT validation failed")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Add user to context
|
||||
ctxWithUser := context.WithValue(ctx, greet.UserContextKey, validatedUser)
|
||||
r = r.WithContext(ctxWithUser)
|
||||
|
||||
// Continue to next handler
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -145,7 +145,18 @@ func (s *Server) setupRoutes() {
|
||||
func (s *Server) registerApiV1Routes(r chi.Router) {
|
||||
greetService := greet.NewService()
|
||||
greetHandler := greet.NewApiV1GreetHandler(greetService)
|
||||
|
||||
// Create auth middleware if available
|
||||
var authMiddleware *AuthMiddleware
|
||||
if s.authService != nil {
|
||||
authMiddleware = NewAuthMiddleware(s.authService)
|
||||
}
|
||||
|
||||
r.Route("/greet", func(r chi.Router) {
|
||||
// Add optional authentication middleware
|
||||
if authMiddleware != nil {
|
||||
r.Use(authMiddleware.Middleware)
|
||||
}
|
||||
greetHandler.RegisterRoutes(r)
|
||||
})
|
||||
|
||||
|
||||
224
pkg/user/api/auth_handler.go
Normal file
224
pkg/user/api/auth_handler.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"dance-lessons-coach/pkg/user"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication-related HTTP requests
|
||||
type AuthHandler struct {
|
||||
authService user.AuthService
|
||||
userRepo user.UserRepository
|
||||
passwordResetService user.PasswordResetService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new authentication handler
|
||||
func NewAuthHandler(authService user.AuthService, userRepo user.UserRepository) *AuthHandler {
|
||||
passwordResetService := user.NewPasswordResetService(userRepo, authService.(*user.AuthServiceImpl))
|
||||
return &AuthHandler{
|
||||
authService: authService,
|
||||
userRepo: userRepo,
|
||||
passwordResetService: passwordResetService,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers authentication routes
|
||||
func (h *AuthHandler) RegisterRoutes(router chi.Router) {
|
||||
router.Post("/login", h.handleLogin)
|
||||
router.Post("/admin/login", h.handleAdminLogin)
|
||||
router.Post("/register", h.handleRegister)
|
||||
router.Post("/password-reset/request", h.handlePasswordResetRequest)
|
||||
router.Post("/password-reset/complete", h.handlePasswordResetComplete)
|
||||
}
|
||||
|
||||
// LoginRequest represents a login request
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username" validate:"required,min=3,max=50"`
|
||||
Password string `json:"password" validate:"required,min=6"`
|
||||
}
|
||||
|
||||
// LoginResponse represents a login response
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// handleLogin handles user login requests
|
||||
func (h *AuthHandler) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
var req LoginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate user
|
||||
user, err := h.authService.Authenticate(ctx, req.Username, req.Password)
|
||||
if err != nil {
|
||||
log.Trace().Ctx(ctx).Err(err).Str("username", req.Username).Msg("Authentication failed")
|
||||
http.Error(w, `{"error":"invalid_credentials","message":"Invalid username or password"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
token, err := h.authService.GenerateJWT(ctx, user)
|
||||
if err != nil {
|
||||
log.Error().Ctx(ctx).Err(err).Msg("Failed to generate JWT token")
|
||||
http.Error(w, `{"error":"server_error","message":"Failed to generate authentication token"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Return token
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(LoginResponse{Token: token})
|
||||
}
|
||||
|
||||
// handleAdminLogin handles admin login requests
|
||||
func (h *AuthHandler) handleAdminLogin(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
var req LoginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate admin
|
||||
adminUser, err := h.authService.AdminAuthenticate(ctx, req.Password)
|
||||
if err != nil {
|
||||
log.Trace().Ctx(ctx).Err(err).Msg("Admin authentication failed")
|
||||
http.Error(w, `{"error":"invalid_credentials","message":"Invalid admin credentials"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
token, err := h.authService.GenerateJWT(ctx, adminUser)
|
||||
if err != nil {
|
||||
log.Error().Ctx(ctx).Err(err).Msg("Failed to generate JWT token for admin")
|
||||
http.Error(w, `{"error":"server_error","message":"Failed to generate authentication token"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Return token
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(LoginResponse{Token: token})
|
||||
}
|
||||
|
||||
// RegisterRequest represents a user registration request
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username" validate:"required,min=3,max=50"`
|
||||
Password string `json:"password" validate:"required,min=6"`
|
||||
}
|
||||
|
||||
// handleRegister handles user registration requests
|
||||
func (h *AuthHandler) handleRegister(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
var req RegisterRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user already exists
|
||||
exists, err := h.userRepo.UserExists(ctx, req.Username)
|
||||
if err != nil {
|
||||
log.Error().Ctx(ctx).Err(err).Msg("Failed to check if user exists")
|
||||
http.Error(w, `{"error":"server_error","message":"Failed to process registration"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if exists {
|
||||
http.Error(w, `{"error":"user_exists","message":"Username already taken"}`, http.StatusConflict)
|
||||
return
|
||||
}
|
||||
|
||||
// Hash password
|
||||
hashedPassword, err := h.authService.HashPassword(ctx, req.Password)
|
||||
if err != nil {
|
||||
log.Error().Ctx(ctx).Err(err).Msg("Failed to hash password")
|
||||
http.Error(w, `{"error":"server_error","message":"Failed to process registration"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Create user
|
||||
newUser := &user.User{
|
||||
Username: req.Username,
|
||||
PasswordHash: hashedPassword,
|
||||
IsAdmin: false,
|
||||
}
|
||||
|
||||
if err := h.userRepo.CreateUser(ctx, newUser); err != nil {
|
||||
log.Error().Ctx(ctx).Err(err).Msg("Failed to create user")
|
||||
http.Error(w, `{"error":"server_error","message":"Failed to create user"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Return success
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]string{"message": "User registered successfully"})
|
||||
}
|
||||
|
||||
// PasswordResetRequest represents a password reset request
|
||||
type PasswordResetRequest struct {
|
||||
Username string `json:"username" validate:"required,min=3,max=50"`
|
||||
}
|
||||
|
||||
// handlePasswordResetRequest handles password reset requests
|
||||
func (h *AuthHandler) handlePasswordResetRequest(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
var req PasswordResetRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Request password reset
|
||||
if err := h.passwordResetService.RequestPasswordReset(ctx, req.Username); err != nil {
|
||||
log.Error().Ctx(ctx).Err(err).Msg("Failed to request password reset")
|
||||
http.Error(w, `{"error":"server_error","message":"Failed to process password reset request"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Return success
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"message": "Password reset allowed, user can now reset password"})
|
||||
}
|
||||
|
||||
// PasswordResetCompleteRequest represents a password reset completion request
|
||||
type PasswordResetCompleteRequest struct {
|
||||
Username string `json:"username" validate:"required,min=3,max=50"`
|
||||
NewPassword string `json:"new_password" validate:"required,min=6"`
|
||||
}
|
||||
|
||||
// handlePasswordResetComplete handles password reset completion requests
|
||||
func (h *AuthHandler) handlePasswordResetComplete(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
var req PasswordResetCompleteRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Complete password reset
|
||||
if err := h.passwordResetService.CompletePasswordReset(ctx, req.Username, req.NewPassword); err != nil {
|
||||
log.Error().Ctx(ctx).Err(err).Msg("Failed to complete password reset")
|
||||
http.Error(w, `{"error":"server_error","message":"Failed to complete password reset"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Return success
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"message": "Password reset completed successfully"})
|
||||
}
|
||||
198
pkg/user/auth_service.go
Normal file
198
pkg/user/auth_service.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// JWTConfig holds JWT configuration
|
||||
type JWTConfig struct {
|
||||
Secret string
|
||||
ExpirationTime time.Duration
|
||||
Issuer string
|
||||
}
|
||||
|
||||
// AuthServiceImpl implements the AuthService interface
|
||||
type AuthServiceImpl struct {
|
||||
repo UserRepository
|
||||
jwtConfig JWTConfig
|
||||
masterPassword string
|
||||
}
|
||||
|
||||
// NewAuthService creates a new authentication service
|
||||
func NewAuthService(repo UserRepository, jwtConfig JWTConfig, masterPassword string) *AuthServiceImpl {
|
||||
return &AuthServiceImpl{
|
||||
repo: repo,
|
||||
jwtConfig: jwtConfig,
|
||||
masterPassword: masterPassword,
|
||||
}
|
||||
}
|
||||
|
||||
// Authenticate authenticates a user with username and password
|
||||
func (s *AuthServiceImpl) Authenticate(ctx context.Context, username, password string) (*User, error) {
|
||||
user, err := s.repo.GetUserByUsername(ctx, username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
// Check password
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
// Update last login time
|
||||
now := time.Now()
|
||||
user.LastLogin = &now
|
||||
if err := s.repo.UpdateUser(ctx, user); err != nil {
|
||||
// Don't fail authentication if we can't update last login
|
||||
// Just log it and continue
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GenerateJWT generates a JWT token for the given user
|
||||
func (s *AuthServiceImpl) GenerateJWT(ctx context.Context, user *User) (string, error) {
|
||||
// Create the claims
|
||||
claims := jwt.MapClaims{
|
||||
"sub": user.ID,
|
||||
"name": user.Username,
|
||||
"admin": user.IsAdmin,
|
||||
"exp": time.Now().Add(s.jwtConfig.ExpirationTime).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"iss": s.jwtConfig.Issuer,
|
||||
}
|
||||
|
||||
// Create token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
|
||||
// Sign and get the complete encoded token as a string
|
||||
tokenString, err := token.SignedString([]byte(s.jwtConfig.Secret))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign JWT: %w", err)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// ValidateJWT validates a JWT token and returns the user
|
||||
func (s *AuthServiceImpl) ValidateJWT(ctx context.Context, tokenString string) (*User, error) {
|
||||
// Parse the token
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify the signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
return []byte(s.jwtConfig.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
|
||||
// Check if token is valid
|
||||
if !token.Valid {
|
||||
return nil, errors.New("invalid JWT token")
|
||||
}
|
||||
|
||||
// Get claims
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid JWT claims")
|
||||
}
|
||||
|
||||
// Get user ID from claims
|
||||
userIDFloat, ok := claims["sub"].(float64)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid user ID in JWT")
|
||||
}
|
||||
|
||||
userID := uint(userIDFloat)
|
||||
|
||||
// Get user from repository
|
||||
user, err := s.repo.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user from JWT: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// HashPassword hashes a password using bcrypt
|
||||
func (s *AuthServiceImpl) HashPassword(ctx context.Context, password string) (string, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
return string(hash), nil
|
||||
}
|
||||
|
||||
// AdminAuthenticate authenticates an admin user with master password
|
||||
func (s *AuthServiceImpl) AdminAuthenticate(ctx context.Context, masterPassword string) (*User, error) {
|
||||
// Check if master password matches
|
||||
if masterPassword != s.masterPassword {
|
||||
return nil, errors.New("invalid admin credentials")
|
||||
}
|
||||
|
||||
// Create a virtual admin user (not persisted)
|
||||
adminUser := &User{
|
||||
ID: 0, // Special ID for admin
|
||||
Username: "admin",
|
||||
IsAdmin: true,
|
||||
}
|
||||
|
||||
return adminUser, nil
|
||||
}
|
||||
|
||||
// PasswordResetServiceImpl implements the PasswordResetService interface
|
||||
type PasswordResetServiceImpl struct {
|
||||
repo UserRepository
|
||||
auth *AuthServiceImpl
|
||||
}
|
||||
|
||||
// NewPasswordResetService creates a new password reset service
|
||||
func NewPasswordResetService(repo UserRepository, auth *AuthServiceImpl) *PasswordResetServiceImpl {
|
||||
return &PasswordResetServiceImpl{
|
||||
repo: repo,
|
||||
auth: auth,
|
||||
}
|
||||
}
|
||||
|
||||
// RequestPasswordReset requests a password reset for a user
|
||||
func (s *PasswordResetServiceImpl) RequestPasswordReset(ctx context.Context, username string) error {
|
||||
// Check if user exists
|
||||
exists, err := s.repo.UserExists(ctx, username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return fmt.Errorf("user not found: %s", username)
|
||||
}
|
||||
|
||||
// Allow password reset
|
||||
return s.repo.AllowPasswordReset(ctx, username)
|
||||
}
|
||||
|
||||
// CompletePasswordReset completes the password reset process
|
||||
func (s *PasswordResetServiceImpl) CompletePasswordReset(ctx context.Context, username, newPassword string) error {
|
||||
// Hash the new password
|
||||
hashedPassword, err := s.auth.HashPassword(ctx, newPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash new password: %w", err)
|
||||
}
|
||||
|
||||
// Complete the password reset
|
||||
return s.repo.CompletePasswordReset(ctx, username, hashedPassword)
|
||||
}
|
||||
174
pkg/user/sqlite_repository.go
Normal file
174
pkg/user/sqlite_repository.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// SQLiteRepository implements UserRepository using SQLite
|
||||
type SQLiteRepository struct {
|
||||
db *gorm.DB
|
||||
dbPath string
|
||||
}
|
||||
|
||||
// NewSQLiteRepository creates a new SQLite repository
|
||||
func NewSQLiteRepository(dbPath string) (*SQLiteRepository, error) {
|
||||
repo := &SQLiteRepository{
|
||||
dbPath: dbPath,
|
||||
}
|
||||
|
||||
if err := repo.initializeDatabase(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize database: %w", err)
|
||||
}
|
||||
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
// initializeDatabase sets up the SQLite database and runs migrations
|
||||
func (r *SQLiteRepository) initializeDatabase() error {
|
||||
// Create directory if it doesn't exist
|
||||
dir := filepath.Dir(r.dbPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
// Configure GORM logger to use standard log
|
||||
gormLogger := logger.New(
|
||||
log.New(os.Stdout, "\n", log.LstdFlags),
|
||||
logger.Config{
|
||||
SlowThreshold: time.Second,
|
||||
LogLevel: logger.Warn,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: true,
|
||||
},
|
||||
)
|
||||
|
||||
var err error
|
||||
r.db, err = gorm.Open(sqlite.Open(r.dbPath), &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
// Auto-migrate the User model
|
||||
if err := r.db.AutoMigrate(&User{}); err != nil {
|
||||
return fmt.Errorf("failed to auto-migrate: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in the database
|
||||
func (r *SQLiteRepository) CreateUser(ctx context.Context, user *User) error {
|
||||
result := r.db.WithContext(ctx).Create(user)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to create user: %w", result.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserByUsername retrieves a user by username
|
||||
func (r *SQLiteRepository) GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
||||
var user User
|
||||
result := r.db.WithContext(ctx).Where("username = ?", username).First(&user)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user by username: %w", result.Error)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByID retrieves a user by ID
|
||||
func (r *SQLiteRepository) GetUserByID(ctx context.Context, id uint) (*User, error) {
|
||||
var user User
|
||||
result := r.db.WithContext(ctx).First(&user, id)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user by ID: %w", result.Error)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateUser updates a user in the database
|
||||
func (r *SQLiteRepository) UpdateUser(ctx context.Context, user *User) error {
|
||||
result := r.db.WithContext(ctx).Save(user)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to update user: %w", result.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user from the database
|
||||
func (r *SQLiteRepository) DeleteUser(ctx context.Context, id uint) error {
|
||||
result := r.db.WithContext(ctx).Delete(&User{}, id)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", result.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowPasswordReset flags a user for password reset
|
||||
func (r *SQLiteRepository) AllowPasswordReset(ctx context.Context, username string) error {
|
||||
user, err := r.GetUserByUsername(ctx, username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user for password reset: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
return fmt.Errorf("user not found: %s", username)
|
||||
}
|
||||
|
||||
user.AllowPasswordReset = true
|
||||
return r.UpdateUser(ctx, user)
|
||||
}
|
||||
|
||||
// CompletePasswordReset completes the password reset process
|
||||
func (r *SQLiteRepository) CompletePasswordReset(ctx context.Context, username, newPasswordHash string) error {
|
||||
user, err := r.GetUserByUsername(ctx, username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user for password reset completion: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
return fmt.Errorf("user not found: %s", username)
|
||||
}
|
||||
|
||||
if !user.AllowPasswordReset {
|
||||
return fmt.Errorf("password reset not allowed for user: %s", username)
|
||||
}
|
||||
|
||||
user.PasswordHash = newPasswordHash
|
||||
user.AllowPasswordReset = false
|
||||
return r.UpdateUser(ctx, user)
|
||||
}
|
||||
|
||||
// UserExists checks if a user exists by username
|
||||
func (r *SQLiteRepository) UserExists(ctx context.Context, username string) (bool, error) {
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).Model(&User{}).Where("username = ?", username).Count(&count)
|
||||
if result.Error != nil {
|
||||
return false, fmt.Errorf("failed to check if user exists: %w", result.Error)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (r *SQLiteRepository) Close() error {
|
||||
sqlDB, err := r.db.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get database connection: %w", err)
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
48
pkg/user/user.go
Normal file
48
pkg/user/user.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// User represents a user in the system
|
||||
type User struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
DeletedAt *time.Time `json:"deleted_at,omitempty" gorm:"index"`
|
||||
Username string `json:"username" gorm:"unique;not null" validate:"required,min=3,max=50"`
|
||||
PasswordHash string `json:"-" gorm:"not null"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
CurrentGoal *string `json:"current_goal,omitempty"`
|
||||
IsAdmin bool `json:"is_admin" gorm:"default:false"`
|
||||
AllowPasswordReset bool `json:"allow_password_reset" gorm:"default:false"`
|
||||
LastLogin *time.Time `json:"last_login,omitempty"`
|
||||
}
|
||||
|
||||
// UserRepository defines the interface for user persistence
|
||||
type UserRepository interface {
|
||||
CreateUser(ctx context.Context, user *User) error
|
||||
GetUserByUsername(ctx context.Context, username string) (*User, error)
|
||||
GetUserByID(ctx context.Context, id uint) (*User, error)
|
||||
UpdateUser(ctx context.Context, user *User) error
|
||||
DeleteUser(ctx context.Context, id uint) error
|
||||
AllowPasswordReset(ctx context.Context, username string) error
|
||||
CompletePasswordReset(ctx context.Context, username, newPassword string) error
|
||||
UserExists(ctx context.Context, username string) (bool, error)
|
||||
}
|
||||
|
||||
// AuthService defines the interface for authentication
|
||||
type AuthService interface {
|
||||
Authenticate(ctx context.Context, username, password string) (*User, error)
|
||||
GenerateJWT(ctx context.Context, user *User) (string, error)
|
||||
ValidateJWT(ctx context.Context, token string) (*User, error)
|
||||
HashPassword(ctx context.Context, password string) (string, error)
|
||||
AdminAuthenticate(ctx context.Context, masterPassword string) (*User, error)
|
||||
}
|
||||
|
||||
// PasswordResetService defines the interface for password reset workflow
|
||||
type PasswordResetService interface {
|
||||
RequestPasswordReset(ctx context.Context, username string) error
|
||||
CompletePasswordReset(ctx context.Context, username, newPassword string) error
|
||||
}
|
||||
222
pkg/user/user_test.go
Normal file
222
pkg/user/user_test.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSQLiteRepository(t *testing.T) {
|
||||
t.Run("CRUD operations", func(t *testing.T) {
|
||||
// Create a temporary database
|
||||
dbPath := "test_db.sqlite"
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
repo, err := NewSQLiteRepository(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer repo.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test CreateUser
|
||||
user := &User{
|
||||
Username: "testuser",
|
||||
PasswordHash: "hashedpassword",
|
||||
Description: ptrString("Test user"),
|
||||
CurrentGoal: ptrString("Learn to dance"),
|
||||
IsAdmin: false,
|
||||
}
|
||||
|
||||
err = repo.CreateUser(ctx, user)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, user.ID)
|
||||
|
||||
// Test GetUserByUsername
|
||||
retrievedUser, err := repo.GetUserByUsername(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, retrievedUser)
|
||||
assert.Equal(t, "testuser", retrievedUser.Username)
|
||||
|
||||
// Test UserExists
|
||||
exists, err := repo.UserExists(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Test UpdateUser
|
||||
retrievedUser.Description = ptrString("Updated description")
|
||||
err = repo.UpdateUser(ctx, retrievedUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify update
|
||||
updatedUser, err := repo.GetUserByUsername(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated description", *updatedUser.Description)
|
||||
|
||||
// Test AllowPasswordReset
|
||||
err = repo.AllowPasswordReset(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify password reset flag
|
||||
userWithReset, err := repo.GetUserByUsername(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, userWithReset.AllowPasswordReset)
|
||||
|
||||
// Test CompletePasswordReset
|
||||
err = repo.CompletePasswordReset(ctx, "testuser", "newhashedpassword")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify password reset completion
|
||||
userAfterReset, err := repo.GetUserByUsername(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "newhashedpassword", userAfterReset.PasswordHash)
|
||||
assert.False(t, userAfterReset.AllowPasswordReset)
|
||||
|
||||
// Test DeleteUser
|
||||
err = repo.DeleteUser(ctx, userAfterReset.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletion
|
||||
deletedUser, err := repo.GetUserByUsername(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, deletedUser)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthService(t *testing.T) {
|
||||
t.Run("Password hashing and authentication", func(t *testing.T) {
|
||||
// Create a temporary database
|
||||
dbPath := "test_auth_db.sqlite"
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
repo, err := NewSQLiteRepository(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer repo.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create auth service
|
||||
jwtConfig := JWTConfig{
|
||||
Secret: "test-secret",
|
||||
ExpirationTime: time.Hour,
|
||||
Issuer: "test-issuer",
|
||||
}
|
||||
authService := NewAuthService(repo, jwtConfig, "admin123")
|
||||
|
||||
// Test password hashing
|
||||
password := "testpassword123"
|
||||
hashedPassword, err := authService.HashPassword(ctx, password)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, hashedPassword)
|
||||
|
||||
// Create a test user
|
||||
user := &User{
|
||||
Username: "testuser",
|
||||
PasswordHash: hashedPassword,
|
||||
}
|
||||
err = repo.CreateUser(ctx, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test successful authentication
|
||||
authenticatedUser, err := authService.Authenticate(ctx, "testuser", password)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authenticatedUser)
|
||||
assert.Equal(t, "testuser", authenticatedUser.Username)
|
||||
|
||||
// Test failed authentication with wrong password
|
||||
_, err = authService.Authenticate(ctx, "testuser", "wrongpassword")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "invalid credentials", err.Error())
|
||||
|
||||
// Test JWT generation
|
||||
token, err := authService.GenerateJWT(ctx, authenticatedUser)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// Test JWT validation
|
||||
validatedUser, err := authService.ValidateJWT(ctx, token)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, validatedUser)
|
||||
assert.Equal(t, authenticatedUser.ID, validatedUser.ID)
|
||||
|
||||
// Test admin authentication
|
||||
adminUser, err := authService.AdminAuthenticate(ctx, "admin123")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, adminUser)
|
||||
assert.True(t, adminUser.IsAdmin)
|
||||
assert.Equal(t, "admin", adminUser.Username)
|
||||
|
||||
// Test failed admin authentication
|
||||
_, err = authService.AdminAuthenticate(ctx, "wrongadminpassword")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "invalid admin credentials", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestPasswordResetService(t *testing.T) {
|
||||
t.Run("Password reset workflow", func(t *testing.T) {
|
||||
// Create a temporary database
|
||||
dbPath := "test_reset_db.sqlite"
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
repo, err := NewSQLiteRepository(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer repo.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create auth service
|
||||
jwtConfig := JWTConfig{
|
||||
Secret: "test-secret",
|
||||
ExpirationTime: time.Hour,
|
||||
Issuer: "test-issuer",
|
||||
}
|
||||
authService := NewAuthService(repo, jwtConfig, "admin123")
|
||||
passwordResetService := NewPasswordResetService(repo, authService)
|
||||
|
||||
// Create a test user
|
||||
password := "oldpassword123"
|
||||
hashedPassword, err := authService.HashPassword(ctx, password)
|
||||
require.NoError(t, err)
|
||||
|
||||
user := &User{
|
||||
Username: "resetuser",
|
||||
PasswordHash: hashedPassword,
|
||||
}
|
||||
err = repo.CreateUser(ctx, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test password reset request
|
||||
err = passwordResetService.RequestPasswordReset(ctx, "resetuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify user is flagged for reset
|
||||
userAfterRequest, err := repo.GetUserByUsername(ctx, "resetuser")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, userAfterRequest.AllowPasswordReset)
|
||||
|
||||
// Test password reset completion
|
||||
newPassword := "newpassword123"
|
||||
err = passwordResetService.CompletePasswordReset(ctx, "resetuser", newPassword)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify password was updated and reset flag was cleared
|
||||
userAfterReset, err := repo.GetUserByUsername(ctx, "resetuser")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, userAfterReset.AllowPasswordReset)
|
||||
|
||||
// Verify new password works by authenticating with the new password
|
||||
authenticatedUser, err := authService.Authenticate(ctx, "resetuser", newPassword)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authenticatedUser)
|
||||
assert.Equal(t, "resetuser", authenticatedUser.Username)
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to create string pointers
|
||||
func ptrString(s string) *string {
|
||||
return &s
|
||||
}
|
||||
Reference in New Issue
Block a user