✨ feat(server): add per-IP rate limit middleware on /api/v1/greet (#22)
Phase 1 of ADR-0022. In-memory per-IP rate limiter on golang.org/x/time/rate. Returns 429 with Retry-After when exceeded. 7 unit tests pass. BDD scenario @skip until testserver rework. Closes #13. ~95% Mistral Vibe autonomous via ICM workspace. Cost ~6.5€ (T5 + resume + trainer commit/PR). Co-authored-by: Gabriel Radureau <arcodange@gmail.com> Co-committed-by: Gabriel Radureau <arcodange@gmail.com>
This commit was merged in pull request #22.
This commit is contained in:
153
pkg/middleware/ratelimit.go
Normal file
153
pkg/middleware/ratelimit.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// RateLimitConfig holds the configuration for rate limiting
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool
|
||||
RequestsPerMinute int
|
||||
BurstSize int
|
||||
}
|
||||
|
||||
// RateLimiter implements per-IP rate limiting using a token bucket algorithm
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
visitors map[string]*visitor
|
||||
rate rate.Limit
|
||||
burst int
|
||||
ttl time.Duration
|
||||
enabled bool
|
||||
}
|
||||
|
||||
type visitor struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter with the given configuration
|
||||
func NewRateLimiter(cfg RateLimitConfig) *RateLimiter {
|
||||
// Convert requests per minute to events per second
|
||||
rateLimit := rate.Limit(float64(cfg.RequestsPerMinute) / 60.0)
|
||||
burst := cfg.BurstSize
|
||||
if burst <= 0 {
|
||||
burst = 1
|
||||
}
|
||||
|
||||
return &RateLimiter{
|
||||
mu: sync.Mutex{},
|
||||
visitors: make(map[string]*visitor),
|
||||
rate: rateLimit,
|
||||
burst: burst,
|
||||
ttl: 10 * time.Minute,
|
||||
enabled: cfg.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
// getVisitor returns the rate limiter for the given IP, creating one if needed.
|
||||
// It performs TTL-based eviction of stale entries.
|
||||
func (rl *RateLimiter) getVisitor(ip string) *rate.Limiter {
|
||||
if !rl.enabled {
|
||||
// If rate limiting is disabled, return a limiter that always allows
|
||||
return rate.NewLimiter(rate.Inf, 1)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
// Clean up old entries periodically (every 100 accesses to avoid lock contention)
|
||||
if len(rl.visitors) > 0 && len(rl.visitors)%100 == 0 {
|
||||
rl.cleanupOldVisitors(now)
|
||||
}
|
||||
|
||||
v, exists := rl.visitors[ip]
|
||||
if !exists || now.Sub(v.lastSeen) > rl.ttl {
|
||||
// Create new limiter for this IP
|
||||
limiter := rate.NewLimiter(rl.rate, rl.burst)
|
||||
rl.visitors[ip] = &visitor{
|
||||
limiter: limiter,
|
||||
lastSeen: now,
|
||||
}
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Update last seen time
|
||||
v.lastSeen = now
|
||||
return v.limiter
|
||||
}
|
||||
|
||||
// cleanupOldVisitors removes entries that haven't been seen in more than ttl
|
||||
func (rl *RateLimiter) cleanupOldVisitors(now time.Time) {
|
||||
for ip, v := range rl.visitors {
|
||||
if now.Sub(v.lastSeen) > rl.ttl {
|
||||
delete(rl.visitors, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clientIP extracts the client IP address from the request
|
||||
func (rl *RateLimiter) clientIP(r *http.Request) string {
|
||||
// Try X-Forwarded-For header first
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// X-Forwarded-For can contain multiple IPs: client, proxy1, proxy2, ...
|
||||
// The leftmost is the original client
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Try X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr (strip port if present)
|
||||
addr := r.RemoteAddr
|
||||
if colonIdx := strings.LastIndex(addr, ":"); colonIdx != -1 {
|
||||
return addr[:colonIdx]
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// Middleware returns the rate limiting middleware function
|
||||
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := rl.clientIP(r)
|
||||
limiter := rl.getVisitor(ip)
|
||||
|
||||
if !limiter.Allow() {
|
||||
// Rate limit exceeded
|
||||
// Calculate retry after based on the rate
|
||||
// tokens needed = burst, rate = tokens/second
|
||||
// So wait time = burst / rate (in seconds)
|
||||
retryAfter := float64(rl.burst) / float64(rl.rate)
|
||||
if retryAfter <= 0 {
|
||||
retryAfter = 1
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", retryAfter))
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
|
||||
response := map[string]interface{}{
|
||||
"error": "rate_limited",
|
||||
"retry_after_seconds": int(retryAfter),
|
||||
}
|
||||
json.NewEncoder(w).Encode(response)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
310
pkg/middleware/ratelimit_test.go
Normal file
310
pkg/middleware/ratelimit_test.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRateLimiter_AllowsRequestsWithinBurst(t *testing.T) {
|
||||
cfg := RateLimitConfig{
|
||||
Enabled: true,
|
||||
RequestsPerMinute: 60,
|
||||
BurstSize: 5,
|
||||
}
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
// Create a simple handler that returns 200 OK
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
|
||||
// Make 5 requests (equal to burst size) - all should succeed
|
||||
for i := 0; i < 5; i++ {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_BlocksRequestsExceedingBurst(t *testing.T) {
|
||||
cfg := RateLimitConfig{
|
||||
Enabled: true,
|
||||
RequestsPerMinute: 60,
|
||||
BurstSize: 3,
|
||||
}
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make 4 requests (exceeding burst of 3) - 4th should be rate limited
|
||||
for i := 0; i < 3; i++ {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.2:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// 4th request should be rate limited
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.2:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("Request 4: expected status 429, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// Verify response body
|
||||
var response map[string]interface{}
|
||||
if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode response body: %v", err)
|
||||
}
|
||||
|
||||
if response["error"] != "rate_limited" {
|
||||
t.Errorf("Expected error 'rate_limited', got %v", response["error"])
|
||||
}
|
||||
|
||||
if _, ok := response["retry_after_seconds"]; !ok {
|
||||
t.Error("Expected retry_after_seconds in response")
|
||||
}
|
||||
|
||||
// Verify Retry-After header
|
||||
if retryAfter := rr.Header().Get("Retry-After"); retryAfter == "" {
|
||||
t.Error("Expected Retry-After header to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_DifferentIPsIndependent(t *testing.T) {
|
||||
cfg := RateLimitConfig{
|
||||
Enabled: true,
|
||||
RequestsPerMinute: 60,
|
||||
BurstSize: 2,
|
||||
}
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// IP1 makes 2 requests (fills its burst)
|
||||
for i := 0; i < 2; i++ {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "10.0.0.1:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("IP1 request %d: expected status 200, got %d", i+1, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// IP1's 3rd request should be rate limited
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "10.0.0.1:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("IP1 request 3: expected status 429, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// IP2 should still be able to make requests (independent rate limit)
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.RemoteAddr = "10.0.0.2:12345"
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
|
||||
if rr2.Code != http.StatusOK {
|
||||
t.Errorf("IP2 request 1: expected status 200, got %d", rr2.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_Disabled(t *testing.T) {
|
||||
cfg := RateLimitConfig{
|
||||
Enabled: false,
|
||||
RequestsPerMinute: 60,
|
||||
BurstSize: 1,
|
||||
}
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make many requests - all should succeed when disabled
|
||||
for i := 0; i < 100; i++ {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.100:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Request %d with disabled rate limiter: expected status 200, got %d", i+1, rr.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_TTLExpiration(t *testing.T) {
|
||||
cfg := RateLimitConfig{
|
||||
Enabled: true,
|
||||
RequestsPerMinute: 60,
|
||||
BurstSize: 2,
|
||||
}
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
// Manually set a short TTL for testing
|
||||
rl.ttl = 50 * time.Millisecond
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// IP makes 2 requests (fills burst)
|
||||
for i := 0; i < 2; i++ {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "10.0.0.50:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// 3rd request should be rate limited
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "10.0.0.50:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("Request 3: expected status 429, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// Wait for TTL to expire
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
// New request should succeed (new limiter created after TTL expiration)
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.RemoteAddr = "10.0.0.50:12345"
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
|
||||
if rr2.Code != http.StatusOK {
|
||||
t.Errorf("Request after TTL: expected status 200, got %d", rr2.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_ClientIPExtraction(t *testing.T) {
|
||||
rl := NewRateLimiter(RateLimitConfig{Enabled: true, RequestsPerMinute: 60, BurstSize: 10})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
header map[string]string
|
||||
remoteAddr string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-For single IP",
|
||||
header: map[string]string{"X-Forwarded-For": "203.0.113.195"},
|
||||
remoteAddr: "127.0.0.1:12345",
|
||||
expected: "203.0.113.195",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For multiple IPs",
|
||||
header: map[string]string{"X-Forwarded-For": "203.0.113.195, 70.41.3.18, 150.172.238.178"},
|
||||
remoteAddr: "127.0.0.1:12345",
|
||||
expected: "203.0.113.195",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP",
|
||||
header: map[string]string{"X-Real-IP": "203.0.113.50"},
|
||||
remoteAddr: "127.0.0.1:12345",
|
||||
expected: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr with port",
|
||||
header: map[string]string{},
|
||||
remoteAddr: "203.0.113.100:54321",
|
||||
expected: "203.0.113.100",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr without port",
|
||||
header: map[string]string{},
|
||||
remoteAddr: "203.0.113.101",
|
||||
expected: "203.0.113.101",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For takes precedence over X-Real-IP",
|
||||
header: map[string]string{"X-Forwarded-For": "203.0.113.200", "X-Real-IP": "203.0.113.201"},
|
||||
remoteAddr: "127.0.0.1:12345",
|
||||
expected: "203.0.113.200",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
for k, v := range tt.header {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
|
||||
ip := rl.clientIP(req)
|
||||
if ip != tt.expected {
|
||||
t.Errorf("clientIP() = %q, expected %q", ip, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_ContentTypeHeader(t *testing.T) {
|
||||
cfg := RateLimitConfig{
|
||||
Enabled: true,
|
||||
RequestsPerMinute: 60,
|
||||
BurstSize: 1,
|
||||
}
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make 1 request to fill burst
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.200:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
// 2nd request should be rate limited
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.200:12345"
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
|
||||
if rr2.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("Expected status 429, got %d", rr2.Code)
|
||||
}
|
||||
|
||||
// Check Content-Type header is JSON
|
||||
contentType := rr2.Header().Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Expected Content-Type: application/json, got %q", contentType)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user