Files
dance-lessons-coach/pkg/user/postgres_repository.go
Gabriel Radureau 7a2b1a0a87 feat(user): NewPostgresRepositoryFromDSN factory + integration test (T12 stage 1/2)
First building block for parallel-safe BDD scenario isolation (T12 plan,
ADR-0025 follow-up). PR #28 had to revert BDD_SCHEMA_ISOLATION because
SetupScenarioSchema created an empty schema without migrations -- the
production server's repo never saw it. This PR adds the missing piece:
a factory that opens a *PostgresRepository connected via an arbitrary
DSN AND runs AutoMigrate against it, so a per-scenario schema actually
gets the users table.

Public API additions in pkg/user/postgres_repository.go:

- NewPostgresRepositoryFromDSN(cfg, dsn) (*PostgresRepository, error)
  Opens the repo from an explicit DSN (overrides cfg's host/port/etc),
  runs AutoMigrate -- creates tables in whatever schema the DSN's
  search_path points to.

- BuildSchemaIsolatedDSN(cfg, schemaName) string
  Builds a DSN with `search_path=<schemaName>` from a base config.

The existing NewPostgresRepository(cfg) is unchanged. Existing Close()
method is reused.

Integration test in postgres_repository_isolated_test.go proves:
- AutoMigrate creates `users` table in the per-scenario schema (not public)
- A CreateUser through the isolated repo writes into the per-scenario schema
- public.users sees ZERO rows for the test username
- The per-scenario schema users table sees exactly 1 row

Test skips gracefully when DLC_DATABASE_HOST is not set.

Out of scope (T12 stage 2/2 next):
- Wiring this factory into pkg/bdd/testserver/SetupScenarioSchema
- Spawning a fresh server.Server per scenario (requires NewServerWithUserRepo)
- Removing -p 1 from scripts/run-bdd-tests.sh after parallel safety is achieved

Per code-reviewer skill SOLID/DDD section:
- SRP: factory has single responsibility (open + migrate, no business logic)
- OCP: the new factory extends the package without changing existing callers
- Cognitive load: 1 file, 50 lines added, 1 dedicated test file

🤖 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-03 18:03:08 +02:00

413 lines
12 KiB
Go

package user
import (
	"context"
	"errors"
	"fmt"
	"log"
	"os"
	"strings"
	"time"

	"dance-lessons-coach/pkg/config"

	"github.com/rs/zerolog"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)
// ZerologWriter implements logger.Writer interface using zerolog.
// It adapts a zerolog.Logger to GORM's single-method Printf contract so
// GORM's SQL/query logging flows through the application's structured logger.
type ZerologWriter struct {
	logger zerolog.Logger // destination logger; Printf picks the level per message
}
// Printf satisfies GORM's logger.Writer. The formatted message is routed to a
// zerolog level based on its content: error-looking messages go to Error,
// slow-query/timeout messages to Warn, everything else to Debug.
// Empty messages are dropped.
func (zw *ZerologWriter) Printf(format string, v ...interface{}) {
	message := fmt.Sprintf(format, v...)
	if len(message) == 0 {
		return
	}
	switch {
	case containsErrorIndicators(message):
		zw.logger.Error().Str("gorm", message).Send()
	case containsSlowQueryIndicators(message):
		zw.logger.Warn().Str("gorm", message).Send()
	default:
		// Regular SQL statements are debug-level noise.
		zw.logger.Debug().Str("gorm", message).Send()
	}
}
// containsErrorIndicators reports whether the message contains error-related
// keywords, ignoring case.
//
// Fix: the previous version paired a case-insensitive hand-rolled search with
// a keyword list that duplicated each word in two casings ("error"/"Error"…),
// which was redundant. This version lower-cases the message once and checks a
// lowercase keyword list via the standard library — identical matches for all
// ASCII input, less code.
func containsErrorIndicators(message string) bool {
	lowered := strings.ToLower(message)
	for _, keyword := range []string{"error", "failed", "not found"} {
		if strings.Contains(lowered, keyword) {
			return true
		}
	}
	return false
}
// containsSlowQueryIndicators reports whether the message mentions a slow
// query or timeout, ignoring case.
//
// Fix: same cleanup as containsErrorIndicators — the mixed-case keyword list
// ("slow"/"Slow"…) was redundant given the case-insensitive check; lower-case
// the message once and use the standard library search.
func containsSlowQueryIndicators(message string) bool {
	lowered := strings.ToLower(message)
	for _, keyword := range []string{"slow", "timeout"} {
		if strings.Contains(lowered, keyword) {
			return true
		}
	}
	return false
}
// containsIgnoreCase reports whether substr occurs anywhere in s, comparing
// ASCII letters case-insensitively. Thin string-typed wrapper over the byte
// slice implementation.
func containsIgnoreCase(s, substr string) bool {
	haystack := []byte(s)
	needle := []byte(substr)
	return containsIgnoreCaseBytes(haystack, needle)
}
// containsIgnoreCaseBytes reports whether substr occurs in s, comparing byte
// by byte with ASCII case folding. An empty substr always matches.
func containsIgnoreCaseBytes(s, substr []byte) bool {
	if len(substr) == 0 {
		return true
	}
	// Fold a single ASCII byte to lowercase; non-letters pass through.
	fold := func(b byte) byte {
		if b >= 'A' && b <= 'Z' {
			return b + ('a' - 'A')
		}
		return b
	}
	last := len(s) - len(substr)
	for start := 0; start <= last; start++ {
		j := 0
		for j < len(substr) && fold(s[start+j]) == fold(substr[j]) {
			j++
		}
		if j == len(substr) {
			return true
		}
	}
	return false
}
// toLower returns the ASCII-lowercase form of b; bytes outside 'A'..'Z'
// are returned unchanged.
func toLower(b byte) byte {
	if b < 'A' || b > 'Z' {
		return b
	}
	// Setting bit 0x20 maps 'A'..'Z' onto 'a'..'z'.
	return b | 0x20
}
// PostgresRepository implements UserRepository using PostgreSQL (via GORM).
type PostgresRepository struct {
	db         *gorm.DB       // open GORM connection; set by the constructors
	config     *config.Config // app config: pool sizes, logging, telemetry toggle
	spanPrefix string         // prefix for telemetry span names, e.g. "user.repo."
}
// NewPostgresRepository creates a new PostgreSQL repository from the given
// config, opening the connection and running migrations.
func NewPostgresRepository(cfg *config.Config) (*PostgresRepository, error) {
	r := &PostgresRepository{
		config:     cfg,
		spanPrefix: "user.repo.",
	}
	err := r.initializeDatabase()
	if err != nil {
		return nil, fmt.Errorf("failed to initialize PostgreSQL database: %w", err)
	}
	return r, nil
}
// NewPostgresRepositoryFromDSN creates a PostgresRepository connected via the given DSN
// and runs AutoMigrate against it. Used by BDD test infra to create a per-scenario
// repository pointing at an isolated schema (the DSN typically includes search_path=<schema>).
//
// Pass the same cfg used elsewhere (it is required by methods that read pool settings),
// but the DSN passed here OVERRIDES the host/port/dbname/etc that cfg would have built.
//
// Fix: when AutoMigrate fails, the freshly opened connection pool is now closed
// before returning the error, so a failed construction no longer leaks a
// database connection.
func NewPostgresRepositoryFromDSN(cfg *config.Config, dsn string) (*PostgresRepository, error) {
	repo := &PostgresRepository{
		config:     cfg,
		spanPrefix: "user.repo.",
	}
	// Console-style logger at Warn level; this factory is test infrastructure,
	// so it does not route through the zerolog/JSON setup.
	gormLogger := logger.New(
		log.New(os.Stderr, "\n", log.LstdFlags),
		logger.Config{
			SlowThreshold:             time.Second,
			LogLevel:                  logger.Warn,
			IgnoreRecordNotFoundError: true,
			Colorful:                  false,
		},
	)
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{Logger: gormLogger})
	if err != nil {
		return nil, fmt.Errorf("failed to connect to PostgreSQL with custom DSN: %w", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("failed to get sql.DB from gorm: %w", err)
	}
	// Apply the same pool settings cfg-driven connections use.
	sqlDB.SetMaxOpenConns(cfg.GetDatabaseMaxOpenConns())
	sqlDB.SetMaxIdleConns(cfg.GetDatabaseMaxIdleConns())
	sqlDB.SetConnMaxLifetime(cfg.GetDatabaseConnMaxLifetime())
	if err := db.AutoMigrate(&User{}); err != nil {
		// Close the pool so the discarded repository does not leak a connection.
		_ = sqlDB.Close()
		return nil, fmt.Errorf("failed to auto-migrate via custom DSN: %w", err)
	}
	repo.db = db
	return repo, nil
}
// BuildSchemaIsolatedDSN returns a Postgres DSN that targets the given schema via
// the search_path connection parameter. Use with NewPostgresRepositoryFromDSN to
// get a repository whose connection only sees the per-scenario schema.
//
// Fix: values are now quoted when needed. In libpq keyword/value DSNs a value
// containing spaces, quotes or backslashes (a password, typically) must be
// wrapped in single quotes with ' and \ backslash-escaped; the previous
// version interpolated values verbatim and produced an unparseable DSN for
// such inputs. Plain values are emitted unchanged, so DSNs for typical
// configs are byte-identical to before.
func BuildSchemaIsolatedDSN(cfg *config.Config, schemaName string) string {
	return fmt.Sprintf(
		"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s search_path=%s",
		quoteDSNValue(cfg.GetDatabaseHost()),
		cfg.GetDatabasePort(),
		quoteDSNValue(cfg.GetDatabaseUser()),
		quoteDSNValue(cfg.GetDatabasePassword()),
		quoteDSNValue(cfg.GetDatabaseName()),
		quoteDSNValue(cfg.GetDatabaseSSLMode()),
		quoteDSNValue(schemaName),
	)
}

// quoteDSNValue returns v ready for use as a libpq keyword/value DSN value:
// unchanged when it is non-empty and contains no space, quote or backslash;
// otherwise single-quoted with ' and \ escaped per the libpq syntax.
func quoteDSNValue(v string) string {
	if v != "" && !strings.ContainsAny(v, ` '\`) {
		return v
	}
	escaped := strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v)
	return "'" + escaped + "'"
}
// NOTE: both constructors share the Close method defined later in this file.
// initializeDatabase sets up the PostgreSQL database connection and runs migrations.
//
// It chooses a GORM logger (zerolog-backed when JSON logging is enabled,
// console logger otherwise), opens the connection using a DSN built from the
// config, applies the configured pool settings, and auto-migrates the User
// model.
//
// Fix: when AutoMigrate fails, the freshly opened connection pool is now
// closed before returning, so the discarded repository does not leak a
// connection.
func (r *PostgresRepository) initializeDatabase() error {
	// Configure GORM logger based on config.
	var gormLogger logger.Interface
	if r.config.GetLoggingJSON() {
		// Route GORM output through zerolog so SQL logs share the app's JSON format.
		var logOutput = os.Stderr
		// If a log file is configured, append to it; on open failure we fall
		// back to stderr silently (best-effort, preserves prior behavior).
		if output := r.config.GetLogOutput(); output != "" {
			if file, err := os.OpenFile(output, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil {
				logOutput = file
			}
		}
		globalLogger := zerolog.New(logOutput).With().Str("component", "gorm").Logger()
		zw := &ZerologWriter{logger: globalLogger}
		gormLogger = logger.New(
			zw,
			logger.Config{
				SlowThreshold:             time.Second,
				LogLevel:                  logger.Warn,
				IgnoreRecordNotFoundError: true,
				Colorful:                  false,
			},
		)
	} else {
		// Console logger for non-JSON mode.
		gormLogger = logger.New(
			log.New(os.Stderr, "\n", log.LstdFlags),
			logger.Config{
				SlowThreshold:             time.Second,
				LogLevel:                  logger.Warn,
				IgnoreRecordNotFoundError: true,
				Colorful:                  true,
			},
		)
	}
	// Build the PostgreSQL DSN from config.
	dsn := fmt.Sprintf(
		"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
		r.config.GetDatabaseHost(),
		r.config.GetDatabasePort(),
		r.config.GetDatabaseUser(),
		r.config.GetDatabasePassword(),
		r.config.GetDatabaseName(),
		r.config.GetDatabaseSSLMode(),
	)
	var err error
	r.db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
		Logger: gormLogger,
	})
	if err != nil {
		return fmt.Errorf("failed to connect to PostgreSQL: %w", err)
	}
	// Configure connection pool.
	sqlDB, err := r.db.DB()
	if err != nil {
		return fmt.Errorf("failed to get SQL DB: %w", err)
	}
	sqlDB.SetMaxOpenConns(r.config.GetDatabaseMaxOpenConns())
	sqlDB.SetMaxIdleConns(r.config.GetDatabaseMaxIdleConns())
	sqlDB.SetConnMaxLifetime(r.config.GetDatabaseConnMaxLifetime())
	// Auto-migrate the User model.
	if err := r.db.AutoMigrate(&User{}); err != nil {
		// Close the pool so a failed init does not leak the connection.
		_ = sqlDB.Close()
		return fmt.Errorf("failed to auto-migrate: %w", err)
	}
	return nil
}
// CreateUser inserts the given user row, recording a telemetry span (and any
// error on it) when persistence telemetry is enabled.
func (r *PostgresRepository) CreateUser(ctx context.Context, user *User) error {
	ctx, span := r.createSpan(ctx, "create_user")
	if span != nil {
		defer span.End()
	}
	if res := r.db.WithContext(ctx).Create(user); res.Error != nil {
		if span != nil {
			span.RecordError(res.Error)
		}
		return fmt.Errorf("failed to create user: %w", res.Error)
	}
	return nil
}
// GetUserByUsername retrieves a user by username. Returns (nil, nil) when no
// matching row exists. Records a telemetry span with the username attribute
// when persistence telemetry is enabled.
func (r *PostgresRepository) GetUserByUsername(ctx context.Context, username string) (*User, error) {
	ctx, span := r.createSpan(ctx, "get_user_by_username")
	if span != nil {
		defer span.End()
		span.SetAttributes(attribute.String("username", username))
	}
	var u User
	err := r.db.WithContext(ctx).Where("username = ?", username).First(&u).Error
	switch {
	case err == nil:
		return &u, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// Not-found is not an error for callers; they check for a nil user.
		return nil, nil
	default:
		if span != nil {
			span.RecordError(err)
		}
		return nil, fmt.Errorf("failed to get user by username: %w", err)
	}
}
// GetUserByID retrieves a user by primary key. Returns (nil, nil) when no
// row exists, matching GetUserByUsername's not-found convention.
//
// Consistency fix: this method now records a telemetry span (and errors on
// it) like CreateUser and GetUserByUsername, instead of being the one read
// path invisible to tracing.
func (r *PostgresRepository) GetUserByID(ctx context.Context, id uint) (*User, error) {
	ctx, span := r.createSpan(ctx, "get_user_by_id")
	if span != nil {
		defer span.End()
		span.SetAttributes(attribute.Int64("user_id", int64(id)))
	}
	var user User
	result := r.db.WithContext(ctx).First(&user, id)
	if result.Error != nil {
		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		if span != nil {
			span.RecordError(result.Error)
		}
		return nil, fmt.Errorf("failed to get user by ID: %w", result.Error)
	}
	return &user, nil
}
// UpdateUser persists all fields of the given user (GORM Save).
//
// Consistency fix: now records a telemetry span (and errors on it) like
// CreateUser, so writes are uniformly traced.
func (r *PostgresRepository) UpdateUser(ctx context.Context, user *User) error {
	ctx, span := r.createSpan(ctx, "update_user")
	if span != nil {
		defer span.End()
	}
	result := r.db.WithContext(ctx).Save(user)
	if result.Error != nil {
		if span != nil {
			span.RecordError(result.Error)
		}
		return fmt.Errorf("failed to update user: %w", result.Error)
	}
	return nil
}
// DeleteUser deletes a user row by primary key.
//
// Consistency fix: now records a telemetry span (and errors on it) like the
// other mutating repository methods.
func (r *PostgresRepository) DeleteUser(ctx context.Context, id uint) error {
	ctx, span := r.createSpan(ctx, "delete_user")
	if span != nil {
		defer span.End()
		span.SetAttributes(attribute.Int64("user_id", int64(id)))
	}
	result := r.db.WithContext(ctx).Delete(&User{}, id)
	if result.Error != nil {
		if span != nil {
			span.RecordError(result.Error)
		}
		return fmt.Errorf("failed to delete user: %w", result.Error)
	}
	return nil
}
// AllowPasswordReset flags the named user for password reset by setting
// AllowPasswordReset=true and saving the row. Errors if the user cannot be
// loaded or does not exist.
func (r *PostgresRepository) AllowPasswordReset(ctx context.Context, username string) error {
	u, err := r.GetUserByUsername(ctx, username)
	switch {
	case err != nil:
		return fmt.Errorf("failed to get user for password reset: %w", err)
	case u == nil:
		return fmt.Errorf("user not found: %s", username)
	}
	u.AllowPasswordReset = true
	return r.UpdateUser(ctx, u)
}
// CompletePasswordReset stores the new password hash for the named user and
// clears the reset flag. It refuses to proceed when the user does not exist
// or was never flagged via AllowPasswordReset.
func (r *PostgresRepository) CompletePasswordReset(ctx context.Context, username, newPasswordHash string) error {
	u, err := r.GetUserByUsername(ctx, username)
	switch {
	case err != nil:
		return fmt.Errorf("failed to get user for password reset completion: %w", err)
	case u == nil:
		return fmt.Errorf("user not found: %s", username)
	case !u.AllowPasswordReset:
		return fmt.Errorf("password reset not allowed for user: %s", username)
	}
	u.PasswordHash = newPasswordHash
	u.AllowPasswordReset = false
	return r.UpdateUser(ctx, u)
}
// UserExists reports whether a row with the given username exists.
func (r *PostgresRepository) UserExists(ctx context.Context, username string) (bool, error) {
	var n int64
	err := r.db.WithContext(ctx).Model(&User{}).Where("username = ?", username).Count(&n).Error
	if err != nil {
		return false, fmt.Errorf("failed to check if user exists: %w", err)
	}
	return n > 0, nil
}
// Close closes the underlying database connection pool. Shared by all
// constructors of PostgresRepository.
func (r *PostgresRepository) Close() error {
	conn, err := r.db.DB()
	if err != nil {
		return fmt.Errorf("failed to get database connection: %w", err)
	}
	return conn.Close()
}
// CheckDatabaseHealth checks if the database is healthy and responsive.
//
// Fix: previously this ran COUNT(*) over the users table, which is O(rows)
// on Postgres and fails if migrations haven't created the table yet. A
// driver-level ping is constant-cost, honors ctx cancellation, and tests
// exactly what a health check should: connectivity.
func (r *PostgresRepository) CheckDatabaseHealth(ctx context.Context) error {
	sqlDB, err := r.db.DB()
	if err != nil {
		return fmt.Errorf("database health check failed: %w", err)
	}
	if err := sqlDB.PingContext(ctx); err != nil {
		return fmt.Errorf("database health check failed: %w", err)
	}
	return nil
}
// createSpan starts a new telemetry span named spanPrefix+operation when
// persistence telemetry is enabled; otherwise it returns whatever span is
// already carried by ctx (possibly a no-op span) unchanged.
func (r *PostgresRepository) createSpan(ctx context.Context, operation string) (context.Context, trace.Span) {
	cfg := r.config
	if cfg == nil || !cfg.GetPersistenceTelemetryEnabled() {
		// Telemetry disabled: don't start a new span, reuse the context's.
		return ctx, trace.SpanFromContext(ctx)
	}
	tracer := otel.Tracer("user-repository")
	return tracer.Start(ctx, r.spanPrefix+operation)
}