Files
dance-lessons-coach/pkg/auth/oidc.go
Gabriel Radureau b27d8168eb feat(auth): OIDC HTTP handlers /start + /callback (ADR-0028 Phase B.4)
Two endpoints implementing the OIDC Authorization Code with PKCE flow:
- GET /api/v1/auth/oidc/{provider}/start — generates state + PKCE
  verifier, redirects to provider's authorization_endpoint
- GET /api/v1/auth/oidc/{provider}/callback — validates state,
  exchanges code, validates id_token, signs up on first-use, issues JWT

Wires into pkg/server/server.go alongside the magic-link handler;
gated on len(GetOIDCProviders()) > 0 so it stays inactive until at
least one provider is configured.

pkg/auth/oidc.go: adds two small getters (ClientID, IssuerURL) that
the handler needs to build redirect URLs.

Authoring: mostly Mistral Vibe (batch7, $4.60 / 45 steps — Q-045 hit
the price cap before merge). Trainer takeover (~5 min):
- removed the broken test file (Mistral's fakeOIDCUserSvc /
  fakeOIDCUserRepo didn't implement the full interfaces; tests
  for the handler will land in a follow-up PR using the existing
  fakeUserSvc / fakeUserRepo from magic_link_handler_test.go)
- verified build + vet + go test ./pkg/user/api/... green

Phase B.5 (BDD scenarios with a mock provider) and the missing
oidc_handler_test.go remain TODO. Brief ready at:
~/Work/Vibe/workspaces/PHASE-B-5-READY-TO-LAUNCH.md
2026-05-05 22:29:14 +02:00

346 lines
9.3 KiB
Go

// Package auth provides OpenID Connect client primitives for the
// dance-lessons-coach passwordless-auth migration (ADR-0028 Phase B).
//
// This file defines only the client surface; the HTTP handlers are wired
// up in pkg/user/api/oidc_handler.go (a separate phase, B.3).
package auth
import (
"context"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"math/big"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
)
// OIDCClient talks to a single OpenID Connect provider.
//
// It carries the static OAuth client credentials plus two lazily populated
// caches: the provider's discovery document and its signing keys (JWKS).
// Each cache has its own RWMutex.
type OIDCClient struct {
	// Static configuration, set by NewOIDCClient. httpClient can be swapped
	// with SetHTTPClient.
	issuerURL    string
	clientID     string
	clientSecret string
	httpClient   *http.Client

	// Discovery document; nil until the first successful Discover call.
	discoveryMu sync.RWMutex
	discovery   *Discovery

	// RSA verification keys indexed by "kid". They are reloaded on demand by
	// RefreshJWKS (ValidateIDToken calls it when it meets an unknown kid);
	// jwksFetched records when the last reload happened.
	jwksMu      sync.RWMutex
	jwks        map[string]*rsa.PublicKey
	jwksFetched time.Time
}
// Discovery holds the fields this package reads from a provider's
// .well-known/openid-configuration document. Any other members of that
// document are ignored when decoding.
type Discovery struct {
	Issuer                string `json:"issuer"`                      // provider's issuer identifier
	AuthorizationEndpoint string `json:"authorization_endpoint"`      // where the user is redirected to log in
	TokenEndpoint         string `json:"token_endpoint"`              // where authorization codes are redeemed
	JWKSUri               string `json:"jwks_uri"`                    // location of the signing-key set
	UserinfoEndpoint      string `json:"userinfo_endpoint,omitempty"` // optional profile endpoint
}
// TokenResponse is the JSON body returned by the provider's token endpoint
// when an authorization code is redeemed (see ExchangeCode).
type TokenResponse struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int64  `json:"expires_in"` // access-token lifetime as sent by the provider
	RefreshToken string `json:"refresh_token,omitempty"`
	IDToken      string `json:"id_token"` // signed JWT, checked by ValidateIDToken
	Scope        string `json:"scope,omitempty"`
}
// IDTokenClaims is the claim set decoded from an ID token: the registered
// JWT claims (iss, sub, aud, exp, nbf, iat, jti) plus the optional profile
// claims declared below.
//
// NOTE(review): Email is only as trustworthy as EmailVerified; presumably the
// caller checks EmailVerified before linking accounts by email — confirm.
type IDTokenClaims struct {
	jwt.RegisteredClaims

	Email         string `json:"email,omitempty"`
	EmailVerified bool   `json:"email_verified,omitempty"`
	Name          string `json:"name,omitempty"`
}
// jwks is the JSON Web Key Set document served at the provider's jwks_uri.
type jwks struct {
	Keys []jwk `json:"keys"`
}

