package api

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"

	"dance-lessons-coach/pkg/auth"
	"dance-lessons-coach/pkg/user"

	"github.com/go-chi/chi/v5"
	"github.com/stretchr/testify/assert"
)

// fakeUserSvc and fakeUserRepo are defined in magic_link_handler_test.go.
// That file is in the same package (api), so the fakes are reused here directly.

// setupMockOIDCProvider starts an httptest server acting as a minimal OIDC
// provider. It serves a discovery document at /.well-known/openid-configuration
// whose issuer, authorization (/auth), token (/token) and JWKS (/jwks)
// endpoints all point back at the server itself. Every other path answers 404.
//
// The server is shut down automatically through t.Cleanup. Callers that also
// Close it explicitly are fine: httptest.Server.Close ignores repeated calls.
//
// server is declared before it is assigned (the Q-062 mitigation pattern) so
// the handler closure can read server.URL once the server is listening.
func setupMockOIDCProvider(t *testing.T) *httptest.Server {
	t.Helper()

	var server *httptest.Server
	server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/.well-known/openid-configuration" {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprintf(w, `{"issuer":"%s","authorization_endpoint":"%s/auth","token_endpoint":"%s/token","jwks_uri":"%s/jwks"}`,
			server.URL, server.URL, server.URL, server.URL)
	}))
	t.Cleanup(server.Close)
	return server
}

// mountOIDCHandler registers handler's routes on a fresh chi router and
// returns that router, ready to serve test requests.
func mountOIDCHandler(t *testing.T, handler *OIDCHandler) *chi.Mux {
	t.Helper()
	r := chi.NewRouter()
	handler.RegisterRoutes(r)
	return r
}

// newTestOIDCHandler builds an OIDCHandler for the given provider clients,
// backed by in-memory fakes and the base URL "http://localhost:8080".
//
// A single fake user service is shared by the handler and the fake repository.
// The repository then observes the same users the service creates. Separate
// instances would give the handler two disjoint user stores.
func newTestOIDCHandler(clients map[string]*auth.OIDCClient) *OIDCHandler {
	svc := newFakeUserSvc()
	return NewOIDCHandler(
		clients,
		svc,
		&fakeUserRepo{svc: svc},
		"http://localhost:8080",
	)
}

// TestOIDCHandler_Start_RejectsUnknownProvider checks that /oidc/unknown/start
// answers 404 when no provider by that name is configured.
func TestOIDCHandler_Start_RejectsUnknownProvider(t *testing.T) { handler := newTestOIDCHandler(map[string]*auth.OIDCClient{}) router := mountOIDCHandler(t, handler) req := httptest.NewRequest(http.MethodGet, "/oidc/unknown/start", nil) rr := httptest.NewRecorder() router.ServeHTTP(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) } // TestOIDCHandler_Callback_RejectsMissingState tests that callback without state returns 401. func TestOIDCHandler_Callback_RejectsMissingState(t *testing.T) { client := auth.NewOIDCClient("http://mock-provider", "test-id", "test-secret") handler := newTestOIDCHandler(map[string]*auth.OIDCClient{"test": client}) router := mountOIDCHandler(t, handler) req := httptest.NewRequest(http.MethodGet, "/oidc/test/callback", nil) rr := httptest.NewRecorder() router.ServeHTTP(rr, req) assert.Equal(t, http.StatusUnauthorized, rr.Code) } // TestOIDCHandler_Callback_RejectsUnknownState tests that callback with unknown state returns 401. func TestOIDCHandler_Callback_RejectsUnknownState(t *testing.T) { client := auth.NewOIDCClient("http://mock-provider", "test-id", "test-secret") handler := newTestOIDCHandler(map[string]*auth.OIDCClient{"test": client}) router := mountOIDCHandler(t, handler) req := httptest.NewRequest(http.MethodGet, "/oidc/test/callback?state=unknown&code=any", nil) rr := httptest.NewRecorder() router.ServeHTTP(rr, req) assert.Equal(t, http.StatusUnauthorized, rr.Code) } // TestOIDCHandler_Start_RedirectsWithPKCE tests that starting with a valid provider redirects with PKCE. 
func TestOIDCHandler_Start_RedirectsWithPKCE(t *testing.T) { // Setup mock OIDC provider mockServer := setupMockOIDCProvider(t) defer mockServer.Close() // Create OIDC client pointing to mock server client := auth.NewOIDCClient(mockServer.URL, "test-id", "test-secret") // Set a custom HTTP client that can reach the mock server client.SetHTTPClient(mockServer.Client()) handler := newTestOIDCHandler(map[string]*auth.OIDCClient{"test": client}) router := mountOIDCHandler(t, handler) req := httptest.NewRequest(http.MethodGet, "/oidc/test/start", nil) rr := httptest.NewRecorder() router.ServeHTTP(rr, req) // Assert 302 redirect assert.Equal(t, http.StatusFound, rr.Code) // Get Location header location := rr.Header().Get("Location") assert.NotEmpty(t, location) // Location should start with the mock auth endpoint expectedAuthEndpoint := mockServer.URL + "/auth" assert.Contains(t, location, expectedAuthEndpoint) // Location should contain code_challenge and state assert.Contains(t, location, "code_challenge=") assert.Contains(t, location, "state=") assert.Contains(t, location, "response_type=code") assert.Contains(t, location, "client_id=test-id") assert.Contains(t, location, "code_challenge_method=S256") } // Ensure the interfaces are satisfied at compile time var _ user.UserService = (*fakeUserSvc)(nil) var _ user.UserRepository = (*fakeUserRepo)(nil)