Implements Phase 1 of ADR-0022 (Rate Limiting and Cache Strategy): an in-memory per-IP rate limiter using golang.org/x/time/rate. When the limit is exceeded, it returns HTTP 429 with a JSON body and a Retry-After header.

Changes:
- New: pkg/middleware/ratelimit.go (153 lines; 7 unit tests in ratelimit_test.go)
- Modified: pkg/config/config.go (RateLimit struct + 3 SetDefaults + 3 BindEnv + 3 getters)
- Modified: pkg/server/server.go (wired into /api/v1/greet, conditional on Enabled)
- Modified: pkg/bdd/testserver/server.go (env-var support for rate limit config)
- New: pkg/bdd/steps/ratelimit_steps.go (step definitions)
- Added: features/greet/greet.feature scenario (currently @skip @bdd-deferred — see note below)

Known limitation: the BDD scenario is tagged @skip @bdd-deferred because the testserver loads its config once at startup; env vars set inside a step do not reach the already-running server. The middleware itself is fully covered by unit tests. To re-enable the BDD scenario, the testserver needs either an admin endpoint or a per-scenario fresh-server pattern.

Closes #13 (Phase 1 only — Phase 2, Redis + cache service, is deferred).

About 95% of this change was generated autonomously by Mistral Vibe via the ICM workspace ~/Work/Vibe/workspaces/rate-limit-middleware/. The trainer (Claude) finalized the commit/PR step after Mistral hit its max-turns limit.

