Files
dance-lessons-coach/pkg/middleware/ratelimit.go
Gabriel Radureau 54dd0cc80f
All checks were successful
CI/CD Pipeline / Build Docker Cache (push) Successful in 1m3s
CI/CD Pipeline / CI Pipeline (push) Successful in 4m4s
CI/CD Pipeline / Trigger Docker Push (push) Successful in 6s
feat(server): add per-IP rate limit middleware on /api/v1/greet (#22)
Phase 1 of ADR-0022. In-memory per-IP rate limiter on golang.org/x/time/rate. Returns 429 with Retry-After when exceeded. 7 unit tests pass. BDD scenario @skip until testserver rework. Closes #13.

~95% Mistral Vibe autonomous via ICM workspace. Cost ~6.5€ (T5 + resume + trainer commit/PR).

Co-authored-by: Gabriel Radureau <arcodange@gmail.com>
Co-committed-by: Gabriel Radureau <arcodange@gmail.com>
2026-05-03 13:16:29 +02:00

154 lines
3.8 KiB
Go

package middleware
import (
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/time/rate"
)
// RateLimitConfig holds the configuration for rate limiting.
// It is the sole input to NewRateLimiter.
type RateLimitConfig struct {
// Enabled switches limiting on. When false, getVisitor hands out a limiter
// with an infinite rate and no per-IP state is recorded.
Enabled bool
// RequestsPerMinute is the sustained request budget per client IP.
// NewRateLimiter divides it by 60 to obtain a per-second refill rate.
RequestsPerMinute int
// BurstSize is the token-bucket capacity per client IP.
// NewRateLimiter replaces values <= 0 with 1.
BurstSize int
}
// RateLimiter implements per-IP rate limiting using a token bucket algorithm.
// Each distinct client key gets its own rate.Limiter, stored in visitors.
// Because it embeds a sync.Mutex it must be used through a pointer and never
// copied; construct it with NewRateLimiter.
type RateLimiter struct {
// mu guards visitors; getVisitor holds it for every map read and write.
mu sync.Mutex
// visitors maps a client key (as returned by clientIP) to its limiter state.
visitors map[string]*visitor
// rate is the refill rate, in events per second, given to each new limiter.
rate rate.Limit
// burst is the bucket capacity given to each new limiter.
burst int
// ttl is the idle time after which a visitor entry is treated as stale and
// replaced or evicted. NewRateLimiter sets it to 10 minutes.
ttl time.Duration
// enabled mirrors RateLimitConfig.Enabled.
enabled bool
}
// visitor is the per-client state kept in RateLimiter.visitors.
type visitor struct {
// limiter is the token bucket for this client.
limiter *rate.Limiter
// lastSeen is the time of the most recent getVisitor call for this client;
// it is compared against RateLimiter.ttl to decide staleness.
lastSeen time.Time
}
// NewRateLimiter builds a RateLimiter from cfg.
//
// cfg.RequestsPerMinute is converted to the per-second rate that
// golang.org/x/time/rate works in. cfg.BurstSize values below 1 are raised to
// 1. Visitor entries are considered stale after 10 minutes of inactivity.
//
// The per-minute value is not validated: a non-positive RequestsPerMinute
// produces a non-positive rate.Limit, which is passed to the limiters as-is.
func NewRateLimiter(cfg RateLimitConfig) *RateLimiter {
	const (
		secondsPerMinute = 60.0
		visitorTTL       = 10 * time.Minute
	)

	bucketSize := cfg.BurstSize
	if bucketSize < 1 {
		bucketSize = 1
	}

	// The zero value of sync.Mutex is an unlocked mutex, so mu needs no
	// explicit initialisation.
	rl := &RateLimiter{
		visitors: make(map[string]*visitor),
		ttl:      visitorTTL,
		enabled:  cfg.Enabled,
	}
	rl.rate = rate.Limit(float64(cfg.RequestsPerMinute) / secondsPerMinute)
	rl.burst = bucketSize
	return rl
}
// getVisitor returns the token-bucket limiter for ip.
//
// When limiting is disabled it returns a fresh limiter with an infinite rate
// and records nothing. Otherwise it looks ip up under rl.mu:
//
//   - a known visitor seen within rl.ttl has its lastSeen refreshed and its
//     existing limiter returned;
//   - an unknown visitor, or one idle for longer than rl.ttl, gets a brand-new
//     limiter (a full bucket) stored under ip.
//
// Stale entries are swept only on the second path, and only when the map size
// is a positive multiple of 100. Sweeping on the lookup path would make every
// request from an already-known visitor pay an O(n) scan under the mutex for
// as long as the map size happened to sit on a multiple of 100.
func (rl *RateLimiter) getVisitor(ip string) *rate.Limiter {
	if !rl.enabled {
		// An infinite-rate limiter lets the caller keep a single code path.
		return rate.NewLimiter(rate.Inf, 1)
	}

	now := time.Now()

	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Fast path: known and still fresh.
	if v, ok := rl.visitors[ip]; ok && now.Sub(v.lastSeen) <= rl.ttl {
		v.lastSeen = now
		return v.limiter
	}

	// Slow path: we are about to write an entry, so this is the point where
	// the map can grow. Amortise eviction over inserts.
	if n := len(rl.visitors); n > 0 && n%100 == 0 {
		rl.cleanupOldVisitors(now)
	}

	limiter := rate.NewLimiter(rl.rate, rl.burst)
	rl.visitors[ip] = &visitor{
		limiter:  limiter,
		lastSeen: now,
	}
	return limiter
}
// cleanupOldVisitors drops every visitor whose lastSeen lies more than rl.ttl
// before now. An entry idle for exactly rl.ttl is kept.
//
// It reads and writes rl.visitors without taking rl.mu itself, so the caller
// is expected to hold the lock.
func (rl *RateLimiter) cleanupOldVisitors(now time.Time) {
	for key, entry := range rl.visitors {
		idle := now.Sub(entry.lastSeen)
		if idle <= rl.ttl {
			continue
		}
		// Deleting the current key while ranging over a map is permitted in Go.
		delete(rl.visitors, key)
	}
}
// clientIP returns the string used as the rate-limit key for r.
//
// Sources are tried in order, and the first non-empty one wins:
//
//  1. the first (leftmost) element of X-Forwarded-For, trimmed;
//  2. X-Real-IP, trimmed;
//  3. the host part of r.RemoteAddr, or r.RemoteAddr unchanged when it
//     cannot be split into host and port.
//
// NOTE(security): X-Forwarded-For and X-Real-IP are ordinary request headers.
// Unless a trusted reverse proxy overwrites them before the request reaches
// this middleware, a client can choose its own key on every request and so
// sidestep the limit. Confirm the deployment topology before relying on them.
func (rl *RateLimiter) clientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// The header is "client, proxy1, proxy2, ..."; keep only the first
		// element without allocating a slice for the rest.
		first := xff
		if i := strings.IndexByte(xff, ','); i >= 0 {
			first = xff[:i]
		}
		// An empty first element (e.g. ", 10.0.0.1") must not become the key:
		// every such request would share one bucket named "".
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	// net.SplitHostPort understands bracketed IPv6 ("[::1]:8080" -> "::1").
	// Cutting at the last colon would leave the brackets on, and would
	// truncate a bare IPv6 address that has no port.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port (or not host:port at all): use the address as given.
		return r.RemoteAddr
	}
	return host
}
// Middleware wraps next with per-client rate limiting.
//
// For each request it derives a client key with clientIP, fetches that
// client's limiter with getVisitor and consumes one token. If a token is
// available the request is forwarded to next untouched. Otherwise next is not
// called and the response is:
//
//	HTTP 429 Too Many Requests
//	Content-Type: application/json
//	Retry-After: <seconds>
//	{"error":"rate_limited","retry_after_seconds":<seconds>}
//
// <seconds> is burst/rate — the time for an empty bucket to refill completely,
// a conservative upper bound on how long the client must wait — rounded up to
// a whole second and never less than 1. The same integer is used in the header
// and the body. When the configured rate is not positive no finite wait can be
// computed, so 1 is used.
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if rl.getVisitor(rl.clientIP(r)).Allow() {
			next.ServeHTTP(w, r)
			return
		}

		retryAfter := 1
		if rl.rate > 0 {
			// Guarding on rate > 0 keeps the division finite; Retry-After
			// must be a non-negative integer number of seconds.
			wait := float64(rl.burst) / float64(rl.rate)
			secs := int(wait)
			if float64(secs) < wait {
				secs++ // round up: never tell the client to come back too early
			}
			if secs > retryAfter {
				retryAfter = secs
			}
		}

		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("Retry-After", fmt.Sprintf("%d", retryAfter))
		w.WriteHeader(http.StatusTooManyRequests)

		response := map[string]interface{}{
			"error":               "rate_limited",
			"retry_after_seconds": retryAfter,
		}
		// The status line is already sent; an encode failure here means the
		// client went away, and there is nothing left to report it to.
		_ = json.NewEncoder(w).Encode(response)
	})
}