From ab4918adfc5b74ae397b17b347e22d7b2eccc0a8 Mon Sep 17 00:00:00 2001 From: Gabriel Radureau Date: Wed, 6 May 2026 06:58:25 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=AA=20test(server):=20unit=20tests=20f?= =?UTF-8?q?or=20AuthMiddleware=20Optional/Required=20handlers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generated by Mistral Vibe. Co-Authored-By: Mistral Vibe --- pkg/server/middleware_test.go | 181 ++++++++++++++++++++++++++++++++++ 1 file changed, 181 insertions(+) create mode 100644 pkg/server/middleware_test.go diff --git a/pkg/server/middleware_test.go b/pkg/server/middleware_test.go new file mode 100644 index 0000000..30973ac --- /dev/null +++ b/pkg/server/middleware_test.go @@ -0,0 +1,181 @@ +package server + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "dance-lessons-coach/pkg/auth" + "dance-lessons-coach/pkg/user" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeTokenValidator is a minimal tokenValidator stub. +type fakeTokenValidator struct { + validUser *user.User + err error + seen string // captures the last token passed in +} + +func (f *fakeTokenValidator) ValidateJWT(ctx context.Context, token string) (*user.User, error) { + f.seen = token + if f.err != nil { + return nil, f.err + } + return f.validUser, nil +} + +// nextHandler returns 200 with a flag in body indicating whether a user +// was injected into context. +func nextHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + u, ok := auth.GetAuthenticatedUserFromContext(r.Context()) + if ok && u != nil { + w.Header().Set("X-User", u.Username) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + }) +} + +func TestOptionalHandler_NoHeader_PassesThrough(t *testing.T) { + fv := &fakeTokenValidator{} + mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler()) + + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + rec := httptest.NewRecorder() + mw.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Empty(t, rec.Header().Get("X-User"), "no user expected when no Authorization header") + assert.Empty(t, fv.seen, "validator should not have been called") +} + +func TestOptionalHandler_MalformedHeader_PassesThrough(t *testing.T) { + fv := &fakeTokenValidator{} + mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler()) + + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + req.Header.Set("Authorization", "Basic xxx") + rec := httptest.NewRecorder() + mw.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Empty(t, rec.Header().Get("X-User")) + assert.Empty(t, fv.seen, "validator should not have been called for non-Bearer scheme") +} + +func TestOptionalHandler_BearerCaseInsensitive(t *testing.T) { + fv := &fakeTokenValidator{validUser: &user.User{Username: "alice"}} + mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler()) + + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + req.Header.Set("Authorization", "bearer abc123") // lowercase + rec := httptest.NewRecorder() + mw.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "alice", rec.Header().Get("X-User"), "case-insensitive Bearer per RFC 6750") + assert.Equal(t, "abc123", fv.seen) +} + +func TestOptionalHandler_InvalidJWT_PassesThrough(t *testing.T) { + fv := &fakeTokenValidator{err: errors.New("bad signature")} + mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler()) + + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + req.Header.Set("Authorization", "Bearer xxx") + rec := httptest.NewRecorder() + mw.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code, "optional auth never returns 401") + assert.Empty(t, rec.Header().Get("X-User")) +} + +func TestOptionalHandler_ValidJWT_InjectsUser(t *testing.T) { + fv := &fakeTokenValidator{validUser: &user.User{ID: 7, Username: "bob"}} + mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler()) + + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + req.Header.Set("Authorization", "Bearer goodtoken") + rec := httptest.NewRecorder() + mw.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "bob", rec.Header().Get("X-User")) + assert.Equal(t, "goodtoken", fv.seen) +} + +func TestRequiredHandler_NoHeader_Returns401(t *testing.T) { + fv := &fakeTokenValidator{} + mw := NewAuthMiddleware(fv).RequiredHandler(nextHandler()) + + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + rec := httptest.NewRecorder() + mw.ServeHTTP(rec, req) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + assert.Contains(t, rec.Header().Get("WWW-Authenticate"), "Bearer", "RFC 6750 challenge header") + assert.Contains(t, rec.Body.String(), "unauthorized") +} + +func TestRequiredHandler_InvalidJWT_Returns401WithErrorTag(t *testing.T) { + fv := &fakeTokenValidator{err: errors.New("expired")} + mw := NewAuthMiddleware(fv).RequiredHandler(nextHandler()) + + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + req.Header.Set("Authorization", "Bearer xxx") + rec := httptest.NewRecorder() + mw.ServeHTTP(rec, req) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + assert.Contains(t, rec.Header().Get("WWW-Authenticate"), `error="invalid_token"`) +} + +func TestRequiredHandler_ValidJWT_PassesThrough(t *testing.T) { + fv := &fakeTokenValidator{validUser: &user.User{Username: "carol"}} + mw := NewAuthMiddleware(fv).RequiredHandler(nextHandler()) + + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + req.Header.Set("Authorization", "Bearer goodtoken") + rec := httptest.NewRecorder() + mw.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "carol", rec.Header().Get("X-User")) +} + +func TestExtractBearerToken_EdgeCases(t *testing.T) { + cases := []struct { + in string + out string + ok bool + }{ + {"", "", false}, + {"Bearer ", "", true}, // empty token, but matches the prefix — caller decides + {"Bearer xxx", "xxx", true}, + {"bearer xxx", "xxx", true}, // case-insensitive + {"BEARER xxx", "xxx", true}, + {"Basic xxx", "", false}, + {"Bearer", "", false}, // no separating space + {"Bear", "", false}, + } + for _, c := range cases { + t.Run(c.in, func(t *testing.T) { + tok, ok := extractBearerToken(c.in) + assert.Equal(t, c.ok, ok) + assert.Equal(t, c.out, tok) + }) + } +} + +func TestFirstWord(t *testing.T) { + assert.Equal(t, "Bearer", firstWord("Bearer xxx")) + assert.Equal(t, "Basic", firstWord("Basic\tabc")) + assert.Equal(t, "Token", firstWord("Token")) + assert.Equal(t, "", firstWord("")) +}