✨ feat(auth): OIDC HTTP handlers /start + /callback (ADR-0028 Phase B.4) (#75)
Co-authored-by: Gabriel Radureau <arcodange@gmail.com> Co-committed-by: Gabriel Radureau <arcodange@gmail.com>
This commit was merged in pull request #75.
This commit is contained in:
329
pkg/user/api/oidc_handler.go
Normal file
329
pkg/user/api/oidc_handler.go
Normal file
@@ -0,0 +1,329 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"dance-lessons-coach/pkg/auth"
|
||||
"dance-lessons-coach/pkg/user"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// OIDCHandler exposes the OIDC authorization-code endpoints.
|
||||
type OIDCHandler struct {
|
||||
clients map[string]*auth.OIDCClient // keyed by provider name
|
||||
users user.UserService
|
||||
repo user.UserRepository
|
||||
redirectBase string
|
||||
|
||||
pkceMu sync.Mutex
|
||||
pkceStore map[string]pkceEntry
|
||||
}
|
||||
|
||||
type pkceEntry struct {
|
||||
codeVerifier string
|
||||
providerName string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// NewOIDCHandler creates a new OIDCHandler.
|
||||
func NewOIDCHandler(clients map[string]*auth.OIDCClient, users user.UserService, repo user.UserRepository, redirectBase string) *OIDCHandler {
|
||||
return &OIDCHandler{
|
||||
clients: clients,
|
||||
users: users,
|
||||
repo: repo,
|
||||
redirectBase: redirectBase,
|
||||
pkceStore: make(map[string]pkceEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes mounts the OIDC endpoints on the provided router.
|
||||
func (h *OIDCHandler) RegisterRoutes(router chi.Router) {
|
||||
router.Get("/oidc/{provider}/start", h.handleStart)
|
||||
router.Get("/oidc/{provider}/callback", h.handleCallback)
|
||||
}
|
||||
|
||||
// handleStart initiates the OIDC authorization-code flow.
|
||||
//
|
||||
// @Summary Start OIDC authorization
|
||||
// @Description Generates PKCE state and verifier, redirects to the OIDC provider authorization endpoint.
|
||||
// @Tags API/v1/User
|
||||
// @Produce json
|
||||
// @Param provider path string true "OIDC provider name"
|
||||
// @Success 302 {string}string "Redirect to OIDC provider"
|
||||
// @Failure 404 {object}map[string]string "Unknown provider"
|
||||
// @Failure 502 {object}map[string]string "Discovery failed"
|
||||
// @Router /v1/auth/oidc/{provider}/start [get]
|
||||
func (h *OIDCHandler) handleStart(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
provider := chi.URLParam(r, "provider")
|
||||
|
||||
client, exists := h.clients[provider]
|
||||
if !exists {
|
||||
log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC start: unknown provider")
|
||||
writeJSONError(w, http.StatusNotFound, "unknown_provider", "unknown OIDC provider")
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure discovery is loaded
|
||||
disc, err := client.Discover(ctx)
|
||||
if err != nil {
|
||||
log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC start: discovery failed")
|
||||
writeJSONError(w, http.StatusBadGateway, "discovery_failed", fmt.Sprintf("OIDC discovery failed: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Generate state: 32 bytes random, base64-url-no-padding
|
||||
state := generateRandomBase64URL(32)
|
||||
|
||||
// Generate code verifier: 32 bytes random, base64-url-no-padding
|
||||
codeVerifier := generateRandomBase64URL(32)
|
||||
|
||||
// Compute code challenge: SHA256 hash of code verifier, base64-url-no-padding
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
|
||||
// Store PKCE entry
|
||||
h.pkceMu.Lock()
|
||||
// Lazy-clean expired entries
|
||||
now := time.Now()
|
||||
for k, entry := range h.pkceStore {
|
||||
if entry.expiresAt.Before(now) {
|
||||
delete(h.pkceStore, k)
|
||||
}
|
||||
}
|
||||
h.pkceStore[state] = pkceEntry{
|
||||
codeVerifier: codeVerifier,
|
||||
providerName: provider,
|
||||
expiresAt: now.Add(10 * time.Minute),
|
||||
}
|
||||
h.pkceMu.Unlock()
|
||||
|
||||
// Build redirect URL
|
||||
redirectURI := fmt.Sprintf("%s/api/v1/auth/oidc/%s/callback", h.redirectBase, provider)
|
||||
|
||||
v := url.Values{}
|
||||
v.Set("response_type", "code")
|
||||
v.Set("client_id", client.ClientID())
|
||||
v.Set("redirect_uri", redirectURI)
|
||||
v.Set("state", state)
|
||||
v.Set("code_challenge", codeChallenge)
|
||||
v.Set("code_challenge_method", "S256")
|
||||
v.Set("scope", "openid email profile")
|
||||
|
||||
target := disc.AuthorizationEndpoint + "?" + v.Encode()
|
||||
|
||||
log.Debug().Ctx(ctx).Str("provider", provider).Str("target", target).Msg("OIDC start: redirecting to provider")
|
||||
|
||||
http.Redirect(w, r, target, http.StatusFound)
|
||||
}
|
||||
|
||||
// handleCallback handles the OIDC callback after authorization.
|
||||
//
|
||||
// @Summary OIDC callback handler
|
||||
// @Description Validates state, exchanges code for tokens, validates id_token, signs up on first use, issues JWT.
|
||||
// @Tags API/v1/User
|
||||
// @Produce json
|
||||
// @Param provider path string true "OIDC provider name"
|
||||
// @Param state query string true "State parameter"
|
||||
// @Param code query string false "Authorization code"
|
||||
// @Param error query string false "OIDC error"
|
||||
// @Success 200 {object} OIDCCallbackResponse "Successfully signed in via OIDC"
|
||||
// @Failure 401 {object} map[string]string "Invalid state, missing code, or OIDC error"
|
||||
// @Failure 502 {object} map[string]string "Token exchange or validation failed"
|
||||
// @Failure 500 {object} map[string]string "Internal server error"
|
||||
// @Router /v1/auth/oidc/{provider}/callback [get]
|
||||
func (h *OIDCHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
provider := chi.URLParam(r, "provider")
|
||||
|
||||
client, exists := h.clients[provider]
|
||||
if !exists {
|
||||
log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC callback: unknown provider")
|
||||
writeJSONError(w, http.StatusNotFound, "unknown_provider", "unknown OIDC provider")
|
||||
return
|
||||
}
|
||||
|
||||
// Read query parameters
|
||||
state := r.URL.Query().Get("state")
|
||||
code := r.URL.Query().Get("code")
|
||||
oidcError := r.URL.Query().Get("error")
|
||||
|
||||
// If OIDC provider returned an error
|
||||
if oidcError != "" {
|
||||
log.Warn().Ctx(ctx).Str("provider", provider).Str("error", oidcError).Msg("OIDC callback: provider error")
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]string{
|
||||
"error": "oidc_error",
|
||||
"provider_error": oidcError,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate state
|
||||
if state == "" {
|
||||
log.Warn().Ctx(ctx).Msg("OIDC callback: missing state")
|
||||
writeJSONError(w, http.StatusUnauthorized, "invalid_state", "missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
h.pkceMu.Lock()
|
||||
entry, exists := h.pkceStore[state]
|
||||
if !exists {
|
||||
h.pkceMu.Unlock()
|
||||
log.Warn().Ctx(ctx).Str("state", state).Msg("OIDC callback: state not found")
|
||||
writeJSONError(w, http.StatusUnauthorized, "invalid_state", "state not found or already used")
|
||||
return
|
||||
}
|
||||
|
||||
// Check expiration and provider match
|
||||
now := time.Now()
|
||||
if entry.expiresAt.Before(now) {
|
||||
delete(h.pkceStore, state)
|
||||
h.pkceMu.Unlock()
|
||||
log.Warn().Ctx(ctx).Str("state", state).Msg("OIDC callback: state expired")
|
||||
writeJSONError(w, http.StatusUnauthorized, "invalid_state", "state expired")
|
||||
return
|
||||
}
|
||||
|
||||
if entry.providerName != provider {
|
||||
delete(h.pkceStore, state)
|
||||
h.pkceMu.Unlock()
|
||||
log.Warn().Ctx(ctx).Str("state", state).Str("expected_provider", entry.providerName).Str("actual_provider", provider).Msg("OIDC callback: provider mismatch")
|
||||
writeJSONError(w, http.StatusUnauthorized, "invalid_state", "provider mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the entry (single-use)
|
||||
codeVerifier := entry.codeVerifier
|
||||
delete(h.pkceStore, state)
|
||||
h.pkceMu.Unlock()
|
||||
|
||||
// Validate code parameter
|
||||
if code == "" {
|
||||
log.Warn().Ctx(ctx).Msg("OIDC callback: missing code")
|
||||
writeJSONError(w, http.StatusUnauthorized, "invalid_request", "missing authorization code")
|
||||
return
|
||||
}
|
||||
|
||||
// Build redirect URI
|
||||
redirectURI := fmt.Sprintf("%s/api/v1/auth/oidc/%s/callback", h.redirectBase, provider)
|
||||
|
||||
// Exchange code for tokens
|
||||
tokenResp, err := client.ExchangeCode(ctx, code, codeVerifier, redirectURI)
|
||||
if err != nil {
|
||||
log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC callback: code exchange failed")
|
||||
writeJSONError(w, http.StatusBadGateway, "token_exchange_failed", fmt.Sprintf("code exchange failed: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate ID token
|
||||
claims, err := client.ValidateIDToken(ctx, tokenResp.IDToken)
|
||||
if err != nil {
|
||||
log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC callback: ID token validation failed")
|
||||
writeJSONError(w, http.StatusUnauthorized, "invalid_id_token", fmt.Sprintf("ID token validation failed: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Check email in claims
|
||||
if claims.Email == "" {
|
||||
log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC callback: no email in ID token")
|
||||
writeJSONError(w, http.StatusUnauthorized, "no_email_in_id_token", "ID token does not contain an email claim")
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure user exists (sign-up on first use)
|
||||
u, err := h.ensureUser(ctx, claims.Email)
|
||||
if err != nil {
|
||||
log.Error().Ctx(ctx).Err(err).Str("email", claims.Email).Msg("OIDC callback: user upsert failed")
|
||||
writeJSONError(w, http.StatusInternalServerError, "server_error", fmt.Sprintf("user upsert failed: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Generate JWT
|
||||
jwtToken, err := h.users.GenerateJWT(ctx, u)
|
||||
if err != nil {
|
||||
log.Error().Ctx(ctx).Err(err).Str("email", claims.Email).Msg("OIDC callback: JWT generation failed")
|
||||
writeJSONError(w, http.StatusInternalServerError, "server_error", fmt.Sprintf("JWT generation failed: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
log.Info().Ctx(ctx).Str("provider", provider).Str("email", claims.Email).Msg("OIDC callback: user signed in successfully")
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]string{
|
||||
"message": "signed in via oidc",
|
||||
"token": jwtToken,
|
||||
"user": claims.Email,
|
||||
})
|
||||
}
|
||||
|
||||
// ensureUser returns the user keyed on email (stored as Username),
|
||||
// creating them if absent. Newly-created users get a random unguessable
|
||||
// bcrypt-hashed password so the password endpoints stay locked out.
|
||||
func (h *OIDCHandler) ensureUser(ctx context.Context, email string) (*user.User, error) {
|
||||
if h.repo != nil {
|
||||
existing, err := h.repo.GetUserByUsername(ctx, email)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user by username: %w", err)
|
||||
}
|
||||
if existing != nil {
|
||||
return existing, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Generate random password
|
||||
rawPass := generateRandomHex(32)
|
||||
hash, err := h.users.HashPassword(ctx, rawPass)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
u := &user.User{
|
||||
Username: email,
|
||||
PasswordHash: hash,
|
||||
IsAdmin: false,
|
||||
}
|
||||
|
||||
if err := h.users.CreateUser(ctx, u); err != nil {
|
||||
return nil, fmt.Errorf("create user: %w", err)
|
||||
}
|
||||
|
||||
if h.repo != nil {
|
||||
return h.repo.GetUserByUsername(ctx, email)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// generateRandomBase64URL generates a random string suitable for use in OIDC PKCE flows.
|
||||
func generateRandomBase64URL(length int) string {
|
||||
b := make([]byte, length)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("failed to read random bytes: %v", err))
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
// generateRandomHex generates a random hex string.
|
||||
func generateRandomHex(length int) string {
|
||||
b := make([]byte, length/2)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("failed to read random bytes: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// OIDCCallbackResponse represents the JSON response from the OIDC callback.
|
||||
type OIDCCallbackResponse struct {
|
||||
Message string `json:"message"`
|
||||
Token string `json:"token"`
|
||||
User string `json:"user"`
|
||||
}
|
||||
Reference in New Issue
Block a user