diff --git a/pkg/server/middleware.go b/pkg/server/middleware.go index 3c2c732..656f9a2 100644 --- a/pkg/server/middleware.go +++ b/pkg/server/middleware.go @@ -3,6 +3,7 @@ package server import ( "context" "net/http" + "strings" "dance-lessons-coach/pkg/auth" "dance-lessons-coach/pkg/user" @@ -10,54 +11,123 @@ import ( "github.com/rs/zerolog/log" ) -// AuthMiddleware handles JWT authentication and adds user to context -type AuthMiddleware struct { - authService user.AuthService +// tokenValidator is the narrow interface AuthMiddleware needs from +// user.AuthService — only JWT validation. ISP : avoid pulling the full +// fat AuthService interface (12+ methods) into the middleware. +type tokenValidator interface { + ValidateJWT(ctx context.Context, token string) (*user.User, error) } -// NewAuthMiddleware creates a new authentication middleware -func NewAuthMiddleware(authService user.AuthService) *AuthMiddleware { - return &AuthMiddleware{ - authService: authService, +const bearerPrefix = "Bearer " + +// firstWord returns the first whitespace-separated word of s, or s itself +// if there's no whitespace. Used for log-safe scheme extraction. +func firstWord(s string) string { + if i := strings.IndexAny(s, " \t"); i >= 0 { + return s[:i] } + return s } -// Middleware returns the authentication middleware function -func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler { +// extractBearerToken pulls the bearer token out of an Authorization header. +// Returns ("", false) if absent or malformed. RFC 6750 specifies the +// scheme is case-insensitive ; we honor that. +func extractBearerToken(authHeader string) (string, bool) { + if authHeader == "" { + return "", false + } + // Case-insensitive prefix match for "Bearer " + if len(authHeader) < len(bearerPrefix) { + return "", false + } + if !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) { + return "", false + } + return authHeader[len(bearerPrefix):], true +} + +// AuthMiddleware (existing type kept for backwards compatibility ; the +// constructor now returns a struct that exposes BOTH the optional and +// required handlers). The legacy .Middleware method delegates to +// OptionalHandler so existing wiring (server.go r.Use(authMiddleware.Middleware)) +// keeps working. +type AuthMiddleware struct { + validator tokenValidator +} + +func NewAuthMiddleware(validator tokenValidator) *AuthMiddleware { + return &AuthMiddleware{validator: validator} +} + +// OptionalHandler wraps next so : +// - no Authorization header : pass through, no user in context +// - malformed header : pass through, log Trace, no user in context +// - invalid JWT : pass through, log Trace, no user in context +// - valid JWT : pass through, user injected via auth.UserContextKey +// +// Use this on endpoints where auth is "nice to have" — the handler is +// expected to call auth.GetAuthenticatedUserFromContext and decide. +func (m *AuthMiddleware) OptionalHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - - // Extract Authorization header - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - // No authorization header, pass through with no user + token, ok := extractBearerToken(r.Header.Get("Authorization")) + if !ok { + // Header absent or malformed — log size only (Q-064 : no raw value). + if h := r.Header.Get("Authorization"); h != "" { + log.Trace().Ctx(ctx). + Int("auth_header_len", len(h)). + Str("scheme_word", firstWord(h)). + Msg("Optional auth : malformed Authorization header") + } next.ServeHTTP(w, r) return } - - // Extract token from "Bearer " format - const bearerPrefix = "Bearer " - if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix { - log.Trace().Ctx(ctx).Str("auth_header", authHeader).Msg("Invalid authorization header format") - next.ServeHTTP(w, r) - return - } - - token := authHeader[len(bearerPrefix):] - - // Validate JWT token - validatedUser, err := m.authService.ValidateJWT(ctx, token) + validatedUser, err := m.validator.ValidateJWT(ctx, token) if err != nil { - log.Trace().Ctx(ctx).Err(err).Msg("JWT validation failed") + log.Trace().Ctx(ctx).Err(err).Msg("Optional auth : JWT validation failed") next.ServeHTTP(w, r) return } - - // Add user to context ctxWithUser := context.WithValue(ctx, auth.UserContextKey, validatedUser) - r = r.WithContext(ctxWithUser) - - // Continue to next handler - next.ServeHTTP(w, r) + next.ServeHTTP(w, r.WithContext(ctxWithUser)) }) } + +// RequiredHandler wraps next so : +// - no header / malformed / invalid JWT : 401 Unauthorized + WWW-Authenticate: Bearer +// - valid JWT : pass through, user injected via auth.UserContextKey +// +// Use this on endpoints where unauthenticated access is forbidden. +// Conforms to RFC 6750. +func (m *AuthMiddleware) RequiredHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + token, ok := extractBearerToken(r.Header.Get("Authorization")) + if !ok { + w.Header().Set("WWW-Authenticate", `Bearer realm="dance-lessons-coach"`) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized","message":"missing or malformed Authorization header"}`)) + return + } + validatedUser, err := m.validator.ValidateJWT(ctx, token) + if err != nil { + w.Header().Set("WWW-Authenticate", `Bearer realm="dance-lessons-coach", error="invalid_token"`) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized","message":"invalid token"}`)) + return + } + ctxWithUser := context.WithValue(ctx, auth.UserContextKey, validatedUser) + next.ServeHTTP(w, r.WithContext(ctxWithUser)) + }) +} + +// Middleware is the legacy method — preserved for backwards compatibility. +// Delegates to OptionalHandler. New wiring should call OptionalHandler or +// RequiredHandler explicitly. +// +// Deprecated: use OptionalHandler() directly. +func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler { + return m.OptionalHandler(next) +}