♻️ refactor(server): split AuthMiddleware into Optional/Required (RFC 6750 + ISP narrow interface)
Generated by Mistral Vibe. Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
This commit was merged in pull request #91.
This commit is contained in:
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"dance-lessons-coach/pkg/auth"
|
||||
"dance-lessons-coach/pkg/user"
|
||||
@@ -10,54 +11,123 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// AuthMiddleware handles JWT authentication and adds user to context
|
||||
// tokenValidator is the narrow interface AuthMiddleware needs from
|
||||
// user.AuthService — only JWT validation. ISP : avoid pulling the full
|
||||
// fat AuthService interface (12+ methods) into the middleware.
|
||||
type tokenValidator interface {
|
||||
ValidateJWT(ctx context.Context, token string) (*user.User, error)
|
||||
}
|
||||
|
||||
const bearerPrefix = "Bearer "
|
||||
|
||||
// firstWord returns the first whitespace-separated word of s, or s itself
|
||||
// if there's no whitespace. Used for log-safe scheme extraction.
|
||||
func firstWord(s string) string {
|
||||
if i := strings.IndexAny(s, " \t"); i >= 0 {
|
||||
return s[:i]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// extractBearerToken pulls the bearer token out of an Authorization header.
|
||||
// Returns ("", false) if absent or malformed. RFC 6750 specifies the
|
||||
// scheme is case-insensitive ; we honor that.
|
||||
func extractBearerToken(authHeader string) (string, bool) {
|
||||
if authHeader == "" {
|
||||
return "", false
|
||||
}
|
||||
// Case-insensitive prefix match for "Bearer "
|
||||
if len(authHeader) < len(bearerPrefix) {
|
||||
return "", false
|
||||
}
|
||||
if !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
|
||||
return "", false
|
||||
}
|
||||
return authHeader[len(bearerPrefix):], true
|
||||
}
|
||||
|
||||
// AuthMiddleware (existing type kept for backwards compatibility ; the
|
||||
// constructor now returns a struct that exposes BOTH the optional and
|
||||
// required handlers). The legacy .Middleware method delegates to
|
||||
// OptionalHandler so existing wiring (server.go r.Use(authMiddleware.Middleware))
|
||||
// keeps working.
|
||||
type AuthMiddleware struct {
|
||||
authService user.AuthService
|
||||
validator tokenValidator
|
||||
}
|
||||
|
||||
// NewAuthMiddleware creates a new authentication middleware
|
||||
func NewAuthMiddleware(authService user.AuthService) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
authService: authService,
|
||||
}
|
||||
func NewAuthMiddleware(validator tokenValidator) *AuthMiddleware {
|
||||
return &AuthMiddleware{validator: validator}
|
||||
}
|
||||
|
||||
// Middleware returns the authentication middleware function
|
||||
func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
// OptionalHandler wraps next so :
|
||||
// - no Authorization header : pass through, no user in context
|
||||
// - malformed header : pass through, log Trace, no user in context
|
||||
// - invalid JWT : pass through, log Trace, no user in context
|
||||
// - valid JWT : pass through, user injected via auth.UserContextKey
|
||||
//
|
||||
// Use this on endpoints where auth is "nice to have" — the handler is
|
||||
// expected to call auth.GetAuthenticatedUserFromContext and decide.
|
||||
func (m *AuthMiddleware) OptionalHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Extract Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
// No authorization header, pass through with no user
|
||||
token, ok := extractBearerToken(r.Header.Get("Authorization"))
|
||||
if !ok {
|
||||
// Header absent or malformed — log size only (Q-064 : no raw value).
|
||||
if h := r.Header.Get("Authorization"); h != "" {
|
||||
log.Trace().Ctx(ctx).
|
||||
Int("auth_header_len", len(h)).
|
||||
Str("scheme_word", firstWord(h)).
|
||||
Msg("Optional auth : malformed Authorization header")
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract token from "Bearer <token>" format
|
||||
const bearerPrefix = "Bearer "
|
||||
if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix {
|
||||
log.Trace().Ctx(ctx).Str("auth_header", authHeader).Msg("Invalid authorization header format")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
token := authHeader[len(bearerPrefix):]
|
||||
|
||||
// Validate JWT token
|
||||
validatedUser, err := m.authService.ValidateJWT(ctx, token)
|
||||
validatedUser, err := m.validator.ValidateJWT(ctx, token)
|
||||
if err != nil {
|
||||
log.Trace().Ctx(ctx).Err(err).Msg("JWT validation failed")
|
||||
log.Trace().Ctx(ctx).Err(err).Msg("Optional auth : JWT validation failed")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Add user to context
|
||||
ctxWithUser := context.WithValue(ctx, auth.UserContextKey, validatedUser)
|
||||
r = r.WithContext(ctxWithUser)
|
||||
|
||||
// Continue to next handler
|
||||
next.ServeHTTP(w, r)
|
||||
next.ServeHTTP(w, r.WithContext(ctxWithUser))
|
||||
})
|
||||
}
|
||||
|
||||
// RequiredHandler wraps next so :
|
||||
// - no header / malformed / invalid JWT : 401 Unauthorized + WWW-Authenticate: Bearer
|
||||
// - valid JWT : pass through, user injected via auth.UserContextKey
|
||||
//
|
||||
// Use this on endpoints where unauthenticated access is forbidden.
|
||||
// Conforms to RFC 6750.
|
||||
func (m *AuthMiddleware) RequiredHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
token, ok := extractBearerToken(r.Header.Get("Authorization"))
|
||||
if !ok {
|
||||
w.Header().Set("WWW-Authenticate", `Bearer realm="dance-lessons-coach"`)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"unauthorized","message":"missing or malformed Authorization header"}`))
|
||||
return
|
||||
}
|
||||
validatedUser, err := m.validator.ValidateJWT(ctx, token)
|
||||
if err != nil {
|
||||
w.Header().Set("WWW-Authenticate", `Bearer realm="dance-lessons-coach", error="invalid_token"`)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"unauthorized","message":"invalid token"}`))
|
||||
return
|
||||
}
|
||||
ctxWithUser := context.WithValue(ctx, auth.UserContextKey, validatedUser)
|
||||
next.ServeHTTP(w, r.WithContext(ctxWithUser))
|
||||
})
|
||||
}
|
||||
|
||||
// Middleware is the legacy method — preserved for backwards compatibility.
|
||||
// Delegates to OptionalHandler. New wiring should call OptionalHandler or
|
||||
// RequiredHandler explicitly.
|
||||
//
|
||||
// Deprecated: use OptionalHandler() directly.
|
||||
func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
return m.OptionalHandler(next)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user