🧪 test(server): unit tests for AuthMiddleware Optional/Required handlers #92
181
pkg/server/middleware_test.go
Normal file
181
pkg/server/middleware_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"dance-lessons-coach/pkg/auth"
|
||||
"dance-lessons-coach/pkg/user"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeTokenValidator is a minimal tokenValidator stub.
|
||||
type fakeTokenValidator struct {
|
||||
validUser *user.User
|
||||
err error
|
||||
seen string // captures the last token passed in
|
||||
}
|
||||
|
||||
func (f *fakeTokenValidator) ValidateJWT(ctx context.Context, token string) (*user.User, error) {
|
||||
f.seen = token
|
||||
if f.err != nil {
|
||||
return nil, f.err
|
||||
}
|
||||
return f.validUser, nil
|
||||
}
|
||||
|
||||
// nextHandler returns 200 with a flag in body indicating whether a user
|
||||
// was injected into context.
|
||||
func nextHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
u, ok := auth.GetAuthenticatedUserFromContext(r.Context())
|
||||
if ok && u != nil {
|
||||
w.Header().Set("X-User", u.Username)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestOptionalHandler_NoHeader_PassesThrough(t *testing.T) {
|
||||
fv := &fakeTokenValidator{}
|
||||
mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
mw.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Empty(t, rec.Header().Get("X-User"), "no user expected when no Authorization header")
|
||||
assert.Empty(t, fv.seen, "validator should not have been called")
|
||||
}
|
||||
|
||||
func TestOptionalHandler_MalformedHeader_PassesThrough(t *testing.T) {
|
||||
fv := &fakeTokenValidator{}
|
||||
mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||
req.Header.Set("Authorization", "Basic xxx")
|
||||
rec := httptest.NewRecorder()
|
||||
mw.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Empty(t, rec.Header().Get("X-User"))
|
||||
assert.Empty(t, fv.seen, "validator should not have been called for non-Bearer scheme")
|
||||
}
|
||||
|
||||
func TestOptionalHandler_BearerCaseInsensitive(t *testing.T) {
|
||||
fv := &fakeTokenValidator{validUser: &user.User{Username: "alice"}}
|
||||
mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||
req.Header.Set("Authorization", "bearer abc123") // lowercase
|
||||
rec := httptest.NewRecorder()
|
||||
mw.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "alice", rec.Header().Get("X-User"), "case-insensitive Bearer per RFC 6750")
|
||||
assert.Equal(t, "abc123", fv.seen)
|
||||
}
|
||||
|
||||
func TestOptionalHandler_InvalidJWT_PassesThrough(t *testing.T) {
|
||||
fv := &fakeTokenValidator{err: errors.New("bad signature")}
|
||||
mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||
req.Header.Set("Authorization", "Bearer xxx")
|
||||
rec := httptest.NewRecorder()
|
||||
mw.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "optional auth never returns 401")
|
||||
assert.Empty(t, rec.Header().Get("X-User"))
|
||||
}
|
||||
|
||||
func TestOptionalHandler_ValidJWT_InjectsUser(t *testing.T) {
|
||||
fv := &fakeTokenValidator{validUser: &user.User{ID: 7, Username: "bob"}}
|
||||
mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||
req.Header.Set("Authorization", "Bearer goodtoken")
|
||||
rec := httptest.NewRecorder()
|
||||
mw.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "bob", rec.Header().Get("X-User"))
|
||||
assert.Equal(t, "goodtoken", fv.seen)
|
||||
}
|
||||
|
||||
func TestRequiredHandler_NoHeader_Returns401(t *testing.T) {
|
||||
fv := &fakeTokenValidator{}
|
||||
mw := NewAuthMiddleware(fv).RequiredHandler(nextHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
mw.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.Contains(t, rec.Header().Get("WWW-Authenticate"), "Bearer", "RFC 6750 challenge header")
|
||||
assert.Contains(t, rec.Body.String(), "unauthorized")
|
||||
}
|
||||
|
||||
func TestRequiredHandler_InvalidJWT_Returns401WithErrorTag(t *testing.T) {
|
||||
fv := &fakeTokenValidator{err: errors.New("expired")}
|
||||
mw := NewAuthMiddleware(fv).RequiredHandler(nextHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||
req.Header.Set("Authorization", "Bearer xxx")
|
||||
rec := httptest.NewRecorder()
|
||||
mw.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.Contains(t, rec.Header().Get("WWW-Authenticate"), `error="invalid_token"`)
|
||||
}
|
||||
|
||||
func TestRequiredHandler_ValidJWT_PassesThrough(t *testing.T) {
|
||||
fv := &fakeTokenValidator{validUser: &user.User{Username: "carol"}}
|
||||
mw := NewAuthMiddleware(fv).RequiredHandler(nextHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||
req.Header.Set("Authorization", "Bearer goodtoken")
|
||||
rec := httptest.NewRecorder()
|
||||
mw.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "carol", rec.Header().Get("X-User"))
|
||||
}
|
||||
|
||||
func TestExtractBearerToken_EdgeCases(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
out string
|
||||
ok bool
|
||||
}{
|
||||
{"", "", false},
|
||||
{"Bearer ", "", true}, // empty token, but matches the prefix — caller decides
|
||||
{"Bearer xxx", "xxx", true},
|
||||
{"bearer xxx", "xxx", true}, // case-insensitive
|
||||
{"BEARER xxx", "xxx", true},
|
||||
{"Basic xxx", "", false},
|
||||
{"Bearer", "", false}, // no separating space
|
||||
{"Bear", "", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.in, func(t *testing.T) {
|
||||
tok, ok := extractBearerToken(c.in)
|
||||
assert.Equal(t, c.ok, ok)
|
||||
assert.Equal(t, c.out, tok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstWord(t *testing.T) {
|
||||
assert.Equal(t, "Bearer", firstWord("Bearer xxx"))
|
||||
assert.Equal(t, "Basic", firstWord("Basic\tabc"))
|
||||
assert.Equal(t, "Token", firstWord("Token"))
|
||||
assert.Equal(t, "", firstWord(""))
|
||||
}
|
||||
Reference in New Issue
Block a user