// jwk is a single key from a JWKS. Only the RSA public-key members (n, e)
// are modelled; members used by other key types are ignored when decoding.
type jwk struct {
	Kid string `json:"kid"`           // key ID, matched against the token header's "kid"
	Kty string `json:"kty"`           // key type, e.g. "RSA"
	N   string `json:"n"`             // RSA modulus, base64url-encoded
	E   string `json:"e"`             // RSA public exponent, base64url-encoded
	Use string `json:"use,omitempty"` // intended use, e.g. "sig"
	Alg string `json:"alg,omitempty"` // algorithm the key is meant for
}
// NewOIDCClient builds a client for the provider identified by issuerURL.
//
// No network traffic happens here. The discovery document and the JWKS are
// fetched lazily on first use, so a provider that is temporarily down does
// not block server startup. Outgoing requests use an HTTP client with a
// 10-second timeout; SetHTTPClient can replace it.
func NewOIDCClient(issuerURL, clientID, clientSecret string) *OIDCClient {
	c := new(OIDCClient)
	c.issuerURL = issuerURL
	c.clientID = clientID
	c.clientSecret = clientSecret
	c.httpClient = &http.Client{Timeout: 10 * time.Second}
	c.jwks = map[string]*rsa.PublicKey{}
	return c
}
// ClientID reports the OAuth client identifier this client was constructed
// with, e.g. for building the provider redirect URL.
func (c *OIDCClient) ClientID() string { return c.clientID }
// IssuerURL reports the issuer URL exactly as it was configured. ID tokens
// are checked against this exact string.
func (c *OIDCClient) IssuerURL() string { return c.issuerURL }
// SetHTTPClient replaces the HTTP client used for discovery, JWKS and token
// requests, typically to point tests at a fake provider.
//
// Passing nil restores a default client with a 10-second timeout, instead of
// storing a nil client that would panic on the next request. The field is not
// guarded by a lock, so call this before the OIDCClient is shared between
// goroutines.
func (c *OIDCClient) SetHTTPClient(client *http.Client) {
	if client == nil {
		client = &http.Client{Timeout: 10 * time.Second}
	}
	c.httpClient = client
}
// decodeRSAPublicKey rebuilds an *rsa.PublicKey from a JWK's base64url
// encoded modulus (n) and public exponent (e).
//
// Trailing "=" padding is tolerated even though JWKs should be unpadded.
// It returns an error when:
//   - either value is not valid base64url;
//   - the modulus is zero;
//   - the exponent is outside [2, 2^31-1]. crypto/rsa refuses such keys at
//     verification time, and checking here also prevents a silently
//     truncated Int64/int conversion of an oversized exponent.
func decodeRSAPublicKey(j jwk) (*rsa.PublicKey, error) {
	nBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(j.N, "="))
	if err != nil {
		return nil, fmt.Errorf("decode n: %w", err)
	}
	eBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(j.E, "="))
	if err != nil {
		return nil, fmt.Errorf("decode e: %w", err)
	}

	n := new(big.Int).SetBytes(nBytes)
	if n.Sign() == 0 {
		return nil, fmt.Errorf("invalid modulus: zero")
	}

	e := new(big.Int).SetBytes(eBytes)
	if !e.IsInt64() || e.Int64() < 2 || e.Int64() > 1<<31-1 {
		return nil, fmt.Errorf("invalid exponent: %s", e.String())
	}
	return &rsa.PublicKey{N: n, E: int(e.Int64())}, nil
}
// Discover returns the provider's discovery document. The first successful
// call fetches it from <issuer>/.well-known/openid-configuration and caches
// it; later calls return the cached copy.
//
// The fetch runs under discoveryMu's write lock with a re-check. Callers that
// were waiting therefore reuse a document fetched while they were blocked. A
// failed fetch is not cached, so the next call retries.
//
// A document that lacks authorization_endpoint, token_endpoint or jwks_uri
// is rejected, because the authorization-code flow needs all three.
//
// NOTE(review): the document's "issuer" field is not compared with the
// configured issuer URL (OIDC Discovery 1.0 §4.3 requires them to match).
func (c *OIDCClient) Discover(ctx context.Context) (*Discovery, error) {
	c.discoveryMu.RLock()
	cached := c.discovery
	c.discoveryMu.RUnlock()
	if cached != nil {
		return cached, nil
	}

	c.discoveryMu.Lock()
	defer c.discoveryMu.Unlock()
	// Another goroutine may have finished the fetch while we waited.
	if c.discovery != nil {
		return c.discovery, nil
	}

	// OIDC Discovery 1.0 §4: any terminating "/" on the issuer must be
	// removed before the well-known path is appended; otherwise issuers
	// configured with a trailing slash produce ".../.well-known/..." with a
	// double slash.
	wellKnownURL := strings.TrimSuffix(c.issuerURL, "/") + "/.well-known/openid-configuration"
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURL, nil)
	if err != nil {
		return nil, fmt.Errorf("create discovery request: %w", err)
	}
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("fetch discovery: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("discovery HTTP %d", resp.StatusCode)
	}

	var disc Discovery
	if err := json.NewDecoder(resp.Body).Decode(&disc); err != nil {
		return nil, fmt.Errorf("decode discovery: %w", err)
	}
	switch {
	case disc.AuthorizationEndpoint == "":
		return nil, fmt.Errorf("discovery document missing authorization_endpoint")
	case disc.TokenEndpoint == "":
		return nil, fmt.Errorf("discovery document missing token_endpoint")
	case disc.JWKSUri == "":
		return nil, fmt.Errorf("discovery document missing jwks_uri")
	}

	c.discovery = &disc
	return &disc, nil
}
// RefreshJWKS downloads the provider's JSON Web Key Set from the jwks_uri in
// the discovery document, and replaces the cached verification keys.
//
// Only RSA keys are kept. Keys whose "use" is set to something other than
// "sig" (e.g. encryption keys) are skipped. The new key map is built fully
// before it is swapped in under jwksMu. As a result, a failed refresh
// (network, HTTP status, JSON, or an undecodable RSA key) leaves the
// previously cached keys untouched.
func (c *OIDCClient) RefreshJWKS(ctx context.Context) error {
	// Go through Discover instead of reading c.discovery directly: Discover
	// reads the field under discoveryMu, so this avoids a data race.
	disc, err := c.Discover(ctx)
	if err != nil {
		return fmt.Errorf("discover: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, disc.JWKSUri, nil)
	if err != nil {
		return fmt.Errorf("create JWKS request: %w", err)
	}
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("fetch JWKS: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("JWKS HTTP %d", resp.StatusCode)
	}

	var keySet jwks
	if err := json.NewDecoder(resp.Body).Decode(&keySet); err != nil {
		return fmt.Errorf("decode JWKS: %w", err)
	}

	keys := make(map[string]*rsa.PublicKey, len(keySet.Keys))
	for _, key := range keySet.Keys {
		if key.Kty != "RSA" {
			continue
		}
		if key.Use != "" && key.Use != "sig" {
			continue
		}
		pubKey, err := decodeRSAPublicKey(key)
		if err != nil {
			return fmt.Errorf("decode RSA key %s: %w", key.Kid, err)
		}
		keys[key.Kid] = pubKey
	}

	c.jwksMu.Lock()
	c.jwks = keys
	c.jwksFetched = time.Now()
	c.jwksMu.Unlock()
	return nil
}
// ExchangeCode redeems an authorization code at the provider's token endpoint
// (OAuth 2.0 authorization-code grant with a PKCE code_verifier) and returns
// the decoded token response.
//
// Behaviour:
//   - redirectURI is sent as-is. Per RFC 6749 §4.1.3 it must match the value
//     used in the authorization request.
//   - client_secret is only sent when one is configured, so public PKCE-only
//     clients do not send an empty secret.
//   - On a non-200 reply, the OAuth "error" / "error_description" fields are
//     included in the returned error when the provider supplies them.
//   - A 200 reply without an id_token is rejected, since the OIDC flow cannot
//     continue without one.
func (c *OIDCClient) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI string) (*TokenResponse, error) {
	// Discover returns the cached document under its lock, avoiding an
	// unsynchronised read of c.discovery.
	disc, err := c.Discover(ctx)
	if err != nil {
		return nil, fmt.Errorf("discover: %w", err)
	}

	form := url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("code", code)
	form.Set("code_verifier", codeVerifier)
	form.Set("redirect_uri", redirectURI)
	form.Set("client_id", c.clientID)
	if c.clientSecret != "" {
		form.Set("client_secret", c.clientSecret)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, disc.TokenEndpoint, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, fmt.Errorf("create token request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("exchange code: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// RFC 6749 §5.2 error bodies are JSON. Reading them is best-effort:
		// if decoding fails, only the status code is reported.
		var oauthErr struct {
			Error            string `json:"error"`
			ErrorDescription string `json:"error_description"`
		}
		if json.NewDecoder(resp.Body).Decode(&oauthErr) == nil && oauthErr.Error != "" {
			if oauthErr.ErrorDescription != "" {
				return nil, fmt.Errorf("token HTTP %d: %q: %q", resp.StatusCode, oauthErr.Error, oauthErr.ErrorDescription)
			}
			return nil, fmt.Errorf("token HTTP %d: %q", resp.StatusCode, oauthErr.Error)
		}
		return nil, fmt.Errorf("token HTTP %d", resp.StatusCode)
	}

	var tokenResp TokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		return nil, fmt.Errorf("decode token response: %w", err)
	}
	if tokenResp.IDToken == "" {
		return nil, fmt.Errorf("token response missing id_token")
	}
	return &tokenResp, nil
}
// ValidateIDToken verifies an ID token's signature against the provider's
// JWKS, checks its standard claims, and returns the parsed claims on success.
//
// A token is accepted only if:
//   - it is signed with an RSA algorithm (RS256/384/512 or PS256/384/512);
//   - its header carries a non-empty "kid";
//   - that kid is in the cached JWKS. On a miss, the JWKS is re-fetched once
//     so that provider key rotation is picked up;
//   - "exp" is present and in the future, and "nbf", if present, has passed
//     (both checked by the jwt library);
//   - "iss" equals the configured issuer URL exactly;
//   - "aud" contains this client's ID.
//
// NOTE(review): the "nonce" and "azp" claims are not checked here.
func (c *OIDCClient) ValidateIDToken(ctx context.Context, idToken string) (*IDTokenClaims, error) {
	// cachedKey looks up a verification key under the read lock.
	cachedKey := func(kid string) (*rsa.PublicKey, bool) {
		c.jwksMu.RLock()
		defer c.jwksMu.RUnlock()
		key, ok := c.jwks[kid]
		return key, ok
	}

	// The jwt library calls keyFunc only after the token's alg has passed
	// the WithValidMethods allow-list. A token with an unexpected algorithm
	// therefore never triggers a JWKS fetch.
	keyFunc := func(token *jwt.Token) (any, error) {
		kid, ok := token.Header["kid"].(string)
		if !ok || kid == "" {
			return nil, fmt.Errorf("missing kid in token header")
		}
		if key, ok := cachedKey(kid); ok {
			return key, nil
		}
		// Unknown kid: the provider may have rotated its signing keys.
		if err := c.RefreshJWKS(ctx); err != nil {
			return nil, fmt.Errorf("refresh JWKS: %w", err)
		}
		if key, ok := cachedKey(kid); ok {
			return key, nil
		}
		return nil, fmt.Errorf("key %s not found in JWKS", kid)
	}

	parsedToken, err := jwt.ParseWithClaims(idToken, &IDTokenClaims{}, keyFunc,
		jwt.WithValidMethods([]string{"RS256", "RS384", "RS512", "PS256", "PS384", "PS512"}),
		// OIDC Core §2: exp is REQUIRED in an ID token.
		jwt.WithExpirationRequired(),
	)
	if err != nil {
		return nil, fmt.Errorf("parse token: %w", err)
	}
	claims, ok := parsedToken.Claims.(*IDTokenClaims)
	if !ok {
		return nil, fmt.Errorf("invalid claims type after parse")
	}

	if claims.Issuer != c.issuerURL {
		return nil, fmt.Errorf("issuer mismatch: expected %s, got %s", c.issuerURL, claims.Issuer)
	}

	audValid := false
	for _, aud := range claims.Audience {
		if aud == c.clientID {
			audValid = true
			break
		}
	}
	if !audValid {
		return nil, fmt.Errorf("audience does not contain client ID %s", c.clientID)
	}

	return claims, nil
}