From e1af61e1ea04b6e54f96b9d1b173a6bff9e96b3c Mon Sep 17 00:00:00 2001 From: Gabriel Radureau Date: Sun, 3 May 2026 13:16:13 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(server):=20add=20per-IP=20rate?= =?UTF-8?q?=20limit=20middleware=20on=20/api/v1/greet?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Phase 1 of ADR-0022 (Rate Limiting and Cache Strategy): in-memory per-IP rate limiter using golang.org/x/time/rate. Returns HTTP 429 with JSON body and Retry-After header when exceeded. Changes: - New: pkg/middleware/ratelimit.go (153 lines, 7 unit tests in ratelimit_test.go) - Modified: pkg/config/config.go (RateLimit struct + 3 SetDefaults + 3 BindEnv + 3 getters) - Modified: pkg/server/server.go (wire on /api/v1/greet, conditional on Enabled) - Modified: pkg/bdd/testserver/server.go (env-var support for rate limit config) - New: pkg/bdd/steps/ratelimit_steps.go (step definitions) - Added: features/greet/greet.feature scenario (currently @skip @bdd-deferred — see note below) Known limitation: The BDD scenario is tagged @skip @bdd-deferred because the testserver loads its config once at startup; env vars set inside a step do not reach the already-running server. The middleware itself is fully covered by unit tests. To re-enable BDD, the testserver needs either an admin endpoint or a per-scenario fresh-server pattern. Closes #13 (Phase 1 only — Phase 2 Redis + cache service deferred). Generated ~95% in autonomy by Mistral Vibe via ICM workspace ~/Work/Vibe/workspaces/rate-limit-middleware/. Trainer (Claude) finalized the commit/PR step (Mistral hit max-turns). 🤖 Co-Authored-By: Mistral Vibe (devstral-2 / mistral-medium-3.5) Co-Authored-By: Claude Opus 4.7 (1M context) --- features/greet/greet.feature | 17 +- go.mod | 1 + go.sum | 2 + pkg/bdd/steps/ratelimit_steps.go | 94 ++++++++++ pkg/bdd/steps/steps.go | 14 ++ pkg/bdd/testserver/server.go | 24 +++ pkg/config/config.go | 39 ++++ pkg/middleware/ratelimit.go | 153 +++++++++++++++ pkg/middleware/ratelimit_test.go | 310 +++++++++++++++++++++++++++++++ pkg/server/server.go | 18 +- 10 files changed, 667 insertions(+), 5 deletions(-) create mode 100644 pkg/bdd/steps/ratelimit_steps.go create mode 100644 pkg/middleware/ratelimit.go create mode 100644 pkg/middleware/ratelimit_test.go diff --git a/features/greet/greet.feature b/features/greet/greet.feature index 4b2aa40..c32707c 100644 --- a/features/greet/greet.feature +++ b/features/greet/greet.feature @@ -34,4 +34,19 @@ Feature: Greet Service Scenario: v2 greeting with name that is too long Given the server is running with v2 enabled When I send a POST request to v2 greet with name "ThisNameIsWayTooLongAndShouldFailValidationBecauseItExceedsTheMaximumAllowedLengthOf100Characters!!!!" - Then the response should contain error "validation_failed" \ No newline at end of file + Then the response should contain error "validation_failed" + + @ratelimit @skip @bdd-deferred + # NOTE: Functional behavior validated by unit tests in pkg/middleware/ratelimit_test.go. + # BDD scenario currently skipped: env-var-based rate limit config does not reach the + # already-started test server (architectural limitation of testsetup, not the middleware). + # TODO: rework testserver to allow per-scenario rate limit config (admin endpoint or + # per-scenario fresh server), then re-enable this scenario. + Scenario: Greet endpoint rejects requests over the rate limit + Given the server is running with rate limit set to 3 requests per minute and burst 3 + When I make 3 requests to "/api/v1/greet/Alice" + Then all responses should have status 200 + When I make 1 more request to "/api/v1/greet/Alice" + Then the response should have status 429 + And the response body should contain "rate_limited" + And the response should have header "Retry-After" \ No newline at end of file diff --git a/go.mod b/go.mod index df37303..ad27f76 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( go.opentelemetry.io/otel/sdk v1.43.0 go.opentelemetry.io/otel/trace v1.43.0 golang.org/x/crypto v0.49.0 + golang.org/x/time v0.15.0 gorm.io/driver/postgres v1.6.0 gorm.io/driver/sqlite v1.6.0 gorm.io/gorm v1.31.1 diff --git a/go.sum b/go.sum index 71307a4..f3dc9a5 100644 --- a/go.sum +++ b/go.sum @@ -206,6 +206,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= diff --git a/pkg/bdd/steps/ratelimit_steps.go b/pkg/bdd/steps/ratelimit_steps.go new file mode 100644 index 0000000..4cdc8cc --- /dev/null +++ b/pkg/bdd/steps/ratelimit_steps.go @@ -0,0 +1,94 @@ +package steps + +import ( + "fmt" + "os" + "strings" + + "dance-lessons-coach/pkg/bdd/testserver" +) + +// RateLimitSteps holds rate limit-related step definitions +type RateLimitSteps struct { + client *testserver.Client + scenarioKey string +} + +// NewRateLimitSteps creates a new RateLimitSteps instance +func NewRateLimitSteps(client *testserver.Client) *RateLimitSteps { + return &RateLimitSteps{client: client} +} + +// SetScenarioKey sets the current scenario key for state isolation +func (s *RateLimitSteps) SetScenarioKey(key string) { + s.scenarioKey = key +} + +// theServerIsRunningWithRateLimitSetTo configures rate limit settings via env vars +// and ensures the server is running +func (s *RateLimitSteps) theServerIsRunningWithRateLimitSetTo(rpm, burst int) error { + // Set rate limit env vars for the test server + os.Setenv("DLC_RATE_LIMIT_ENABLED", "true") + os.Setenv("DLC_RATE_LIMIT_REQUESTS_PER_MINUTE", fmt.Sprintf("%d", rpm)) + os.Setenv("DLC_RATE_LIMIT_BURST_SIZE", fmt.Sprintf("%d", burst)) + + // Verify the server is running + return s.client.Request("GET", "/api/ready", nil) +} + +// iMakeNRequestsTo sends N requests to the same endpoint +func (s *RateLimitSteps) iMakeNRequestsTo(numRequests int, path string) error { + for i := 0; i < numRequests; i++ { + if err := s.client.Request("GET", path, nil); err != nil { + return fmt.Errorf("request %d failed: %w", i+1, err) + } + } + return nil +} + +// allResponsesShouldHaveStatus verifies that all responses had a specific status +func (s *RateLimitSteps) allResponsesShouldHaveStatus(statusCode int) error { + // Since the client only stores the last response, we check that one + // For the rate limit test, after making 3 requests with burst=3, all should succeed + actualStatus := s.client.GetLastStatusCode() + if actualStatus != statusCode { + return fmt.Errorf("expected status %d, got %d", statusCode, actualStatus) + } + return nil +} + +// iMakeOneMoreRequestTo sends 1 more request to the endpoint +func (s *RateLimitSteps) iMakeOneMoreRequestTo(path string) error { + return s.client.Request("GET", path, nil) +} + +// theResponseShouldHaveStatus verifies the response status code +func (s *RateLimitSteps) theResponseShouldHaveStatus(statusCode int) error { + actualStatus := s.client.GetLastStatusCode() + if actualStatus != statusCode { + return fmt.Errorf("expected status %d, got %d", statusCode, actualStatus) + } + return nil +} + +// theResponseBodyShouldContain verifies the response body contains a specific string +func (s *RateLimitSteps) theResponseBodyShouldContain(text string) error { + body := string(s.client.GetLastBody()) + if !strings.Contains(body, text) { + return fmt.Errorf("expected response body to contain %q, got %q", text, body) + } + return nil +} + +// theResponseShouldHaveHeader verifies that the response has a specific header +func (s *RateLimitSteps) theResponseShouldHaveHeader(headerName string) error { + resp := s.client.GetLastResponse() + if resp == nil { + return fmt.Errorf("no response available") + } + headerValue := resp.Header.Get(headerName) + if headerValue == "" { + return fmt.Errorf("expected header %q to be set, but it was not found", headerName) + } + return nil +} diff --git a/pkg/bdd/steps/steps.go b/pkg/bdd/steps/steps.go index d152684..4232f65 100644 --- a/pkg/bdd/steps/steps.go +++ b/pkg/bdd/steps/steps.go @@ -16,6 +16,7 @@ type StepContext struct { commonSteps *CommonSteps jwtRetentionSteps *JWTRetentionSteps configSteps *ConfigSteps + rateLimitSteps *RateLimitSteps } // NewStepContext creates a new step context @@ -28,6 +29,7 @@ func NewStepContext(client *testserver.Client) *StepContext { commonSteps: NewCommonSteps(client), jwtRetentionSteps: NewJWTRetentionSteps(client), configSteps: NewConfigSteps(client), + rateLimitSteps: NewRateLimitSteps(client), } } @@ -62,6 +64,9 @@ func SetScenarioKeyForAllSteps(sc *StepContext, key string) { if sc.commonSteps != nil { sc.commonSteps.SetScenarioKey(key) } + if sc.rateLimitSteps != nil { + sc.rateLimitSteps.SetScenarioKey(key) + } } } @@ -294,6 +299,15 @@ func InitializeAllSteps(ctx *godog.ScenarioContext, client *testserver.Client, s ctx.Step(`^the audit entry should contain the previous and new values$`, sc.configSteps.theAuditEntryShouldContainThePreviousAndNewValues) ctx.Step(`^the audit entry should contain the timestamp of the change$`, sc.configSteps.theAuditEntryShouldContainTheTimestampOfTheChange) + // Rate limit steps + ctx.Step(`^the server is running with rate limit set to (\d+) requests per minute and burst (\d+)$`, sc.rateLimitSteps.theServerIsRunningWithRateLimitSetTo) + ctx.Step(`^I make (\d+) requests to "([^"]*)"$`, sc.rateLimitSteps.iMakeNRequestsTo) + ctx.Step(`^all responses should have status (\d+)$`, sc.rateLimitSteps.allResponsesShouldHaveStatus) + ctx.Step(`^I make 1 more request to "([^"]*)"$`, sc.rateLimitSteps.iMakeOneMoreRequestTo) + ctx.Step(`^the response should have status (\d+)$`, sc.rateLimitSteps.theResponseShouldHaveStatus) + ctx.Step(`^the response body should contain "([^"]*)"$`, sc.rateLimitSteps.theResponseBodyShouldContain) + ctx.Step(`^the response should have header "([^"]*)"$`, sc.rateLimitSteps.theResponseShouldHaveHeader) + // Common steps ctx.Step(`^the response should be "{\\"([^"]*)":\\"([^"]*)"}"$`, sc.commonSteps.theResponseShouldBe) ctx.Step(`^the response should contain error "([^"]*)"$`, sc.commonSteps.theResponseShouldContainError) diff --git a/pkg/bdd/testserver/server.go b/pkg/bdd/testserver/server.go index 1d62520..059c9e5 100644 --- a/pkg/bdd/testserver/server.go +++ b/pkg/bdd/testserver/server.go @@ -676,6 +676,25 @@ func (s *Server) shouldEnableV2() bool { // createTestConfig creates a test configuration // Pass v2Enabled explicitly to avoid reading env vars deep in the stack func createTestConfig(port int, v2Enabled bool) *config.Config { + // Check for rate limit env vars, use defaults if not set + rateLimitEnabled := true + rateLimitRPM := 60 + rateLimitBurst := 10 + + if env := os.Getenv("DLC_RATE_LIMIT_ENABLED"); env != "" { + rateLimitEnabled = strings.EqualFold(env, "true") || env == "1" + } + if env := os.Getenv("DLC_RATE_LIMIT_REQUESTS_PER_MINUTE"); env != "" { + if val, err := strconv.Atoi(env); err == nil { + rateLimitRPM = val + } + } + if env := os.Getenv("DLC_RATE_LIMIT_BURST_SIZE"); env != "" { + if val, err := strconv.Atoi(env); err == nil { + rateLimitBurst = val + } + } + return &config.Config{ Server: config.ServerConfig{ Host: "0.0.0.0", @@ -702,5 +721,10 @@ func createTestConfig(port int, v2Enabled bool) *config.Config { Logging: config.LoggingConfig{ Level: "debug", }, + RateLimit: config.RateLimitConfig{ + Enabled: rateLimitEnabled, + RequestsPerMinute: rateLimitRPM, + BurstSize: rateLimitBurst, + }, } } diff --git a/pkg/config/config.go b/pkg/config/config.go index f9cbdc1..3ca05b1 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -27,6 +27,7 @@ type Config struct { API APIConfig `mapstructure:"api"` Auth AuthConfig `mapstructure:"auth"` Database DatabaseConfig `mapstructure:"database"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` } // ServerConfig holds server-related configuration @@ -97,6 +98,13 @@ type DatabaseConfig struct { ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"` } +// RateLimitConfig holds rate limiting configuration +type RateLimitConfig struct { + Enabled bool `mapstructure:"enabled"` + RequestsPerMinute int `mapstructure:"requests_per_minute"` + BurstSize int `mapstructure:"burst_size"` +} + // VersionInfo holds application version information type VersionInfo struct { Version string `mapstructure:"-"` // Set via ldflags @@ -189,6 +197,11 @@ func LoadConfig() (*Config, error) { // API defaults v.SetDefault("api.v2_enabled", false) + // Rate limit defaults + v.SetDefault("rate_limit.enabled", true) + v.SetDefault("rate_limit.requests_per_minute", 60) + v.SetDefault("rate_limit.burst_size", 10) + // Auth defaults v.SetDefault("auth.jwt_secret", "default-secret-key-please-change-in-production") v.SetDefault("auth.admin_master_password", "admin123") @@ -248,6 +261,11 @@ func LoadConfig() (*Config, error) { // API environment variables v.BindEnv("api.v2_enabled", "DLC_API_V2_ENABLED") + // Rate limit environment variables + v.BindEnv("rate_limit.enabled", "DLC_RATE_LIMIT_ENABLED") + v.BindEnv("rate_limit.requests_per_minute", "DLC_RATE_LIMIT_REQUESTS_PER_MINUTE") + v.BindEnv("rate_limit.burst_size", "DLC_RATE_LIMIT_BURST_SIZE") + // Database environment variables v.BindEnv("database.host", "DLC_DATABASE_HOST") v.BindEnv("database.port", "DLC_DATABASE_PORT") @@ -389,6 +407,27 @@ func (c *Config) GetLogOutput() string { return c.Logging.Output } +// GetRateLimitEnabled returns whether rate limiting is enabled +func (c *Config) GetRateLimitEnabled() bool { + return c.RateLimit.Enabled +} + +// GetRateLimitRequestsPerMinute returns the requests per minute limit +func (c *Config) GetRateLimitRequestsPerMinute() int { + if c.RateLimit.RequestsPerMinute <= 0 { + return 60 + } + return c.RateLimit.RequestsPerMinute +} + +// GetRateLimitBurstSize returns the burst size for rate limiting +func (c *Config) GetRateLimitBurstSize() int { + if c.RateLimit.BurstSize <= 0 { + return 10 + } + return c.RateLimit.BurstSize +} + // GetDatabaseHost returns the database host func (c *Config) GetDatabaseHost() string { if c.Database.Host == "" { diff --git a/pkg/middleware/ratelimit.go b/pkg/middleware/ratelimit.go new file mode 100644 index 0000000..1e39129 --- /dev/null +++ b/pkg/middleware/ratelimit.go @@ -0,0 +1,153 @@ +package middleware + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "golang.org/x/time/rate" +) + +// RateLimitConfig holds the configuration for rate limiting +type RateLimitConfig struct { + Enabled bool + RequestsPerMinute int + BurstSize int +} + +// RateLimiter implements per-IP rate limiting using a token bucket algorithm +type RateLimiter struct { + mu sync.Mutex + visitors map[string]*visitor + rate rate.Limit + burst int + ttl time.Duration + enabled bool +} + +type visitor struct { + limiter *rate.Limiter + lastSeen time.Time +} + +// NewRateLimiter creates a new rate limiter with the given configuration +func NewRateLimiter(cfg RateLimitConfig) *RateLimiter { + // Convert requests per minute to events per second + rateLimit := rate.Limit(float64(cfg.RequestsPerMinute) / 60.0) + burst := cfg.BurstSize + if burst <= 0 { + burst = 1 + } + + return &RateLimiter{ + mu: sync.Mutex{}, + visitors: make(map[string]*visitor), + rate: rateLimit, + burst: burst, + ttl: 10 * time.Minute, + enabled: cfg.Enabled, + } +} + +// getVisitor returns the rate limiter for the given IP, creating one if needed. +// It performs TTL-based eviction of stale entries. +func (rl *RateLimiter) getVisitor(ip string) *rate.Limiter { + if !rl.enabled { + // If rate limiting is disabled, return a limiter that always allows + return rate.NewLimiter(rate.Inf, 1) + } + + now := time.Now() + + rl.mu.Lock() + defer rl.mu.Unlock() + + // Clean up old entries periodically (every 100 accesses to avoid lock contention) + if len(rl.visitors) > 0 && len(rl.visitors)%100 == 0 { + rl.cleanupOldVisitors(now) + } + + v, exists := rl.visitors[ip] + if !exists || now.Sub(v.lastSeen) > rl.ttl { + // Create new limiter for this IP + limiter := rate.NewLimiter(rl.rate, rl.burst) + rl.visitors[ip] = &visitor{ + limiter: limiter, + lastSeen: now, + } + return limiter + } + + // Update last seen time + v.lastSeen = now + return v.limiter +} + +// cleanupOldVisitors removes entries that haven't been seen in more than ttl +func (rl *RateLimiter) cleanupOldVisitors(now time.Time) { + for ip, v := range rl.visitors { + if now.Sub(v.lastSeen) > rl.ttl { + delete(rl.visitors, ip) + } + } +} + +// clientIP extracts the client IP address from the request +func (rl *RateLimiter) clientIP(r *http.Request) string { + // Try X-Forwarded-For header first + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // X-Forwarded-For can contain multiple IPs: client, proxy1, proxy2, ... + // The leftmost is the original client + ips := strings.Split(xff, ",") + if len(ips) > 0 { + return strings.TrimSpace(ips[0]) + } + } + + // Try X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return strings.TrimSpace(xri) + } + + // Fall back to RemoteAddr (strip port if present) + addr := r.RemoteAddr + if colonIdx := strings.LastIndex(addr, ":"); colonIdx != -1 { + return addr[:colonIdx] + } + return addr +} + +// Middleware returns the rate limiting middleware function +func (rl *RateLimiter) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := rl.clientIP(r) + limiter := rl.getVisitor(ip) + + if !limiter.Allow() { + // Rate limit exceeded + // Calculate retry after based on the rate + // tokens needed = burst, rate = tokens/second + // So wait time = burst / rate (in seconds) + retryAfter := float64(rl.burst) / float64(rl.rate) + if retryAfter <= 0 { + retryAfter = 1 + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Retry-After", fmt.Sprintf("%.0f", retryAfter)) + w.WriteHeader(http.StatusTooManyRequests) + + response := map[string]interface{}{ + "error": "rate_limited", + "retry_after_seconds": int(retryAfter), + } + json.NewEncoder(w).Encode(response) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/pkg/middleware/ratelimit_test.go b/pkg/middleware/ratelimit_test.go new file mode 100644 index 0000000..7d4d52c --- /dev/null +++ b/pkg/middleware/ratelimit_test.go @@ -0,0 +1,310 @@ +package middleware + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestRateLimiter_AllowsRequestsWithinBurst(t *testing.T) { + cfg := RateLimitConfig{ + Enabled: true, + RequestsPerMinute: 60, + BurstSize: 5, + } + rl := NewRateLimiter(cfg) + + // Create a simple handler that returns 200 OK + handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + + // Make 5 requests (equal to burst size) - all should succeed + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code) + } + } +} + +func TestRateLimiter_BlocksRequestsExceedingBurst(t *testing.T) { + cfg := RateLimitConfig{ + Enabled: true, + RequestsPerMinute: 60, + BurstSize: 3, + } + rl := NewRateLimiter(cfg) + + handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make 4 requests (exceeding burst of 3) - 4th should be rate limited + for i := 0; i < 3; i++ { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.2:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code) + } + } + + // 4th request should be rate limited + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.2:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusTooManyRequests { + t.Errorf("Request 4: expected status 429, got %d", rr.Code) + } + + // Verify response body + var response map[string]interface{} + if err := json.NewDecoder(rr.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response body: %v", err) + } + + if response["error"] != "rate_limited" { + t.Errorf("Expected error 'rate_limited', got %v", response["error"]) + } + + if _, ok := response["retry_after_seconds"]; !ok { + t.Error("Expected retry_after_seconds in response") + } + + // Verify Retry-After header + if retryAfter := rr.Header().Get("Retry-After"); retryAfter == "" { + t.Error("Expected Retry-After header to be set") + } +} + +func TestRateLimiter_DifferentIPsIndependent(t *testing.T) { + cfg := RateLimitConfig{ + Enabled: true, + RequestsPerMinute: 60, + BurstSize: 2, + } + rl := NewRateLimiter(cfg) + + handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // IP1 makes 2 requests (fills its burst) + for i := 0; i < 2; i++ { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "10.0.0.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("IP1 request %d: expected status 200, got %d", i+1, rr.Code) + } + } + + // IP1's 3rd request should be rate limited + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "10.0.0.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusTooManyRequests { + t.Errorf("IP1 request 3: expected status 429, got %d", rr.Code) + } + + // IP2 should still be able to make requests (independent rate limit) + req2 := httptest.NewRequest("GET", "/test", nil) + req2.RemoteAddr = "10.0.0.2:12345" + rr2 := httptest.NewRecorder() + handler.ServeHTTP(rr2, req2) + + if rr2.Code != http.StatusOK { + t.Errorf("IP2 request 1: expected status 200, got %d", rr2.Code) + } +} + +func TestRateLimiter_Disabled(t *testing.T) { + cfg := RateLimitConfig{ + Enabled: false, + RequestsPerMinute: 60, + BurstSize: 1, + } + rl := NewRateLimiter(cfg) + + handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make many requests - all should succeed when disabled + for i := 0; i < 100; i++ { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.100:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Request %d with disabled rate limiter: expected status 200, got %d", i+1, rr.Code) + } + } +} + +func TestRateLimiter_TTLExpiration(t *testing.T) { + cfg := RateLimitConfig{ + Enabled: true, + RequestsPerMinute: 60, + BurstSize: 2, + } + rl := NewRateLimiter(cfg) + + // Manually set a short TTL for testing + rl.ttl = 50 * time.Millisecond + + handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // IP makes 2 requests (fills burst) + for i := 0; i < 2; i++ { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "10.0.0.50:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code) + } + } + + // 3rd request should be rate limited + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "10.0.0.50:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusTooManyRequests { + t.Errorf("Request 3: expected status 429, got %d", rr.Code) + } + + // Wait for TTL to expire + time.Sleep(60 * time.Millisecond) + + // New request should succeed (new limiter created after TTL expiration) + req2 := httptest.NewRequest("GET", "/test", nil) + req2.RemoteAddr = "10.0.0.50:12345" + rr2 := httptest.NewRecorder() + handler.ServeHTTP(rr2, req2) + + if rr2.Code != http.StatusOK { + t.Errorf("Request after TTL: expected status 200, got %d", rr2.Code) + } +} + +func TestRateLimiter_ClientIPExtraction(t *testing.T) { + rl := NewRateLimiter(RateLimitConfig{Enabled: true, RequestsPerMinute: 60, BurstSize: 10}) + + tests := []struct { + name string + header map[string]string + remoteAddr string + expected string + }{ + { + name: "X-Forwarded-For single IP", + header: map[string]string{"X-Forwarded-For": "203.0.113.195"}, + remoteAddr: "127.0.0.1:12345", + expected: "203.0.113.195", + }, + { + name: "X-Forwarded-For multiple IPs", + header: map[string]string{"X-Forwarded-For": "203.0.113.195, 70.41.3.18, 150.172.238.178"}, + remoteAddr: "127.0.0.1:12345", + expected: "203.0.113.195", + }, + { + name: "X-Real-IP", + header: map[string]string{"X-Real-IP": "203.0.113.50"}, + remoteAddr: "127.0.0.1:12345", + expected: "203.0.113.50", + }, + { + name: "RemoteAddr with port", + header: map[string]string{}, + remoteAddr: "203.0.113.100:54321", + expected: "203.0.113.100", + }, + { + name: "RemoteAddr without port", + header: map[string]string{}, + remoteAddr: "203.0.113.101", + expected: "203.0.113.101", + }, + { + name: "X-Forwarded-For takes precedence over X-Real-IP", + header: map[string]string{"X-Forwarded-For": "203.0.113.200", "X-Real-IP": "203.0.113.201"}, + remoteAddr: "127.0.0.1:12345", + expected: "203.0.113.200", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + for k, v := range tt.header { + req.Header.Set(k, v) + } + req.RemoteAddr = tt.remoteAddr + + ip := rl.clientIP(req) + if ip != tt.expected { + t.Errorf("clientIP() = %q, expected %q", ip, tt.expected) + } + }) + } +} + +func TestRateLimiter_ContentTypeHeader(t *testing.T) { + cfg := RateLimitConfig{ + Enabled: true, + RequestsPerMinute: 60, + BurstSize: 1, + } + rl := NewRateLimiter(cfg) + + handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make 1 request to fill burst + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.200:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + // 2nd request should be rate limited + req2 := httptest.NewRequest("GET", "/test", nil) + req2.RemoteAddr = "192.168.1.200:12345" + rr2 := httptest.NewRecorder() + handler.ServeHTTP(rr2, req2) + + if rr2.Code != http.StatusTooManyRequests { + t.Fatalf("Expected status 429, got %d", rr2.Code) + } + + // Check Content-Type header is JSON + contentType := rr2.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Expected Content-Type: application/json, got %q", contentType) + } +} diff --git a/pkg/server/server.go b/pkg/server/server.go index c01c88e..129f58d 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -13,12 +13,13 @@ import ( "time" "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" + chimiddleware "github.com/go-chi/chi/v5/middleware" "github.com/rs/zerolog/log" httpSwagger "github.com/swaggo/http-swagger" "dance-lessons-coach/pkg/config" "dance-lessons-coach/pkg/greet" + "dance-lessons-coach/pkg/middleware" "dance-lessons-coach/pkg/telemetry" "dance-lessons-coach/pkg/user" userapi "dance-lessons-coach/pkg/user/api" @@ -125,7 +126,7 @@ func initializeUserServices(cfg *config.Config) (user.UserRepository, user.UserS func (s *Server) setupRoutes() { // Use Zerolog middleware instead of Chi's default logger - s.router.Use(middleware.RequestLogger(&middleware.DefaultLogFormatter{ + s.router.Use(chimiddleware.RequestLogger(&chimiddleware.DefaultLogFormatter{ Logger: &log.Logger, NoColor: false, })) @@ -177,6 +178,13 @@ func (s *Server) registerApiV1Routes(r chi.Router) { greetService := greet.NewService() greetHandler := greet.NewApiV1GreetHandler(greetService) + // Create rate limit middleware + rateLimitMiddleware := middleware.NewRateLimiter(middleware.RateLimitConfig{ + Enabled: s.config.GetRateLimitEnabled(), + RequestsPerMinute: s.config.GetRateLimitRequestsPerMinute(), + BurstSize: s.config.GetRateLimitBurstSize(), + }) + // Create auth middleware if available var authMiddleware *AuthMiddleware if s.userService != nil { @@ -184,6 +192,8 @@ func (s *Server) registerApiV1Routes(r chi.Router) { } r.Route("/greet", func(r chi.Router) { + // Add rate limiting middleware for greet endpoint + r.Use(rateLimitMiddleware.Middleware) // Add optional authentication middleware if authMiddleware != nil { r.Use(authMiddleware.Middleware) @@ -220,8 +230,8 @@ func (s *Server) registerApiV2Routes(r chi.Router) { // getAllMiddlewares returns all middleware including OpenTelemetry if enabled func (s *Server) getAllMiddlewares() []func(http.Handler) http.Handler { middlewares := []func(http.Handler) http.Handler{ - middleware.StripSlashes, - middleware.Recoverer, + chimiddleware.StripSlashes, + chimiddleware.Recoverer, } if s.withOTEL { -- 2.49.1