package auth

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"encoding/base64"
	"fmt"
	"math/big"
	"net/http"
	"net/http/httptest"
	"net/url"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/golang-jwt/jwt/v5"
)

// Fixture values shared by every test in this file.
const (
	testClientID     = "client_id"
	testClientSecret = "client_secret"
	testKeyID        = "test-key-id"
)

// mockProvider describes a fake OpenID provider served by httptest.
type mockProvider struct {
	// jwks is the JWKS document served at /jwks. When empty, /jwks is routed
	// like any other unknown path.
	jwks string
	// onRequest, if set, is invoked for every incoming request before routing.
	// It runs on the server's handler goroutine.
	onRequest func(r *http.Request)
	// fallback, if set, handles every path not served above; otherwise such
	// paths are answered with 404.
	fallback http.HandlerFunc
}

// start launches the provider on a new httptest server and returns it. The
// server is closed automatically via t.Cleanup. The discovery document served
// at /.well-known/openid-configuration advertises the server's own URL as
// issuer and as the base of every endpoint.
func (p mockProvider) start(t *testing.T) *httptest.Server {
	t.Helper()
	var server *httptest.Server
	server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if p.onRequest != nil {
			p.onRequest(r)
		}
		switch {
		case r.URL.Path == "/.well-known/openid-configuration":
			// server is assigned before any request can reach this handler.
			writeJSON(w, discoveryJSON(server.URL))
		case r.URL.Path == "/jwks" && p.jwks != "":
			writeJSON(w, p.jwks)
		case p.fallback != nil:
			p.fallback(w, r)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	t.Cleanup(server.Close)
	return server
}

// writeJSON replies 200 OK with body as an application/json payload.
func writeJSON(w http.ResponseWriter, body string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	// A failed write means the client went away; the test notices that on the
	// client side, so the error is deliberately ignored here.
	_, _ = fmt.Fprint(w, body)
}

// discoveryJSON returns an OpenID discovery document whose issuer is baseURL
// and whose authorization, token and JWKS endpoints live under baseURL.
func discoveryJSON(baseURL string) string {
	return fmt.Sprintf(`{"issuer":"%s","authorization_endpoint":"%s/auth","token_endpoint":"%s/token","jwks_uri":"%s/jwks"}`,
		baseURL, baseURL, baseURL, baseURL)
}

// jwksJSON returns a JWKS document containing the public half of privKey as a
// single RS256 signing key identified by testKeyID.
func jwksJSON(privKey *rsa.PrivateKey) string {
	n, e := encodeRSAPublicKey(privKey)
	return fmt.Sprintf(`{"keys":[{"kid":"%s","kty":"RSA","use":"sig","alg":"RS256","n":"%s","e":"%s"}]}`,
		testKeyID, n, e)
}

// startSigningProvider generates a fresh RSA key, starts a mock provider that
// publishes its public half at /jwks, and returns the server and the key.
func startSigningProvider(t *testing.T) (*httptest.Server, *rsa.PrivateKey) {
	t.Helper()
	privKey := generateTestRSAKey(t)
	server := mockProvider{jwks: jwksJSON(privKey)}.start(t)
	return server, privKey
}

// signIDToken signs claims with privKey using RS256 and tags the token header
// with kid testKeyID so the client can find the key in the mock JWKS.
func signIDToken(t *testing.T, privKey *rsa.PrivateKey, claims *IDTokenClaims) string {
	t.Helper()
	token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
	token.Header["kid"] = testKeyID
	signed, err := token.SignedString(privKey)
	if err != nil {
		t.Fatalf("failed to sign token: %v", err)
	}
	return signed
}

func TestNewOIDCClient(t *testing.T) {
	c := NewOIDCClient("https://example.com", testClientID, testClientSecret)
	if c == nil {
		t.Fatal("NewOIDCClient returned nil")
	}
	if c.issuerURL != "https://example.com" {
		t.Errorf("issuerURL not set: got %q", c.issuerURL)
	}
}

func TestDiscover_HappyPath(t *testing.T) {
	server := mockProvider{
		// Discovery must not touch any endpoint other than the well-known one.
		fallback: func(w http.ResponseWriter, r *http.Request) {
			t.Errorf("unexpected path: %s", r.URL.Path)
			w.WriteHeader(http.StatusNotFound)
		},
	}.start(t)

	client := NewOIDCClient(server.URL, testClientID, testClientSecret)
	client.httpClient = server.Client()

	disc, err := client.Discover(context.Background())
	if err != nil {
		t.Fatalf("Discover failed: %v", err)
	}
	if disc.Issuer != server.URL {
		t.Errorf("issuer mismatch: got %s, want %s", disc.Issuer, server.URL)
	}
	if disc.TokenEndpoint != server.URL+"/token" {
		t.Errorf("token endpoint mismatch: got %s", disc.TokenEndpoint)
	}
	if disc.JWKSUri != server.URL+"/jwks" {
		t.Errorf("jwks_uri mismatch: got %s", disc.JWKSUri)
	}
}

func TestDiscover_Idempotent(t *testing.T) {
	// Written on the handler goroutine, read on the test goroutine: access it
	// only through sync/atomic.
	var requestCount int32
	server := mockProvider{
		onRequest: func(*http.Request) { atomic.AddInt32(&requestCount, 1) },
	}.start(t)

	client := NewOIDCClient(server.URL, testClientID, testClientSecret)
	client.httpClient = server.Client()

	if _, err := client.Discover(context.Background()); err != nil {
		t.Fatalf("First Discover failed: %v", err)
	}
	// The second call should be served from the client's cached discovery.
	if _, err := client.Discover(context.Background()); err != nil {
		t.Fatalf("Second Discover failed: %v", err)
	}

	if got := atomic.LoadInt32(&requestCount); got != 1 {
		t.Errorf("Expected 1 HTTP request, got %d", got)
	}
}

// generateTestRSAKey returns a fresh 2048-bit RSA key, failing the test on error.
func generateTestRSAKey(t *testing.T) *rsa.PrivateKey {
	t.Helper()
	privKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("failed to generate RSA key: %v", err)
	}
	return privKey
}

