feat(auth): OIDC HTTP handlers /start + /callback (ADR-0028 Phase B.4) #75

Merged
arcodange merged 1 commits from vibe/batch7-task-oidc-handlers into main 2026-05-05 22:29:35 +02:00
3 changed files with 357 additions and 0 deletions

View File

@@ -94,6 +94,21 @@ func NewOIDCClient(issuerURL, clientID, clientSecret string) *OIDCClient {
}
}
// ClientID returns the OIDC client ID.
func (c *OIDCClient) ClientID() string {
return c.clientID
}
// IssuerURL returns the OIDC issuer URL.
func (c *OIDCClient) IssuerURL() string {
return c.issuerURL
}
// SetHTTPClient sets a custom HTTP client for testing.
func (c *OIDCClient) SetHTTPClient(client *http.Client) {
c.httpClient = client
}
// decodeRSAPublicKey reconstructs an *rsa.PublicKey from JWK n and e values.
func decodeRSAPublicKey(j jwk) (*rsa.PublicKey, error) {
nBytes, err := base64.RawURLEncoding.DecodeString(j.N)

View File

@@ -18,6 +18,7 @@ import (
"github.com/rs/zerolog/log"
httpSwagger "github.com/swaggo/http-swagger"
"dance-lessons-coach/pkg/auth"
"dance-lessons-coach/pkg/cache"
"dance-lessons-coach/pkg/config"
"dance-lessons-coach/pkg/email"
@@ -279,6 +280,18 @@ func (s *Server) registerApiV1Routes(r chi.Router) {
)
mlHandler.RegisterRoutes(r)
}
// OIDC handlers (ADR-0028 Phase B.4)
oidcProviders := s.config.GetOIDCProviders()
if len(oidcProviders) > 0 {
oidcClients := make(map[string]*auth.OIDCClient, len(oidcProviders))
for name, p := range oidcProviders {
oidcClients[name] = auth.NewOIDCClient(p.IssuerURL, p.ClientID, p.ClientSecret)
}
redirectBase := s.config.GetMagicLinkConfig().BaseURL
oidcHandler := userapi.NewOIDCHandler(oidcClients, s.userService, s.userRepo, redirectBase)
oidcHandler.RegisterRoutes(r)
}
})
// Register admin routes

View File

