package user import ( "context" "errors" "fmt" "log" "os" "time" "dance-lessons-coach/pkg/config" "github.com/rs/zerolog" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" ) // ZerologWriter implements logger.Writer interface using zerolog type ZerologWriter struct { logger zerolog.Logger } func (zw *ZerologWriter) Printf(format string, v ...interface{}) { message := fmt.Sprintf(format, v...) // Determine appropriate log level based on message content if len(message) > 0 { // Check for error indicators if containsErrorIndicators(message) { zw.logger.Error().Str("gorm", message).Send() return } // Check for slow query indicators if containsSlowQueryIndicators(message) { zw.logger.Warn().Str("gorm", message).Send() return } // Default to debug level for regular SQL queries zw.logger.Debug().Str("gorm", message).Send() } } // containsErrorIndicators checks if the message contains error-related keywords func containsErrorIndicators(message string) bool { errorKeywords := []string{"error", "Error", "failed", "Failed", "not found", "Not Found"} for _, keyword := range errorKeywords { if containsIgnoreCase(message, keyword) { return true } } return false } // containsSlowQueryIndicators checks if the message contains slow query indicators func containsSlowQueryIndicators(message string) bool { slowKeywords := []string{"slow", "Slow", "timeout", "Timeout"} for _, keyword := range slowKeywords { if containsIgnoreCase(message, keyword) { return true } } return false } // containsIgnoreCase performs case-insensitive string containment check func containsIgnoreCase(s, substr string) bool { return containsIgnoreCaseBytes([]byte(s), []byte(substr)) } // containsIgnoreCaseBytes is a helper for case-insensitive byte slice containment func containsIgnoreCaseBytes(s, substr []byte) bool { if len(substr) == 0 { return true } if len(s) < len(substr) { return false } for i := 0; i <= len(s)-len(substr); i++ { match := true for j := 0; j < len(substr); j++ { if toLower(s[i+j]) != toLower(substr[j]) { match = false break } } if match { return true } } return false } // toLower converts byte to lowercase func toLower(b byte) byte { if b >= 'A' && b <= 'Z' { return b + 32 } return b } // PostgresRepository implements UserRepository using PostgreSQL type PostgresRepository struct { db *gorm.DB config *config.Config spanPrefix string } // NewPostgresRepository creates a new PostgreSQL repository func NewPostgresRepository(cfg *config.Config) (*PostgresRepository, error) { repo := &PostgresRepository{ config: cfg, spanPrefix: "user.repo.", } if err := repo.initializeDatabase(); err != nil { return nil, fmt.Errorf("failed to initialize PostgreSQL database: %w", err) } return repo, nil } // NewPostgresRepositoryFromDSN creates a PostgresRepository connected via the given DSN // and runs AutoMigrate against it. Used by BDD test infra to create a per-scenario // repository pointing at an isolated schema (the DSN typically includes search_path=). // // Pass the same cfg used elsewhere (it is required by methods that read pool settings), // but the DSN passed here OVERRIDES the host/port/dbname/etc that cfg would have built. func NewPostgresRepositoryFromDSN(cfg *config.Config, dsn string) (*PostgresRepository, error) { repo := &PostgresRepository{ config: cfg, spanPrefix: "user.repo.", } gormLogger := logger.New( log.New(os.Stderr, "\n", log.LstdFlags), logger.Config{ SlowThreshold: time.Second, LogLevel: logger.Warn, IgnoreRecordNotFoundError: true, Colorful: false, }, ) db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{Logger: gormLogger}) if err != nil { return nil, fmt.Errorf("failed to connect to PostgreSQL with custom DSN: %w", err) } sqlDB, err := db.DB() if err != nil { return nil, fmt.Errorf("failed to get sql.DB from gorm: %w", err) } sqlDB.SetMaxOpenConns(cfg.GetDatabaseMaxOpenConns()) sqlDB.SetMaxIdleConns(cfg.GetDatabaseMaxIdleConns()) sqlDB.SetConnMaxLifetime(cfg.GetDatabaseConnMaxLifetime()) if err := db.AutoMigrate(&User{}); err != nil { return nil, fmt.Errorf("failed to auto-migrate via custom DSN: %w", err) } repo.db = db return repo, nil } // BuildSchemaIsolatedDSN returns a Postgres DSN that targets the given schema via // the search_path connection parameter. Use with NewPostgresRepositoryFromDSN to // get a repository whose connection only sees the per-scenario schema. func BuildSchemaIsolatedDSN(cfg *config.Config, schemaName string) string { return fmt.Sprintf( "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s search_path=%s", cfg.GetDatabaseHost(), cfg.GetDatabasePort(), cfg.GetDatabaseUser(), cfg.GetDatabasePassword(), cfg.GetDatabaseName(), cfg.GetDatabaseSSLMode(), schemaName, ) } // Exec runs a raw SQL statement against the repository's connection. // Used by BDD test infra for schema lifecycle (CREATE SCHEMA / DROP SCHEMA). // Avoid in production code paths -- prefer the typed Repository methods. func (r *PostgresRepository) Exec(sql string) error { if r.db == nil { return fmt.Errorf("Exec called on PostgresRepository with nil db") } return r.db.Exec(sql).Error } // initializeDatabase sets up the PostgreSQL database connection and runs migrations func (r *PostgresRepository) initializeDatabase() error { // Configure GORM logger based on config var gormLogger logger.Interface if r.config.GetLoggingJSON() { // Create zerolog logger that respects the configured output var logOutput = os.Stderr // If a log file is configured, use it if output := r.config.GetLogOutput(); output != "" { if file, err := os.OpenFile(output, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil { logOutput = file } } // Create zerolog logger with component context globalLogger := zerolog.New(logOutput).With().Str("component", "gorm").Logger() zw := &ZerologWriter{logger: globalLogger} gormLogger = logger.New( zw, logger.Config{ SlowThreshold: time.Second, LogLevel: logger.Warn, IgnoreRecordNotFoundError: true, Colorful: false, }, ) } else { // Use console logger for non-JSON mode gormLogger = logger.New( log.New(os.Stderr, "\n", log.LstdFlags), logger.Config{ SlowThreshold: time.Second, LogLevel: logger.Warn, IgnoreRecordNotFoundError: true, Colorful: true, }, ) } // Build PostgreSQL DSN dsn := fmt.Sprintf( "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", r.config.GetDatabaseHost(), r.config.GetDatabasePort(), r.config.GetDatabaseUser(), r.config.GetDatabasePassword(), r.config.GetDatabaseName(), r.config.GetDatabaseSSLMode(), ) var err error r.db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{ Logger: gormLogger, }) if err != nil { return fmt.Errorf("failed to connect to PostgreSQL: %w", err) } // Configure connection pool sqlDB, err := r.db.DB() if err != nil { return fmt.Errorf("failed to get SQL DB: %w", err) } // Set connection pool settings sqlDB.SetMaxOpenConns(r.config.GetDatabaseMaxOpenConns()) sqlDB.SetMaxIdleConns(r.config.GetDatabaseMaxIdleConns()) sqlDB.SetConnMaxLifetime(r.config.GetDatabaseConnMaxLifetime()) // Auto-migrate the User model if err := r.db.AutoMigrate(&User{}); err != nil { return fmt.Errorf("failed to auto-migrate: %w", err) } return nil } // CreateUser creates a new user in the database func (r *PostgresRepository) CreateUser(ctx context.Context, user *User) error { // Create telemetry span ctx, span := r.createSpan(ctx, "create_user") if span != nil { defer span.End() } result := r.db.WithContext(ctx).Create(user) if result.Error != nil { if span != nil { span.RecordError(result.Error) } return fmt.Errorf("failed to create user: %w", result.Error) } return nil } // GetUserByUsername retrieves a user by username func (r *PostgresRepository) GetUserByUsername(ctx context.Context, username string) (*User, error) { // Create telemetry span ctx, span := r.createSpan(ctx, "get_user_by_username") if span != nil { defer span.End() span.SetAttributes(attribute.String("username", username)) } var user User result := r.db.WithContext(ctx).Where("username = ?", username).First(&user) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } if span != nil { span.RecordError(result.Error) } return nil, fmt.Errorf("failed to get user by username: %w", result.Error) } return &user, nil } // GetUserByID retrieves a user by ID func (r *PostgresRepository) GetUserByID(ctx context.Context, id uint) (*User, error) { var user User result := r.db.WithContext(ctx).First(&user, id) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get user by ID: %w", result.Error) } return &user, nil } // UpdateUser updates a user in the database func (r *PostgresRepository) UpdateUser(ctx context.Context, user *User) error { result := r.db.WithContext(ctx).Save(user) if result.Error != nil { return fmt.Errorf("failed to update user: %w", result.Error) } return nil } // DeleteUser deletes a user from the database func (r *PostgresRepository) DeleteUser(ctx context.Context, id uint) error { result := r.db.WithContext(ctx).Delete(&User{}, id) if result.Error != nil { return fmt.Errorf("failed to delete user: %w", result.Error) } return nil } // AllowPasswordReset flags a user for password reset func (r *PostgresRepository) AllowPasswordReset(ctx context.Context, username string) error { user, err := r.GetUserByUsername(ctx, username) if err != nil { return fmt.Errorf("failed to get user for password reset: %w", err) } if user == nil { return fmt.Errorf("user not found: %s", username) } user.AllowPasswordReset = true return r.UpdateUser(ctx, user) } // CompletePasswordReset completes the password reset process func (r *PostgresRepository) CompletePasswordReset(ctx context.Context, username, newPasswordHash string) error { user, err := r.GetUserByUsername(ctx, username) if err != nil { return fmt.Errorf("failed to get user for password reset completion: %w", err) } if user == nil { return fmt.Errorf("user not found: %s", username) } if !user.AllowPasswordReset { return fmt.Errorf("password reset not allowed for user: %s", username) } user.PasswordHash = newPasswordHash user.AllowPasswordReset = false return r.UpdateUser(ctx, user) } // UserExists checks if a user exists by username func (r *PostgresRepository) UserExists(ctx context.Context, username string) (bool, error) { var count int64 result := r.db.WithContext(ctx).Model(&User{}).Where("username = ?", username).Count(&count) if result.Error != nil { return false, fmt.Errorf("failed to check if user exists: %w", result.Error) } return count > 0, nil } // Close closes the database connection func (r *PostgresRepository) Close() error { sqlDB, err := r.db.DB() if err != nil { return fmt.Errorf("failed to get database connection: %w", err) } return sqlDB.Close() } // CheckDatabaseHealth checks if the database is healthy and responsive func (r *PostgresRepository) CheckDatabaseHealth(ctx context.Context) error { // Simple query to test database connectivity var count int64 result := r.db.WithContext(ctx).Model(&User{}).Count(&count) if result.Error != nil { return fmt.Errorf("database health check failed: %w", result.Error) } return nil } // createSpan creates a new telemetry span if persistence telemetry is enabled func (r *PostgresRepository) createSpan(ctx context.Context, operation string) (context.Context, trace.Span) { if r.config == nil || !r.config.GetPersistenceTelemetryEnabled() { return ctx, trace.SpanFromContext(ctx) } // Create a new span with the operation name spanName := r.spanPrefix + operation tr := otel.Tracer("user-repository") return tr.Start(ctx, spanName) }