// encodeRSAPublicKey returns the modulus and public exponent of privKey as
// unpadded base64url strings, the encoding used by the JWK "n" and "e" members.
func encodeRSAPublicKey(privKey *rsa.PrivateKey) (n, e string) {
	n = base64.RawURLEncoding.EncodeToString(privKey.PublicKey.N.Bytes())
	e = base64.RawURLEncoding.EncodeToString(big.NewInt(int64(privKey.PublicKey.E)).Bytes())
	return n, e
}

func TestRefreshJWKS_HappyPath(t *testing.T) {
	privKey := generateTestRSAKey(t)

	// Per-endpoint hit counters, shared with the handler goroutine.
	var discoveryCalls, jwksCalls int32
	server := mockProvider{
		jwks: jwksJSON(privKey),
		onRequest: func(r *http.Request) {
			switch r.URL.Path {
			case "/.well-known/openid-configuration":
				atomic.AddInt32(&discoveryCalls, 1)
			case "/jwks":
				atomic.AddInt32(&jwksCalls, 1)
			}
		},
	}.start(t)

	client := NewOIDCClient(server.URL, testClientID, testClientSecret)
	client.httpClient = server.Client()

	// Discover first so the client knows the jwks_uri.
	if _, err := client.Discover(context.Background()); err != nil {
		t.Fatalf("Discover failed: %v", err)
	}
	if err := client.RefreshJWKS(context.Background()); err != nil {
		t.Fatalf("RefreshJWKS failed: %v", err)
	}

	if atomic.LoadInt32(&discoveryCalls) == 0 {
		t.Error("discovery endpoint was not called")
	}
	if atomic.LoadInt32(&jwksCalls) == 0 {
		t.Error("jwks endpoint was not called")
	}

	// Inspect the cached key set under the client's own lock.
	client.jwksMu.RLock()
	defer client.jwksMu.RUnlock()
	if len(client.jwks) != 1 {
		t.Errorf("expected 1 key in jwks, got %d", len(client.jwks))
	}
	if _, exists := client.jwks[testKeyID]; !exists {
		t.Error("test-key-id not found in jwks")
	}
}