🤖
Co-Authored-By: Mistral Vibe (devstral-2 / mistral-medium-3.5)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
311 lines
8.3 KiB
Go
311 lines
8.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestRateLimiter_AllowsRequestsWithinBurst(t *testing.T) {
|
|
cfg := RateLimitConfig{
|
|
Enabled: true,
|
|
RequestsPerMinute: 60,
|
|
BurstSize: 5,
|
|
}
|
|
rl := NewRateLimiter(cfg)
|
|
|
|
// Create a simple handler that returns 200 OK
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("OK"))
|
|
}))
|
|
|
|
// Make 5 requests (equal to burst size) - all should succeed
|
|
for i := 0; i < 5; i++ {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
rr := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_BlocksRequestsExceedingBurst(t *testing.T) {
|
|
cfg := RateLimitConfig{
|
|
Enabled: true,
|
|
RequestsPerMinute: 60,
|
|
BurstSize: 3,
|
|
}
|
|
rl := NewRateLimiter(cfg)
|
|
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Make 4 requests (exceeding burst of 3) - 4th should be rate limited
|
|
for i := 0; i < 3; i++ {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "192.168.1.2:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code)
|
|
}
|
|
}
|
|
|
|
// 4th request should be rate limited
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "192.168.1.2:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusTooManyRequests {
|
|
t.Errorf("Request 4: expected status 429, got %d", rr.Code)
|
|
}
|
|
|
|
// Verify response body
|
|
var response map[string]interface{}
|
|
if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
|
|
t.Fatalf("Failed to decode response body: %v", err)
|
|
}
|
|
|
|
if response["error"] != "rate_limited" {
|
|
t.Errorf("Expected error 'rate_limited', got %v", response["error"])
|
|
}
|
|
|
|
if _, ok := response["retry_after_seconds"]; !ok {
|
|
t.Error("Expected retry_after_seconds in response")
|
|
}
|
|
|
|
// Verify Retry-After header
|
|
if retryAfter := rr.Header().Get("Retry-After"); retryAfter == "" {
|
|
t.Error("Expected Retry-After header to be set")
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_DifferentIPsIndependent(t *testing.T) {
|
|
cfg := RateLimitConfig{
|
|
Enabled: true,
|
|
RequestsPerMinute: 60,
|
|
BurstSize: 2,
|
|
}
|
|
rl := NewRateLimiter(cfg)
|
|
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// IP1 makes 2 requests (fills its burst)
|
|
for i := 0; i < 2; i++ {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "10.0.0.1:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("IP1 request %d: expected status 200, got %d", i+1, rr.Code)
|
|
}
|
|
}
|
|
|
|
// IP1's 3rd request should be rate limited
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "10.0.0.1:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusTooManyRequests {
|
|
t.Errorf("IP1 request 3: expected status 429, got %d", rr.Code)
|
|
}
|
|
|
|
// IP2 should still be able to make requests (independent rate limit)
|
|
req2 := httptest.NewRequest("GET", "/test", nil)
|
|
req2.RemoteAddr = "10.0.0.2:12345"
|
|
rr2 := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr2, req2)
|
|
|
|
if rr2.Code != http.StatusOK {
|
|
t.Errorf("IP2 request 1: expected status 200, got %d", rr2.Code)
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_Disabled(t *testing.T) {
|
|
cfg := RateLimitConfig{
|
|
Enabled: false,
|
|
RequestsPerMinute: 60,
|
|
BurstSize: 1,
|
|
}
|
|
rl := NewRateLimiter(cfg)
|
|
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Make many requests - all should succeed when disabled
|
|
for i := 0; i < 100; i++ {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "192.168.1.100:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Request %d with disabled rate limiter: expected status 200, got %d", i+1, rr.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_TTLExpiration(t *testing.T) {
|
|
cfg := RateLimitConfig{
|
|
Enabled: true,
|
|
RequestsPerMinute: 60,
|
|
BurstSize: 2,
|
|
}
|
|
rl := NewRateLimiter(cfg)
|
|
|
|
// Manually set a short TTL for testing
|
|
rl.ttl = 50 * time.Millisecond
|
|
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// IP makes 2 requests (fills burst)
|
|
for i := 0; i < 2; i++ {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "10.0.0.50:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code)
|
|
}
|
|
}
|
|
|
|
// 3rd request should be rate limited
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "10.0.0.50:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusTooManyRequests {
|
|
t.Errorf("Request 3: expected status 429, got %d", rr.Code)
|
|
}
|
|
|
|
// Wait for TTL to expire
|
|
time.Sleep(60 * time.Millisecond)
|
|
|
|
// New request should succeed (new limiter created after TTL expiration)
|
|
req2 := httptest.NewRequest("GET", "/test", nil)
|
|
req2.RemoteAddr = "10.0.0.50:12345"
|
|
rr2 := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr2, req2)
|
|
|
|
if rr2.Code != http.StatusOK {
|
|
t.Errorf("Request after TTL: expected status 200, got %d", rr2.Code)
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_ClientIPExtraction(t *testing.T) {
|
|
rl := NewRateLimiter(RateLimitConfig{Enabled: true, RequestsPerMinute: 60, BurstSize: 10})
|
|
|
|
tests := []struct {
|
|
name string
|
|
header map[string]string
|
|
remoteAddr string
|
|
expected string
|
|
}{
|
|
{
|
|
name: "X-Forwarded-For single IP",
|
|
header: map[string]string{"X-Forwarded-For": "203.0.113.195"},
|
|
remoteAddr: "127.0.0.1:12345",
|
|
expected: "203.0.113.195",
|
|
},
|
|
{
|
|
name: "X-Forwarded-For multiple IPs",
|
|
header: map[string]string{"X-Forwarded-For": "203.0.113.195, 70.41.3.18, 150.172.238.178"},
|
|
remoteAddr: "127.0.0.1:12345",
|
|
expected: "203.0.113.195",
|
|
},
|
|
{
|
|
name: "X-Real-IP",
|
|
header: map[string]string{"X-Real-IP": "203.0.113.50"},
|
|
remoteAddr: "127.0.0.1:12345",
|
|
expected: "203.0.113.50",
|
|
},
|
|
{
|
|
name: "RemoteAddr with port",
|
|
header: map[string]string{},
|
|
remoteAddr: "203.0.113.100:54321",
|
|
expected: "203.0.113.100",
|
|
},
|
|
{
|
|
name: "RemoteAddr without port",
|
|
header: map[string]string{},
|
|
remoteAddr: "203.0.113.101",
|
|
expected: "203.0.113.101",
|
|
},
|
|
{
|
|
name: "X-Forwarded-For takes precedence over X-Real-IP",
|
|
header: map[string]string{"X-Forwarded-For": "203.0.113.200", "X-Real-IP": "203.0.113.201"},
|
|
remoteAddr: "127.0.0.1:12345",
|
|
expected: "203.0.113.200",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
for k, v := range tt.header {
|
|
req.Header.Set(k, v)
|
|
}
|
|
req.RemoteAddr = tt.remoteAddr
|
|
|
|
ip := rl.clientIP(req)
|
|
if ip != tt.expected {
|
|
t.Errorf("clientIP() = %q, expected %q", ip, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_ContentTypeHeader(t *testing.T) {
|
|
cfg := RateLimitConfig{
|
|
Enabled: true,
|
|
RequestsPerMinute: 60,
|
|
BurstSize: 1,
|
|
}
|
|
rl := NewRateLimiter(cfg)
|
|
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Make 1 request to fill burst
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "192.168.1.200:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
// 2nd request should be rate limited
|
|
req2 := httptest.NewRequest("GET", "/test", nil)
|
|
req2.RemoteAddr = "192.168.1.200:12345"
|
|
rr2 := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr2, req2)
|
|
|
|
if rr2.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("Expected status 429, got %d", rr2.Code)
|
|
}
|
|
|
|
// Check Content-Type header is JSON
|
|
contentType := rr2.Header().Get("Content-Type")
|
|
if contentType != "application/json" {
|
|
t.Errorf("Expected Content-Type: application/json, got %q", contentType)
|
|
}
|
|
}
|