// Package auth provides OpenID Connect client primitives for the // dance-lessons-coach passwordless-auth migration (ADR-0028 Phase B). // // This file defines the client surface only. HTTP handlers wire-up // happens in pkg/user/api/oidc_handler.go (separate phase B.3). package auth import ( "context" "crypto/rsa" "encoding/base64" "encoding/json" "fmt" "math/big" "net/http" "net/url" "strings" "sync" "time" "github.com/golang-jwt/jwt/v5" ) // OIDCClient is a per-provider OIDC client. // Holds the discovery document + JWKS cache + OAuth code-exchange config. type OIDCClient struct { issuerURL string clientID string clientSecret string httpClient *http.Client // discovery document, lazy-fetched on first use discoveryMu sync.RWMutex discovery *Discovery // JWKS cache (id_token signature verification keys), refreshed periodically jwksMu sync.RWMutex jwks map[string]*rsa.PublicKey jwksFetched time.Time } // Discovery is the subset of the .well-known/openid-configuration document we use. type Discovery struct { Issuer string `json:"issuer"` AuthorizationEndpoint string `json:"authorization_endpoint"` TokenEndpoint string `json:"token_endpoint"` JWKSUri string `json:"jwks_uri"` UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"` } // TokenResponse is the response from the token endpoint after code exchange. type TokenResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` ExpiresIn int64 `json:"expires_in"` RefreshToken string `json:"refresh_token,omitempty"` IDToken string `json:"id_token"` Scope string `json:"scope,omitempty"` } // IDTokenClaims represents the parsed claims from an ID token. type IDTokenClaims struct { jwt.RegisteredClaims Email string `json:"email,omitempty"` EmailVerified bool `json:"email_verified,omitempty"` Name string `json:"name,omitempty"` } // jwks represents the JWKS (JSON Web Key Set) response. type jwks struct { Keys []jwk `json:"keys"` } // jwk represents a single JSON Web Key. type jwk struct { Kid string `json:"kid"` Kty string `json:"kty"` N string `json:"n"` E string `json:"e"` Use string `json:"use,omitempty"` Alg string `json:"alg,omitempty"` } // NewOIDCClient constructs a client. Discovery + JWKS are NOT fetched eagerly; // they are lazy-loaded on first use to avoid blocking server startup if the // provider is temporarily down. func NewOIDCClient(issuerURL, clientID, clientSecret string) *OIDCClient { return &OIDCClient{ issuerURL: issuerURL, clientID: clientID, clientSecret: clientSecret, httpClient: &http.Client{Timeout: 10 * time.Second}, jwks: make(map[string]*rsa.PublicKey), } } // ClientID returns the OIDC client ID. func (c *OIDCClient) ClientID() string { return c.clientID } // IssuerURL returns the OIDC issuer URL. func (c *OIDCClient) IssuerURL() string { return c.issuerURL } // SetHTTPClient sets a custom HTTP client for testing. func (c *OIDCClient) SetHTTPClient(client *http.Client) { c.httpClient = client } // decodeRSAPublicKey reconstructs an *rsa.PublicKey from JWK n and e values. func decodeRSAPublicKey(j jwk) (*rsa.PublicKey, error) { nBytes, err := base64.RawURLEncoding.DecodeString(j.N) if err != nil { return nil, fmt.Errorf("decode n: %w", err) } eBytes, err := base64.RawURLEncoding.DecodeString(j.E) if err != nil { return nil, fmt.Errorf("decode e: %w", err) } n := new(big.Int).SetBytes(nBytes) e := new(big.Int).SetBytes(eBytes) return &rsa.PublicKey{N: n, E: int(e.Int64())}, nil } // Discover fetches and caches the .well-known document. Idempotent. // First call: HTTP fetch + cache. Subsequent calls: cached value. func (c *OIDCClient) Discover(ctx context.Context) (*Discovery, error) { c.discoveryMu.RLock() if c.discovery != nil { c.discoveryMu.RUnlock() return c.discovery, nil } c.discoveryMu.RUnlock() c.discoveryMu.Lock() defer c.discoveryMu.Unlock() // Double-check after acquiring write lock if c.discovery != nil { return c.discovery, nil } wellKnownURL := fmt.Sprintf("%s/.well-known/openid-configuration", c.issuerURL) req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURL, nil) if err != nil { return nil, fmt.Errorf("create discovery request: %w", err) } resp, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("fetch discovery: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("discovery HTTP %d", resp.StatusCode) } var disc Discovery if err := json.NewDecoder(resp.Body).Decode(&disc); err != nil { return nil, fmt.Errorf("decode discovery: %w", err) } c.discovery = &disc return &disc, nil } // RefreshJWKS fetches JWKS URI, parse keys, populate jwks map. func (c *OIDCClient) RefreshJWKS(ctx context.Context) error { // Ensure discovery is loaded if c.discovery == nil { if _, err := c.Discover(ctx); err != nil { return fmt.Errorf("discover: %w", err) } } req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.discovery.JWKSUri, nil) if err != nil { return fmt.Errorf("create JWKS request: %w", err) } resp, err := c.httpClient.Do(req) if err != nil { return fmt.Errorf("fetch JWKS: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("JWKS HTTP %d", resp.StatusCode) } var keySet jwks if err := json.NewDecoder(resp.Body).Decode(&keySet); err != nil { return fmt.Errorf("decode JWKS: %w", err) } c.jwksMu.Lock() defer c.jwksMu.Unlock() c.jwks = make(map[string]*rsa.PublicKey) for _, key := range keySet.Keys { if key.Kty == "RSA" { pubKey, err := decodeRSAPublicKey(key) if err != nil { return fmt.Errorf("decode RSA key %s: %w", key.Kid, err) } c.jwks[key.Kid] = pubKey } } c.jwksFetched = time.Now() return nil } // ExchangeCode exchanges an authorization code for an access token and ID token. func (c *OIDCClient) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI string) (*TokenResponse, error) { // Ensure discovery is loaded if c.discovery == nil { if _, err := c.Discover(ctx); err != nil { return nil, fmt.Errorf("discover: %w", err) } } form := url.Values{} form.Set("grant_type", "authorization_code") form.Set("code", code) form.Set("code_verifier", codeVerifier) form.Set("redirect_uri", redirectURI) form.Set("client_id", c.clientID) form.Set("client_secret", c.clientSecret) req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.discovery.TokenEndpoint, strings.NewReader(form.Encode())) if err != nil { return nil, fmt.Errorf("create token request: %w", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("exchange code: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("token HTTP %d", resp.StatusCode) } var tokenResp TokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { return nil, fmt.Errorf("decode token response: %w", err) } return &tokenResp, nil } // ValidateIDToken verifies the signature and claims of an ID token. func (c *OIDCClient) ValidateIDToken(ctx context.Context, idToken string) (*IDTokenClaims, error) { // First, parse without verification to get the kid parser := jwt.NewParser() unverifiedToken, _, err := parser.ParseUnverified(idToken, &IDTokenClaims{}) if err != nil { return nil, fmt.Errorf("parse unverified token: %w", err) } claims, ok := unverifiedToken.Claims.(*IDTokenClaims) if !ok { return nil, fmt.Errorf("invalid claims type") } // Get kid from header kid, ok := unverifiedToken.Header["kid"].(string) if !ok || kid == "" { return nil, fmt.Errorf("missing kid in token header") } // Get the key, refreshing JWKS if needed c.jwksMu.RLock() _, keyExists := c.jwks[kid] c.jwksMu.RUnlock() if !keyExists { if err := c.RefreshJWKS(ctx); err != nil { return nil, fmt.Errorf("refresh JWKS: %w", err) } c.jwksMu.RLock() _, keyExists = c.jwks[kid] c.jwksMu.RUnlock() if !keyExists { return nil, fmt.Errorf("key %s not found in JWKS", kid) } } // Parse with verification keyFunc := func(token *jwt.Token) (interface{}, error) { if kid, ok := token.Header["kid"].(string); ok { c.jwksMu.RLock() defer c.jwksMu.RUnlock() if key, exists := c.jwks[kid]; exists { return key, nil } } return nil, fmt.Errorf("key not found") } parsedToken, err := jwt.ParseWithClaims(idToken, &IDTokenClaims{}, keyFunc) if err != nil { return nil, fmt.Errorf("parse token: %w", err) } claims, ok = parsedToken.Claims.(*IDTokenClaims) if !ok { return nil, fmt.Errorf("invalid claims type after parse") } // Validate claims if claims.Issuer != c.issuerURL { return nil, fmt.Errorf("issuer mismatch: expected %s, got %s", c.issuerURL, claims.Issuer) } // Check audience contains clientID audValid := false if claims.Audience != nil { for _, aud := range claims.Audience { if aud == c.clientID { audValid = true break } } } if !audValid { return nil, fmt.Errorf("audience does not contain client ID %s", c.clientID) } // Check expiration if claims.ExpiresAt != nil && time.Now().UTC().After(claims.ExpiresAt.Time) { return nil, fmt.Errorf("token expired") } return claims, nil }