func TestExchangeCode_HappyPath(t *testing.T) {
	const tokenResponseJSON = `{"access_token":"access-token-123","id_token":"id-token-456","token_type":"Bearer","expires_in":3600}`

	// The token request form is captured on the handler goroutine and read on
	// the test goroutine, so it is guarded by formMu.
	var (
		formMu       sync.Mutex
		receivedForm url.Values
	)
	server := mockProvider{
		fallback: func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path != "/token" {
				w.WriteHeader(http.StatusNotFound)
				return
			}
			if r.Method != http.MethodPost {
				t.Errorf("expected POST, got %s", r.Method)
				w.WriteHeader(http.StatusMethodNotAllowed)
				return
			}
			if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" {
				t.Errorf("expected Content-Type application/x-www-form-urlencoded, got %s", ct)
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			if err := r.ParseForm(); err != nil {
				t.Errorf("failed to parse form: %v", err)
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			formMu.Lock()
			receivedForm = r.Form
			formMu.Unlock()
			writeJSON(w, tokenResponseJSON)
		},
	}.start(t)

	client := NewOIDCClient(server.URL, testClientID, testClientSecret)
	client.httpClient = server.Client()

	// Discover first so the client knows the token endpoint.
	if _, err := client.Discover(context.Background()); err != nil {
		t.Fatalf("Discover failed: %v", err)
	}

	resp, err := client.ExchangeCode(context.Background(), "auth-code-789", "code-verifier-123", "https://app.example.com/callback")
	if err != nil {
		t.Fatalf("ExchangeCode failed: %v", err)
	}
	if resp.AccessToken != "access-token-123" {
		t.Errorf("access token mismatch: got %s", resp.AccessToken)
	}
	if resp.IDToken != "id-token-456" {
		t.Errorf("id token mismatch: got %s", resp.IDToken)
	}
	if resp.TokenType != "Bearer" {
		t.Errorf("token type mismatch: got %s", resp.TokenType)
	}

	formMu.Lock()
	form := receivedForm
	formMu.Unlock()

	// Every parameter the client is expected to send to the token endpoint.
	wantForm := []struct{ field, want string }{
		{"grant_type", "authorization_code"},
		{"code", "auth-code-789"},
		{"code_verifier", "code-verifier-123"},
		{"redirect_uri", "https://app.example.com/callback"},
		{"client_id", testClientID},
		{"client_secret", testClientSecret},
	}
	for _, tc := range wantForm {
		if got := form.Get(tc.field); got != tc.want {
			t.Errorf("%s mismatch: got %s", tc.field, got)
		}
	}
}

func TestValidateIDToken_HappyPath(t *testing.T) {
	server, privKey := startSigningProvider(t)
	client := NewOIDCClient(server.URL, testClientID, testClientSecret)
	client.httpClient = server.Client()

	signedToken := signIDToken(t, privKey, &IDTokenClaims{
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    server.URL,
			Audience:  jwt.ClaimStrings{testClientID},
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
			IssuedAt:  jwt.NewNumericDate(time.Now()),
			Subject:   "user-123",
		},
		Email:         "user@example.com",
		EmailVerified: true,
		Name:          "Test User",
	})

	claims, err := client.ValidateIDToken(context.Background(), signedToken)
	if err != nil {
		t.Fatalf("ValidateIDToken failed: %v", err)
	}
	if claims.Issuer != server.URL {
		t.Errorf("issuer mismatch: got %s, want %s", claims.Issuer, server.URL)
	}
	if claims.Subject != "user-123" {
		t.Errorf("subject mismatch: got %s", claims.Subject)
	}
	if claims.Email != "user@example.com" {
		t.Errorf("email mismatch: got %s", claims.Email)
	}
	if !claims.EmailVerified {
		t.Error("email_verified should be true")
	}
	if claims.Name != "Test User" {
		t.Errorf("name mismatch: got %s", claims.Name)
	}
}

func TestValidateIDToken_RejectsExpired(t *testing.T) {
	server, privKey := startSigningProvider(t)
	client := NewOIDCClient(server.URL, testClientID, testClientSecret)
	client.httpClient = server.Client()

	// Correctly issued and signed, but expired one hour ago.
	signedToken := signIDToken(t, privKey, &IDTokenClaims{
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    server.URL,
			Audience:  jwt.ClaimStrings{testClientID},
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)),
			IssuedAt:  jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
			Subject:   "user-123",
		},
	})

	if _, err := client.ValidateIDToken(context.Background(), signedToken); err == nil {
		t.Error("expected error for expired token, got nil")
	}
}

func TestValidateIDToken_RejectsWrongIssuer(t *testing.T) {
	const wrongIssuer = "https://wrong-provider.example.com"

	server, privKey := startSigningProvider(t)
	client := NewOIDCClient(server.URL, testClientID, testClientSecret)
	client.httpClient = server.Client()

	// Validly signed by the provider's key, but claiming a different issuer.
	signedToken := signIDToken(t, privKey, &IDTokenClaims{
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    wrongIssuer,
			Audience:  jwt.ClaimStrings{testClientID},
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
			IssuedAt:  jwt.NewNumericDate(time.Now()),
			Subject:   "user-123",
		},
	})

	if _, err := client.ValidateIDToken(context.Background(), signedToken); err == nil {
		t.Error("expected error for wrong issuer, got nil")
	}
}