From dd25d7d1ae996f782f191eec474f82eb03acea59 Mon Sep 17 00:00:00 2001 From: Gabriel Radureau Date: Tue, 5 May 2026 19:53:47 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(auth):=20implement=20OIDC=20cl?= =?UTF-8?q?ient=20methods=20(ADR-0028=20Phase=20B.3)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements the 4 OIDCClient methods that were TODO skeletons in PR #69: - Discover(ctx) — fetch + cache .well-known/openid-configuration - RefreshJWKS(ctx) — fetch JWKS, parse RSA public keys (n/e base64-url) - ExchangeCode(ctx, code, codeVerifier, redirectURI) — POST token endpoint with PKCE - ValidateIDToken(ctx, idToken) — verify signature via JWKS, validate claims Plus 7 unit tests using httptest.NewServer to mock the OIDC provider: TestDiscover_HappyPath, TestDiscover_Idempotent, TestRefreshJWKS_HappyPath, TestExchangeCode_HappyPath, TestValidateIDToken_HappyPath, TestValidateIDToken_RejectsExpired, TestValidateIDToken_RejectsWrongIssuer. Mostly authored by Mistral Vibe (batch6, $3.51 / 46 steps), trainer-takeover on 2 bugs: - closure auto-reference bug (server := httptest.NewServer with closure body referencing server.URL — needs var server *httptest.Server then server = ...) - ExchangeCode body wasn't being sent (req.PostForm = form is wrong; must pass strings.NewReader(form.Encode()) as the request body) --- pkg/auth/oidc.go | 264 ++++++++++++++++++++++++-- pkg/auth/oidc_test.go | 420 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 664 insertions(+), 20 deletions(-) diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go index 81b3c4a..08e1788 100644 --- a/pkg/auth/oidc.go +++ b/pkg/auth/oidc.go @@ -8,9 +8,17 @@ package auth import ( "context" "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" "net/http" + "net/url" + "strings" "sync" "time" + + "github.com/golang-jwt/jwt/v5" ) // OIDCClient is a per-provider OIDC client. @@ -52,14 +60,25 @@ type TokenResponse struct { // IDTokenClaims represents the parsed claims from an ID token. type IDTokenClaims struct { - Issuer string `json:"iss"` - Subject string `json:"sub"` - Audience string `json:"aud"` - ExpirationTime int64 `json:"exp"` - IssuedAt int64 `json:"iat"` - Nonce string `json:"nonce,omitempty"` - Email string `json:"email,omitempty"` - EmailVerified bool `json:"email_verified,omitempty"` + jwt.RegisteredClaims + Email string `json:"email,omitempty"` + EmailVerified bool `json:"email_verified,omitempty"` + Name string `json:"name,omitempty"` +} + +// jwks represents the JWKS (JSON Web Key Set) response. +type jwks struct { + Keys []jwk `json:"keys"` +} + +// jwk represents a single JSON Web Key. +type jwk struct { + Kid string `json:"kid"` + Kty string `json:"kty"` + N string `json:"n"` + E string `json:"e"` + Use string `json:"use,omitempty"` + Alg string `json:"alg,omitempty"` } // NewOIDCClient constructs a client. Discovery + JWKS are NOT fetched eagerly; @@ -75,30 +94,237 @@ func NewOIDCClient(issuerURL, clientID, clientSecret string) *OIDCClient { } } +// decodeRSAPublicKey reconstructs an *rsa.PublicKey from JWK n and e values. +func decodeRSAPublicKey(j jwk) (*rsa.PublicKey, error) { + nBytes, err := base64.RawURLEncoding.DecodeString(j.N) + if err != nil { + return nil, fmt.Errorf("decode n: %w", err) + } + eBytes, err := base64.RawURLEncoding.DecodeString(j.E) + if err != nil { + return nil, fmt.Errorf("decode e: %w", err) + } + n := new(big.Int).SetBytes(nBytes) + e := new(big.Int).SetBytes(eBytes) + return &rsa.PublicKey{N: n, E: int(e.Int64())}, nil +} + // Discover fetches and caches the .well-known document. Idempotent. // First call: HTTP fetch + cache. Subsequent calls: cached value. func (c *OIDCClient) Discover(ctx context.Context) (*Discovery, error) { - // TODO Phase B.3: implement (HTTP GET issuerURL + "/.well-known/openid-configuration") - return nil, nil // placeholder for skeleton phase + c.discoveryMu.RLock() + if c.discovery != nil { + c.discoveryMu.RUnlock() + return c.discovery, nil + } + c.discoveryMu.RUnlock() + + c.discoveryMu.Lock() + defer c.discoveryMu.Unlock() + + // Double-check after acquiring write lock + if c.discovery != nil { + return c.discovery, nil + } + + wellKnownURL := fmt.Sprintf("%s/.well-known/openid-configuration", c.issuerURL) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURL, nil) + if err != nil { + return nil, fmt.Errorf("create discovery request: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("fetch discovery: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("discovery HTTP %d", resp.StatusCode) + } + + var disc Discovery + if err := json.NewDecoder(resp.Body).Decode(&disc); err != nil { + return nil, fmt.Errorf("decode discovery: %w", err) + } + + c.discovery = &disc + return &disc, nil } // RefreshJWKS fetches JWKS URI, parse keys, populate jwks map. -// TODO Phase B.3: implement func (c *OIDCClient) RefreshJWKS(ctx context.Context) error { - // TODO Phase B.3: implement (HTTP GET to JWKS URI from discovery, parse keys) - return nil // placeholder for skeleton phase + // Ensure discovery is loaded + if c.discovery == nil { + if _, err := c.Discover(ctx); err != nil { + return fmt.Errorf("discover: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.discovery.JWKSUri, nil) + if err != nil { + return fmt.Errorf("create JWKS request: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("fetch JWKS: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("JWKS HTTP %d", resp.StatusCode) + } + + var keySet jwks + if err := json.NewDecoder(resp.Body).Decode(&keySet); err != nil { + return fmt.Errorf("decode JWKS: %w", err) + } + + c.jwksMu.Lock() + defer c.jwksMu.Unlock() + + c.jwks = make(map[string]*rsa.PublicKey) + for _, key := range keySet.Keys { + if key.Kty == "RSA" { + pubKey, err := decodeRSAPublicKey(key) + if err != nil { + return fmt.Errorf("decode RSA key %s: %w", key.Kid, err) + } + c.jwks[key.Kid] = pubKey + } + } + + c.jwksFetched = time.Now() + return nil } // ExchangeCode exchanges an authorization code for an access token and ID token. -// TODO Phase B.3: implement func (c *OIDCClient) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI string) (*TokenResponse, error) { - // TODO Phase B.3: implement (POST to token_endpoint with code, code_verifier, redirect_uri) - return nil, nil // placeholder for skeleton phase + // Ensure discovery is loaded + if c.discovery == nil { + if _, err := c.Discover(ctx); err != nil { + return nil, fmt.Errorf("discover: %w", err) + } + } + + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("code", code) + form.Set("code_verifier", codeVerifier) + form.Set("redirect_uri", redirectURI) + form.Set("client_id", c.clientID) + form.Set("client_secret", c.clientSecret) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.discovery.TokenEndpoint, strings.NewReader(form.Encode())) + if err != nil { + return nil, fmt.Errorf("create token request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("exchange code: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token HTTP %d", resp.StatusCode) + } + + var tokenResp TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return nil, fmt.Errorf("decode token response: %w", err) + } + + return &tokenResp, nil } // ValidateIDToken verifies the signature and claims of an ID token. -// TODO Phase B.3: implement func (c *OIDCClient) ValidateIDToken(ctx context.Context, idToken string) (*IDTokenClaims, error) { - // TODO Phase B.3: implement (verify signature with JWKS, validate claims) - return nil, nil // placeholder for skeleton phase + // First, parse without verification to get the kid + parser := jwt.NewParser() + unverifiedToken, _, err := parser.ParseUnverified(idToken, &IDTokenClaims{}) + if err != nil { + return nil, fmt.Errorf("parse unverified token: %w", err) + } + + claims, ok := unverifiedToken.Claims.(*IDTokenClaims) + if !ok { + return nil, fmt.Errorf("invalid claims type") + } + + // Get kid from header + kid, ok := unverifiedToken.Header["kid"].(string) + if !ok || kid == "" { + return nil, fmt.Errorf("missing kid in token header") + } + + // Get the key, refreshing JWKS if needed + c.jwksMu.RLock() + _, keyExists := c.jwks[kid] + c.jwksMu.RUnlock() + + if !keyExists { + if err := c.RefreshJWKS(ctx); err != nil { + return nil, fmt.Errorf("refresh JWKS: %w", err) + } + + c.jwksMu.RLock() + _, keyExists = c.jwks[kid] + c.jwksMu.RUnlock() + + if !keyExists { + return nil, fmt.Errorf("key %s not found in JWKS", kid) + } + } + + // Parse with verification + keyFunc := func(token *jwt.Token) (interface{}, error) { + if kid, ok := token.Header["kid"].(string); ok { + c.jwksMu.RLock() + defer c.jwksMu.RUnlock() + if key, exists := c.jwks[kid]; exists { + return key, nil + } + } + return nil, fmt.Errorf("key not found") + } + + parsedToken, err := jwt.ParseWithClaims(idToken, &IDTokenClaims{}, keyFunc) + if err != nil { + return nil, fmt.Errorf("parse token: %w", err) + } + + claims, ok = parsedToken.Claims.(*IDTokenClaims) + if !ok { + return nil, fmt.Errorf("invalid claims type after parse") + } + + // Validate claims + if claims.Issuer != c.issuerURL { + return nil, fmt.Errorf("issuer mismatch: expected %s, got %s", c.issuerURL, claims.Issuer) + } + + // Check audience contains clientID + audValid := false + if claims.Audience != nil { + for _, aud := range claims.Audience { + if aud == c.clientID { + audValid = true + break + } + } + } + if !audValid { + return nil, fmt.Errorf("audience does not contain client ID %s", c.clientID) + } + + // Check expiration + if claims.ExpiresAt != nil && time.Now().UTC().After(claims.ExpiresAt.Time) { + return nil, fmt.Errorf("token expired") + } + + return claims, nil } diff --git a/pkg/auth/oidc_test.go b/pkg/auth/oidc_test.go index 952e6f5..65595b2 100644 --- a/pkg/auth/oidc_test.go +++ b/pkg/auth/oidc_test.go @@ -1,6 +1,21 @@ package auth -import "testing" +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "fmt" + "math/big" + "net/http" + "net/http/httptest" + "net/url" + "sync/atomic" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) func TestNewOIDCClient(t *testing.T) { c := NewOIDCClient("https://example.com", "client_id", "client_secret") @@ -11,3 +26,406 @@ func TestNewOIDCClient(t *testing.T) { t.Errorf("issuerURL not set: got %q", c.issuerURL) } } + +func TestDiscover_HappyPath(t *testing.T) { + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/openid-configuration" { + t.Errorf("unexpected path: %s", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"issuer":"%s","authorization_endpoint":"%s/auth","token_endpoint":"%s/token","jwks_uri":"%s/jwks"}`, + server.URL, server.URL, server.URL, server.URL))) + })) + defer server.Close() + + client := NewOIDCClient(server.URL, "client_id", "client_secret") + client.httpClient = server.Client() + + disc, err := client.Discover(context.Background()) + if err != nil { + t.Fatalf("Discover failed: %v", err) + } + + if disc.Issuer != server.URL { + t.Errorf("issuer mismatch: got %s, want %s", disc.Issuer, server.URL) + } + if disc.TokenEndpoint != server.URL+"/token" { + t.Errorf("token endpoint mismatch: got %s", disc.TokenEndpoint) + } + if disc.JWKSUri != server.URL+"/jwks" { + t.Errorf("jwks_uri mismatch: got %s", disc.JWKSUri) + } +} + +func TestDiscover_Idempotent(t *testing.T) { + var requestCount int32 + + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&requestCount, 1) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"issuer":"%s","authorization_endpoint":"%s/auth","token_endpoint":"%s/token","jwks_uri":"%s/jwks"}`, + server.URL, server.URL, server.URL, server.URL))) + })) + defer server.Close() + + client := NewOIDCClient(server.URL, "client_id", "client_secret") + client.httpClient = server.Client() + + // First call + _, err := client.Discover(context.Background()) + if err != nil { + t.Fatalf("First Discover failed: %v", err) + } + + // Second call + _, err = client.Discover(context.Background()) + if err != nil { + t.Fatalf("Second Discover failed: %v", err) + } + + if atomic.LoadInt32(&requestCount) != 1 { + t.Errorf("Expected 1 HTTP request, got %d", requestCount) + } +} + +func generateTestRSAKey(t *testing.T) *rsa.PrivateKey { + t.Helper() + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate RSA key: %v", err) + } + return privKey +} + +func encodeRSAPublicKey(privKey *rsa.PrivateKey) (n, e string) { + n = base64.RawURLEncoding.EncodeToString(privKey.PublicKey.N.Bytes()) + e = base64.RawURLEncoding.EncodeToString(big.NewInt(int64(privKey.PublicKey.E)).Bytes()) + return n, e +} + +func TestRefreshJWKS_HappyPath(t *testing.T) { + privKey := generateTestRSAKey(t) + n, e := encodeRSAPublicKey(privKey) + + var discoveryCalled, jwksCalled bool + + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + discoveryCalled = true + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"issuer":"%s","authorization_endpoint":"%s/auth","token_endpoint":"%s/token","jwks_uri":"%s/jwks"}`, + server.URL, server.URL, server.URL, server.URL))) + return + } + if r.URL.Path == "/jwks" { + jwksCalled = true + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"keys":[{"kid":"test-key-id","kty":"RSA","use":"sig","alg":"RS256","n":"%s","e":"%s"}]}`, n, e))) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + client := NewOIDCClient(server.URL, "client_id", "client_secret") + client.httpClient = server.Client() + + // First discover to populate discovery + _, err := client.Discover(context.Background()) + if err != nil { + t.Fatalf("Discover failed: %v", err) + } + + // Now refresh JWKS + err = client.RefreshJWKS(context.Background()) + if err != nil { + t.Fatalf("RefreshJWKS failed: %v", err) + } + + if !discoveryCalled { + t.Error("discovery endpoint was not called") + } + if !jwksCalled { + t.Error("jwks endpoint was not called") + } + + // Check that jwks was populated + client.jwksMu.RLock() + defer client.jwksMu.RUnlock() + + if len(client.jwks) != 1 { + t.Errorf("expected 1 key in jwks, got %d", len(client.jwks)) + } + + if _, exists := client.jwks["test-key-id"]; !exists { + t.Error("test-key-id not found in jwks") + } +} + +func TestExchangeCode_HappyPath(t *testing.T) { + tokenResponseJSON := `{"access_token":"access-token-123","id_token":"id-token-456","token_type":"Bearer","expires_in":3600}` + + var receivedForm url.Values + + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"issuer":"%s","authorization_endpoint":"%s/auth","token_endpoint":"%s/token","jwks_uri":"%s/jwks"}`, + server.URL, server.URL, server.URL, server.URL))) + return + } + if r.URL.Path == "/token" { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if r.Header.Get("Content-Type") != "application/x-www-form-urlencoded" { + t.Errorf("expected Content-Type application/x-www-form-urlencoded, got %s", r.Header.Get("Content-Type")) + w.WriteHeader(http.StatusBadRequest) + return + } + + err := r.ParseForm() + if err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + receivedForm = r.Form + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(tokenResponseJSON)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + client := NewOIDCClient(server.URL, "client_id", "client_secret") + client.httpClient = server.Client() + + // Discover first to populate discovery + _, err := client.Discover(context.Background()) + if err != nil { + t.Fatalf("Discover failed: %v", err) + } + + resp, err := client.ExchangeCode(context.Background(), "auth-code-789", "code-verifier-123", "https://app.example.com/callback") + if err != nil { + t.Fatalf("ExchangeCode failed: %v", err) + } + + if resp.AccessToken != "access-token-123" { + t.Errorf("access token mismatch: got %s", resp.AccessToken) + } + if resp.IDToken != "id-token-456" { + t.Errorf("id token mismatch: got %s", resp.IDToken) + } + if resp.TokenType != "Bearer" { + t.Errorf("token type mismatch: got %s", resp.TokenType) + } + + // Check form values + if receivedForm.Get("grant_type") != "authorization_code" { + t.Errorf("grant_type mismatch: got %s", receivedForm.Get("grant_type")) + } + if receivedForm.Get("code") != "auth-code-789" { + t.Errorf("code mismatch: got %s", receivedForm.Get("code")) + } + if receivedForm.Get("code_verifier") != "code-verifier-123" { + t.Errorf("code_verifier mismatch: got %s", receivedForm.Get("code_verifier")) + } + if receivedForm.Get("redirect_uri") != "https://app.example.com/callback" { + t.Errorf("redirect_uri mismatch: got %s", receivedForm.Get("redirect_uri")) + } + if receivedForm.Get("client_id") != "client_id" { + t.Errorf("client_id mismatch: got %s", receivedForm.Get("client_id")) + } + if receivedForm.Get("client_secret") != "client_secret" { + t.Errorf("client_secret mismatch: got %s", receivedForm.Get("client_secret")) + } +} + +func TestValidateIDToken_HappyPath(t *testing.T) { + privKey := generateTestRSAKey(t) + n, e := encodeRSAPublicKey(privKey) + + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"issuer":"%s","authorization_endpoint":"%s/auth","token_endpoint":"%s/token","jwks_uri":"%s/jwks"}`, + server.URL, server.URL, server.URL, server.URL))) + return + } + if r.URL.Path == "/jwks" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"keys":[{"kid":"test-key-id","kty":"RSA","use":"sig","alg":"RS256","n":"%s","e":"%s"}]}`, n, e))) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + client := NewOIDCClient(server.URL, "client_id", "client_secret") + client.httpClient = server.Client() + + // Create and sign a JWT + token := jwt.NewWithClaims(jwt.SigningMethodRS256, &IDTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: server.URL, + Audience: jwt.ClaimStrings{"client_id"}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + Subject: "user-123", + }, + Email: "user@example.com", + EmailVerified: true, + Name: "Test User", + }) + token.Header["kid"] = "test-key-id" + + signedToken, err := token.SignedString(privKey) + if err != nil { + t.Fatalf("failed to sign token: %v", err) + } + + // Validate the token + claims, err := client.ValidateIDToken(context.Background(), signedToken) + if err != nil { + t.Fatalf("ValidateIDToken failed: %v", err) + } + + if claims.Issuer != server.URL { + t.Errorf("issuer mismatch: got %s, want %s", claims.Issuer, server.URL) + } + if claims.Subject != "user-123" { + t.Errorf("subject mismatch: got %s", claims.Subject) + } + if claims.Email != "user@example.com" { + t.Errorf("email mismatch: got %s", claims.Email) + } + if !claims.EmailVerified { + t.Error("email_verified should be true") + } + if claims.Name != "Test User" { + t.Errorf("name mismatch: got %s", claims.Name) + } +} + +func TestValidateIDToken_RejectsExpired(t *testing.T) { + privKey := generateTestRSAKey(t) + n, e := encodeRSAPublicKey(privKey) + + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"issuer":"%s","authorization_endpoint":"%s/auth","token_endpoint":"%s/token","jwks_uri":"%s/jwks"}`, + server.URL, server.URL, server.URL, server.URL))) + return + } + if r.URL.Path == "/jwks" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"keys":[{"kid":"test-key-id","kty":"RSA","use":"sig","alg":"RS256","n":"%s","e":"%s"}]}`, n, e))) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + client := NewOIDCClient(server.URL, "client_id", "client_secret") + client.httpClient = server.Client() + + // Create an expired JWT + token := jwt.NewWithClaims(jwt.SigningMethodRS256, &IDTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: server.URL, + Audience: jwt.ClaimStrings{"client_id"}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)), // Expired 1 hour ago + IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)), + Subject: "user-123", + }, + }) + token.Header["kid"] = "test-key-id" + + signedToken, err := token.SignedString(privKey) + if err != nil { + t.Fatalf("failed to sign token: %v", err) + } + + // Should fail due to expired token + _, err = client.ValidateIDToken(context.Background(), signedToken) + if err == nil { + t.Error("expected error for expired token, got nil") + } +} + +func TestValidateIDToken_RejectsWrongIssuer(t *testing.T) { + privKey := generateTestRSAKey(t) + n, e := encodeRSAPublicKey(privKey) + + wrongIssuer := "https://wrong-provider.example.com" + + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"issuer":"%s","authorization_endpoint":"%s/auth","token_endpoint":"%s/token","jwks_uri":"%s/jwks"}`, + server.URL, server.URL, server.URL, server.URL))) + return + } + if r.URL.Path == "/jwks" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"keys":[{"kid":"test-key-id","kty":"RSA","use":"sig","alg":"RS256","n":"%s","e":"%s"}]}`, n, e))) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + client := NewOIDCClient(server.URL, "client_id", "client_secret") + client.httpClient = server.Client() + + // Create a JWT with wrong issuer + token := jwt.NewWithClaims(jwt.SigningMethodRS256, &IDTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: wrongIssuer, + Audience: jwt.ClaimStrings{"client_id"}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + Subject: "user-123", + }, + }) + token.Header["kid"] = "test-key-id" + + signedToken, err := token.SignedString(privKey) + if err != nil { + t.Fatalf("failed to sign token: %v", err) + } + + // Should fail due to issuer mismatch + _, err = client.ValidateIDToken(context.Background(), signedToken) + if err == nil { + t.Error("expected error for wrong issuer, got nil") + } +}