package api

import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"dance-lessons-coach/pkg/config"
	"dance-lessons-coach/pkg/email"
	"dance-lessons-coach/pkg/user"

	"github.com/go-chi/chi/v5"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// fakeMLRepo is a mutex-guarded, map-backed stand-in for the magic-link
// token repository used by the handler tests. Setting failOn makes the
// matching operation return a simulated error.
type fakeMLRepo struct {
	mu     sync.Mutex
	tokens map[string]*user.MagicLinkToken // key: TokenHash
	nextID uint
	failOn string // "create" / "get" / "mark" / "" (none)
}

// newFakeMLRepo returns an empty repository with its token map initialised.
func newFakeMLRepo() *fakeMLRepo {
	repo := &fakeMLRepo{}
	repo.tokens = make(map[string]*user.MagicLinkToken)
	return repo
}

// CreateMagicLinkToken assigns the next sequential ID to tok and stores the
// pointer itself (not a copy) under its TokenHash. It fails when failOn is
// "create".
func (r *fakeMLRepo) CreateMagicLinkToken(_ context.Context, tok *user.MagicLinkToken) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.failOn == "create" {
		return errors.New("simulated create failure")
	}

	r.nextID++
	tok.ID = r.nextID
	r.tokens[tok.TokenHash] = tok
	return nil
}

// GetMagicLinkTokenByHash returns a copy of the token stored under hash, or
// (nil, nil) when no such token exists. It fails when failOn is "get".
func (r *fakeMLRepo) GetMagicLinkTokenByHash(_ context.Context, hash string) (*user.MagicLinkToken, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.failOn == "get" {
		return nil, errors.New("simulated get failure")
	}

	if stored, found := r.tokens[hash]; found {
		// Hand back a copy so callers cannot mutate the stored token.
		snapshot := *stored
		return &snapshot, nil
	}
	return nil, nil
}

// MarkMagicLinkTokenConsumed sets ConsumedAt on the stored token whose ID
// matches id. It fails when failOn is "mark" or when no token has that ID.
func (r *fakeMLRepo) MarkMagicLinkTokenConsumed(_ context.Context, id uint, when time.Time) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.failOn == "mark" {
		return errors.New("simulated mark failure")
	}

	for _, stored := range r.tokens {
		if stored.ID != id {
			continue
		}
		stored.ConsumedAt = &when
		return nil
	}
	return errors.New("not found")
}

// DeleteExpiredMagicLinkTokens is a no-op in this fake: it removes nothing
// and always reports zero deleted tokens.
func (r *fakeMLRepo) DeleteExpiredMagicLinkTokens(_ context.Context, _ time.Time) (int64, error) {
	return 0, nil
}

