package middleware import ( "encoding/json" "fmt" "net/http" "strings" "sync" "time" "golang.org/x/time/rate" ) // RateLimitConfig holds the configuration for rate limiting type RateLimitConfig struct { Enabled bool RequestsPerMinute int BurstSize int } // RateLimiter implements per-IP rate limiting using a token bucket algorithm type RateLimiter struct { mu sync.Mutex visitors map[string]*visitor rate rate.Limit burst int ttl time.Duration enabled bool } type visitor struct { limiter *rate.Limiter lastSeen time.Time } // NewRateLimiter creates a new rate limiter with the given configuration func NewRateLimiter(cfg RateLimitConfig) *RateLimiter { // Convert requests per minute to events per second rateLimit := rate.Limit(float64(cfg.RequestsPerMinute) / 60.0) burst := cfg.BurstSize if burst <= 0 { burst = 1 } return &RateLimiter{ mu: sync.Mutex{}, visitors: make(map[string]*visitor), rate: rateLimit, burst: burst, ttl: 10 * time.Minute, enabled: cfg.Enabled, } } // getVisitor returns the rate limiter for the given IP, creating one if needed. // It performs TTL-based eviction of stale entries. func (rl *RateLimiter) getVisitor(ip string) *rate.Limiter { if !rl.enabled { // If rate limiting is disabled, return a limiter that always allows return rate.NewLimiter(rate.Inf, 1) } now := time.Now() rl.mu.Lock() defer rl.mu.Unlock() // Clean up old entries periodically (every 100 accesses to avoid lock contention) if len(rl.visitors) > 0 && len(rl.visitors)%100 == 0 { rl.cleanupOldVisitors(now) } v, exists := rl.visitors[ip] if !exists || now.Sub(v.lastSeen) > rl.ttl { // Create new limiter for this IP limiter := rate.NewLimiter(rl.rate, rl.burst) rl.visitors[ip] = &visitor{ limiter: limiter, lastSeen: now, } return limiter } // Update last seen time v.lastSeen = now return v.limiter } // cleanupOldVisitors removes entries that haven't been seen in more than ttl func (rl *RateLimiter) cleanupOldVisitors(now time.Time) { for ip, v := range rl.visitors { if now.Sub(v.lastSeen) > rl.ttl { delete(rl.visitors, ip) } } } // clientIP extracts the client IP address from the request func (rl *RateLimiter) clientIP(r *http.Request) string { // Try X-Forwarded-For header first if xff := r.Header.Get("X-Forwarded-For"); xff != "" { // X-Forwarded-For can contain multiple IPs: client, proxy1, proxy2, ... // The leftmost is the original client ips := strings.Split(xff, ",") if len(ips) > 0 { return strings.TrimSpace(ips[0]) } } // Try X-Real-IP header if xri := r.Header.Get("X-Real-IP"); xri != "" { return strings.TrimSpace(xri) } // Fall back to RemoteAddr (strip port if present) addr := r.RemoteAddr if colonIdx := strings.LastIndex(addr, ":"); colonIdx != -1 { return addr[:colonIdx] } return addr } // Middleware returns the rate limiting middleware function func (rl *RateLimiter) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := rl.clientIP(r) limiter := rl.getVisitor(ip) if !limiter.Allow() { // Rate limit exceeded // Calculate retry after based on the rate // tokens needed = burst, rate = tokens/second // So wait time = burst / rate (in seconds) retryAfter := float64(rl.burst) / float64(rl.rate) if retryAfter <= 0 { retryAfter = 1 } w.Header().Set("Content-Type", "application/json") w.Header().Set("Retry-After", fmt.Sprintf("%.0f", retryAfter)) w.WriteHeader(http.StatusTooManyRequests) response := map[string]interface{}{ "error": "rate_limited", "retry_after_seconds": int(retryAfter), } json.NewEncoder(w).Encode(response) return } next.ServeHTTP(w, r) }) }