✨ feat(auth): magic-link request + consume HTTP handlers (ADR-0028 Phase A.4) #62
@@ -104,10 +104,17 @@ type APIConfig struct {
|
|||||||
|
|
||||||
// AuthConfig holds authentication configuration
|
// AuthConfig holds authentication configuration
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
JWTSecret string `mapstructure:"jwt_secret"`
|
JWTSecret string `mapstructure:"jwt_secret"`
|
||||||
AdminMasterPassword string `mapstructure:"admin_master_password"`
|
AdminMasterPassword string `mapstructure:"admin_master_password"`
|
||||||
JWT JWTConfig `mapstructure:"jwt"`
|
JWT JWTConfig `mapstructure:"jwt"`
|
||||||
Email EmailConfig `mapstructure:"email"`
|
Email EmailConfig `mapstructure:"email"`
|
||||||
|
MagicLink MagicLinkConfig `mapstructure:"magic_link"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MagicLinkConfig holds passwordless-auth magic-link parameters (ADR-0028 Phase A).
|
||||||
|
type MagicLinkConfig struct {
|
||||||
|
TTL time.Duration `mapstructure:"ttl"`
|
||||||
|
BaseURL string `mapstructure:"base_url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmailConfig holds outgoing email transport configuration.
|
// EmailConfig holds outgoing email transport configuration.
|
||||||
@@ -276,6 +283,10 @@ func LoadConfig() (*Config, error) {
|
|||||||
v.SetDefault("auth.email.smtp_use_tls", false)
|
v.SetDefault("auth.email.smtp_use_tls", false)
|
||||||
v.SetDefault("auth.email.timeout", 10*time.Second)
|
v.SetDefault("auth.email.timeout", 10*time.Second)
|
||||||
|
|
||||||
|
// Magic-link defaults (ADR-0028 Phase A).
|
||||||
|
v.SetDefault("auth.magic_link.ttl", 15*time.Minute)
|
||||||
|
v.SetDefault("auth.magic_link.base_url", "http://localhost:8080")
|
||||||
|
|
||||||
// Check for custom config file path via environment variable
|
// Check for custom config file path via environment variable
|
||||||
if configFile := os.Getenv("DLC_CONFIG_FILE"); configFile != "" {
|
if configFile := os.Getenv("DLC_CONFIG_FILE"); configFile != "" {
|
||||||
v.SetConfigFile(configFile)
|
v.SetConfigFile(configFile)
|
||||||
@@ -328,6 +339,10 @@ func LoadConfig() (*Config, error) {
|
|||||||
v.BindEnv("auth.email.smtp_password", "DLC_AUTH_EMAIL_SMTP_PASSWORD")
|
v.BindEnv("auth.email.smtp_password", "DLC_AUTH_EMAIL_SMTP_PASSWORD")
|
||||||
v.BindEnv("auth.email.smtp_use_tls", "DLC_AUTH_EMAIL_SMTP_USE_TLS")
|
v.BindEnv("auth.email.smtp_use_tls", "DLC_AUTH_EMAIL_SMTP_USE_TLS")
|
||||||
v.BindEnv("auth.email.timeout", "DLC_AUTH_EMAIL_TIMEOUT")
|
v.BindEnv("auth.email.timeout", "DLC_AUTH_EMAIL_TIMEOUT")
|
||||||
|
|
||||||
|
// Magic-link environment variables (ADR-0028 Phase A).
|
||||||
|
v.BindEnv("auth.magic_link.ttl", "DLC_AUTH_MAGIC_LINK_TTL")
|
||||||
|
v.BindEnv("auth.magic_link.base_url", "DLC_AUTH_MAGIC_LINK_BASE_URL")
|
||||||
v.BindEnv("telemetry.sampler.type", "DLC_TELEMETRY_SAMPLER_TYPE")
|
v.BindEnv("telemetry.sampler.type", "DLC_TELEMETRY_SAMPLER_TYPE")
|
||||||
v.BindEnv("telemetry.sampler.ratio", "DLC_TELEMETRY_SAMPLER_RATIO")
|
v.BindEnv("telemetry.sampler.ratio", "DLC_TELEMETRY_SAMPLER_RATIO")
|
||||||
|
|
||||||
@@ -466,6 +481,19 @@ func (c *Config) GetEmailConfig() EmailConfig {
|
|||||||
return c.Auth.Email
|
return c.Auth.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetMagicLinkConfig returns the passwordless-auth magic-link parameters
|
||||||
|
// (ADR-0028 Phase A). TTL defaults to 15m, BaseURL to http://localhost:8080.
|
||||||
|
func (c *Config) GetMagicLinkConfig() MagicLinkConfig {
|
||||||
|
out := c.Auth.MagicLink
|
||||||
|
if out.TTL <= 0 {
|
||||||
|
out.TTL = 15 * time.Minute
|
||||||
|
}
|
||||||
|
if out.BaseURL == "" {
|
||||||
|
out.BaseURL = "http://localhost:8080"
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// GetJWTTTL returns the JWT TTL
|
// GetJWTTTL returns the JWT TTL
|
||||||
func (c *Config) GetJWTTTL() time.Duration {
|
func (c *Config) GetJWTTTL() time.Duration {
|
||||||
if c.Auth.JWT.TTL == 0 {
|
if c.Auth.JWT.TTL == 0 {
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
|
|
||||||
"dance-lessons-coach/pkg/cache"
|
"dance-lessons-coach/pkg/cache"
|
||||||
"dance-lessons-coach/pkg/config"
|
"dance-lessons-coach/pkg/config"
|
||||||
|
"dance-lessons-coach/pkg/email"
|
||||||
"dance-lessons-coach/pkg/greet"
|
"dance-lessons-coach/pkg/greet"
|
||||||
"dance-lessons-coach/pkg/middleware"
|
"dance-lessons-coach/pkg/middleware"
|
||||||
"dance-lessons-coach/pkg/telemetry"
|
"dance-lessons-coach/pkg/telemetry"
|
||||||
@@ -252,6 +253,29 @@ func (s *Server) registerApiV1Routes(r chi.Router) {
|
|||||||
handler := userapi.NewAuthHandler(s.userService, s.userService, s.validator)
|
handler := userapi.NewAuthHandler(s.userService, s.userService, s.validator)
|
||||||
r.Route("/auth", func(r chi.Router) {
|
r.Route("/auth", func(r chi.Router) {
|
||||||
handler.RegisterRoutes(r)
|
handler.RegisterRoutes(r)
|
||||||
|
// Magic-link routes (ADR-0028 Phase A). Mounted only when the
|
||||||
|
// userRepo also implements MagicLinkRepository (PostgresRepository does).
|
||||||
|
if mlRepo, ok := s.userRepo.(user.MagicLinkRepository); ok {
|
||||||
|
emailCfg := s.config.GetEmailConfig()
|
||||||
|
sender := email.NewSMTPSender(email.SMTPConfig{
|
||||||
|
Host: emailCfg.SMTPHost,
|
||||||
|
Port: emailCfg.SMTPPort,
|
||||||
|
Username: emailCfg.SMTPUsername,
|
||||||
|
Password: emailCfg.SMTPPassword,
|
||||||
|
UseTLS: emailCfg.SMTPUseTLS,
|
||||||
|
Timeout: emailCfg.Timeout,
|
||||||
|
})
|
||||||
|
mlHandler := userapi.NewMagicLinkHandler(
|
||||||
|
mlRepo,
|
||||||
|
s.userService,
|
||||||
|
s.userRepo,
|
||||||
|
sender,
|
||||||
|
s.config.GetMagicLinkConfig(),
|
||||||
|
emailCfg.From,
|
||||||
|
s.validator,
|
||||||
|
)
|
||||||
|
mlHandler.RegisterRoutes(r)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Register admin routes
|
// Register admin routes
|
||||||
|
|||||||
274
pkg/user/api/magic_link_handler.go
Normal file
274
pkg/user/api/magic_link_handler.go
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"dance-lessons-coach/pkg/config"
|
||||||
|
"dance-lessons-coach/pkg/email"
|
||||||
|
"dance-lessons-coach/pkg/user"
|
||||||
|
"dance-lessons-coach/pkg/validation"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MagicLinkHandler exposes the passwordless-auth endpoints described
|
||||||
|
// in ADR-0028 Phase A : `POST /magic-link/request` and
|
||||||
|
// `GET /magic-link/consume?token=...`.
|
||||||
|
type MagicLinkHandler struct {
|
||||||
|
tokens user.MagicLinkRepository
|
||||||
|
users user.UserService
|
||||||
|
repo user.UserRepository // for GetUserByUsername (sign-up flow)
|
||||||
|
sender email.Sender
|
||||||
|
cfg config.MagicLinkConfig
|
||||||
|
emailFrom string
|
||||||
|
validator *validation.Validator
|
||||||
|
clock func() time.Time
|
||||||
|
newPassword func() (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMagicLinkHandler wires the handler. emailFrom must be the From
|
||||||
|
// address (typically cfg.GetEmailConfig().From).
|
||||||
|
func NewMagicLinkHandler(
|
||||||
|
tokens user.MagicLinkRepository,
|
||||||
|
users user.UserService,
|
||||||
|
repo user.UserRepository,
|
||||||
|
sender email.Sender,
|
||||||
|
cfg config.MagicLinkConfig,
|
||||||
|
emailFrom string,
|
||||||
|
validator *validation.Validator,
|
||||||
|
) *MagicLinkHandler {
|
||||||
|
return &MagicLinkHandler{
|
||||||
|
tokens: tokens,
|
||||||
|
users: users,
|
||||||
|
repo: repo,
|
||||||
|
sender: sender,
|
||||||
|
cfg: cfg,
|
||||||
|
emailFrom: emailFrom,
|
||||||
|
validator: validator,
|
||||||
|
clock: time.Now,
|
||||||
|
newPassword: func() (string, error) {
|
||||||
|
var raw [48]byte
|
||||||
|
if _, err := rand.Read(raw[:]); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(raw[:]), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes mounts the two endpoints on the provided router.
|
||||||
|
func (h *MagicLinkHandler) RegisterRoutes(router chi.Router) {
|
||||||
|
router.Post("/magic-link/request", h.handleRequest)
|
||||||
|
router.Get("/magic-link/consume", h.handleConsume)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MagicLinkRequest is the body of POST /magic-link/request.
|
||||||
|
type MagicLinkRequest struct {
|
||||||
|
Email string `json:"email" validate:"required,email,max=255"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MagicLinkResponse is the response shape for both endpoints.
|
||||||
|
type MagicLinkResponse struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Token string `json:"token,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRequest godoc
|
||||||
|
//
|
||||||
|
// @Summary Request a magic link
|
||||||
|
// @Description Generates a passwordless-auth one-time token and emails it. Always 200 to prevent email enumeration.
|
||||||
|
// @Tags API/v1/User
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param request body MagicLinkRequest true "Email address"
|
||||||
|
// @Success 200 {object} MagicLinkResponse "Email queued (or silently dropped)"
|
||||||
|
// @Failure 400 {object} map[string]string "Invalid request body"
|
||||||
|
// @Router /v1/auth/magic-link/request [post]
|
||||||
|
func (h *MagicLinkHandler) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
var req MagicLinkRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, `{"error":"invalid_request","message":"Invalid JSON request body"}`, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.validator != nil {
|
||||||
|
if err := h.validator.Validate(req); err != nil {
|
||||||
|
h.writeValidationError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := strings.ToLower(strings.TrimSpace(req.Email))
|
||||||
|
|
||||||
|
plain, hashHex, err := user.GenerateMagicLinkToken()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Ctx(ctx).Err(err).Msg("magic link request: rand failed")
|
||||||
|
http.Error(w, `{"error":"server_error","message":"Failed to generate token"}`, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
now := h.clock()
|
||||||
|
tok := &user.MagicLinkToken{
|
||||||
|
Email: addr,
|
||||||
|
TokenHash: hashHex,
|
||||||
|
ExpiresAt: now.Add(h.cfg.TTL),
|
||||||
|
}
|
||||||
|
if err := h.tokens.CreateMagicLinkToken(ctx, tok); err != nil {
|
||||||
|
log.Error().Ctx(ctx).Err(err).Str("email", addr).Msg("magic link request: persist failed")
|
||||||
|
writeJSON(w, http.StatusOK, MagicLinkResponse{Message: "If that email is valid, a link has been sent."})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
link := buildMagicLinkURL(h.cfg.BaseURL, plain)
|
||||||
|
subject := "Your sign-in link"
|
||||||
|
bodyText := fmt.Sprintf("Sign in by clicking the link below.\n\n%s\n\nThe link is valid for %s and can only be used once.\nIf you did not request this, ignore this email.\n", link, h.cfg.TTL)
|
||||||
|
bodyHTML := fmt.Sprintf(`<p>Sign in by clicking the link below.</p><p><a href="%s">%s</a></p><p>The link is valid for %s and can only be used once.<br>If you did not request this, ignore this email.</p>`, link, link, h.cfg.TTL)
|
||||||
|
|
||||||
|
msg := email.Message{
|
||||||
|
From: h.emailFrom,
|
||||||
|
To: addr,
|
||||||
|
Subject: subject,
|
||||||
|
BodyText: bodyText,
|
||||||
|
BodyHTML: bodyHTML,
|
||||||
|
}
|
||||||
|
if err := h.sender.Send(ctx, msg); err != nil {
|
||||||
|
log.Error().Ctx(ctx).Err(err).Str("to", addr).Msg("magic link request: email send failed")
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, MagicLinkResponse{Message: "If that email is valid, a link has been sent."})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleConsume validates the token, marks it consumed, ensures a
|
||||||
|
// matching User row exists (sign-up on first link), and issues a JWT.
|
||||||
|
//
|
||||||
|
// All failure modes (missing, expired, already-consumed) collapse to a
|
||||||
|
// single 401 to prevent attackers distinguishing them.
|
||||||
|
//
|
||||||
|
// @Summary Consume a magic link
|
||||||
|
// @Description Validates the magic-link token, ensures the user exists (signup-on-first-use), issues a JWT.
|
||||||
|
// @Tags API/v1/User
|
||||||
|
// @Produce json
|
||||||
|
// @Param token query string true "The magic-link token"
|
||||||
|
// @Success 200 {object} MagicLinkResponse "Signed in"
|
||||||
|
// @Failure 400 {object} map[string]string "Missing token"
|
||||||
|
// @Failure 401 {object} map[string]string "Invalid or expired token"
|
||||||
|
// @Router /v1/auth/magic-link/consume [get]
|
||||||
|
func (h *MagicLinkHandler) handleConsume(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
plain := strings.TrimSpace(r.URL.Query().Get("token"))
|
||||||
|
if plain == "" {
|
||||||
|
writeJSONError(w, http.StatusBadRequest, "invalid_request", "missing token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tok, err := h.tokens.GetMagicLinkTokenByHash(ctx, user.HashMagicLinkToken(plain))
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Ctx(ctx).Err(err).Msg("magic link consume: lookup failed")
|
||||||
|
writeJSONError(w, http.StatusInternalServerError, "server_error", "lookup failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tok == nil || tok.ConsumedAt != nil || h.clock().After(tok.ExpiresAt) {
|
||||||
|
writeJSONError(w, http.StatusUnauthorized, "invalid_token", "magic link is invalid or expired")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.tokens.MarkMagicLinkTokenConsumed(ctx, tok.ID, h.clock()); err != nil {
|
||||||
|
log.Error().Ctx(ctx).Err(err).Uint("id", tok.ID).Msg("magic link consume: mark failed")
|
||||||
|
writeJSONError(w, http.StatusInternalServerError, "server_error", "consume failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := h.ensureUser(ctx, tok.Email)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Ctx(ctx).Err(err).Str("email", tok.Email).Msg("magic link consume: user upsert failed")
|
||||||
|
writeJSONError(w, http.StatusInternalServerError, "server_error", "user upsert failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jwt, err := h.users.GenerateJWT(ctx, u)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Ctx(ctx).Err(err).Msg("magic link consume: JWT generation failed")
|
||||||
|
writeJSONError(w, http.StatusInternalServerError, "server_error", "jwt generation failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, MagicLinkResponse{Message: "signed in", Token: jwt})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureUser returns the user keyed on email (stored as Username),
|
||||||
|
// creating them if absent. Newly-created users get a random unguessable
|
||||||
|
// bcrypt-hashed password so the password endpoints stay locked out.
|
||||||
|
func (h *MagicLinkHandler) ensureUser(ctx context.Context, email string) (*user.User, error) {
|
||||||
|
if h.repo != nil {
|
||||||
|
existing, err := h.repo.GetUserByUsername(ctx, email)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if existing != nil {
|
||||||
|
return existing, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rawPass, err := h.newPassword()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("magic link signup rand: %w", err)
|
||||||
|
}
|
||||||
|
hash, err := h.users.HashPassword(ctx, rawPass)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("magic link signup hash: %w", err)
|
||||||
|
}
|
||||||
|
u := &user.User{
|
||||||
|
Username: email,
|
||||||
|
PasswordHash: hash,
|
||||||
|
IsAdmin: false,
|
||||||
|
}
|
||||||
|
if err := h.users.CreateUser(ctx, u); err != nil {
|
||||||
|
return nil, fmt.Errorf("magic link signup create: %w", err)
|
||||||
|
}
|
||||||
|
if h.repo != nil {
|
||||||
|
return h.repo.GetUserByUsername(ctx, email)
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MagicLinkHandler) writeValidationError(w http.ResponseWriter, err error) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
var ve *validation.ValidationError
|
||||||
|
if errors.As(err, &ve) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"error": "validation_failed",
|
||||||
|
"message": "Invalid request data",
|
||||||
|
"details": ve.Messages,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"error": "validation_failed",
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeJSON(w http.ResponseWriter, status int, v any) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
_ = json.NewEncoder(w).Encode(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeJSONError(w http.ResponseWriter, status int, code, msg string) {
|
||||||
|
writeJSON(w, status, map[string]string{"error": code, "message": msg})
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildMagicLinkURL(baseURL, token string) string {
|
||||||
|
base := strings.TrimRight(baseURL, "/")
|
||||||
|
return fmt.Sprintf("%s/api/v1/auth/magic-link/consume?token=%s", base, token)
|
||||||
|
}
|
||||||
371
pkg/user/api/magic_link_handler_test.go
Normal file
371
pkg/user/api/magic_link_handler_test.go
Normal file
@@ -0,0 +1,371 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"dance-lessons-coach/pkg/config"
|
||||||
|
"dance-lessons-coach/pkg/email"
|
||||||
|
"dance-lessons-coach/pkg/user"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeMLRepo is an in-memory MagicLinkRepository for the handler tests.
|
||||||
|
type fakeMLRepo struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
tokens map[string]*user.MagicLinkToken // key: TokenHash
|
||||||
|
nextID uint
|
||||||
|
failOn string // "create" / "get" / "mark" / "" (none)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeMLRepo() *fakeMLRepo {
|
||||||
|
return &fakeMLRepo{tokens: map[string]*user.MagicLinkToken{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *fakeMLRepo) CreateMagicLinkToken(_ context.Context, t *user.MagicLinkToken) error {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
if r.failOn == "create" {
|
||||||
|
return errors.New("simulated create failure")
|
||||||
|
}
|
||||||
|
r.nextID++
|
||||||
|
t.ID = r.nextID
|
||||||
|
r.tokens[t.TokenHash] = t
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *fakeMLRepo) GetMagicLinkTokenByHash(_ context.Context, h string) (*user.MagicLinkToken, error) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
if r.failOn == "get" {
|
||||||
|
return nil, errors.New("simulated get failure")
|
||||||
|
}
|
||||||
|
t, ok := r.tokens[h]
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
cp := *t
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *fakeMLRepo) MarkMagicLinkTokenConsumed(_ context.Context, id uint, when time.Time) error {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
if r.failOn == "mark" {
|
||||||
|
return errors.New("simulated mark failure")
|
||||||
|
}
|
||||||
|
for _, t := range r.tokens {
|
||||||
|
if t.ID == id {
|
||||||
|
t.ConsumedAt = &when
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errors.New("not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *fakeMLRepo) DeleteExpiredMagicLinkTokens(_ context.Context, _ time.Time) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeUserSvc is a minimal user.UserService stub.
|
||||||
|
type fakeUserSvc struct {
|
||||||
|
createdUsers []*user.User
|
||||||
|
jwtForID map[uint]string
|
||||||
|
hashCalls int
|
||||||
|
failOn string // "create" / "hash" / "jwt"
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeUserSvc() *fakeUserSvc { return &fakeUserSvc{jwtForID: map[uint]string{}} }
|
||||||
|
|
||||||
|
func (s *fakeUserSvc) Authenticate(_ context.Context, _, _ string) (*user.User, error) {
|
||||||
|
return nil, errors.New("not used in magic-link tests")
|
||||||
|
}
|
||||||
|
func (s *fakeUserSvc) GenerateJWT(_ context.Context, u *user.User) (string, error) {
|
||||||
|
if s.failOn == "jwt" {
|
||||||
|
return "", errors.New("simulated jwt failure")
|
||||||
|
}
|
||||||
|
return "jwt-for-user-" + u.Username, nil
|
||||||
|
}
|
||||||
|
func (s *fakeUserSvc) ValidateJWT(_ context.Context, _ string) (*user.User, error) {
|
||||||
|
return nil, errors.New("not used")
|
||||||
|
}
|
||||||
|
func (s *fakeUserSvc) AdminAuthenticate(_ context.Context, _ string) (*user.User, error) {
|
||||||
|
return nil, errors.New("not used")
|
||||||
|
}
|
||||||
|
func (s *fakeUserSvc) AddJWTSecret(_ string, _ bool, _ time.Duration) {}
|
||||||
|
func (s *fakeUserSvc) RotateJWTSecret(_ string) {}
|
||||||
|
func (s *fakeUserSvc) GetJWTSecretByIndex(_ int) (string, bool) { return "", false }
|
||||||
|
func (s *fakeUserSvc) ResetJWTSecrets() {}
|
||||||
|
func (s *fakeUserSvc) StartJWTSecretCleanupLoop(_ context.Context, _ time.Duration) {}
|
||||||
|
func (s *fakeUserSvc) RemoveExpiredJWTSecrets() int { return 0 }
|
||||||
|
func (s *fakeUserSvc) ListJWTSecretsInfo() []user.JWTSecretInfo { return nil }
|
||||||
|
|
||||||
|
func (s *fakeUserSvc) UserExists(_ context.Context, username string) (bool, error) {
|
||||||
|
for _, u := range s.createdUsers {
|
||||||
|
if u.Username == username {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
func (s *fakeUserSvc) CreateUser(_ context.Context, u *user.User) error {
|
||||||
|
if s.failOn == "create" {
|
||||||
|
return errors.New("simulated create failure")
|
||||||
|
}
|
||||||
|
u.ID = uint(len(s.createdUsers) + 1)
|
||||||
|
cp := *u
|
||||||
|
s.createdUsers = append(s.createdUsers, &cp)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *fakeUserSvc) HashPassword(_ context.Context, p string) (string, error) {
|
||||||
|
s.hashCalls++
|
||||||
|
if s.failOn == "hash" {
|
||||||
|
return "", errors.New("simulated hash failure")
|
||||||
|
}
|
||||||
|
return "hash:" + p, nil
|
||||||
|
}
|
||||||
|
func (s *fakeUserSvc) RequestPasswordReset(_ context.Context, _ string) error { return nil }
|
||||||
|
func (s *fakeUserSvc) CompletePasswordReset(_ context.Context, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeUserRepo implements user.UserRepository using fakeUserSvc's slice.
|
||||||
|
type fakeUserRepo struct{ svc *fakeUserSvc }
|
||||||
|
|
||||||
|
func (r *fakeUserRepo) CreateUser(_ context.Context, u *user.User) error {
|
||||||
|
return r.svc.CreateUser(context.Background(), u)
|
||||||
|
}
|
||||||
|
func (r *fakeUserRepo) GetUserByUsername(_ context.Context, name string) (*user.User, error) {
|
||||||
|
for _, u := range r.svc.createdUsers {
|
||||||
|
if u.Username == name {
|
||||||
|
cp := *u
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (r *fakeUserRepo) GetUserByID(_ context.Context, _ uint) (*user.User, error) { return nil, nil }
|
||||||
|
func (r *fakeUserRepo) UpdateUser(_ context.Context, _ *user.User) error { return nil }
|
||||||
|
func (r *fakeUserRepo) DeleteUser(_ context.Context, _ uint) error { return nil }
|
||||||
|
func (r *fakeUserRepo) AllowPasswordReset(_ context.Context, _ string) error { return nil }
|
||||||
|
func (r *fakeUserRepo) CompletePasswordReset(_ context.Context, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (r *fakeUserRepo) UserExists(_ context.Context, name string) (bool, error) {
|
||||||
|
return r.svc.UserExists(context.Background(), name)
|
||||||
|
}
|
||||||
|
func (r *fakeUserRepo) CheckDatabaseHealth(_ context.Context) error { return nil }
|
||||||
|
|
||||||
|
// recordingSender captures email.Send calls without sending anything.
|
||||||
|
type recordingSender struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
messages []email.Message
|
||||||
|
failNext bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *recordingSender) Send(_ context.Context, m email.Message) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if s.failNext {
|
||||||
|
return errors.New("simulated send failure")
|
||||||
|
}
|
||||||
|
s.messages = append(s.messages, m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHandler(t *testing.T) (*MagicLinkHandler, *fakeMLRepo, *fakeUserSvc, *recordingSender) {
|
||||||
|
t.Helper()
|
||||||
|
mlRepo := newFakeMLRepo()
|
||||||
|
svc := newFakeUserSvc()
|
||||||
|
repo := &fakeUserRepo{svc: svc}
|
||||||
|
sender := &recordingSender{}
|
||||||
|
h := NewMagicLinkHandler(
|
||||||
|
mlRepo, svc, repo, sender,
|
||||||
|
config.MagicLinkConfig{TTL: 15 * time.Minute, BaseURL: "http://test.local"},
|
||||||
|
"noreply@test.local",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
return h, mlRepo, svc, sender
|
||||||
|
}
|
||||||
|
|
||||||
|
func mountAndRequest(h *MagicLinkHandler, method, path, body string) *httptest.ResponseRecorder {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
h.RegisterRoutes(r)
|
||||||
|
req := httptest.NewRequest(method, path, strings.NewReader(body))
|
||||||
|
if body != "" {
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
}
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rr, req)
|
||||||
|
return rr
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRequest_HappyPath confirms POST /magic-link/request stores a token,
|
||||||
|
// sends an email containing the link, and returns 200 with a generic body.
|
||||||
|
func TestRequest_HappyPath(t *testing.T) {
|
||||||
|
h, mlRepo, _, sender := newHandler(t)
|
||||||
|
|
||||||
|
rr := mountAndRequest(h, http.MethodPost, "/magic-link/request", `{"email":"alice@example.com"}`)
|
||||||
|
require.Equal(t, http.StatusOK, rr.Code)
|
||||||
|
assert.Contains(t, rr.Body.String(), "If that email is valid")
|
||||||
|
|
||||||
|
// One token persisted, email lower-cased.
|
||||||
|
require.Len(t, mlRepo.tokens, 1)
|
||||||
|
for _, tok := range mlRepo.tokens {
|
||||||
|
assert.Equal(t, "alice@example.com", tok.Email)
|
||||||
|
assert.Greater(t, tok.ExpiresAt.Unix(), time.Now().Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
// One email sent to the same address, link points at our test base URL.
|
||||||
|
require.Len(t, sender.messages, 1)
|
||||||
|
assert.Equal(t, "alice@example.com", sender.messages[0].To)
|
||||||
|
assert.Contains(t, sender.messages[0].BodyText, "http://test.local/api/v1/auth/magic-link/consume?token=")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRequest_NormalizesEmail confirms the email is lower-cased + trimmed.
|
||||||
|
func TestRequest_NormalizesEmail(t *testing.T) {
|
||||||
|
h, mlRepo, _, sender := newHandler(t)
|
||||||
|
rr := mountAndRequest(h, http.MethodPost, "/magic-link/request", `{"email":" Alice@Example.COM "}`)
|
||||||
|
require.Equal(t, http.StatusOK, rr.Code)
|
||||||
|
require.Len(t, mlRepo.tokens, 1)
|
||||||
|
for _, tok := range mlRepo.tokens {
|
||||||
|
assert.Equal(t, "alice@example.com", tok.Email)
|
||||||
|
}
|
||||||
|
assert.Equal(t, "alice@example.com", sender.messages[0].To)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRequest_BadJSON returns 400.
|
||||||
|
func TestRequest_BadJSON(t *testing.T) {
|
||||||
|
h, _, _, _ := newHandler(t)
|
||||||
|
rr := mountAndRequest(h, http.MethodPost, "/magic-link/request", `not json`)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRequest_PersistFailureStillReturns200 — a DB error must NOT leak
|
||||||
|
// to the user (would let attackers detect storage outages).
|
||||||
|
func TestRequest_PersistFailureStillReturns200(t *testing.T) {
|
||||||
|
h, mlRepo, _, sender := newHandler(t)
|
||||||
|
mlRepo.failOn = "create"
|
||||||
|
rr := mountAndRequest(h, http.MethodPost, "/magic-link/request", `{"email":"bob@example.com"}`)
|
||||||
|
assert.Equal(t, http.StatusOK, rr.Code)
|
||||||
|
// No email was sent because no token was persisted.
|
||||||
|
assert.Empty(t, sender.messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConsume_HappyPath_NewUser exercises sign-up-on-first-link.
|
||||||
|
func TestConsume_HappyPath_NewUser(t *testing.T) {
|
||||||
|
h, mlRepo, svc, _ := newHandler(t)
|
||||||
|
|
||||||
|
// Seed one token by going through the request flow.
|
||||||
|
mountAndRequest(h, http.MethodPost, "/magic-link/request", `{"email":"alice@example.com"}`)
|
||||||
|
require.Len(t, mlRepo.tokens, 1)
|
||||||
|
|
||||||
|
// We need the plaintext to consume — derive it from the only token in the
|
||||||
|
// repo by reverse trick : the request handler doesn't expose it. So we
|
||||||
|
// drive consume with a fresh known-plaintext we put into the repo
|
||||||
|
// directly.
|
||||||
|
plain, hashHex, err := user.GenerateMagicLinkToken()
|
||||||
|
require.NoError(t, err)
|
||||||
|
mlRepo.tokens = map[string]*user.MagicLinkToken{
|
||||||
|
hashHex: {ID: 99, Email: "alice@example.com", TokenHash: hashHex, ExpiresAt: time.Now().Add(5 * time.Minute)},
|
||||||
|
}
|
||||||
|
mlRepo.nextID = 99
|
||||||
|
|
||||||
|
rr := mountAndRequest(h, http.MethodGet, "/magic-link/consume?token="+plain, "")
|
||||||
|
require.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
|
||||||
|
|
||||||
|
var resp MagicLinkResponse
|
||||||
|
require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &resp))
|
||||||
|
assert.Equal(t, "signed in", resp.Message)
|
||||||
|
assert.Equal(t, "jwt-for-user-alice@example.com", resp.Token)
|
||||||
|
|
||||||
|
// User was created.
|
||||||
|
require.Len(t, svc.createdUsers, 1)
|
||||||
|
assert.Equal(t, "alice@example.com", svc.createdUsers[0].Username)
|
||||||
|
assert.NotEmpty(t, svc.createdUsers[0].PasswordHash, "passwordless user must still have a non-empty hash (random unguessable value)")
|
||||||
|
assert.Equal(t, 1, svc.hashCalls)
|
||||||
|
|
||||||
|
// Token marked consumed.
|
||||||
|
for _, tok := range mlRepo.tokens {
|
||||||
|
require.NotNil(t, tok.ConsumedAt, "consumed_at must be set after consume")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConsume_HappyPath_ExistingUser confirms no new user is created
|
||||||
|
// when the email is already known.
|
||||||
|
func TestConsume_HappyPath_ExistingUser(t *testing.T) {
|
||||||
|
h, mlRepo, svc, _ := newHandler(t)
|
||||||
|
|
||||||
|
// Pre-seed the user.
|
||||||
|
require.NoError(t, svc.CreateUser(context.Background(), &user.User{Username: "carol@example.com", PasswordHash: "x"}))
|
||||||
|
require.Len(t, svc.createdUsers, 1)
|
||||||
|
preCount := len(svc.createdUsers)
|
||||||
|
|
||||||
|
plain, hashHex, err := user.GenerateMagicLinkToken()
|
||||||
|
require.NoError(t, err)
|
||||||
|
mlRepo.tokens[hashHex] = &user.MagicLinkToken{ID: 1, Email: "carol@example.com", TokenHash: hashHex, ExpiresAt: time.Now().Add(5 * time.Minute)}
|
||||||
|
|
||||||
|
rr := mountAndRequest(h, http.MethodGet, "/magic-link/consume?token="+plain, "")
|
||||||
|
require.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
|
||||||
|
|
||||||
|
// No new user.
|
||||||
|
assert.Len(t, svc.createdUsers, preCount)
|
||||||
|
assert.Equal(t, 0, svc.hashCalls, "no hash call when user exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConsume_MissingToken returns 400.
|
||||||
|
func TestConsume_MissingToken(t *testing.T) {
|
||||||
|
h, _, _, _ := newHandler(t)
|
||||||
|
rr := mountAndRequest(h, http.MethodGet, "/magic-link/consume", "")
|
||||||
|
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConsume_UnknownToken returns 401 (single generic shape).
|
||||||
|
func TestConsume_UnknownToken(t *testing.T) {
|
||||||
|
h, _, _, _ := newHandler(t)
|
||||||
|
rr := mountAndRequest(h, http.MethodGet, "/magic-link/consume?token=neverissued", "")
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConsume_ExpiredToken returns 401.
|
||||||
|
func TestConsume_ExpiredToken(t *testing.T) {
|
||||||
|
h, mlRepo, _, _ := newHandler(t)
|
||||||
|
plain, hashHex, err := user.GenerateMagicLinkToken()
|
||||||
|
require.NoError(t, err)
|
||||||
|
mlRepo.tokens[hashHex] = &user.MagicLinkToken{
|
||||||
|
ID: 1, Email: "x@example.com", TokenHash: hashHex,
|
||||||
|
ExpiresAt: time.Now().Add(-1 * time.Minute), // already expired
|
||||||
|
}
|
||||||
|
rr := mountAndRequest(h, http.MethodGet, "/magic-link/consume?token="+plain, "")
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConsume_AlreadyConsumed returns 401 — single-use guarantee.
|
||||||
|
func TestConsume_AlreadyConsumed(t *testing.T) {
|
||||||
|
h, mlRepo, _, _ := newHandler(t)
|
||||||
|
plain, hashHex, err := user.GenerateMagicLinkToken()
|
||||||
|
require.NoError(t, err)
|
||||||
|
now := time.Now()
|
||||||
|
mlRepo.tokens[hashHex] = &user.MagicLinkToken{
|
||||||
|
ID: 1, Email: "x@example.com", TokenHash: hashHex,
|
||||||
|
ExpiresAt: now.Add(5 * time.Minute), ConsumedAt: &now,
|
||||||
|
}
|
||||||
|
rr := mountAndRequest(h, http.MethodGet, "/magic-link/consume?token="+plain, "")
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildMagicLinkURL_TrailingSlash exercises the small helper.
|
||||||
|
func TestBuildMagicLinkURL_TrailingSlash(t *testing.T) {
|
||||||
|
got := buildMagicLinkURL("http://x.local/", "abc")
|
||||||
|
assert.Equal(t, "http://x.local/api/v1/auth/magic-link/consume?token=abc", got)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user