diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go index 08e1788..c11ffc8 100644 --- a/pkg/auth/oidc.go +++ b/pkg/auth/oidc.go @@ -94,6 +94,21 @@ func NewOIDCClient(issuerURL, clientID, clientSecret string) *OIDCClient { } } +// ClientID returns the OIDC client ID. +func (c *OIDCClient) ClientID() string { + return c.clientID +} + +// IssuerURL returns the OIDC issuer URL. +func (c *OIDCClient) IssuerURL() string { + return c.issuerURL +} + +// SetHTTPClient sets a custom HTTP client for testing. +func (c *OIDCClient) SetHTTPClient(client *http.Client) { + c.httpClient = client +} + // decodeRSAPublicKey reconstructs an *rsa.PublicKey from JWK n and e values. func decodeRSAPublicKey(j jwk) (*rsa.PublicKey, error) { nBytes, err := base64.RawURLEncoding.DecodeString(j.N) diff --git a/pkg/server/server.go b/pkg/server/server.go index b82a516..8aa26b0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -18,6 +18,7 @@ import ( "github.com/rs/zerolog/log" httpSwagger "github.com/swaggo/http-swagger" + "dance-lessons-coach/pkg/auth" "dance-lessons-coach/pkg/cache" "dance-lessons-coach/pkg/config" "dance-lessons-coach/pkg/email" @@ -279,6 +280,18 @@ func (s *Server) registerApiV1Routes(r chi.Router) { ) mlHandler.RegisterRoutes(r) } + + // OIDC handlers (ADR-0028 Phase B.4) + oidcProviders := s.config.GetOIDCProviders() + if len(oidcProviders) > 0 { + oidcClients := make(map[string]*auth.OIDCClient, len(oidcProviders)) + for name, p := range oidcProviders { + oidcClients[name] = auth.NewOIDCClient(p.IssuerURL, p.ClientID, p.ClientSecret) + } + redirectBase := s.config.GetMagicLinkConfig().BaseURL + oidcHandler := userapi.NewOIDCHandler(oidcClients, s.userService, s.userRepo, redirectBase) + oidcHandler.RegisterRoutes(r) + } }) // Register admin routes diff --git a/pkg/user/api/oidc_handler.go b/pkg/user/api/oidc_handler.go new file mode 100644 index 0000000..3d88f57 --- /dev/null +++ b/pkg/user/api/oidc_handler.go @@ -0,0 +1,329 @@ +package api + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "net/http" + "net/url" + "sync" + "time" + + "dance-lessons-coach/pkg/auth" + "dance-lessons-coach/pkg/user" + + "github.com/go-chi/chi/v5" + "github.com/rs/zerolog/log" +) + +// OIDCHandler exposes the OIDC authorization-code endpoints. +type OIDCHandler struct { + clients map[string]*auth.OIDCClient // keyed by provider name + users user.UserService + repo user.UserRepository + redirectBase string + + pkceMu sync.Mutex + pkceStore map[string]pkceEntry +} + +type pkceEntry struct { + codeVerifier string + providerName string + expiresAt time.Time +} + +// NewOIDCHandler creates a new OIDCHandler. +func NewOIDCHandler(clients map[string]*auth.OIDCClient, users user.UserService, repo user.UserRepository, redirectBase string) *OIDCHandler { + return &OIDCHandler{ + clients: clients, + users: users, + repo: repo, + redirectBase: redirectBase, + pkceStore: make(map[string]pkceEntry), + } +} + +// RegisterRoutes mounts the OIDC endpoints on the provided router. +func (h *OIDCHandler) RegisterRoutes(router chi.Router) { + router.Get("/oidc/{provider}/start", h.handleStart) + router.Get("/oidc/{provider}/callback", h.handleCallback) +} + +// handleStart initiates the OIDC authorization-code flow. +// +// @Summary Start OIDC authorization +// @Description Generates PKCE state and verifier, redirects to the OIDC provider authorization endpoint. +// @Tags API/v1/User +// @Produce json +// @Param provider path string true "OIDC provider name" +// @Success 302 {string}string "Redirect to OIDC provider" +// @Failure 404 {object}map[string]string "Unknown provider" +// @Failure 502 {object}map[string]string "Discovery failed" +// @Router /v1/auth/oidc/{provider}/start [get] +func (h *OIDCHandler) handleStart(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + provider := chi.URLParam(r, "provider") + + client, exists := h.clients[provider] + if !exists { + log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC start: unknown provider") + writeJSONError(w, http.StatusNotFound, "unknown_provider", "unknown OIDC provider") + return + } + + // Ensure discovery is loaded + disc, err := client.Discover(ctx) + if err != nil { + log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC start: discovery failed") + writeJSONError(w, http.StatusBadGateway, "discovery_failed", fmt.Sprintf("OIDC discovery failed: %v", err)) + return + } + + // Generate state: 32 bytes random, base64-url-no-padding + state := generateRandomBase64URL(32) + + // Generate code verifier: 32 bytes random, base64-url-no-padding + codeVerifier := generateRandomBase64URL(32) + + // Compute code challenge: SHA256 hash of code verifier, base64-url-no-padding + hash := sha256.Sum256([]byte(codeVerifier)) + codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) + + // Store PKCE entry + h.pkceMu.Lock() + // Lazy-clean expired entries + now := time.Now() + for k, entry := range h.pkceStore { + if entry.expiresAt.Before(now) { + delete(h.pkceStore, k) + } + } + h.pkceStore[state] = pkceEntry{ + codeVerifier: codeVerifier, + providerName: provider, + expiresAt: now.Add(10 * time.Minute), + } + h.pkceMu.Unlock() + + // Build redirect URL + redirectURI := fmt.Sprintf("%s/api/v1/auth/oidc/%s/callback", h.redirectBase, provider) + + v := url.Values{} + v.Set("response_type", "code") + v.Set("client_id", client.ClientID()) + v.Set("redirect_uri", redirectURI) + v.Set("state", state) + v.Set("code_challenge", codeChallenge) + v.Set("code_challenge_method", "S256") + v.Set("scope", "openid email profile") + + target := disc.AuthorizationEndpoint + "?" + v.Encode() + + log.Debug().Ctx(ctx).Str("provider", provider).Str("target", target).Msg("OIDC start: redirecting to provider") + + http.Redirect(w, r, target, http.StatusFound) +} + +// handleCallback handles the OIDC callback after authorization. +// +// @Summary OIDC callback handler +// @Description Validates state, exchanges code for tokens, validates id_token, signs up on first use, issues JWT. +// @Tags API/v1/User +// @Produce json +// @Param provider path string true "OIDC provider name" +// @Param state query string true "State parameter" +// @Param code query string false "Authorization code" +// @Param error query string false "OIDC error" +// @Success 200 {object} OIDCCallbackResponse "Successfully signed in via OIDC" +// @Failure 401 {object} map[string]string "Invalid state, missing code, or OIDC error" +// @Failure 502 {object} map[string]string "Token exchange or validation failed" +// @Failure 500 {object} map[string]string "Internal server error" +// @Router /v1/auth/oidc/{provider}/callback [get] +func (h *OIDCHandler) handleCallback(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + provider := chi.URLParam(r, "provider") + + client, exists := h.clients[provider] + if !exists { + log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC callback: unknown provider") + writeJSONError(w, http.StatusNotFound, "unknown_provider", "unknown OIDC provider") + return + } + + // Read query parameters + state := r.URL.Query().Get("state") + code := r.URL.Query().Get("code") + oidcError := r.URL.Query().Get("error") + + // If OIDC provider returned an error + if oidcError != "" { + log.Warn().Ctx(ctx).Str("provider", provider).Str("error", oidcError).Msg("OIDC callback: provider error") + writeJSON(w, http.StatusUnauthorized, map[string]string{ + "error": "oidc_error", + "provider_error": oidcError, + }) + return + } + + // Validate state + if state == "" { + log.Warn().Ctx(ctx).Msg("OIDC callback: missing state") + writeJSONError(w, http.StatusUnauthorized, "invalid_state", "missing state parameter") + return + } + + h.pkceMu.Lock() + entry, exists := h.pkceStore[state] + if !exists { + h.pkceMu.Unlock() + log.Warn().Ctx(ctx).Str("state", state).Msg("OIDC callback: state not found") + writeJSONError(w, http.StatusUnauthorized, "invalid_state", "state not found or already used") + return + } + + // Check expiration and provider match + now := time.Now() + if entry.expiresAt.Before(now) { + delete(h.pkceStore, state) + h.pkceMu.Unlock() + log.Warn().Ctx(ctx).Str("state", state).Msg("OIDC callback: state expired") + writeJSONError(w, http.StatusUnauthorized, "invalid_state", "state expired") + return + } + + if entry.providerName != provider { + delete(h.pkceStore, state) + h.pkceMu.Unlock() + log.Warn().Ctx(ctx).Str("state", state).Str("expected_provider", entry.providerName).Str("actual_provider", provider).Msg("OIDC callback: provider mismatch") + writeJSONError(w, http.StatusUnauthorized, "invalid_state", "provider mismatch") + return + } + + // Delete the entry (single-use) + codeVerifier := entry.codeVerifier + delete(h.pkceStore, state) + h.pkceMu.Unlock() + + // Validate code parameter + if code == "" { + log.Warn().Ctx(ctx).Msg("OIDC callback: missing code") + writeJSONError(w, http.StatusUnauthorized, "invalid_request", "missing authorization code") + return + } + + // Build redirect URI + redirectURI := fmt.Sprintf("%s/api/v1/auth/oidc/%s/callback", h.redirectBase, provider) + + // Exchange code for tokens + tokenResp, err := client.ExchangeCode(ctx, code, codeVerifier, redirectURI) + if err != nil { + log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC callback: code exchange failed") + writeJSONError(w, http.StatusBadGateway, "token_exchange_failed", fmt.Sprintf("code exchange failed: %v", err)) + return + } + + // Validate ID token + claims, err := client.ValidateIDToken(ctx, tokenResp.IDToken) + if err != nil { + log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC callback: ID token validation failed") + writeJSONError(w, http.StatusUnauthorized, "invalid_id_token", fmt.Sprintf("ID token validation failed: %v", err)) + return + } + + // Check email in claims + if claims.Email == "" { + log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC callback: no email in ID token") + writeJSONError(w, http.StatusUnauthorized, "no_email_in_id_token", "ID token does not contain an email claim") + return + } + + // Ensure user exists (sign-up on first use) + u, err := h.ensureUser(ctx, claims.Email) + if err != nil { + log.Error().Ctx(ctx).Err(err).Str("email", claims.Email).Msg("OIDC callback: user upsert failed") + writeJSONError(w, http.StatusInternalServerError, "server_error", fmt.Sprintf("user upsert failed: %v", err)) + return + } + + // Generate JWT + jwtToken, err := h.users.GenerateJWT(ctx, u) + if err != nil { + log.Error().Ctx(ctx).Err(err).Str("email", claims.Email).Msg("OIDC callback: JWT generation failed") + writeJSONError(w, http.StatusInternalServerError, "server_error", fmt.Sprintf("JWT generation failed: %v", err)) + return + } + + log.Info().Ctx(ctx).Str("provider", provider).Str("email", claims.Email).Msg("OIDC callback: user signed in successfully") + + writeJSON(w, http.StatusOK, map[string]string{ + "message": "signed in via oidc", + "token": jwtToken, + "user": claims.Email, + }) +} + +// ensureUser returns the user keyed on email (stored as Username), +// creating them if absent. Newly-created users get a random unguessable +// bcrypt-hashed password so the password endpoints stay locked out. +func (h *OIDCHandler) ensureUser(ctx context.Context, email string) (*user.User, error) { + if h.repo != nil { + existing, err := h.repo.GetUserByUsername(ctx, email) + if err != nil { + return nil, fmt.Errorf("get user by username: %w", err) + } + if existing != nil { + return existing, nil + } + } + + // Generate random password + rawPass := generateRandomHex(32) + hash, err := h.users.HashPassword(ctx, rawPass) + if err != nil { + return nil, fmt.Errorf("hash password: %w", err) + } + + u := &user.User{ + Username: email, + PasswordHash: hash, + IsAdmin: false, + } + + if err := h.users.CreateUser(ctx, u); err != nil { + return nil, fmt.Errorf("create user: %w", err) + } + + if h.repo != nil { + return h.repo.GetUserByUsername(ctx, email) + } + return u, nil +} + +// generateRandomBase64URL generates a random string suitable for use in OIDC PKCE flows. +func generateRandomBase64URL(length int) string { + b := make([]byte, length) + if _, err := rand.Read(b); err != nil { + panic(fmt.Sprintf("failed to read random bytes: %v", err)) + } + return base64.RawURLEncoding.EncodeToString(b) +} + +// generateRandomHex generates a random hex string. +func generateRandomHex(length int) string { + b := make([]byte, length/2) + if _, err := rand.Read(b); err != nil { + panic(fmt.Sprintf("failed to read random bytes: %v", err)) + } + return hex.EncodeToString(b) +} + +// OIDCCallbackResponse represents the JSON response from the OIDC callback. +type OIDCCallbackResponse struct { + Message string `json:"message"` + Token string `json:"token"` + User string `json:"user"` +}