@@ -0,0 +1,329 @@
package api
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/http"
"net/url"
"sync"
"time"
"dance-lessons-coach/pkg/auth"
"dance-lessons-coach/pkg/user"
"github.com/go-chi/chi/v5"
"github.com/rs/zerolog/log"
)
// OIDCHandler exposes the OIDC authorization-code endpoints.
type OIDCHandler struct {
clients map[string]*auth.OIDCClient // keyed by provider name
users user.UserService
repo user.UserRepository
redirectBase string
pkceMu sync.Mutex
pkceStore map[string]pkceEntry
}
type pkceEntry struct {
codeVerifier string
providerName string
expiresAt time.Time
}
// NewOIDCHandler creates a new OIDCHandler.
func NewOIDCHandler(clients map[string]*auth.OIDCClient, users user.UserService, repo user.UserRepository, redirectBase string) *OIDCHandler {
return &OIDCHandler{
clients: clients,
users: users,
repo: repo,
redirectBase: redirectBase,
pkceStore: make(map[string]pkceEntry),
}
}
// RegisterRoutes mounts the OIDC endpoints on the provided router.
func (h *OIDCHandler) RegisterRoutes(router chi.Router) {
router.Get("/oidc/{provider}/start", h.handleStart)
router.Get("/oidc/{provider}/callback", h.handleCallback)
}
// handleStart initiates the OIDC authorization-code flow.
//
// @Summary Start OIDC authorization
// @Description Generates PKCE state and verifier, redirects to the OIDC provider authorization endpoint.
// @Tags API/v1/User
// @Produce json
// @Param provider path string true "OIDC provider name"
// @Success 302 {string}string "Redirect to OIDC provider"
// @Failure 404 {object}map[string]string "Unknown provider"
// @Failure 502 {object}map[string]string "Discovery failed"
// @Router /v1/auth/oidc/{provider}/start [get]
func (h *OIDCHandler) handleStart(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
provider := chi.URLParam(r, "provider")
client, exists := h.clients[provider]
if !exists {
log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC start: unknown provider")
writeJSONError(w, http.StatusNotFound, "unknown_provider", "unknown OIDC provider")
return
}
// Ensure discovery is loaded
disc, err := client.Discover(ctx)
if err != nil {
log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC start: discovery failed")
writeJSONError(w, http.StatusBadGateway, "discovery_failed", fmt.Sprintf("OIDC discovery failed: %v", err))
return
}
// Generate state: 32 bytes random, base64-url-no-padding
state := generateRandomBase64URL(32)
// Generate code verifier: 32 bytes random, base64-url-no-padding
codeVerifier := generateRandomBase64URL(32)
// Compute code challenge: SHA256 hash of code verifier, base64-url-no-padding
hash := sha256.Sum256([]byte(codeVerifier))
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
// Store PKCE entry
h.pkceMu.Lock()
// Lazy-clean expired entries
now := time.Now()
for k, entry := range h.pkceStore {
if entry.expiresAt.Before(now) {
delete(h.pkceStore, k)
}
}
h.pkceStore[state] = pkceEntry{
codeVerifier: codeVerifier,
providerName: provider,
expiresAt: now.Add(10 * time.Minute),
}
h.pkceMu.Unlock()
// Build redirect URL
redirectURI := fmt.Sprintf("%s/api/v1/auth/oidc/%s/callback", h.redirectBase, provider)
v := url.Values{}
v.Set("response_type", "code")
v.Set("client_id", client.ClientID())
v.Set("redirect_uri", redirectURI)
v.Set("state", state)
v.Set("code_challenge", codeChallenge)
v.Set("code_challenge_method", "S256")
v.Set("scope", "openid email profile")
target := disc.AuthorizationEndpoint + "?" + v.Encode()
log.Debug().Ctx(ctx).Str("provider", provider).Str("target", target).Msg("OIDC start: redirecting to provider")
http.Redirect(w, r, target, http.StatusFound)
}
// handleCallback handles the OIDC callback after authorization.
//
// @Summary OIDC callback handler
// @Description Validates state, exchanges code for tokens, validates id_token, signs up on first use, issues JWT.
// @Tags API/v1/User
// @Produce json
// @Param provider path string true "OIDC provider name"
// @Param state query string true "State parameter"
// @Param code query string false "Authorization code"
// @Param error query string false "OIDC error"
// @Success 200 {object} OIDCCallbackResponse "Successfully signed in via OIDC"
// @Failure 401 {object} map[string]string "Invalid state, missing code, or OIDC error"
// @Failure 502 {object} map[string]string "Token exchange or validation failed"
// @Failure 500 {object} map[string]string "Internal server error"
// @Router /v1/auth/oidc/{provider}/callback [get]
func (h *OIDCHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
provider := chi.URLParam(r, "provider")
client, exists := h.clients[provider]
if !exists {
log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC callback: unknown provider")
writeJSONError(w, http.StatusNotFound, "unknown_provider", "unknown OIDC provider")
return
}
// Read query parameters
state := r.URL.Query().Get("state")
code := r.URL.Query().Get("code")
oidcError := r.URL.Query().Get("error")
// If OIDC provider returned an error
if oidcError != "" {
log.Warn().Ctx(ctx).Str("provider", provider).Str("error", oidcError).Msg("OIDC callback: provider error")
writeJSON(w, http.StatusUnauthorized, map[string]string{
"error": "oidc_error",
"provider_error": oidcError,
})
return
}
// Validate state
if state == "" {
log.Warn().Ctx(ctx).Msg("OIDC callback: missing state")
writeJSONError(w, http.StatusUnauthorized, "invalid_state", "missing state parameter")
return
}
h.pkceMu.Lock()
entry, exists := h.pkceStore[state]
if !exists {
h.pkceMu.Unlock()
log.Warn().Ctx(ctx).Str("state", state).Msg("OIDC callback: state not found")
writeJSONError(w, http.StatusUnauthorized, "invalid_state", "state not found or already used")
return
}
// Check expiration and provider match
now := time.Now()
if entry.expiresAt.Before(now) {
delete(h.pkceStore, state)
h.pkceMu.Unlock()
log.Warn().Ctx(ctx).Str("state", state).Msg("OIDC callback: state expired")
writeJSONError(w, http.StatusUnauthorized, "invalid_state", "state expired")
return
}
if entry.providerName != provider {
delete(h.pkceStore, state)
h.pkceMu.Unlock()
log.Warn().Ctx(ctx).Str("state", state).Str("expected_provider", entry.providerName).Str("actual_provider", provider).Msg("OIDC callback: provider mismatch")
writeJSONError(w, http.StatusUnauthorized, "invalid_state", "provider mismatch")
return
}
// Delete the entry (single-use)
codeVerifier := entry.codeVerifier
delete(h.pkceStore, state)
h.pkceMu.Unlock()
// Validate code parameter
if code == "" {
log.Warn().Ctx(ctx).Msg("OIDC callback: missing code")
writeJSONError(w, http.StatusUnauthorized, "invalid_request", "missing authorization code")
return
}
// Build redirect URI
redirectURI := fmt.Sprintf("%s/api/v1/auth/oidc/%s/callback", h.redirectBase, provider)
// Exchange code for tokens
tokenResp, err := client.ExchangeCode(ctx, code, codeVerifier, redirectURI)
if err != nil {
log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC callback: code exchange failed")
writeJSONError(w, http.StatusBadGateway, "token_exchange_failed", fmt.Sprintf("code exchange failed: %v", err))
return
}
// Validate ID token
claims, err := client.ValidateIDToken(ctx, tokenResp.IDToken)
if err != nil {
log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC callback: ID token validation failed")
writeJSONError(w, http.StatusUnauthorized, "invalid_id_token", fmt.Sprintf("ID token validation failed: %v", err))
return
}
// Check email in claims
if claims.Email == "" {
log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC callback: no email in ID token")
writeJSONError(w, http.StatusUnauthorized, "no_email_in_id_token", "ID token does not contain an email claim")
return
}
// Ensure user exists (sign-up on first use)
u, err := h.ensureUser(ctx, claims.Email)
if err != nil {
log.Error().Ctx(ctx).Err(err).Str("email", claims.Email).Msg("OIDC callback: user upsert failed")
writeJSONError(w, http.StatusInternalServerError, "server_error", fmt.Sprintf("user upsert failed: %v", err))
return
}
// Generate JWT
jwtToken, err := h.users.GenerateJWT(ctx, u)
if err != nil {
log.Error().Ctx(ctx).Err(err).Str("email", claims.Email).Msg("OIDC callback: JWT generation failed")
writeJSONError(w, http.StatusInternalServerError, "server_error", fmt.Sprintf("JWT generation failed: %v", err))
return
}
log.Info().Ctx(ctx).Str("provider", provider).Str("email", claims.Email).Msg("OIDC callback: user signed in successfully")
writeJSON(w, http.StatusOK, map[string]string{
"message": "signed in via oidc",
"token": jwtToken,
"user": claims.Email,
})
}
// ensureUser returns the user keyed on email (stored as Username),
// creating them if absent. Newly-created users get a random unguessable
// bcrypt-hashed password so the password endpoints stay locked out.
func (h *OIDCHandler) ensureUser(ctx context.Context, email string) (*user.User, error) {
if h.repo != nil {
existing, err := h.repo.GetUserByUsername(ctx, email)
if err != nil {
return nil, fmt.Errorf("get user by username: %w", err)
}
if existing != nil {
return existing, nil
}
}
// Generate random password
rawPass := generateRandomHex(32)
hash, err := h.users.HashPassword(ctx, rawPass)
if err != nil {
return nil, fmt.Errorf("hash password: %w", err)
}
u := &user.User{
Username: email,
PasswordHash: hash,
IsAdmin: false,
}
if err := h.users.CreateUser(ctx, u); err != nil {
return nil, fmt.Errorf("create user: %w", err)
}
if h.repo != nil {
return h.repo.GetUserByUsername(ctx, email)
}
return u, nil
}
// generateRandomBase64URL generates a random string suitable for use in OIDC PKCE flows.
func generateRandomBase64URL(length int) string {
b := make([]byte, length)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("failed to read random bytes: %v", err))
}
return base64.RawURLEncoding.EncodeToString(b)
}
// generateRandomHex generates a random hex string.
func generateRandomHex(length int) string {
b := make([]byte, length/2)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("failed to read random bytes: %v", err))
}
return hex.EncodeToString(b)
}
// OIDCCallbackResponse represents the JSON response from the OIDC callback.
type OIDCCallbackResponse struct {
Message string `json:"message"`
Token string `json:"token"`
User string `json:"user"`
}