♻️ refactor(server): split AuthMiddleware into Optional/Required (RFC 6750 + ISP narrow interface) #91

Manually merged
arcodange merged 1 commits from vibe/batch-pr-a1-split-auth-middlewares into main 2026-05-06 06:56:41 +02:00
Showing only changes of commit 17de45563d - Show all commits

View File

@@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"net/http" "net/http"
"strings"
"dance-lessons-coach/pkg/auth" "dance-lessons-coach/pkg/auth"
"dance-lessons-coach/pkg/user" "dance-lessons-coach/pkg/user"
@@ -10,54 +11,123 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
// AuthMiddleware handles JWT authentication and adds user to context // tokenValidator is the narrow interface AuthMiddleware needs from
type AuthMiddleware struct { // user.AuthService — only JWT validation. ISP : avoid pulling the full
authService user.AuthService // fat AuthService interface (12+ methods) into the middleware.
type tokenValidator interface {
ValidateJWT(ctx context.Context, token string) (*user.User, error)
} }
// NewAuthMiddleware creates a new authentication middleware const bearerPrefix = "Bearer "
func NewAuthMiddleware(authService user.AuthService) *AuthMiddleware {
return &AuthMiddleware{ // firstWord returns the first whitespace-separated word of s, or s itself
authService: authService, // if there's no whitespace. Used for log-safe scheme extraction.
func firstWord(s string) string {
if i := strings.IndexAny(s, " \t"); i >= 0 {
return s[:i]
} }
return s
} }
// Middleware returns the authentication middleware function // extractBearerToken pulls the bearer token out of an Authorization header.
func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler { // Returns ("", false) if absent or malformed. RFC 6750 specifies the
// scheme is case-insensitive ; we honor that.
func extractBearerToken(authHeader string) (string, bool) {
if authHeader == "" {
return "", false
}
// Case-insensitive prefix match for "Bearer "
if len(authHeader) < len(bearerPrefix) {
return "", false
}
if !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
return "", false
}
return authHeader[len(bearerPrefix):], true
}
// AuthMiddleware (existing type kept for backwards compatibility ; the
// constructor now returns a struct that exposes BOTH the optional and
// required handlers). The legacy .Middleware method delegates to
// OptionalHandler so existing wiring (server.go r.Use(authMiddleware.Middleware))
// keeps working.
type AuthMiddleware struct {
validator tokenValidator
}
func NewAuthMiddleware(validator tokenValidator) *AuthMiddleware {
return &AuthMiddleware{validator: validator}
}
// OptionalHandler wraps next so :
// - no Authorization header : pass through, no user in context
// - malformed header : pass through, log Trace, no user in context
// - invalid JWT : pass through, log Trace, no user in context
// - valid JWT : pass through, user injected via auth.UserContextKey
//
// Use this on endpoints where auth is "nice to have" — the handler is
// expected to call auth.GetAuthenticatedUserFromContext and decide.
func (m *AuthMiddleware) OptionalHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
token, ok := extractBearerToken(r.Header.Get("Authorization"))
// Extract Authorization header if !ok {
authHeader := r.Header.Get("Authorization") // Header absent or malformed — log size only (Q-064 : no raw value).
if authHeader == "" { if h := r.Header.Get("Authorization"); h != "" {
// No authorization header, pass through with no user log.Trace().Ctx(ctx).
Int("auth_header_len", len(h)).
Str("scheme_word", firstWord(h)).
Msg("Optional auth : malformed Authorization header")
}
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
validatedUser, err := m.validator.ValidateJWT(ctx, token)
// Extract token from "Bearer <token>" format
const bearerPrefix = "Bearer "
if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix {
log.Trace().Ctx(ctx).Str("auth_header", authHeader).Msg("Invalid authorization header format")
next.ServeHTTP(w, r)
return
}
token := authHeader[len(bearerPrefix):]
// Validate JWT token
validatedUser, err := m.authService.ValidateJWT(ctx, token)
if err != nil { if err != nil {
log.Trace().Ctx(ctx).Err(err).Msg("JWT validation failed") log.Trace().Ctx(ctx).Err(err).Msg("Optional auth : JWT validation failed")
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
// Add user to context
ctxWithUser := context.WithValue(ctx, auth.UserContextKey, validatedUser) ctxWithUser := context.WithValue(ctx, auth.UserContextKey, validatedUser)
r = r.WithContext(ctxWithUser) next.ServeHTTP(w, r.WithContext(ctxWithUser))
// Continue to next handler
next.ServeHTTP(w, r)
}) })
} }
// RequiredHandler wraps next so :
// - no header / malformed / invalid JWT : 401 Unauthorized + WWW-Authenticate: Bearer
// - valid JWT : pass through, user injected via auth.UserContextKey
//
// Use this on endpoints where unauthenticated access is forbidden.
// Conforms to RFC 6750.
func (m *AuthMiddleware) RequiredHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, ok := extractBearerToken(r.Header.Get("Authorization"))
if !ok {
w.Header().Set("WWW-Authenticate", `Bearer realm="dance-lessons-coach"`)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"unauthorized","message":"missing or malformed Authorization header"}`))
return
}
validatedUser, err := m.validator.ValidateJWT(ctx, token)
if err != nil {
w.Header().Set("WWW-Authenticate", `Bearer realm="dance-lessons-coach", error="invalid_token"`)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"unauthorized","message":"invalid token"}`))
return
}
ctxWithUser := context.WithValue(ctx, auth.UserContextKey, validatedUser)
next.ServeHTTP(w, r.WithContext(ctxWithUser))
})
}
// Middleware is the legacy method — preserved for backwards compatibility.
// Delegates to OptionalHandler. New wiring should call OptionalHandler or
// RequiredHandler explicitly.
//
// Deprecated: use OptionalHandler() directly.
func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
return m.OptionalHandler(next)
}