package middleware import ( "encoding/json" "net/http" "net/http/httptest" "testing" "time" ) func TestRateLimiter_AllowsRequestsWithinBurst(t *testing.T) { cfg := RateLimitConfig{ Enabled: true, RequestsPerMinute: 60, BurstSize: 5, } rl := NewRateLimiter(cfg) // Create a simple handler that returns 200 OK handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("OK")) })) // Make 5 requests (equal to burst size) - all should succeed for i := 0; i < 5; i++ { req := httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "192.168.1.1:12345" rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code) } } } func TestRateLimiter_BlocksRequestsExceedingBurst(t *testing.T) { cfg := RateLimitConfig{ Enabled: true, RequestsPerMinute: 60, BurstSize: 3, } rl := NewRateLimiter(cfg) handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) // Make 4 requests (exceeding burst of 3) - 4th should be rate limited for i := 0; i < 3; i++ { req := httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "192.168.1.2:12345" rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code) } } // 4th request should be rate limited req := httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "192.168.1.2:12345" rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusTooManyRequests { t.Errorf("Request 4: expected status 429, got %d", rr.Code) } // Verify response body var response map[string]interface{} if err := json.NewDecoder(rr.Body).Decode(&response); err != nil { t.Fatalf("Failed to decode response body: %v", err) } if response["error"] != "rate_limited" { t.Errorf("Expected error 'rate_limited', got %v", response["error"]) } if _, ok := response["retry_after_seconds"]; !ok { t.Error("Expected retry_after_seconds in response") } // Verify Retry-After header if retryAfter := rr.Header().Get("Retry-After"); retryAfter == "" { t.Error("Expected Retry-After header to be set") } } func TestRateLimiter_DifferentIPsIndependent(t *testing.T) { cfg := RateLimitConfig{ Enabled: true, RequestsPerMinute: 60, BurstSize: 2, } rl := NewRateLimiter(cfg) handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) // IP1 makes 2 requests (fills its burst) for i := 0; i < 2; i++ { req := httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "10.0.0.1:12345" rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("IP1 request %d: expected status 200, got %d", i+1, rr.Code) } } // IP1's 3rd request should be rate limited req := httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "10.0.0.1:12345" rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusTooManyRequests { t.Errorf("IP1 request 3: expected status 429, got %d", rr.Code) } // IP2 should still be able to make requests (independent rate limit) req2 := httptest.NewRequest("GET", "/test", nil) req2.RemoteAddr = "10.0.0.2:12345" rr2 := httptest.NewRecorder() handler.ServeHTTP(rr2, req2) if rr2.Code != http.StatusOK { t.Errorf("IP2 request 1: expected status 200, got %d", rr2.Code) } } func TestRateLimiter_Disabled(t *testing.T) { cfg := RateLimitConfig{ Enabled: false, RequestsPerMinute: 60, BurstSize: 1, } rl := NewRateLimiter(cfg) handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) // Make many requests - all should succeed when disabled for i := 0; i < 100; i++ { req := httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "192.168.1.100:12345" rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("Request %d with disabled rate limiter: expected status 200, got %d", i+1, rr.Code) } } } func TestRateLimiter_TTLExpiration(t *testing.T) { cfg := RateLimitConfig{ Enabled: true, RequestsPerMinute: 60, BurstSize: 2, } rl := NewRateLimiter(cfg) // Manually set a short TTL for testing rl.ttl = 50 * time.Millisecond handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) // IP makes 2 requests (fills burst) for i := 0; i < 2; i++ { req := httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "10.0.0.50:12345" rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code) } } // 3rd request should be rate limited req := httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "10.0.0.50:12345" rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusTooManyRequests { t.Errorf("Request 3: expected status 429, got %d", rr.Code) } // Wait for TTL to expire time.Sleep(60 * time.Millisecond) // New request should succeed (new limiter created after TTL expiration) req2 := httptest.NewRequest("GET", "/test", nil) req2.RemoteAddr = "10.0.0.50:12345" rr2 := httptest.NewRecorder() handler.ServeHTTP(rr2, req2) if rr2.Code != http.StatusOK { t.Errorf("Request after TTL: expected status 200, got %d", rr2.Code) } } func TestRateLimiter_ClientIPExtraction(t *testing.T) { rl := NewRateLimiter(RateLimitConfig{Enabled: true, RequestsPerMinute: 60, BurstSize: 10}) tests := []struct { name string header map[string]string remoteAddr string expected string }{ { name: "X-Forwarded-For single IP", header: map[string]string{"X-Forwarded-For": "203.0.113.195"}, remoteAddr: "127.0.0.1:12345", expected: "203.0.113.195", }, { name: "X-Forwarded-For multiple IPs", header: map[string]string{"X-Forwarded-For": "203.0.113.195, 70.41.3.18, 150.172.238.178"}, remoteAddr: "127.0.0.1:12345", expected: "203.0.113.195", }, { name: "X-Real-IP", header: map[string]string{"X-Real-IP": "203.0.113.50"}, remoteAddr: "127.0.0.1:12345", expected: "203.0.113.50", }, { name: "RemoteAddr with port", header: map[string]string{}, remoteAddr: "203.0.113.100:54321", expected: "203.0.113.100", }, { name: "RemoteAddr without port", header: map[string]string{}, remoteAddr: "203.0.113.101", expected: "203.0.113.101", }, { name: "X-Forwarded-For takes precedence over X-Real-IP", header: map[string]string{"X-Forwarded-For": "203.0.113.200", "X-Real-IP": "203.0.113.201"}, remoteAddr: "127.0.0.1:12345", expected: "203.0.113.200", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest("GET", "/test", nil) for k, v := range tt.header { req.Header.Set(k, v) } req.RemoteAddr = tt.remoteAddr ip := rl.clientIP(req) if ip != tt.expected { t.Errorf("clientIP() = %q, expected %q", ip, tt.expected) } }) } } func TestRateLimiter_ContentTypeHeader(t *testing.T) { cfg := RateLimitConfig{ Enabled: true, RequestsPerMinute: 60, BurstSize: 1, } rl := NewRateLimiter(cfg) handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) // Make 1 request to fill burst req := httptest.NewRequest("GET", "/test", nil) req.RemoteAddr = "192.168.1.200:12345" rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) // 2nd request should be rate limited req2 := httptest.NewRequest("GET", "/test", nil) req2.RemoteAddr = "192.168.1.200:12345" rr2 := httptest.NewRecorder() handler.ServeHTTP(rr2, req2) if rr2.Code != http.StatusTooManyRequests { t.Fatalf("Expected status 429, got %d", rr2.Code) } // Check Content-Type header is JSON contentType := rr2.Header().Get("Content-Type") if contentType != "application/json" { t.Errorf("Expected Content-Type: application/json, got %q", contentType) } }