diff --git a/pkg/user/api/oidc_handler_test.go b/pkg/user/api/oidc_handler_test.go new file mode 100644 index 0000000..ee97958 --- /dev/null +++ b/pkg/user/api/oidc_handler_test.go @@ -0,0 +1,134 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "dance-lessons-coach/pkg/auth" + "dance-lessons-coach/pkg/user" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" +) + +// fakeUserSvc is reused from magic_link_handler_test.go +// It's in the same package (api) so we can use it directly. + +// fakeUserRepo is reused from magic_link_handler_test.go +// It's in the same package (api) so we can use it directly. + +// setupMockOIDCProvider creates a mock OIDC provider server for testing. +// Uses the Q-062 mitigation pattern with var server *httptest.Server. +func setupMockOIDCProvider(t *testing.T) *httptest.Server { + t.Helper() + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"issuer":"%s","authorization_endpoint":"%s/auth","token_endpoint":"%s/token","jwks_uri":"%s/jwks"}`, + server.URL, server.URL, server.URL, server.URL) + return + } + w.WriteHeader(http.StatusNotFound) + })) + return server +} + +// mountOIDCHandler mounts the OIDCHandler on a new router and returns it. +func mountOIDCHandler(t *testing.T, handler *OIDCHandler) *chi.Mux { + t.Helper() + r := chi.NewRouter() + handler.RegisterRoutes(r) + return r +} + +// newTestOIDCHandler creates an OIDCHandler with the given clients. +func newTestOIDCHandler(clients map[string]*auth.OIDCClient) *OIDCHandler { + return NewOIDCHandler( + clients, + newFakeUserSvc(), + &fakeUserRepo{svc: newFakeUserSvc()}, + "http://localhost:8080", + ) +} + +// TestOIDCHandler_Start_RejectsUnknownProvider tests that starting with an unknown provider returns 404. +func TestOIDCHandler_Start_RejectsUnknownProvider(t *testing.T) { + handler := newTestOIDCHandler(map[string]*auth.OIDCClient{}) + router := mountOIDCHandler(t, handler) + + req := httptest.NewRequest(http.MethodGet, "/oidc/unknown/start", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusNotFound, rr.Code) +} + +// TestOIDCHandler_Callback_RejectsMissingState tests that callback without state returns 401. +func TestOIDCHandler_Callback_RejectsMissingState(t *testing.T) { + client := auth.NewOIDCClient("http://mock-provider", "test-id", "test-secret") + handler := newTestOIDCHandler(map[string]*auth.OIDCClient{"test": client}) + router := mountOIDCHandler(t, handler) + + req := httptest.NewRequest(http.MethodGet, "/oidc/test/callback", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusUnauthorized, rr.Code) +} + +// TestOIDCHandler_Callback_RejectsUnknownState tests that callback with unknown state returns 401. +func TestOIDCHandler_Callback_RejectsUnknownState(t *testing.T) { + client := auth.NewOIDCClient("http://mock-provider", "test-id", "test-secret") + handler := newTestOIDCHandler(map[string]*auth.OIDCClient{"test": client}) + router := mountOIDCHandler(t, handler) + + req := httptest.NewRequest(http.MethodGet, "/oidc/test/callback?state=unknown&code=any", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusUnauthorized, rr.Code) +} + +// TestOIDCHandler_Start_RedirectsWithPKCE tests that starting with a valid provider redirects with PKCE. +func TestOIDCHandler_Start_RedirectsWithPKCE(t *testing.T) { + // Setup mock OIDC provider + mockServer := setupMockOIDCProvider(t) + defer mockServer.Close() + + // Create OIDC client pointing to mock server + client := auth.NewOIDCClient(mockServer.URL, "test-id", "test-secret") + // Set a custom HTTP client that can reach the mock server + client.SetHTTPClient(mockServer.Client()) + + handler := newTestOIDCHandler(map[string]*auth.OIDCClient{"test": client}) + router := mountOIDCHandler(t, handler) + + req := httptest.NewRequest(http.MethodGet, "/oidc/test/start", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + // Assert 302 redirect + assert.Equal(t, http.StatusFound, rr.Code) + + // Get Location header + location := rr.Header().Get("Location") + assert.NotEmpty(t, location) + + // Location should start with the mock auth endpoint + expectedAuthEndpoint := mockServer.URL + "/auth" + assert.Contains(t, location, expectedAuthEndpoint) + + // Location should contain code_challenge and state + assert.Contains(t, location, "code_challenge=") + assert.Contains(t, location, "state=") + assert.Contains(t, location, "response_type=code") + assert.Contains(t, location, "client_id=test-id") + assert.Contains(t, location, "code_challenge_method=S256") +} + +// Ensure the interfaces are satisfied at compile time +var _ user.UserService = (*fakeUserSvc)(nil) +var _ user.UserRepository = (*fakeUserRepo)(nil)