// fakeUserSvc is a minimal user.UserService stub.
type fakeUserSvc struct { createdUsers []*user.User jwtForID map[uint]string hashCalls int failOn string // "create" / "hash" / "jwt" } func newFakeUserSvc() *fakeUserSvc { return &fakeUserSvc{jwtForID: map[uint]string{}} } func (s *fakeUserSvc) Authenticate(_ context.Context, _, _ string) (*user.User, error) { return nil, errors.New("not used in magic-link tests") } func (s *fakeUserSvc) GenerateJWT(_ context.Context, u *user.User) (string, error) { if s.failOn == "jwt" { return "", errors.New("simulated jwt failure") } return "jwt-for-user-" + u.Username, nil } func (s *fakeUserSvc) ValidateJWT(_ context.Context, _ string) (*user.User, error) { return nil, errors.New("not used") } func (s *fakeUserSvc) AdminAuthenticate(_ context.Context, _ string) (*user.User, error) { return nil, errors.New("not used") } func (s *fakeUserSvc) AddJWTSecret(_ string, _ bool, _ time.Duration) {} func (s *fakeUserSvc) RotateJWTSecret(_ string) {} func (s *fakeUserSvc) GetJWTSecretByIndex(_ int) (string, bool) { return "", false } func (s *fakeUserSvc) ResetJWTSecrets() {} func (s *fakeUserSvc) StartJWTSecretCleanupLoop(_ context.Context, _ time.Duration) {} func (s *fakeUserSvc) RemoveExpiredJWTSecrets() int { return 0 } func (s *fakeUserSvc) ListJWTSecretsInfo() []user.JWTSecretInfo { return nil } func (s *fakeUserSvc) UserExists(_ context.Context, username string) (bool, error) { for _, u := range s.createdUsers { if u.Username == username { return true, nil } } return false, nil } func (s *fakeUserSvc) CreateUser(_ context.Context, u *user.User) error { if s.failOn == "create" { return errors.New("simulated create failure") } u.ID = uint(len(s.createdUsers) + 1) cp := *u s.createdUsers = append(s.createdUsers, &cp) return nil } func (s *fakeUserSvc) HashPassword(_ context.Context, p string) (string, error) { s.hashCalls++ if s.failOn == "hash" { return "", errors.New("simulated hash failure") } return "hash:" + p, nil } func (s *fakeUserSvc) RequestPasswordReset(_ 
context.Context, _ string) error { return nil } func (s *fakeUserSvc) CompletePasswordReset(_ context.Context, _, _ string) error { return nil } // fakeUserRepo implements user.UserRepository using fakeUserSvc's slice. type fakeUserRepo struct{ svc *fakeUserSvc } func (r *fakeUserRepo) CreateUser(_ context.Context, u *user.User) error { return r.svc.CreateUser(context.Background(), u) } func (r *fakeUserRepo) GetUserByUsername(_ context.Context, name string) (*user.User, error) { for _, u := range r.svc.createdUsers { if u.Username == name { cp := *u return &cp, nil } } return nil, nil } func (r *fakeUserRepo) GetUserByID(_ context.Context, _ uint) (*user.User, error) { return nil, nil } func (r *fakeUserRepo) UpdateUser(_ context.Context, _ *user.User) error { return nil } func (r *fakeUserRepo) DeleteUser(_ context.Context, _ uint) error { return nil } func (r *fakeUserRepo) AllowPasswordReset(_ context.Context, _ string) error { return nil } func (r *fakeUserRepo) CompletePasswordReset(_ context.Context, _, _ string) error { return nil } func (r *fakeUserRepo) UserExists(_ context.Context, name string) (bool, error) { return r.svc.UserExists(context.Background(), name) } func (r *fakeUserRepo) CheckDatabaseHealth(_ context.Context) error { return nil } // recordingSender captures email.Send calls without sending anything. 
type recordingSender struct { mu sync.Mutex messages []email.Message failNext bool } func (s *recordingSender) Send(_ context.Context, m email.Message) error { s.mu.Lock() defer s.mu.Unlock() if s.failNext { return errors.New("simulated send failure") } s.messages = append(s.messages, m) return nil } func newHandler(t *testing.T) (*MagicLinkHandler, *fakeMLRepo, *fakeUserSvc, *recordingSender) { t.Helper() mlRepo := newFakeMLRepo() svc := newFakeUserSvc() repo := &fakeUserRepo{svc: svc} sender := &recordingSender{} h := NewMagicLinkHandler( mlRepo, svc, repo, sender, config.MagicLinkConfig{TTL: 15 * time.Minute, BaseURL: "http://test.local"}, "noreply@test.local", nil, ) return h, mlRepo, svc, sender } func mountAndRequest(h *MagicLinkHandler, method, path, body string) *httptest.ResponseRecorder { r := chi.NewRouter() h.RegisterRoutes(r) req := httptest.NewRequest(method, path, strings.NewReader(body)) if body != "" { req.Header.Set("Content-Type", "application/json") } rr := httptest.NewRecorder() r.ServeHTTP(rr, req) return rr } // TestRequest_HappyPath confirms POST /magic-link/request stores a token, // sends an email containing the link, and returns 200 with a generic body. func TestRequest_HappyPath(t *testing.T) { h, mlRepo, _, sender := newHandler(t) rr := mountAndRequest(h, http.MethodPost, "/magic-link/request", `{"email":"alice@example.com"}`) require.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), "If that email is valid") // One token persisted, email lower-cased. require.Len(t, mlRepo.tokens, 1) for _, tok := range mlRepo.tokens { assert.Equal(t, "alice@example.com", tok.Email) assert.Greater(t, tok.ExpiresAt.Unix(), time.Now().Unix()) } // One email sent to the same address, link points at our test base URL. 
require.Len(t, sender.messages, 1) assert.Equal(t, "alice@example.com", sender.messages[0].To) assert.Contains(t, sender.messages[0].BodyText, "http://test.local/api/v1/auth/magic-link/consume?token=") } // TestRequest_NormalizesEmail confirms the email is lower-cased + trimmed. func TestRequest_NormalizesEmail(t *testing.T) { h, mlRepo, _, sender := newHandler(t) rr := mountAndRequest(h, http.MethodPost, "/magic-link/request", `{"email":" Alice@Example.COM "}`) require.Equal(t, http.StatusOK, rr.Code) require.Len(t, mlRepo.tokens, 1) for _, tok := range mlRepo.tokens { assert.Equal(t, "alice@example.com", tok.Email) } assert.Equal(t, "alice@example.com", sender.messages[0].To) } // TestRequest_BadJSON returns 400. func TestRequest_BadJSON(t *testing.T) { h, _, _, _ := newHandler(t) rr := mountAndRequest(h, http.MethodPost, "/magic-link/request", `not json`) assert.Equal(t, http.StatusBadRequest, rr.Code) } // TestRequest_PersistFailureStillReturns200 — a DB error must NOT leak // to the user (would let attackers detect storage outages). func TestRequest_PersistFailureStillReturns200(t *testing.T) { h, mlRepo, _, sender := newHandler(t) mlRepo.failOn = "create" rr := mountAndRequest(h, http.MethodPost, "/magic-link/request", `{"email":"bob@example.com"}`) assert.Equal(t, http.StatusOK, rr.Code) // No email was sent because no token was persisted. assert.Empty(t, sender.messages) } // TestConsume_HappyPath_NewUser exercises sign-up-on-first-link. func TestConsume_HappyPath_NewUser(t *testing.T) { h, mlRepo, svc, _ := newHandler(t) // Seed one token by going through the request flow. mountAndRequest(h, http.MethodPost, "/magic-link/request", `{"email":"alice@example.com"}`) require.Len(t, mlRepo.tokens, 1) // We need the plaintext to consume — derive it from the only token in the // repo by reverse trick : the request handler doesn't expose it. So we // drive consume with a fresh known-plaintext we put into the repo // directly. 
plain, hashHex, err := user.GenerateMagicLinkToken() require.NoError(t, err) mlRepo.tokens = map[string]*user.MagicLinkToken{ hashHex: {ID: 99, Email: "alice@example.com", TokenHash: hashHex, ExpiresAt: time.Now().Add(5 * time.Minute)}, } mlRepo.nextID = 99 rr := mountAndRequest(h, http.MethodGet, "/magic-link/consume?token="+plain, "") require.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) var resp MagicLinkResponse require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &resp)) assert.Equal(t, "signed in", resp.Message) assert.Equal(t, "jwt-for-user-alice@example.com", resp.Token) // User was created. require.Len(t, svc.createdUsers, 1) assert.Equal(t, "alice@example.com", svc.createdUsers[0].Username) assert.NotEmpty(t, svc.createdUsers[0].PasswordHash, "passwordless user must still have a non-empty hash (random unguessable value)") assert.Equal(t, 1, svc.hashCalls) // Token marked consumed. for _, tok := range mlRepo.tokens { require.NotNil(t, tok.ConsumedAt, "consumed_at must be set after consume") } } // TestConsume_HappyPath_ExistingUser confirms no new user is created // when the email is already known. func TestConsume_HappyPath_ExistingUser(t *testing.T) { h, mlRepo, svc, _ := newHandler(t) // Pre-seed the user. require.NoError(t, svc.CreateUser(context.Background(), &user.User{Username: "carol@example.com", PasswordHash: "x"})) require.Len(t, svc.createdUsers, 1) preCount := len(svc.createdUsers) plain, hashHex, err := user.GenerateMagicLinkToken() require.NoError(t, err) mlRepo.tokens[hashHex] = &user.MagicLinkToken{ID: 1, Email: "carol@example.com", TokenHash: hashHex, ExpiresAt: time.Now().Add(5 * time.Minute)} rr := mountAndRequest(h, http.MethodGet, "/magic-link/consume?token="+plain, "") require.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) // No new user. assert.Len(t, svc.createdUsers, preCount) assert.Equal(t, 0, svc.hashCalls, "no hash call when user exists") } // TestConsume_MissingToken returns 400. 
func TestConsume_MissingToken(t *testing.T) { h, _, _, _ := newHandler(t) rr := mountAndRequest(h, http.MethodGet, "/magic-link/consume", "") assert.Equal(t, http.StatusBadRequest, rr.Code) } // TestConsume_UnknownToken returns 401 (single generic shape). func TestConsume_UnknownToken(t *testing.T) { h, _, _, _ := newHandler(t) rr := mountAndRequest(h, http.MethodGet, "/magic-link/consume?token=neverissued", "") assert.Equal(t, http.StatusUnauthorized, rr.Code) } // TestConsume_ExpiredToken returns 401. func TestConsume_ExpiredToken(t *testing.T) { h, mlRepo, _, _ := newHandler(t) plain, hashHex, err := user.GenerateMagicLinkToken() require.NoError(t, err) mlRepo.tokens[hashHex] = &user.MagicLinkToken{ ID: 1, Email: "x@example.com", TokenHash: hashHex, ExpiresAt: time.Now().Add(-1 * time.Minute), // already expired } rr := mountAndRequest(h, http.MethodGet, "/magic-link/consume?token="+plain, "") assert.Equal(t, http.StatusUnauthorized, rr.Code) } // TestConsume_AlreadyConsumed returns 401 — single-use guarantee. func TestConsume_AlreadyConsumed(t *testing.T) { h, mlRepo, _, _ := newHandler(t) plain, hashHex, err := user.GenerateMagicLinkToken() require.NoError(t, err) now := time.Now() mlRepo.tokens[hashHex] = &user.MagicLinkToken{ ID: 1, Email: "x@example.com", TokenHash: hashHex, ExpiresAt: now.Add(5 * time.Minute), ConsumedAt: &now, } rr := mountAndRequest(h, http.MethodGet, "/magic-link/consume?token="+plain, "") assert.Equal(t, http.StatusUnauthorized, rr.Code) } // TestBuildMagicLinkURL_TrailingSlash exercises the small helper. func TestBuildMagicLinkURL_TrailingSlash(t *testing.T) { got := buildMagicLinkURL("http://x.local/", "abc") assert.Equal(t, "http://x.local/api/v1/auth/magic-link/consume?token=abc", got) }