✨ feat(auth): OIDC HTTP handlers /start + /callback (ADR-0028 Phase B.4) (#75)
Co-authored-by: Gabriel Radureau <arcodange@gmail.com> Co-committed-by: Gabriel Radureau <arcodange@gmail.com>
This commit was merged in pull request #75.
This commit is contained in:
@@ -94,6 +94,21 @@ func NewOIDCClient(issuerURL, clientID, clientSecret string) *OIDCClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClientID returns the OIDC client ID.
|
||||||
|
func (c *OIDCClient) ClientID() string {
|
||||||
|
return c.clientID
|
||||||
|
}
|
||||||
|
|
||||||
|
// IssuerURL returns the OIDC issuer URL.
|
||||||
|
func (c *OIDCClient) IssuerURL() string {
|
||||||
|
return c.issuerURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHTTPClient sets a custom HTTP client for testing.
|
||||||
|
func (c *OIDCClient) SetHTTPClient(client *http.Client) {
|
||||||
|
c.httpClient = client
|
||||||
|
}
|
||||||
|
|
||||||
// decodeRSAPublicKey reconstructs an *rsa.PublicKey from JWK n and e values.
|
// decodeRSAPublicKey reconstructs an *rsa.PublicKey from JWK n and e values.
|
||||||
func decodeRSAPublicKey(j jwk) (*rsa.PublicKey, error) {
|
func decodeRSAPublicKey(j jwk) (*rsa.PublicKey, error) {
|
||||||
nBytes, err := base64.RawURLEncoding.DecodeString(j.N)
|
nBytes, err := base64.RawURLEncoding.DecodeString(j.N)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
httpSwagger "github.com/swaggo/http-swagger"
|
httpSwagger "github.com/swaggo/http-swagger"
|
||||||
|
|
||||||
|
"dance-lessons-coach/pkg/auth"
|
||||||
"dance-lessons-coach/pkg/cache"
|
"dance-lessons-coach/pkg/cache"
|
||||||
"dance-lessons-coach/pkg/config"
|
"dance-lessons-coach/pkg/config"
|
||||||
"dance-lessons-coach/pkg/email"
|
"dance-lessons-coach/pkg/email"
|
||||||
@@ -279,6 +280,18 @@ func (s *Server) registerApiV1Routes(r chi.Router) {
|
|||||||
)
|
)
|
||||||
mlHandler.RegisterRoutes(r)
|
mlHandler.RegisterRoutes(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OIDC handlers (ADR-0028 Phase B.4)
|
||||||
|
oidcProviders := s.config.GetOIDCProviders()
|
||||||
|
if len(oidcProviders) > 0 {
|
||||||
|
oidcClients := make(map[string]*auth.OIDCClient, len(oidcProviders))
|
||||||
|
for name, p := range oidcProviders {
|
||||||
|
oidcClients[name] = auth.NewOIDCClient(p.IssuerURL, p.ClientID, p.ClientSecret)
|
||||||
|
}
|
||||||
|
redirectBase := s.config.GetMagicLinkConfig().BaseURL
|
||||||
|
oidcHandler := userapi.NewOIDCHandler(oidcClients, s.userService, s.userRepo, redirectBase)
|
||||||
|
oidcHandler.RegisterRoutes(r)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Register admin routes
|
// Register admin routes
|
||||||
|
|||||||
329
pkg/user/api/oidc_handler.go
Normal file
329
pkg/user/api/oidc_handler.go
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"dance-lessons-coach/pkg/auth"
|
||||||
|
"dance-lessons-coach/pkg/user"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OIDCHandler exposes the OIDC authorization-code endpoints.
|
||||||
|
type OIDCHandler struct {
|
||||||
|
clients map[string]*auth.OIDCClient // keyed by provider name
|
||||||
|
users user.UserService
|
||||||
|
repo user.UserRepository
|
||||||
|
redirectBase string
|
||||||
|
|
||||||
|
pkceMu sync.Mutex
|
||||||
|
pkceStore map[string]pkceEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
type pkceEntry struct {
|
||||||
|
codeVerifier string
|
||||||
|
providerName string
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOIDCHandler creates a new OIDCHandler.
|
||||||
|
func NewOIDCHandler(clients map[string]*auth.OIDCClient, users user.UserService, repo user.UserRepository, redirectBase string) *OIDCHandler {
|
||||||
|
return &OIDCHandler{
|
||||||
|
clients: clients,
|
||||||
|
users: users,
|
||||||
|
repo: repo,
|
||||||
|
redirectBase: redirectBase,
|
||||||
|
pkceStore: make(map[string]pkceEntry),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes mounts the OIDC endpoints on the provided router.
|
||||||
|
func (h *OIDCHandler) RegisterRoutes(router chi.Router) {
|
||||||
|
router.Get("/oidc/{provider}/start", h.handleStart)
|
||||||
|
router.Get("/oidc/{provider}/callback", h.handleCallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleStart initiates the OIDC authorization-code flow.
|
||||||
|
//
|
||||||
|
// @Summary Start OIDC authorization
|
||||||
|
// @Description Generates PKCE state and verifier, redirects to the OIDC provider authorization endpoint.
|
||||||
|
// @Tags API/v1/User
|
||||||
|
// @Produce json
|
||||||
|
// @Param provider path string true "OIDC provider name"
|
||||||
|
// @Success 302 {string}string "Redirect to OIDC provider"
|
||||||
|
// @Failure 404 {object}map[string]string "Unknown provider"
|
||||||
|
// @Failure 502 {object}map[string]string "Discovery failed"
|
||||||
|
// @Router /v1/auth/oidc/{provider}/start [get]
|
||||||
|
func (h *OIDCHandler) handleStart(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
provider := chi.URLParam(r, "provider")
|
||||||
|
|
||||||
|
client, exists := h.clients[provider]
|
||||||
|
if !exists {
|
||||||
|
log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC start: unknown provider")
|
||||||
|
writeJSONError(w, http.StatusNotFound, "unknown_provider", "unknown OIDC provider")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure discovery is loaded
|
||||||
|
disc, err := client.Discover(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC start: discovery failed")
|
||||||
|
writeJSONError(w, http.StatusBadGateway, "discovery_failed", fmt.Sprintf("OIDC discovery failed: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate state: 32 bytes random, base64-url-no-padding
|
||||||
|
state := generateRandomBase64URL(32)
|
||||||
|
|
||||||
|
// Generate code verifier: 32 bytes random, base64-url-no-padding
|
||||||
|
codeVerifier := generateRandomBase64URL(32)
|
||||||
|
|
||||||
|
// Compute code challenge: SHA256 hash of code verifier, base64-url-no-padding
|
||||||
|
hash := sha256.Sum256([]byte(codeVerifier))
|
||||||
|
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||||
|
|
||||||
|
// Store PKCE entry
|
||||||
|
h.pkceMu.Lock()
|
||||||
|
// Lazy-clean expired entries
|
||||||
|
now := time.Now()
|
||||||
|
for k, entry := range h.pkceStore {
|
||||||
|
if entry.expiresAt.Before(now) {
|
||||||
|
delete(h.pkceStore, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.pkceStore[state] = pkceEntry{
|
||||||
|
codeVerifier: codeVerifier,
|
||||||
|
providerName: provider,
|
||||||
|
expiresAt: now.Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
h.pkceMu.Unlock()
|
||||||
|
|
||||||
|
// Build redirect URL
|
||||||
|
redirectURI := fmt.Sprintf("%s/api/v1/auth/oidc/%s/callback", h.redirectBase, provider)
|
||||||
|
|
||||||
|
v := url.Values{}
|
||||||
|
v.Set("response_type", "code")
|
||||||
|
v.Set("client_id", client.ClientID())
|
||||||
|
v.Set("redirect_uri", redirectURI)
|
||||||
|
v.Set("state", state)
|
||||||
|
v.Set("code_challenge", codeChallenge)
|
||||||
|
v.Set("code_challenge_method", "S256")
|
||||||
|
v.Set("scope", "openid email profile")
|
||||||
|
|
||||||
|
target := disc.AuthorizationEndpoint + "?" + v.Encode()
|
||||||
|
|
||||||
|
log.Debug().Ctx(ctx).Str("provider", provider).Str("target", target).Msg("OIDC start: redirecting to provider")
|
||||||
|
|
||||||
|
http.Redirect(w, r, target, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleCallback handles the OIDC callback after authorization.
|
||||||
|
//
|
||||||
|
// @Summary OIDC callback handler
|
||||||
|
// @Description Validates state, exchanges code for tokens, validates id_token, signs up on first use, issues JWT.
|
||||||
|
// @Tags API/v1/User
|
||||||
|
// @Produce json
|
||||||
|
// @Param provider path string true "OIDC provider name"
|
||||||
|
// @Param state query string true "State parameter"
|
||||||
|
// @Param code query string false "Authorization code"
|
||||||
|
// @Param error query string false "OIDC error"
|
||||||
|
// @Success 200 {object} OIDCCallbackResponse "Successfully signed in via OIDC"
|
||||||
|
// @Failure 401 {object} map[string]string "Invalid state, missing code, or OIDC error"
|
||||||
|
// @Failure 502 {object} map[string]string "Token exchange or validation failed"
|
||||||
|
// @Failure 500 {object} map[string]string "Internal server error"
|
||||||
|
// @Router /v1/auth/oidc/{provider}/callback [get]
|
||||||
|
func (h *OIDCHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
provider := chi.URLParam(r, "provider")
|
||||||
|
|
||||||
|
client, exists := h.clients[provider]
|
||||||
|
if !exists {
|
||||||
|
log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC callback: unknown provider")
|
||||||
|
writeJSONError(w, http.StatusNotFound, "unknown_provider", "unknown OIDC provider")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read query parameters
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
oidcError := r.URL.Query().Get("error")
|
||||||
|
|
||||||
|
// If OIDC provider returned an error
|
||||||
|
if oidcError != "" {
|
||||||
|
log.Warn().Ctx(ctx).Str("provider", provider).Str("error", oidcError).Msg("OIDC callback: provider error")
|
||||||
|
writeJSON(w, http.StatusUnauthorized, map[string]string{
|
||||||
|
"error": "oidc_error",
|
||||||
|
"provider_error": oidcError,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate state
|
||||||
|
if state == "" {
|
||||||
|
log.Warn().Ctx(ctx).Msg("OIDC callback: missing state")
|
||||||
|
writeJSONError(w, http.StatusUnauthorized, "invalid_state", "missing state parameter")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.pkceMu.Lock()
|
||||||
|
entry, exists := h.pkceStore[state]
|
||||||
|
if !exists {
|
||||||
|
h.pkceMu.Unlock()
|
||||||
|
log.Warn().Ctx(ctx).Str("state", state).Msg("OIDC callback: state not found")
|
||||||
|
writeJSONError(w, http.StatusUnauthorized, "invalid_state", "state not found or already used")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check expiration and provider match
|
||||||
|
now := time.Now()
|
||||||
|
if entry.expiresAt.Before(now) {
|
||||||
|
delete(h.pkceStore, state)
|
||||||
|
h.pkceMu.Unlock()
|
||||||
|
log.Warn().Ctx(ctx).Str("state", state).Msg("OIDC callback: state expired")
|
||||||
|
writeJSONError(w, http.StatusUnauthorized, "invalid_state", "state expired")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if entry.providerName != provider {
|
||||||
|
delete(h.pkceStore, state)
|
||||||
|
h.pkceMu.Unlock()
|
||||||
|
log.Warn().Ctx(ctx).Str("state", state).Str("expected_provider", entry.providerName).Str("actual_provider", provider).Msg("OIDC callback: provider mismatch")
|
||||||
|
writeJSONError(w, http.StatusUnauthorized, "invalid_state", "provider mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the entry (single-use)
|
||||||
|
codeVerifier := entry.codeVerifier
|
||||||
|
delete(h.pkceStore, state)
|
||||||
|
h.pkceMu.Unlock()
|
||||||
|
|
||||||
|
// Validate code parameter
|
||||||
|
if code == "" {
|
||||||
|
log.Warn().Ctx(ctx).Msg("OIDC callback: missing code")
|
||||||
|
writeJSONError(w, http.StatusUnauthorized, "invalid_request", "missing authorization code")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build redirect URI
|
||||||
|
redirectURI := fmt.Sprintf("%s/api/v1/auth/oidc/%s/callback", h.redirectBase, provider)
|
||||||
|
|
||||||
|
// Exchange code for tokens
|
||||||
|
tokenResp, err := client.ExchangeCode(ctx, code, codeVerifier, redirectURI)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC callback: code exchange failed")
|
||||||
|
writeJSONError(w, http.StatusBadGateway, "token_exchange_failed", fmt.Sprintf("code exchange failed: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate ID token
|
||||||
|
claims, err := client.ValidateIDToken(ctx, tokenResp.IDToken)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC callback: ID token validation failed")
|
||||||
|
writeJSONError(w, http.StatusUnauthorized, "invalid_id_token", fmt.Sprintf("ID token validation failed: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check email in claims
|
||||||
|
if claims.Email == "" {
|
||||||
|
log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC callback: no email in ID token")
|
||||||
|
writeJSONError(w, http.StatusUnauthorized, "no_email_in_id_token", "ID token does not contain an email claim")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure user exists (sign-up on first use)
|
||||||
|
u, err := h.ensureUser(ctx, claims.Email)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Ctx(ctx).Err(err).Str("email", claims.Email).Msg("OIDC callback: user upsert failed")
|
||||||
|
writeJSONError(w, http.StatusInternalServerError, "server_error", fmt.Sprintf("user upsert failed: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate JWT
|
||||||
|
jwtToken, err := h.users.GenerateJWT(ctx, u)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Ctx(ctx).Err(err).Str("email", claims.Email).Msg("OIDC callback: JWT generation failed")
|
||||||
|
writeJSONError(w, http.StatusInternalServerError, "server_error", fmt.Sprintf("JWT generation failed: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Ctx(ctx).Str("provider", provider).Str("email", claims.Email).Msg("OIDC callback: user signed in successfully")
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]string{
|
||||||
|
"message": "signed in via oidc",
|
||||||
|
"token": jwtToken,
|
||||||
|
"user": claims.Email,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureUser returns the user keyed on email (stored as Username),
|
||||||
|
// creating them if absent. Newly-created users get a random unguessable
|
||||||
|
// bcrypt-hashed password so the password endpoints stay locked out.
|
||||||
|
func (h *OIDCHandler) ensureUser(ctx context.Context, email string) (*user.User, error) {
|
||||||
|
if h.repo != nil {
|
||||||
|
existing, err := h.repo.GetUserByUsername(ctx, email)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get user by username: %w", err)
|
||||||
|
}
|
||||||
|
if existing != nil {
|
||||||
|
return existing, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate random password
|
||||||
|
rawPass := generateRandomHex(32)
|
||||||
|
hash, err := h.users.HashPassword(ctx, rawPass)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("hash password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
u := &user.User{
|
||||||
|
Username: email,
|
||||||
|
PasswordHash: hash,
|
||||||
|
IsAdmin: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.users.CreateUser(ctx, u); err != nil {
|
||||||
|
return nil, fmt.Errorf("create user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.repo != nil {
|
||||||
|
return h.repo.GetUserByUsername(ctx, email)
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRandomBase64URL generates a random string suitable for use in OIDC PKCE flows.
|
||||||
|
func generateRandomBase64URL(length int) string {
|
||||||
|
b := make([]byte, length)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
panic(fmt.Sprintf("failed to read random bytes: %v", err))
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRandomHex generates a random hex string.
|
||||||
|
func generateRandomHex(length int) string {
|
||||||
|
b := make([]byte, length/2)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
panic(fmt.Sprintf("failed to read random bytes: %v", err))
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OIDCCallbackResponse represents the JSON response from the OIDC callback.
|
||||||
|
type OIDCCallbackResponse struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
User string `json:"user"`
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user