diff --git a/pkg/user/magic_link.go b/pkg/user/magic_link.go new file mode 100644 index 0000000..3a887ea --- /dev/null +++ b/pkg/user/magic_link.go @@ -0,0 +1,150 @@ +package user + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "time" + + "go.opentelemetry.io/otel/attribute" + "gorm.io/gorm" +) + +// MagicLinkToken is the persistent record of a passwordless-auth token. +// +// Per ADR-0028 Phase A: the token VALUE is never stored. Only its SHA-256 +// hash sits in the DB ; if the table leaks, the attacker has no usable +// tokens (mirrors ADR-0021 secret retention via fingerprint approach). +// +// The plaintext token is delivered to the user exactly once via email and +// must be supplied back through the consume endpoint to re-derive the +// hash and find the row. +type MagicLinkToken struct { + ID uint `gorm:"primaryKey"` + CreatedAt time.Time `gorm:"autoCreateTime;not null;index"` + Email string `gorm:"not null;index"` + TokenHash string `gorm:"not null;uniqueIndex;size:64"` // hex-encoded sha256 = 64 chars + ExpiresAt time.Time `gorm:"not null;index"` + ConsumedAt *time.Time `gorm:""` +} + +// MagicLinkRepository is the persistence contract for magic-link tokens. +// PostgresRepository implements it ; tests can use a fake. +type MagicLinkRepository interface { + CreateMagicLinkToken(ctx context.Context, token *MagicLinkToken) error + GetMagicLinkTokenByHash(ctx context.Context, tokenHash string) (*MagicLinkToken, error) + MarkMagicLinkTokenConsumed(ctx context.Context, id uint, consumedAt time.Time) error + DeleteExpiredMagicLinkTokens(ctx context.Context, before time.Time) (int64, error) +} + +// GenerateMagicLinkToken returns a fresh url-safe random token suitable +// for inclusion in an email link, plus its SHA-256 hex digest for storage. +// +// The plaintext is what gets emailed ; the hash is what gets persisted. +// 32 bytes of entropy = 256 bits ; collision-resistant for our scale. +func GenerateMagicLinkToken() (plaintext, hashHex string, err error) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", "", fmt.Errorf("magic link rand: %w", err) + } + plaintext = base64.RawURLEncoding.EncodeToString(buf) + hashHex = HashMagicLinkToken(plaintext) + return plaintext, hashHex, nil +} + +// HashMagicLinkToken returns the lowercase hex sha256 of token. Stable +// over time : the same plaintext always maps to the same hash, so +// consume can re-derive and look up the row. +func HashMagicLinkToken(plaintext string) string { + sum := sha256.Sum256([]byte(plaintext)) + return hex.EncodeToString(sum[:]) +} + +// CreateMagicLinkToken persists a magic-link token. The caller is +// responsible for hashing the plaintext (cf. HashMagicLinkToken) and +// setting ExpiresAt ; this method does not generate either. +func (r *PostgresRepository) CreateMagicLinkToken(ctx context.Context, token *MagicLinkToken) error { + ctx, span := r.createSpan(ctx, "create_magic_link_token") + if span != nil { + defer span.End() + span.SetAttributes(attribute.String("email", token.Email)) + } + if err := r.db.WithContext(ctx).Create(token).Error; err != nil { + if span != nil { + span.RecordError(err) + } + return fmt.Errorf("failed to create magic link token: %w", err) + } + return nil +} + +// GetMagicLinkTokenByHash looks up a magic-link token by its hex sha256. +// Returns (nil, nil) when no row matches — callers must treat that as +// "invalid token" and respond with the same generic error as "expired" +// or "consumed" to avoid leaking which condition failed. +func (r *PostgresRepository) GetMagicLinkTokenByHash(ctx context.Context, tokenHash string) (*MagicLinkToken, error) { + ctx, span := r.createSpan(ctx, "get_magic_link_token_by_hash") + if span != nil { + defer span.End() + } + var t MagicLinkToken + err := r.db.WithContext(ctx).Where("token_hash = ?", tokenHash).First(&t).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + if span != nil { + span.RecordError(err) + } + return nil, fmt.Errorf("failed to get magic link token: %w", err) + } + return &t, nil +} + +// MarkMagicLinkTokenConsumed sets consumed_at on the row with the given +// ID. Idempotent only at the SQL-engine level — the consume handler is +// responsible for refusing to act when consumed_at is already set. +func (r *PostgresRepository) MarkMagicLinkTokenConsumed(ctx context.Context, id uint, consumedAt time.Time) error { + ctx, span := r.createSpan(ctx, "mark_magic_link_token_consumed") + if span != nil { + defer span.End() + } + res := r.db.WithContext(ctx). + Model(&MagicLinkToken{}). + Where("id = ?", id). + Update("consumed_at", consumedAt) + if res.Error != nil { + if span != nil { + span.RecordError(res.Error) + } + return fmt.Errorf("failed to mark magic link token consumed: %w", res.Error) + } + if res.RowsAffected == 0 { + return fmt.Errorf("no magic link token with id=%d", id) + } + return nil +} + +// DeleteExpiredMagicLinkTokens removes rows whose expires_at is strictly +// before the given cutoff. Returns the count deleted. Used by the +// scheduled cleanup job. +func (r *PostgresRepository) DeleteExpiredMagicLinkTokens(ctx context.Context, before time.Time) (int64, error) { + ctx, span := r.createSpan(ctx, "delete_expired_magic_link_tokens") + if span != nil { + defer span.End() + } + res := r.db.WithContext(ctx). + Where("expires_at < ?", before). + Delete(&MagicLinkToken{}) + if res.Error != nil { + if span != nil { + span.RecordError(res.Error) + } + return 0, fmt.Errorf("failed to delete expired magic link tokens: %w", res.Error) + } + return res.RowsAffected, nil +} diff --git a/pkg/user/magic_link_integration_test.go b/pkg/user/magic_link_integration_test.go new file mode 100644 index 0000000..5eb273e --- /dev/null +++ b/pkg/user/magic_link_integration_test.go @@ -0,0 +1,194 @@ +//go:build integration + +// Integration tests for the magic-link repository methods. Run with: +// +// go test -tags integration ./pkg/user/... +// +// Requires a running Postgres reachable via the same env vars / defaults +// the BDD suite already uses (DLC_DATABASE_HOST, etc., default +// localhost:5432 from docker-compose). +package user + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "testing" + "time" + + "dance-lessons-coach/pkg/config" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// freshRepo connects to the local Postgres, creates a uniquely-named +// schema for THIS test, and returns a repository scoped to it. +// On test end, the schema is dropped (cleanup is best-effort). +func freshRepo(t *testing.T) *PostgresRepository { + t.Helper() + cfg, err := config.LoadConfig() + require.NoError(t, err) + + var raw [6]byte + _, err = rand.Read(raw[:]) + require.NoError(t, err) + schema := "ml_test_" + hex.EncodeToString(raw[:]) + + // Bootstrap schema via a default-DSN repo (no search_path). + bootDSN := fmt.Sprintf( + "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", + cfg.GetDatabaseHost(), + cfg.GetDatabasePort(), + cfg.GetDatabaseUser(), + cfg.GetDatabasePassword(), + cfg.GetDatabaseName(), + cfg.GetDatabaseSSLMode(), + ) + bootRepo, err := NewPostgresRepositoryFromDSN(cfg, bootDSN) + require.NoError(t, err) + require.NoError(t, bootRepo.Exec(fmt.Sprintf(`CREATE SCHEMA "%s"`, schema))) + + t.Cleanup(func() { + _ = bootRepo.Exec(fmt.Sprintf(`DROP SCHEMA "%s" CASCADE`, schema)) + }) + + dsn := BuildSchemaIsolatedDSN(cfg, schema) + repo, err := NewPostgresRepositoryFromDSN(cfg, dsn) + require.NoError(t, err) + return repo +} + +// TestMagicLinkRepo_CreateAndGetByHash is the end-to-end happy path : +// store a token, look it up by hash, get the row back. +func TestMagicLinkRepo_CreateAndGetByHash(t *testing.T) { + repo := freshRepo(t) + ctx := context.Background() + + plain, hashHex, err := GenerateMagicLinkToken() + require.NoError(t, err) + + tok := &MagicLinkToken{ + Email: "alice@example.com", + TokenHash: hashHex, + ExpiresAt: time.Now().Add(15 * time.Minute), + } + require.NoError(t, repo.CreateMagicLinkToken(ctx, tok)) + assert.NotZero(t, tok.ID, "ID should be populated by GORM after Create") + + got, err := repo.GetMagicLinkTokenByHash(ctx, hashHex) + require.NoError(t, err) + require.NotNil(t, got, "fresh token must be retrievable") + assert.Equal(t, "alice@example.com", got.Email) + assert.Nil(t, got.ConsumedAt, "fresh token is not yet consumed") + + // Lookup by the plaintext (which the consume handler does NOT receive + // directly — it must hash first). This confirms the hashing direction + // is consistent. + got2, err := repo.GetMagicLinkTokenByHash(ctx, HashMagicLinkToken(plain)) + require.NoError(t, err) + require.NotNil(t, got2) + assert.Equal(t, tok.ID, got2.ID) +} + +// TestMagicLinkRepo_GetByHash_Missing returns (nil, nil) for a hash that +// never existed. Callers must NOT distinguish "missing" from "expired" +// or "consumed" — they all collapse to a single generic error to the user. +func TestMagicLinkRepo_GetByHash_Missing(t *testing.T) { + repo := freshRepo(t) + got, err := repo.GetMagicLinkTokenByHash(context.Background(), HashMagicLinkToken("never-issued")) + require.NoError(t, err) + assert.Nil(t, got) +} + +// TestMagicLinkRepo_MarkConsumed flips consumed_at and refuses to act +// on a non-existent ID. +func TestMagicLinkRepo_MarkConsumed(t *testing.T) { + repo := freshRepo(t) + ctx := context.Background() + + _, hashHex, err := GenerateMagicLinkToken() + require.NoError(t, err) + tok := &MagicLinkToken{ + Email: "bob@example.com", + TokenHash: hashHex, + ExpiresAt: time.Now().Add(15 * time.Minute), + } + require.NoError(t, repo.CreateMagicLinkToken(ctx, tok)) + + now := time.Now().UTC().Truncate(time.Second) + require.NoError(t, repo.MarkMagicLinkTokenConsumed(ctx, tok.ID, now)) + + got, err := repo.GetMagicLinkTokenByHash(ctx, hashHex) + require.NoError(t, err) + require.NotNil(t, got) + require.NotNil(t, got.ConsumedAt, "consumed_at must be set") + assert.WithinDuration(t, now, got.ConsumedAt.UTC(), time.Second) + + // Marking a non-existent ID returns an error (defensive — the consume + // handler should never call us with a fake ID, but if it does we want + // the failure to be loud). + err = repo.MarkMagicLinkTokenConsumed(ctx, 999999, time.Now()) + require.Error(t, err) +} + +// TestMagicLinkRepo_DeleteExpired confirms the cleanup pass deletes +// strictly-before-cutoff rows and leaves future ones alone. +func TestMagicLinkRepo_DeleteExpired(t *testing.T) { + repo := freshRepo(t) + ctx := context.Background() + + now := time.Now() + expired := &MagicLinkToken{ + Email: "expired@example.com", + TokenHash: HashMagicLinkToken("expired-token"), + ExpiresAt: now.Add(-1 * time.Hour), + } + fresh := &MagicLinkToken{ + Email: "fresh@example.com", + TokenHash: HashMagicLinkToken("fresh-token"), + ExpiresAt: now.Add(1 * time.Hour), + } + require.NoError(t, repo.CreateMagicLinkToken(ctx, expired)) + require.NoError(t, repo.CreateMagicLinkToken(ctx, fresh)) + + deleted, err := repo.DeleteExpiredMagicLinkTokens(ctx, now) + require.NoError(t, err) + assert.EqualValues(t, 1, deleted, "exactly one row was past the cutoff") + + // Expired row is gone, fresh row is still there. + got, err := repo.GetMagicLinkTokenByHash(ctx, HashMagicLinkToken("expired-token")) + require.NoError(t, err) + assert.Nil(t, got, "expired token must be gone") + + got, err = repo.GetMagicLinkTokenByHash(ctx, HashMagicLinkToken("fresh-token")) + require.NoError(t, err) + require.NotNil(t, got, "fresh token must remain") +} + +// TestMagicLinkRepo_HashUniqueness is a defensive check that the unique +// index on token_hash actually rejects duplicates. If the index is ever +// dropped from the schema, this test catches it before security does. +func TestMagicLinkRepo_HashUniqueness(t *testing.T) { + repo := freshRepo(t) + ctx := context.Background() + + _, hashHex, err := GenerateMagicLinkToken() + require.NoError(t, err) + + first := &MagicLinkToken{ + Email: "a@example.com", + TokenHash: hashHex, + ExpiresAt: time.Now().Add(15 * time.Minute), + } + require.NoError(t, repo.CreateMagicLinkToken(ctx, first)) + + dup := &MagicLinkToken{ + Email: "b@example.com", + TokenHash: hashHex, // same hash as `first` + ExpiresAt: time.Now().Add(15 * time.Minute), + } + err = repo.CreateMagicLinkToken(ctx, dup) + require.Error(t, err, "second insert with same hash must violate the unique index") +} diff --git a/pkg/user/magic_link_test.go b/pkg/user/magic_link_test.go new file mode 100644 index 0000000..cd51a58 --- /dev/null +++ b/pkg/user/magic_link_test.go @@ -0,0 +1,78 @@ +package user + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestGenerateMagicLinkToken_ShapeAndHashAgree confirms the contract that +// HashMagicLinkToken(plaintext) == returned hashHex. Without that, the +// consume handler can never look up what request stored. +func TestGenerateMagicLinkToken_ShapeAndHashAgree(t *testing.T) { + plain, hashHex, err := GenerateMagicLinkToken() + require.NoError(t, err) + + assert.NotEmpty(t, plain) + assert.NotEmpty(t, hashHex) + assert.Len(t, hashHex, 64, "sha256 hex = 64 chars") + assert.Equal(t, hashHex, HashMagicLinkToken(plain), + "GenerateMagicLinkToken must return a hash that matches HashMagicLinkToken(plain)") +} + +// TestGenerateMagicLinkToken_PlainIsURLSafeBase64 confirms the link can +// be embedded in a URL without further escaping. RawURLEncoding => no +// "/", "+", or "=" padding chars. +func TestGenerateMagicLinkToken_PlainIsURLSafeBase64(t *testing.T) { + plain, _, err := GenerateMagicLinkToken() + require.NoError(t, err) + + for _, bad := range []string{"/", "+", "="} { + assert.False(t, strings.Contains(plain, bad), + "plaintext token must not contain %q (URL-unsafe)", bad) + } + + decoded, err := base64.RawURLEncoding.DecodeString(plain) + require.NoError(t, err, "plaintext must round-trip through RawURLEncoding") + assert.Len(t, decoded, 32, "32 bytes of entropy") +} + +// TestGenerateMagicLinkToken_Unique confirms two consecutive calls +// produce different tokens (not a deterministic seeding bug). +func TestGenerateMagicLinkToken_Unique(t *testing.T) { + a, ah, err := GenerateMagicLinkToken() + require.NoError(t, err) + b, bh, err := GenerateMagicLinkToken() + require.NoError(t, err) + + assert.NotEqual(t, a, b, "plaintexts must differ between calls") + assert.NotEqual(t, ah, bh, "hashes must differ between calls") +} + +// TestHashMagicLinkToken_StableAndCorrect confirms HashMagicLinkToken is +// a pure function (same input -> same output) AND that it produces the +// expected sha256 hex digest. Cross-checked against the stdlib so we +// catch any accidental algorithm swap. +func TestHashMagicLinkToken_StableAndCorrect(t *testing.T) { + const sample = "abc123-test-token" + got1 := HashMagicLinkToken(sample) + got2 := HashMagicLinkToken(sample) + assert.Equal(t, got1, got2, "HashMagicLinkToken must be deterministic") + + sum := sha256.Sum256([]byte(sample)) + want := hex.EncodeToString(sum[:]) + assert.Equal(t, want, got1, "HashMagicLinkToken must be sha256 hex") +} + +// TestHashMagicLinkToken_DiffersOnDifferentInput is the tautological +// counter-test of stability : different inputs -> different outputs. +// Catches the (unlikely) case where someone replaces the impl with +// a constant. +func TestHashMagicLinkToken_DiffersOnDifferentInput(t *testing.T) { + assert.NotEqual(t, HashMagicLinkToken("a"), HashMagicLinkToken("b")) +} diff --git a/pkg/user/postgres_repository.go b/pkg/user/postgres_repository.go index d7c0507..abcc00b 100644 --- a/pkg/user/postgres_repository.go +++ b/pkg/user/postgres_repository.go @@ -160,7 +160,7 @@ func NewPostgresRepositoryFromDSN(cfg *config.Config, dsn string) (*PostgresRepo sqlDB.SetMaxIdleConns(cfg.GetDatabaseMaxIdleConns()) sqlDB.SetConnMaxLifetime(cfg.GetDatabaseConnMaxLifetime()) - if err := db.AutoMigrate(&User{}); err != nil { + if err := db.AutoMigrate(&User{}, &MagicLinkToken{}); err != nil { return nil, fmt.Errorf("failed to auto-migrate via custom DSN: %w", err) } @@ -264,8 +264,8 @@ func (r *PostgresRepository) initializeDatabase() error { sqlDB.SetMaxIdleConns(r.config.GetDatabaseMaxIdleConns()) sqlDB.SetConnMaxLifetime(r.config.GetDatabaseConnMaxLifetime()) - // Auto-migrate the User model - if err := r.db.AutoMigrate(&User{}); err != nil { + // Auto-migrate the User model + MagicLinkToken (ADR-0028 Phase A) + if err := r.db.AutoMigrate(&User{}, &MagicLinkToken{}); err != nil { return fmt.Errorf("failed to auto-migrate: %w", err) }