Phase 1 of ADR-0022. In-memory per-IP rate limiter on golang.org/x/time/rate. Returns 429 with Retry-After when exceeded. 7 unit tests pass. BDD scenario @skip until testserver rework. Closes #13. ~95% Mistral Vibe autonomous via ICM workspace. Cost ~6.5€ (T5 + resume + trainer commit/PR). Co-authored-by: Gabriel Radureau <arcodange@gmail.com> Co-committed-by: Gabriel Radureau <arcodange@gmail.com>
311 lines
8.3 KiB
Go
311 lines
8.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestRateLimiter_AllowsRequestsWithinBurst(t *testing.T) {
|
|
cfg := RateLimitConfig{
|
|
Enabled: true,
|
|
RequestsPerMinute: 60,
|
|
BurstSize: 5,
|
|
}
|
|
rl := NewRateLimiter(cfg)
|
|
|
|
// Create a simple handler that returns 200 OK
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("OK"))
|
|
}))
|
|
|
|
// Make 5 requests (equal to burst size) - all should succeed
|
|
for i := 0; i < 5; i++ {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
rr := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_BlocksRequestsExceedingBurst(t *testing.T) {
|
|
cfg := RateLimitConfig{
|
|
Enabled: true,
|
|
RequestsPerMinute: 60,
|
|
BurstSize: 3,
|
|
}
|
|
rl := NewRateLimiter(cfg)
|
|
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Make 4 requests (exceeding burst of 3) - 4th should be rate limited
|
|
for i := 0; i < 3; i++ {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "192.168.1.2:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code)
|
|
}
|
|
}
|
|
|
|
// 4th request should be rate limited
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "192.168.1.2:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusTooManyRequests {
|
|
t.Errorf("Request 4: expected status 429, got %d", rr.Code)
|
|
}
|
|
|
|
// Verify response body
|
|
var response map[string]interface{}
|
|
if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
|
|
t.Fatalf("Failed to decode response body: %v", err)
|
|
}
|
|
|
|
if response["error"] != "rate_limited" {
|
|
t.Errorf("Expected error 'rate_limited', got %v", response["error"])
|
|
}
|
|
|
|
if _, ok := response["retry_after_seconds"]; !ok {
|
|
t.Error("Expected retry_after_seconds in response")
|
|
}
|
|
|
|
// Verify Retry-After header
|
|
if retryAfter := rr.Header().Get("Retry-After"); retryAfter == "" {
|
|
t.Error("Expected Retry-After header to be set")
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_DifferentIPsIndependent(t *testing.T) {
|
|
cfg := RateLimitConfig{
|
|
Enabled: true,
|
|
RequestsPerMinute: 60,
|
|
BurstSize: 2,
|
|
}
|
|
rl := NewRateLimiter(cfg)
|
|
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// IP1 makes 2 requests (fills its burst)
|
|
for i := 0; i < 2; i++ {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "10.0.0.1:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("IP1 request %d: expected status 200, got %d", i+1, rr.Code)
|
|
}
|
|
}
|
|
|
|
// IP1's 3rd request should be rate limited
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "10.0.0.1:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusTooManyRequests {
|
|
t.Errorf("IP1 request 3: expected status 429, got %d", rr.Code)
|
|
}
|
|
|
|
// IP2 should still be able to make requests (independent rate limit)
|
|
req2 := httptest.NewRequest("GET", "/test", nil)
|
|
req2.RemoteAddr = "10.0.0.2:12345"
|
|
rr2 := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr2, req2)
|
|
|
|
if rr2.Code != http.StatusOK {
|
|
t.Errorf("IP2 request 1: expected status 200, got %d", rr2.Code)
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_Disabled(t *testing.T) {
|
|
cfg := RateLimitConfig{
|
|
Enabled: false,
|
|
RequestsPerMinute: 60,
|
|
BurstSize: 1,
|
|
}
|
|
rl := NewRateLimiter(cfg)
|
|
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Make many requests - all should succeed when disabled
|
|
for i := 0; i < 100; i++ {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "192.168.1.100:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Request %d with disabled rate limiter: expected status 200, got %d", i+1, rr.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_TTLExpiration(t *testing.T) {
|
|
cfg := RateLimitConfig{
|
|
Enabled: true,
|
|
RequestsPerMinute: 60,
|
|
BurstSize: 2,
|
|
}
|
|
rl := NewRateLimiter(cfg)
|
|
|
|
// Manually set a short TTL for testing
|
|
rl.ttl = 50 * time.Millisecond
|
|
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// IP makes 2 requests (fills burst)
|
|
for i := 0; i < 2; i++ {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "10.0.0.50:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code)
|
|
}
|
|
}
|
|
|
|
// 3rd request should be rate limited
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "10.0.0.50:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusTooManyRequests {
|
|
t.Errorf("Request 3: expected status 429, got %d", rr.Code)
|
|
}
|
|
|
|
// Wait for TTL to expire
|
|
time.Sleep(60 * time.Millisecond)
|
|
|
|
// New request should succeed (new limiter created after TTL expiration)
|
|
req2 := httptest.NewRequest("GET", "/test", nil)
|
|
req2.RemoteAddr = "10.0.0.50:12345"
|
|
rr2 := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr2, req2)
|
|
|
|
if rr2.Code != http.StatusOK {
|
|
t.Errorf("Request after TTL: expected status 200, got %d", rr2.Code)
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_ClientIPExtraction(t *testing.T) {
|
|
rl := NewRateLimiter(RateLimitConfig{Enabled: true, RequestsPerMinute: 60, BurstSize: 10})
|
|
|
|
tests := []struct {
|
|
name string
|
|
header map[string]string
|
|
remoteAddr string
|
|
expected string
|
|
}{
|
|
{
|
|
name: "X-Forwarded-For single IP",
|
|
header: map[string]string{"X-Forwarded-For": "203.0.113.195"},
|
|
remoteAddr: "127.0.0.1:12345",
|
|
expected: "203.0.113.195",
|
|
},
|
|
{
|
|
name: "X-Forwarded-For multiple IPs",
|
|
header: map[string]string{"X-Forwarded-For": "203.0.113.195, 70.41.3.18, 150.172.238.178"},
|
|
remoteAddr: "127.0.0.1:12345",
|
|
expected: "203.0.113.195",
|
|
},
|
|
{
|
|
name: "X-Real-IP",
|
|
header: map[string]string{"X-Real-IP": "203.0.113.50"},
|
|
remoteAddr: "127.0.0.1:12345",
|
|
expected: "203.0.113.50",
|
|
},
|
|
{
|
|
name: "RemoteAddr with port",
|
|
header: map[string]string{},
|
|
remoteAddr: "203.0.113.100:54321",
|
|
expected: "203.0.113.100",
|
|
},
|
|
{
|
|
name: "RemoteAddr without port",
|
|
header: map[string]string{},
|
|
remoteAddr: "203.0.113.101",
|
|
expected: "203.0.113.101",
|
|
},
|
|
{
|
|
name: "X-Forwarded-For takes precedence over X-Real-IP",
|
|
header: map[string]string{"X-Forwarded-For": "203.0.113.200", "X-Real-IP": "203.0.113.201"},
|
|
remoteAddr: "127.0.0.1:12345",
|
|
expected: "203.0.113.200",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
for k, v := range tt.header {
|
|
req.Header.Set(k, v)
|
|
}
|
|
req.RemoteAddr = tt.remoteAddr
|
|
|
|
ip := rl.clientIP(req)
|
|
if ip != tt.expected {
|
|
t.Errorf("clientIP() = %q, expected %q", ip, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_ContentTypeHeader(t *testing.T) {
|
|
cfg := RateLimitConfig{
|
|
Enabled: true,
|
|
RequestsPerMinute: 60,
|
|
BurstSize: 1,
|
|
}
|
|
rl := NewRateLimiter(cfg)
|
|
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Make 1 request to fill burst
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "192.168.1.200:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
// 2nd request should be rate limited
|
|
req2 := httptest.NewRequest("GET", "/test", nil)
|
|
req2.RemoteAddr = "192.168.1.200:12345"
|
|
rr2 := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr2, req2)
|
|
|
|
if rr2.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("Expected status 429, got %d", rr2.Code)
|
|
}
|
|
|
|
// Check Content-Type header is JSON
|
|
contentType := rr2.Header().Get("Content-Type")
|
|
if contentType != "application/json" {
|
|
t.Errorf("Expected Content-Type: application/json, got %q", contentType)
|
|
}
|
|
}
|