♻️ refactor(server): split AuthMiddleware into Optional/Required (RFC 6750 + ISP narrow interface) #91
@@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"dance-lessons-coach/pkg/auth"
|
"dance-lessons-coach/pkg/auth"
|
||||||
"dance-lessons-coach/pkg/user"
|
"dance-lessons-coach/pkg/user"
|
||||||
@@ -10,54 +11,123 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AuthMiddleware handles JWT authentication and adds user to context
|
// tokenValidator is the narrow interface AuthMiddleware needs from
|
||||||
|
// user.AuthService — only JWT validation. ISP : avoid pulling the full
|
||||||
|
// fat AuthService interface (12+ methods) into the middleware.
|
||||||
|
type tokenValidator interface {
|
||||||
|
ValidateJWT(ctx context.Context, token string) (*user.User, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
const bearerPrefix = "Bearer "
|
||||||
|
|
||||||
|
// firstWord returns the first whitespace-separated word of s, or s itself
|
||||||
|
// if there's no whitespace. Used for log-safe scheme extraction.
|
||||||
|
func firstWord(s string) string {
|
||||||
|
if i := strings.IndexAny(s, " \t"); i >= 0 {
|
||||||
|
return s[:i]
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractBearerToken pulls the bearer token out of an Authorization header.
|
||||||
|
// Returns ("", false) if absent or malformed. RFC 6750 specifies the
|
||||||
|
// scheme is case-insensitive ; we honor that.
|
||||||
|
func extractBearerToken(authHeader string) (string, bool) {
|
||||||
|
if authHeader == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
// Case-insensitive prefix match for "Bearer "
|
||||||
|
if len(authHeader) < len(bearerPrefix) {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return authHeader[len(bearerPrefix):], true
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthMiddleware (existing type kept for backwards compatibility ; the
|
||||||
|
// constructor now returns a struct that exposes BOTH the optional and
|
||||||
|
// required handlers). The legacy .Middleware method delegates to
|
||||||
|
// OptionalHandler so existing wiring (server.go r.Use(authMiddleware.Middleware))
|
||||||
|
// keeps working.
|
||||||
type AuthMiddleware struct {
|
type AuthMiddleware struct {
|
||||||
authService user.AuthService
|
validator tokenValidator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthMiddleware creates a new authentication middleware
|
func NewAuthMiddleware(validator tokenValidator) *AuthMiddleware {
|
||||||
func NewAuthMiddleware(authService user.AuthService) *AuthMiddleware {
|
return &AuthMiddleware{validator: validator}
|
||||||
return &AuthMiddleware{
|
|
||||||
authService: authService,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Middleware returns the authentication middleware function
|
// OptionalHandler wraps next so :
|
||||||
func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
|
// - no Authorization header : pass through, no user in context
|
||||||
|
// - malformed header : pass through, log Trace, no user in context
|
||||||
|
// - invalid JWT : pass through, log Trace, no user in context
|
||||||
|
// - valid JWT : pass through, user injected via auth.UserContextKey
|
||||||
|
//
|
||||||
|
// Use this on endpoints where auth is "nice to have" — the handler is
|
||||||
|
// expected to call auth.GetAuthenticatedUserFromContext and decide.
|
||||||
|
func (m *AuthMiddleware) OptionalHandler(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
token, ok := extractBearerToken(r.Header.Get("Authorization"))
|
||||||
// Extract Authorization header
|
if !ok {
|
||||||
authHeader := r.Header.Get("Authorization")
|
// Header absent or malformed — log size only (Q-064 : no raw value).
|
||||||
if authHeader == "" {
|
if h := r.Header.Get("Authorization"); h != "" {
|
||||||
// No authorization header, pass through with no user
|
log.Trace().Ctx(ctx).
|
||||||
|
Int("auth_header_len", len(h)).
|
||||||
|
Str("scheme_word", firstWord(h)).
|
||||||
|
Msg("Optional auth : malformed Authorization header")
|
||||||
|
}
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
validatedUser, err := m.validator.ValidateJWT(ctx, token)
|
||||||
// Extract token from "Bearer <token>" format
|
|
||||||
const bearerPrefix = "Bearer "
|
|
||||||
if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix {
|
|
||||||
log.Trace().Ctx(ctx).Str("auth_header", authHeader).Msg("Invalid authorization header format")
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
token := authHeader[len(bearerPrefix):]
|
|
||||||
|
|
||||||
// Validate JWT token
|
|
||||||
validatedUser, err := m.authService.ValidateJWT(ctx, token)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Trace().Ctx(ctx).Err(err).Msg("JWT validation failed")
|
log.Trace().Ctx(ctx).Err(err).Msg("Optional auth : JWT validation failed")
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add user to context
|
|
||||||
ctxWithUser := context.WithValue(ctx, auth.UserContextKey, validatedUser)
|
ctxWithUser := context.WithValue(ctx, auth.UserContextKey, validatedUser)
|
||||||
r = r.WithContext(ctxWithUser)
|
next.ServeHTTP(w, r.WithContext(ctxWithUser))
|
||||||
|
|
||||||
// Continue to next handler
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RequiredHandler wraps next so :
|
||||||
|
// - no header / malformed / invalid JWT : 401 Unauthorized + WWW-Authenticate: Bearer
|
||||||
|
// - valid JWT : pass through, user injected via auth.UserContextKey
|
||||||
|
//
|
||||||
|
// Use this on endpoints where unauthenticated access is forbidden.
|
||||||
|
// Conforms to RFC 6750.
|
||||||
|
func (m *AuthMiddleware) RequiredHandler(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
token, ok := extractBearerToken(r.Header.Get("Authorization"))
|
||||||
|
if !ok {
|
||||||
|
w.Header().Set("WWW-Authenticate", `Bearer realm="dance-lessons-coach"`)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_, _ = w.Write([]byte(`{"error":"unauthorized","message":"missing or malformed Authorization header"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
validatedUser, err := m.validator.ValidateJWT(ctx, token)
|
||||||
|
if err != nil {
|
||||||
|
w.Header().Set("WWW-Authenticate", `Bearer realm="dance-lessons-coach", error="invalid_token"`)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_, _ = w.Write([]byte(`{"error":"unauthorized","message":"invalid token"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctxWithUser := context.WithValue(ctx, auth.UserContextKey, validatedUser)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctxWithUser))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware is the legacy method — preserved for backwards compatibility.
|
||||||
|
// Delegates to OptionalHandler. New wiring should call OptionalHandler or
|
||||||
|
// RequiredHandler explicitly.
|
||||||
|
//
|
||||||
|
// Deprecated: use OptionalHandler() directly.
|
||||||
|
func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
|
||||||
|
return m.OptionalHandler(next)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user