🧪 test(server): unit tests for AuthMiddleware Optional/Required handlers #92
181
pkg/server/middleware_test.go
Normal file
181
pkg/server/middleware_test.go
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"dance-lessons-coach/pkg/auth"
|
||||||
|
"dance-lessons-coach/pkg/user"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeTokenValidator is a minimal tokenValidator stub.
|
||||||
|
type fakeTokenValidator struct {
|
||||||
|
validUser *user.User
|
||||||
|
err error
|
||||||
|
seen string // captures the last token passed in
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeTokenValidator) ValidateJWT(ctx context.Context, token string) (*user.User, error) {
|
||||||
|
f.seen = token
|
||||||
|
if f.err != nil {
|
||||||
|
return nil, f.err
|
||||||
|
}
|
||||||
|
return f.validUser, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nextHandler returns 200 with a flag in body indicating whether a user
|
||||||
|
// was injected into context.
|
||||||
|
func nextHandler() http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
u, ok := auth.GetAuthenticatedUserFromContext(r.Context())
|
||||||
|
if ok && u != nil {
|
||||||
|
w.Header().Set("X-User", u.Username)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptionalHandler_NoHeader_PassesThrough(t *testing.T) {
|
||||||
|
fv := &fakeTokenValidator{}
|
||||||
|
mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler())
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
mw.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Empty(t, rec.Header().Get("X-User"), "no user expected when no Authorization header")
|
||||||
|
assert.Empty(t, fv.seen, "validator should not have been called")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptionalHandler_MalformedHeader_PassesThrough(t *testing.T) {
|
||||||
|
fv := &fakeTokenValidator{}
|
||||||
|
mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler())
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
req.Header.Set("Authorization", "Basic xxx")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
mw.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Empty(t, rec.Header().Get("X-User"))
|
||||||
|
assert.Empty(t, fv.seen, "validator should not have been called for non-Bearer scheme")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptionalHandler_BearerCaseInsensitive(t *testing.T) {
|
||||||
|
fv := &fakeTokenValidator{validUser: &user.User{Username: "alice"}}
|
||||||
|
mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler())
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
req.Header.Set("Authorization", "bearer abc123") // lowercase
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
mw.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, "alice", rec.Header().Get("X-User"), "case-insensitive Bearer per RFC 6750")
|
||||||
|
assert.Equal(t, "abc123", fv.seen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptionalHandler_InvalidJWT_PassesThrough(t *testing.T) {
|
||||||
|
fv := &fakeTokenValidator{err: errors.New("bad signature")}
|
||||||
|
mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler())
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer xxx")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
mw.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code, "optional auth never returns 401")
|
||||||
|
assert.Empty(t, rec.Header().Get("X-User"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptionalHandler_ValidJWT_InjectsUser(t *testing.T) {
|
||||||
|
fv := &fakeTokenValidator{validUser: &user.User{ID: 7, Username: "bob"}}
|
||||||
|
mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler())
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer goodtoken")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
mw.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, "bob", rec.Header().Get("X-User"))
|
||||||
|
assert.Equal(t, "goodtoken", fv.seen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequiredHandler_NoHeader_Returns401(t *testing.T) {
|
||||||
|
fv := &fakeTokenValidator{}
|
||||||
|
mw := NewAuthMiddleware(fv).RequiredHandler(nextHandler())
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
mw.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||||
|
assert.Contains(t, rec.Header().Get("WWW-Authenticate"), "Bearer", "RFC 6750 challenge header")
|
||||||
|
assert.Contains(t, rec.Body.String(), "unauthorized")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequiredHandler_InvalidJWT_Returns401WithErrorTag(t *testing.T) {
|
||||||
|
fv := &fakeTokenValidator{err: errors.New("expired")}
|
||||||
|
mw := NewAuthMiddleware(fv).RequiredHandler(nextHandler())
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer xxx")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
mw.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||||
|
assert.Contains(t, rec.Header().Get("WWW-Authenticate"), `error="invalid_token"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequiredHandler_ValidJWT_PassesThrough(t *testing.T) {
|
||||||
|
fv := &fakeTokenValidator{validUser: &user.User{Username: "carol"}}
|
||||||
|
mw := NewAuthMiddleware(fv).RequiredHandler(nextHandler())
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer goodtoken")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
mw.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, "carol", rec.Header().Get("X-User"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractBearerToken_EdgeCases(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
in string
|
||||||
|
out string
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{"", "", false},
|
||||||
|
{"Bearer ", "", true}, // empty token, but matches the prefix — caller decides
|
||||||
|
{"Bearer xxx", "xxx", true},
|
||||||
|
{"bearer xxx", "xxx", true}, // case-insensitive
|
||||||
|
{"BEARER xxx", "xxx", true},
|
||||||
|
{"Basic xxx", "", false},
|
||||||
|
{"Bearer", "", false}, // no separating space
|
||||||
|
{"Bear", "", false},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.in, func(t *testing.T) {
|
||||||
|
tok, ok := extractBearerToken(c.in)
|
||||||
|
assert.Equal(t, c.ok, ok)
|
||||||
|
assert.Equal(t, c.out, tok)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFirstWord(t *testing.T) {
|
||||||
|
assert.Equal(t, "Bearer", firstWord("Bearer xxx"))
|
||||||
|
assert.Equal(t, "Basic", firstWord("Basic\tabc"))
|
||||||
|
assert.Equal(t, "Token", firstWord("Token"))
|
||||||
|
assert.Equal(t, "", firstWord(""))
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user