package server import ( "context" "errors" "net/http" "net/http/httptest" "testing" "dance-lessons-coach/pkg/auth" "dance-lessons-coach/pkg/user" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // fakeTokenValidator is a minimal tokenValidator stub. type fakeTokenValidator struct { validUser *user.User err error seen string // captures the last token passed in } func (f *fakeTokenValidator) ValidateJWT(ctx context.Context, token string) (*user.User, error) { f.seen = token if f.err != nil { return nil, f.err } return f.validUser, nil } // nextHandler returns 200 with a flag in body indicating whether a user // was injected into context. func nextHandler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { u, ok := auth.GetAuthenticatedUserFromContext(r.Context()) if ok && u != nil { w.Header().Set("X-User", u.Username) } w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) }) } func TestOptionalHandler_NoHeader_PassesThrough(t *testing.T) { fv := &fakeTokenValidator{} mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler()) req := httptest.NewRequest(http.MethodGet, "/foo", nil) rec := httptest.NewRecorder() mw.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) assert.Empty(t, rec.Header().Get("X-User"), "no user expected when no Authorization header") assert.Empty(t, fv.seen, "validator should not have been called") } func TestOptionalHandler_MalformedHeader_PassesThrough(t *testing.T) { fv := &fakeTokenValidator{} mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler()) req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Authorization", "Basic xxx") rec := httptest.NewRecorder() mw.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) assert.Empty(t, rec.Header().Get("X-User")) assert.Empty(t, fv.seen, "validator should not have been called for non-Bearer scheme") } func TestOptionalHandler_BearerCaseInsensitive(t *testing.T) { fv := &fakeTokenValidator{validUser: &user.User{Username: "alice"}} mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler()) req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Authorization", "bearer abc123") // lowercase rec := httptest.NewRecorder() mw.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "alice", rec.Header().Get("X-User"), "case-insensitive Bearer per RFC 6750") assert.Equal(t, "abc123", fv.seen) } func TestOptionalHandler_InvalidJWT_PassesThrough(t *testing.T) { fv := &fakeTokenValidator{err: errors.New("bad signature")} mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler()) req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Authorization", "Bearer xxx") rec := httptest.NewRecorder() mw.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code, "optional auth never returns 401") assert.Empty(t, rec.Header().Get("X-User")) } func TestOptionalHandler_ValidJWT_InjectsUser(t *testing.T) { fv := &fakeTokenValidator{validUser: &user.User{ID: 7, Username: "bob"}} mw := NewAuthMiddleware(fv).OptionalHandler(nextHandler()) req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Authorization", "Bearer goodtoken") rec := httptest.NewRecorder() mw.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "bob", rec.Header().Get("X-User")) assert.Equal(t, "goodtoken", fv.seen) } func TestRequiredHandler_NoHeader_Returns401(t *testing.T) { fv := &fakeTokenValidator{} mw := NewAuthMiddleware(fv).RequiredHandler(nextHandler()) req := httptest.NewRequest(http.MethodGet, "/foo", nil) rec := httptest.NewRecorder() mw.ServeHTTP(rec, req) require.Equal(t, http.StatusUnauthorized, rec.Code) assert.Contains(t, rec.Header().Get("WWW-Authenticate"), "Bearer", "RFC 6750 challenge header") assert.Contains(t, rec.Body.String(), "unauthorized") } func TestRequiredHandler_InvalidJWT_Returns401WithErrorTag(t *testing.T) { fv := &fakeTokenValidator{err: errors.New("expired")} mw := NewAuthMiddleware(fv).RequiredHandler(nextHandler()) req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Authorization", "Bearer xxx") rec := httptest.NewRecorder() mw.ServeHTTP(rec, req) require.Equal(t, http.StatusUnauthorized, rec.Code) assert.Contains(t, rec.Header().Get("WWW-Authenticate"), `error="invalid_token"`) } func TestRequiredHandler_ValidJWT_PassesThrough(t *testing.T) { fv := &fakeTokenValidator{validUser: &user.User{Username: "carol"}} mw := NewAuthMiddleware(fv).RequiredHandler(nextHandler()) req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Authorization", "Bearer goodtoken") rec := httptest.NewRecorder() mw.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "carol", rec.Header().Get("X-User")) } func TestExtractBearerToken_EdgeCases(t *testing.T) { cases := []struct { in string out string ok bool }{ {"", "", false}, {"Bearer ", "", true}, // empty token, but matches the prefix — caller decides {"Bearer xxx", "xxx", true}, {"bearer xxx", "xxx", true}, // case-insensitive {"BEARER xxx", "xxx", true}, {"Basic xxx", "", false}, {"Bearer", "", false}, // no separating space {"Bear", "", false}, } for _, c := range cases { t.Run(c.in, func(t *testing.T) { tok, ok := extractBearerToken(c.in) assert.Equal(t, c.ok, ok) assert.Equal(t, c.out, tok) }) } } func TestFirstWord(t *testing.T) { assert.Equal(t, "Bearer", firstWord("Bearer xxx")) assert.Equal(t, "Basic", firstWord("Basic\tabc")) assert.Equal(t, "Token", firstWord("Token")) assert.Equal(t, "", firstWord("")) }