package api import ( "context" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/hex" "fmt" "net/http" "net/url" "sync" "time" "dance-lessons-coach/pkg/auth" "dance-lessons-coach/pkg/user" "github.com/go-chi/chi/v5" "github.com/rs/zerolog/log" ) // OIDCHandler exposes the OIDC authorization-code endpoints. type OIDCHandler struct { clients map[string]*auth.OIDCClient // keyed by provider name users user.UserService repo user.UserRepository redirectBase string pkceMu sync.Mutex pkceStore map[string]pkceEntry } type pkceEntry struct { codeVerifier string providerName string expiresAt time.Time } // NewOIDCHandler creates a new OIDCHandler. func NewOIDCHandler(clients map[string]*auth.OIDCClient, users user.UserService, repo user.UserRepository, redirectBase string) *OIDCHandler { return &OIDCHandler{ clients: clients, users: users, repo: repo, redirectBase: redirectBase, pkceStore: make(map[string]pkceEntry), } } // RegisterRoutes mounts the OIDC endpoints on the provided router. func (h *OIDCHandler) RegisterRoutes(router chi.Router) { router.Get("/oidc/{provider}/start", h.handleStart) router.Get("/oidc/{provider}/callback", h.handleCallback) } // handleStart initiates the OIDC authorization-code flow. // // @Summary Start OIDC authorization // @Description Generates PKCE state and verifier, redirects to the OIDC provider authorization endpoint. // @Tags API/v1/User // @Produce json // @Param provider path string true "OIDC provider name" // @Success 302 {string}string "Redirect to OIDC provider" // @Failure 404 {object}map[string]string "Unknown provider" // @Failure 502 {object}map[string]string "Discovery failed" // @Router /v1/auth/oidc/{provider}/start [get] func (h *OIDCHandler) handleStart(w http.ResponseWriter, r *http.Request) { ctx := r.Context() provider := chi.URLParam(r, "provider") client, exists := h.clients[provider] if !exists { log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC start: unknown provider") writeJSONError(w, http.StatusNotFound, "unknown_provider", "unknown OIDC provider") return } // Ensure discovery is loaded disc, err := client.Discover(ctx) if err != nil { log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC start: discovery failed") writeJSONError(w, http.StatusBadGateway, "discovery_failed", fmt.Sprintf("OIDC discovery failed: %v", err)) return } // Generate state: 32 bytes random, base64-url-no-padding state := generateRandomBase64URL(32) // Generate code verifier: 32 bytes random, base64-url-no-padding codeVerifier := generateRandomBase64URL(32) // Compute code challenge: SHA256 hash of code verifier, base64-url-no-padding hash := sha256.Sum256([]byte(codeVerifier)) codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) // Store PKCE entry h.pkceMu.Lock() // Lazy-clean expired entries now := time.Now() for k, entry := range h.pkceStore { if entry.expiresAt.Before(now) { delete(h.pkceStore, k) } } h.pkceStore[state] = pkceEntry{ codeVerifier: codeVerifier, providerName: provider, expiresAt: now.Add(10 * time.Minute), } h.pkceMu.Unlock() // Build redirect URL redirectURI := fmt.Sprintf("%s/api/v1/auth/oidc/%s/callback", h.redirectBase, provider) v := url.Values{} v.Set("response_type", "code") v.Set("client_id", client.ClientID()) v.Set("redirect_uri", redirectURI) v.Set("state", state) v.Set("code_challenge", codeChallenge) v.Set("code_challenge_method", "S256") v.Set("scope", "openid email profile") target := disc.AuthorizationEndpoint + "?" + v.Encode() log.Debug().Ctx(ctx).Str("provider", provider).Str("target", target).Msg("OIDC start: redirecting to provider") http.Redirect(w, r, target, http.StatusFound) } // handleCallback handles the OIDC callback after authorization. // // @Summary OIDC callback handler // @Description Validates state, exchanges code for tokens, validates id_token, signs up on first use, issues JWT. // @Tags API/v1/User // @Produce json // @Param provider path string true "OIDC provider name" // @Param state query string true "State parameter" // @Param code query string false "Authorization code" // @Param error query string false "OIDC error" // @Success 200 {object} OIDCCallbackResponse "Successfully signed in via OIDC" // @Failure 401 {object} map[string]string "Invalid state, missing code, or OIDC error" // @Failure 502 {object} map[string]string "Token exchange or validation failed" // @Failure 500 {object} map[string]string "Internal server error" // @Router /v1/auth/oidc/{provider}/callback [get] func (h *OIDCHandler) handleCallback(w http.ResponseWriter, r *http.Request) { ctx := r.Context() provider := chi.URLParam(r, "provider") client, exists := h.clients[provider] if !exists { log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC callback: unknown provider") writeJSONError(w, http.StatusNotFound, "unknown_provider", "unknown OIDC provider") return } // Read query parameters state := r.URL.Query().Get("state") code := r.URL.Query().Get("code") oidcError := r.URL.Query().Get("error") // If OIDC provider returned an error if oidcError != "" { log.Warn().Ctx(ctx).Str("provider", provider).Str("error", oidcError).Msg("OIDC callback: provider error") writeJSON(w, http.StatusUnauthorized, map[string]string{ "error": "oidc_error", "provider_error": oidcError, }) return } // Validate state if state == "" { log.Warn().Ctx(ctx).Msg("OIDC callback: missing state") writeJSONError(w, http.StatusUnauthorized, "invalid_state", "missing state parameter") return } h.pkceMu.Lock() entry, exists := h.pkceStore[state] if !exists { h.pkceMu.Unlock() log.Warn().Ctx(ctx).Str("state", state).Msg("OIDC callback: state not found") writeJSONError(w, http.StatusUnauthorized, "invalid_state", "state not found or already used") return } // Check expiration and provider match now := time.Now() if entry.expiresAt.Before(now) { delete(h.pkceStore, state) h.pkceMu.Unlock() log.Warn().Ctx(ctx).Str("state", state).Msg("OIDC callback: state expired") writeJSONError(w, http.StatusUnauthorized, "invalid_state", "state expired") return } if entry.providerName != provider { delete(h.pkceStore, state) h.pkceMu.Unlock() log.Warn().Ctx(ctx).Str("state", state).Str("expected_provider", entry.providerName).Str("actual_provider", provider).Msg("OIDC callback: provider mismatch") writeJSONError(w, http.StatusUnauthorized, "invalid_state", "provider mismatch") return } // Delete the entry (single-use) codeVerifier := entry.codeVerifier delete(h.pkceStore, state) h.pkceMu.Unlock() // Validate code parameter if code == "" { log.Warn().Ctx(ctx).Msg("OIDC callback: missing code") writeJSONError(w, http.StatusUnauthorized, "invalid_request", "missing authorization code") return } // Build redirect URI redirectURI := fmt.Sprintf("%s/api/v1/auth/oidc/%s/callback", h.redirectBase, provider) // Exchange code for tokens tokenResp, err := client.ExchangeCode(ctx, code, codeVerifier, redirectURI) if err != nil { log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC callback: code exchange failed") writeJSONError(w, http.StatusBadGateway, "token_exchange_failed", fmt.Sprintf("code exchange failed: %v", err)) return } // Validate ID token claims, err := client.ValidateIDToken(ctx, tokenResp.IDToken) if err != nil { log.Error().Ctx(ctx).Err(err).Str("provider", provider).Msg("OIDC callback: ID token validation failed") writeJSONError(w, http.StatusUnauthorized, "invalid_id_token", fmt.Sprintf("ID token validation failed: %v", err)) return } // Check email in claims if claims.Email == "" { log.Warn().Ctx(ctx).Str("provider", provider).Msg("OIDC callback: no email in ID token") writeJSONError(w, http.StatusUnauthorized, "no_email_in_id_token", "ID token does not contain an email claim") return } // Ensure user exists (sign-up on first use) u, err := h.ensureUser(ctx, claims.Email) if err != nil { log.Error().Ctx(ctx).Err(err).Str("email", claims.Email).Msg("OIDC callback: user upsert failed") writeJSONError(w, http.StatusInternalServerError, "server_error", fmt.Sprintf("user upsert failed: %v", err)) return } // Generate JWT jwtToken, err := h.users.GenerateJWT(ctx, u) if err != nil { log.Error().Ctx(ctx).Err(err).Str("email", claims.Email).Msg("OIDC callback: JWT generation failed") writeJSONError(w, http.StatusInternalServerError, "server_error", fmt.Sprintf("JWT generation failed: %v", err)) return } log.Info().Ctx(ctx).Str("provider", provider).Str("email", claims.Email).Msg("OIDC callback: user signed in successfully") writeJSON(w, http.StatusOK, map[string]string{ "message": "signed in via oidc", "token": jwtToken, "user": claims.Email, }) } // ensureUser returns the user keyed on email (stored as Username), // creating them if absent. Newly-created users get a random unguessable // bcrypt-hashed password so the password endpoints stay locked out. func (h *OIDCHandler) ensureUser(ctx context.Context, email string) (*user.User, error) { if h.repo != nil { existing, err := h.repo.GetUserByUsername(ctx, email) if err != nil { return nil, fmt.Errorf("get user by username: %w", err) } if existing != nil { return existing, nil } } // Generate random password rawPass := generateRandomHex(32) hash, err := h.users.HashPassword(ctx, rawPass) if err != nil { return nil, fmt.Errorf("hash password: %w", err) } u := &user.User{ Username: email, PasswordHash: hash, IsAdmin: false, } if err := h.users.CreateUser(ctx, u); err != nil { return nil, fmt.Errorf("create user: %w", err) } if h.repo != nil { return h.repo.GetUserByUsername(ctx, email) } return u, nil } // generateRandomBase64URL generates a random string suitable for use in OIDC PKCE flows. func generateRandomBase64URL(length int) string { b := make([]byte, length) if _, err := rand.Read(b); err != nil { panic(fmt.Sprintf("failed to read random bytes: %v", err)) } return base64.RawURLEncoding.EncodeToString(b) } // generateRandomHex generates a random hex string. func generateRandomHex(length int) string { b := make([]byte, length/2) if _, err := rand.Read(b); err != nil { panic(fmt.Sprintf("failed to read random bytes: %v", err)) } return hex.EncodeToString(b) } // OIDCCallbackResponse represents the JSON response from the OIDC callback. type OIDCCallbackResponse struct { Message string `json:"message"` Token string `json:"token"` User string `json:"user"` }