♻️ refactor(server): split AuthMiddleware into Optional/Required (RFC 6750 + ISP narrow interface)
Some checks failed
CI/CD Pipeline / Build Docker Cache (push) Successful in 15s
CI/CD Pipeline / Trigger Docker Push (push) Has been cancelled
CI/CD Pipeline / CI Pipeline (push) Has been cancelled

Generated by Mistral Vibe.
Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
This commit was merged in pull request #91.
This commit is contained in:
2026-05-06 06:56:02 +02:00
parent e5a1979b1f
commit 17de45563d

View File

@@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"net/http" "net/http"
"strings"
"dance-lessons-coach/pkg/auth" "dance-lessons-coach/pkg/auth"
"dance-lessons-coach/pkg/user" "dance-lessons-coach/pkg/user"
@@ -10,54 +11,123 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
// AuthMiddleware handles JWT authentication and adds user to context // tokenValidator is the narrow interface AuthMiddleware needs from
// user.AuthService — only JWT validation. ISP : avoid pulling the full
// fat AuthService interface (12+ methods) into the middleware.
type tokenValidator interface {
ValidateJWT(ctx context.Context, token string) (*user.User, error)
}
const bearerPrefix = "Bearer "
// firstWord returns the first whitespace-separated word of s, or s itself
// if there's no whitespace. Used for log-safe scheme extraction.
func firstWord(s string) string {
if i := strings.IndexAny(s, " \t"); i >= 0 {
return s[:i]
}
return s
}
// extractBearerToken pulls the bearer token out of an Authorization header.
// Returns ("", false) if absent or malformed. RFC 6750 specifies the
// scheme is case-insensitive ; we honor that.
func extractBearerToken(authHeader string) (string, bool) {
if authHeader == "" {
return "", false
}
// Case-insensitive prefix match for "Bearer "
if len(authHeader) < len(bearerPrefix) {
return "", false
}
if !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
return "", false
}
return authHeader[len(bearerPrefix):], true
}
// AuthMiddleware (existing type kept for backwards compatibility ; the
// constructor now returns a struct that exposes BOTH the optional and
// required handlers). The legacy .Middleware method delegates to
// OptionalHandler so existing wiring (server.go r.Use(authMiddleware.Middleware))
// keeps working.
type AuthMiddleware struct { type AuthMiddleware struct {
authService user.AuthService validator tokenValidator
} }
// NewAuthMiddleware creates a new authentication middleware func NewAuthMiddleware(validator tokenValidator) *AuthMiddleware {
func NewAuthMiddleware(authService user.AuthService) *AuthMiddleware { return &AuthMiddleware{validator: validator}
return &AuthMiddleware{
authService: authService,
}
} }
// Middleware returns the authentication middleware function // OptionalHandler wraps next so :
func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler { // - no Authorization header : pass through, no user in context
// - malformed header : pass through, log Trace, no user in context
// - invalid JWT : pass through, log Trace, no user in context
// - valid JWT : pass through, user injected via auth.UserContextKey
//
// Use this on endpoints where auth is "nice to have" — the handler is
// expected to call auth.GetAuthenticatedUserFromContext and decide.
func (m *AuthMiddleware) OptionalHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
token, ok := extractBearerToken(r.Header.Get("Authorization"))
// Extract Authorization header if !ok {
authHeader := r.Header.Get("Authorization") // Header absent or malformed — log size only (Q-064 : no raw value).
if authHeader == "" { if h := r.Header.Get("Authorization"); h != "" {
// No authorization header, pass through with no user log.Trace().Ctx(ctx).
Int("auth_header_len", len(h)).
Str("scheme_word", firstWord(h)).
Msg("Optional auth : malformed Authorization header")
}
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
validatedUser, err := m.validator.ValidateJWT(ctx, token)
// Extract token from "Bearer <token>" format
const bearerPrefix = "Bearer "
if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix {
log.Trace().Ctx(ctx).Str("auth_header", authHeader).Msg("Invalid authorization header format")
next.ServeHTTP(w, r)
return
}
token := authHeader[len(bearerPrefix):]
// Validate JWT token
validatedUser, err := m.authService.ValidateJWT(ctx, token)
if err != nil { if err != nil {
log.Trace().Ctx(ctx).Err(err).Msg("JWT validation failed") log.Trace().Ctx(ctx).Err(err).Msg("Optional auth : JWT validation failed")
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
// Add user to context
ctxWithUser := context.WithValue(ctx, auth.UserContextKey, validatedUser) ctxWithUser := context.WithValue(ctx, auth.UserContextKey, validatedUser)
r = r.WithContext(ctxWithUser) next.ServeHTTP(w, r.WithContext(ctxWithUser))
// Continue to next handler
next.ServeHTTP(w, r)
}) })
} }
// RequiredHandler wraps next so :
// - no header / malformed / invalid JWT : 401 Unauthorized + WWW-Authenticate: Bearer
// - valid JWT : pass through, user injected via auth.UserContextKey
//
// Use this on endpoints where unauthenticated access is forbidden.
// Conforms to RFC 6750.
func (m *AuthMiddleware) RequiredHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, ok := extractBearerToken(r.Header.Get("Authorization"))
if !ok {
w.Header().Set("WWW-Authenticate", `Bearer realm="dance-lessons-coach"`)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"unauthorized","message":"missing or malformed Authorization header"}`))
return
}
validatedUser, err := m.validator.ValidateJWT(ctx, token)
if err != nil {
w.Header().Set("WWW-Authenticate", `Bearer realm="dance-lessons-coach", error="invalid_token"`)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"unauthorized","message":"invalid token"}`))
return
}
ctxWithUser := context.WithValue(ctx, auth.UserContextKey, validatedUser)
next.ServeHTTP(w, r.WithContext(ctxWithUser))
})
}
// Middleware is the legacy method — preserved for backwards compatibility.
// Delegates to OptionalHandler. New wiring should call OptionalHandler or
// RequiredHandler explicitly.
//
// Deprecated: use OptionalHandler() directly.
func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
return m.OptionalHandler(next)
}