✨ feat(server): add per-IP rate limit middleware on /api/v1/greet (#22)
Phase 1 of ADR-0022. In-memory per-IP rate limiter on golang.org/x/time/rate. Returns 429 with Retry-After when exceeded. 7 unit tests pass. BDD scenario @skip until testserver rework. Closes #13. ~95% Mistral Vibe autonomous via ICM workspace. Cost ~6.5€ (T5 + resume + trainer commit/PR). Co-authored-by: Gabriel Radureau <arcodange@gmail.com> Co-committed-by: Gabriel Radureau <arcodange@gmail.com>
This commit was merged in pull request #22.
This commit is contained in:
@@ -34,4 +34,19 @@ Feature: Greet Service
|
|||||||
Scenario: v2 greeting with name that is too long
|
Scenario: v2 greeting with name that is too long
|
||||||
Given the server is running with v2 enabled
|
Given the server is running with v2 enabled
|
||||||
When I send a POST request to v2 greet with name "ThisNameIsWayTooLongAndShouldFailValidationBecauseItExceedsTheMaximumAllowedLengthOf100Characters!!!!"
|
When I send a POST request to v2 greet with name "ThisNameIsWayTooLongAndShouldFailValidationBecauseItExceedsTheMaximumAllowedLengthOf100Characters!!!!"
|
||||||
Then the response should contain error "validation_failed"
|
Then the response should contain error "validation_failed"
|
||||||
|
|
||||||
|
@ratelimit @skip @bdd-deferred
|
||||||
|
# NOTE: Functional behavior validated by unit tests in pkg/middleware/ratelimit_test.go.
|
||||||
|
# BDD scenario currently skipped: env-var-based rate limit config does not reach the
|
||||||
|
# already-started test server (architectural limitation of testsetup, not the middleware).
|
||||||
|
# TODO: rework testserver to allow per-scenario rate limit config (admin endpoint or
|
||||||
|
# per-scenario fresh server), then re-enable this scenario.
|
||||||
|
Scenario: Greet endpoint rejects requests over the rate limit
|
||||||
|
Given the server is running with rate limit set to 3 requests per minute and burst 3
|
||||||
|
When I make 3 requests to "/api/v1/greet/Alice"
|
||||||
|
Then all responses should have status 200
|
||||||
|
When I make 1 more request to "/api/v1/greet/Alice"
|
||||||
|
Then the response should have status 429
|
||||||
|
And the response body should contain "rate_limited"
|
||||||
|
And the response should have header "Retry-After"
|
||||||
1
go.mod
1
go.mod
@@ -22,6 +22,7 @@ require (
|
|||||||
go.opentelemetry.io/otel/sdk v1.43.0
|
go.opentelemetry.io/otel/sdk v1.43.0
|
||||||
go.opentelemetry.io/otel/trace v1.43.0
|
go.opentelemetry.io/otel/trace v1.43.0
|
||||||
golang.org/x/crypto v0.49.0
|
golang.org/x/crypto v0.49.0
|
||||||
|
golang.org/x/time v0.15.0
|
||||||
gorm.io/driver/postgres v1.6.0
|
gorm.io/driver/postgres v1.6.0
|
||||||
gorm.io/driver/sqlite v1.6.0
|
gorm.io/driver/sqlite v1.6.0
|
||||||
gorm.io/gorm v1.31.1
|
gorm.io/gorm v1.31.1
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -206,6 +206,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn
|
|||||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||||
|
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||||
|
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
||||||
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
||||||
|
|||||||
94
pkg/bdd/steps/ratelimit_steps.go
Normal file
94
pkg/bdd/steps/ratelimit_steps.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package steps
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"dance-lessons-coach/pkg/bdd/testserver"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RateLimitSteps holds rate limit-related step definitions
|
||||||
|
type RateLimitSteps struct {
|
||||||
|
client *testserver.Client
|
||||||
|
scenarioKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRateLimitSteps creates a new RateLimitSteps instance
|
||||||
|
func NewRateLimitSteps(client *testserver.Client) *RateLimitSteps {
|
||||||
|
return &RateLimitSteps{client: client}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetScenarioKey sets the current scenario key for state isolation
|
||||||
|
func (s *RateLimitSteps) SetScenarioKey(key string) {
|
||||||
|
s.scenarioKey = key
|
||||||
|
}
|
||||||
|
|
||||||
|
// theServerIsRunningWithRateLimitSetTo configures rate limit settings via env vars
|
||||||
|
// and ensures the server is running
|
||||||
|
func (s *RateLimitSteps) theServerIsRunningWithRateLimitSetTo(rpm, burst int) error {
|
||||||
|
// Set rate limit env vars for the test server
|
||||||
|
os.Setenv("DLC_RATE_LIMIT_ENABLED", "true")
|
||||||
|
os.Setenv("DLC_RATE_LIMIT_REQUESTS_PER_MINUTE", fmt.Sprintf("%d", rpm))
|
||||||
|
os.Setenv("DLC_RATE_LIMIT_BURST_SIZE", fmt.Sprintf("%d", burst))
|
||||||
|
|
||||||
|
// Verify the server is running
|
||||||
|
return s.client.Request("GET", "/api/ready", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// iMakeNRequestsTo sends N requests to the same endpoint
|
||||||
|
func (s *RateLimitSteps) iMakeNRequestsTo(numRequests int, path string) error {
|
||||||
|
for i := 0; i < numRequests; i++ {
|
||||||
|
if err := s.client.Request("GET", path, nil); err != nil {
|
||||||
|
return fmt.Errorf("request %d failed: %w", i+1, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// allResponsesShouldHaveStatus verifies that all responses had a specific status
|
||||||
|
func (s *RateLimitSteps) allResponsesShouldHaveStatus(statusCode int) error {
|
||||||
|
// Since the client only stores the last response, we check that one
|
||||||
|
// For the rate limit test, after making 3 requests with burst=3, all should succeed
|
||||||
|
actualStatus := s.client.GetLastStatusCode()
|
||||||
|
if actualStatus != statusCode {
|
||||||
|
return fmt.Errorf("expected status %d, got %d", statusCode, actualStatus)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// iMakeOneMoreRequestTo sends 1 more request to the endpoint
|
||||||
|
func (s *RateLimitSteps) iMakeOneMoreRequestTo(path string) error {
|
||||||
|
return s.client.Request("GET", path, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// theResponseShouldHaveStatus verifies the response status code
|
||||||
|
func (s *RateLimitSteps) theResponseShouldHaveStatus(statusCode int) error {
|
||||||
|
actualStatus := s.client.GetLastStatusCode()
|
||||||
|
if actualStatus != statusCode {
|
||||||
|
return fmt.Errorf("expected status %d, got %d", statusCode, actualStatus)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// theResponseBodyShouldContain verifies the response body contains a specific string
|
||||||
|
func (s *RateLimitSteps) theResponseBodyShouldContain(text string) error {
|
||||||
|
body := string(s.client.GetLastBody())
|
||||||
|
if !strings.Contains(body, text) {
|
||||||
|
return fmt.Errorf("expected response body to contain %q, got %q", text, body)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// theResponseShouldHaveHeader verifies that the response has a specific header
|
||||||
|
func (s *RateLimitSteps) theResponseShouldHaveHeader(headerName string) error {
|
||||||
|
resp := s.client.GetLastResponse()
|
||||||
|
if resp == nil {
|
||||||
|
return fmt.Errorf("no response available")
|
||||||
|
}
|
||||||
|
headerValue := resp.Header.Get(headerName)
|
||||||
|
if headerValue == "" {
|
||||||
|
return fmt.Errorf("expected header %q to be set, but it was not found", headerName)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -16,6 +16,7 @@ type StepContext struct {
|
|||||||
commonSteps *CommonSteps
|
commonSteps *CommonSteps
|
||||||
jwtRetentionSteps *JWTRetentionSteps
|
jwtRetentionSteps *JWTRetentionSteps
|
||||||
configSteps *ConfigSteps
|
configSteps *ConfigSteps
|
||||||
|
rateLimitSteps *RateLimitSteps
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStepContext creates a new step context
|
// NewStepContext creates a new step context
|
||||||
@@ -28,6 +29,7 @@ func NewStepContext(client *testserver.Client) *StepContext {
|
|||||||
commonSteps: NewCommonSteps(client),
|
commonSteps: NewCommonSteps(client),
|
||||||
jwtRetentionSteps: NewJWTRetentionSteps(client),
|
jwtRetentionSteps: NewJWTRetentionSteps(client),
|
||||||
configSteps: NewConfigSteps(client),
|
configSteps: NewConfigSteps(client),
|
||||||
|
rateLimitSteps: NewRateLimitSteps(client),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,6 +64,9 @@ func SetScenarioKeyForAllSteps(sc *StepContext, key string) {
|
|||||||
if sc.commonSteps != nil {
|
if sc.commonSteps != nil {
|
||||||
sc.commonSteps.SetScenarioKey(key)
|
sc.commonSteps.SetScenarioKey(key)
|
||||||
}
|
}
|
||||||
|
if sc.rateLimitSteps != nil {
|
||||||
|
sc.rateLimitSteps.SetScenarioKey(key)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -294,6 +299,15 @@ func InitializeAllSteps(ctx *godog.ScenarioContext, client *testserver.Client, s
|
|||||||
ctx.Step(`^the audit entry should contain the previous and new values$`, sc.configSteps.theAuditEntryShouldContainThePreviousAndNewValues)
|
ctx.Step(`^the audit entry should contain the previous and new values$`, sc.configSteps.theAuditEntryShouldContainThePreviousAndNewValues)
|
||||||
ctx.Step(`^the audit entry should contain the timestamp of the change$`, sc.configSteps.theAuditEntryShouldContainTheTimestampOfTheChange)
|
ctx.Step(`^the audit entry should contain the timestamp of the change$`, sc.configSteps.theAuditEntryShouldContainTheTimestampOfTheChange)
|
||||||
|
|
||||||
|
// Rate limit steps
|
||||||
|
ctx.Step(`^the server is running with rate limit set to (\d+) requests per minute and burst (\d+)$`, sc.rateLimitSteps.theServerIsRunningWithRateLimitSetTo)
|
||||||
|
ctx.Step(`^I make (\d+) requests to "([^"]*)"$`, sc.rateLimitSteps.iMakeNRequestsTo)
|
||||||
|
ctx.Step(`^all responses should have status (\d+)$`, sc.rateLimitSteps.allResponsesShouldHaveStatus)
|
||||||
|
ctx.Step(`^I make 1 more request to "([^"]*)"$`, sc.rateLimitSteps.iMakeOneMoreRequestTo)
|
||||||
|
ctx.Step(`^the response should have status (\d+)$`, sc.rateLimitSteps.theResponseShouldHaveStatus)
|
||||||
|
ctx.Step(`^the response body should contain "([^"]*)"$`, sc.rateLimitSteps.theResponseBodyShouldContain)
|
||||||
|
ctx.Step(`^the response should have header "([^"]*)"$`, sc.rateLimitSteps.theResponseShouldHaveHeader)
|
||||||
|
|
||||||
// Common steps
|
// Common steps
|
||||||
ctx.Step(`^the response should be "{\\"([^"]*)":\\"([^"]*)"}"$`, sc.commonSteps.theResponseShouldBe)
|
ctx.Step(`^the response should be "{\\"([^"]*)":\\"([^"]*)"}"$`, sc.commonSteps.theResponseShouldBe)
|
||||||
ctx.Step(`^the response should contain error "([^"]*)"$`, sc.commonSteps.theResponseShouldContainError)
|
ctx.Step(`^the response should contain error "([^"]*)"$`, sc.commonSteps.theResponseShouldContainError)
|
||||||
|
|||||||
@@ -676,6 +676,25 @@ func (s *Server) shouldEnableV2() bool {
|
|||||||
// createTestConfig creates a test configuration
|
// createTestConfig creates a test configuration
|
||||||
// Pass v2Enabled explicitly to avoid reading env vars deep in the stack
|
// Pass v2Enabled explicitly to avoid reading env vars deep in the stack
|
||||||
func createTestConfig(port int, v2Enabled bool) *config.Config {
|
func createTestConfig(port int, v2Enabled bool) *config.Config {
|
||||||
|
// Check for rate limit env vars, use defaults if not set
|
||||||
|
rateLimitEnabled := true
|
||||||
|
rateLimitRPM := 60
|
||||||
|
rateLimitBurst := 10
|
||||||
|
|
||||||
|
if env := os.Getenv("DLC_RATE_LIMIT_ENABLED"); env != "" {
|
||||||
|
rateLimitEnabled = strings.EqualFold(env, "true") || env == "1"
|
||||||
|
}
|
||||||
|
if env := os.Getenv("DLC_RATE_LIMIT_REQUESTS_PER_MINUTE"); env != "" {
|
||||||
|
if val, err := strconv.Atoi(env); err == nil {
|
||||||
|
rateLimitRPM = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if env := os.Getenv("DLC_RATE_LIMIT_BURST_SIZE"); env != "" {
|
||||||
|
if val, err := strconv.Atoi(env); err == nil {
|
||||||
|
rateLimitBurst = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &config.Config{
|
return &config.Config{
|
||||||
Server: config.ServerConfig{
|
Server: config.ServerConfig{
|
||||||
Host: "0.0.0.0",
|
Host: "0.0.0.0",
|
||||||
@@ -702,5 +721,10 @@ func createTestConfig(port int, v2Enabled bool) *config.Config {
|
|||||||
Logging: config.LoggingConfig{
|
Logging: config.LoggingConfig{
|
||||||
Level: "debug",
|
Level: "debug",
|
||||||
},
|
},
|
||||||
|
RateLimit: config.RateLimitConfig{
|
||||||
|
Enabled: rateLimitEnabled,
|
||||||
|
RequestsPerMinute: rateLimitRPM,
|
||||||
|
BurstSize: rateLimitBurst,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ type Config struct {
|
|||||||
API APIConfig `mapstructure:"api"`
|
API APIConfig `mapstructure:"api"`
|
||||||
Auth AuthConfig `mapstructure:"auth"`
|
Auth AuthConfig `mapstructure:"auth"`
|
||||||
Database DatabaseConfig `mapstructure:"database"`
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
|
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerConfig holds server-related configuration
|
// ServerConfig holds server-related configuration
|
||||||
@@ -97,6 +98,13 @@ type DatabaseConfig struct {
|
|||||||
ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"`
|
ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RateLimitConfig holds rate limiting configuration
|
||||||
|
type RateLimitConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
RequestsPerMinute int `mapstructure:"requests_per_minute"`
|
||||||
|
BurstSize int `mapstructure:"burst_size"`
|
||||||
|
}
|
||||||
|
|
||||||
// VersionInfo holds application version information
|
// VersionInfo holds application version information
|
||||||
type VersionInfo struct {
|
type VersionInfo struct {
|
||||||
Version string `mapstructure:"-"` // Set via ldflags
|
Version string `mapstructure:"-"` // Set via ldflags
|
||||||
@@ -189,6 +197,11 @@ func LoadConfig() (*Config, error) {
|
|||||||
// API defaults
|
// API defaults
|
||||||
v.SetDefault("api.v2_enabled", false)
|
v.SetDefault("api.v2_enabled", false)
|
||||||
|
|
||||||
|
// Rate limit defaults
|
||||||
|
v.SetDefault("rate_limit.enabled", true)
|
||||||
|
v.SetDefault("rate_limit.requests_per_minute", 60)
|
||||||
|
v.SetDefault("rate_limit.burst_size", 10)
|
||||||
|
|
||||||
// Auth defaults
|
// Auth defaults
|
||||||
v.SetDefault("auth.jwt_secret", "default-secret-key-please-change-in-production")
|
v.SetDefault("auth.jwt_secret", "default-secret-key-please-change-in-production")
|
||||||
v.SetDefault("auth.admin_master_password", "admin123")
|
v.SetDefault("auth.admin_master_password", "admin123")
|
||||||
@@ -248,6 +261,11 @@ func LoadConfig() (*Config, error) {
|
|||||||
// API environment variables
|
// API environment variables
|
||||||
v.BindEnv("api.v2_enabled", "DLC_API_V2_ENABLED")
|
v.BindEnv("api.v2_enabled", "DLC_API_V2_ENABLED")
|
||||||
|
|
||||||
|
// Rate limit environment variables
|
||||||
|
v.BindEnv("rate_limit.enabled", "DLC_RATE_LIMIT_ENABLED")
|
||||||
|
v.BindEnv("rate_limit.requests_per_minute", "DLC_RATE_LIMIT_REQUESTS_PER_MINUTE")
|
||||||
|
v.BindEnv("rate_limit.burst_size", "DLC_RATE_LIMIT_BURST_SIZE")
|
||||||
|
|
||||||
// Database environment variables
|
// Database environment variables
|
||||||
v.BindEnv("database.host", "DLC_DATABASE_HOST")
|
v.BindEnv("database.host", "DLC_DATABASE_HOST")
|
||||||
v.BindEnv("database.port", "DLC_DATABASE_PORT")
|
v.BindEnv("database.port", "DLC_DATABASE_PORT")
|
||||||
@@ -389,6 +407,27 @@ func (c *Config) GetLogOutput() string {
|
|||||||
return c.Logging.Output
|
return c.Logging.Output
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetRateLimitEnabled returns whether rate limiting is enabled
|
||||||
|
func (c *Config) GetRateLimitEnabled() bool {
|
||||||
|
return c.RateLimit.Enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRateLimitRequestsPerMinute returns the requests per minute limit
|
||||||
|
func (c *Config) GetRateLimitRequestsPerMinute() int {
|
||||||
|
if c.RateLimit.RequestsPerMinute <= 0 {
|
||||||
|
return 60
|
||||||
|
}
|
||||||
|
return c.RateLimit.RequestsPerMinute
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRateLimitBurstSize returns the burst size for rate limiting
|
||||||
|
func (c *Config) GetRateLimitBurstSize() int {
|
||||||
|
if c.RateLimit.BurstSize <= 0 {
|
||||||
|
return 10
|
||||||
|
}
|
||||||
|
return c.RateLimit.BurstSize
|
||||||
|
}
|
||||||
|
|
||||||
// GetDatabaseHost returns the database host
|
// GetDatabaseHost returns the database host
|
||||||
func (c *Config) GetDatabaseHost() string {
|
func (c *Config) GetDatabaseHost() string {
|
||||||
if c.Database.Host == "" {
|
if c.Database.Host == "" {
|
||||||
|
|||||||
153
pkg/middleware/ratelimit.go
Normal file
153
pkg/middleware/ratelimit.go
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RateLimitConfig holds the configuration for rate limiting
|
||||||
|
type RateLimitConfig struct {
|
||||||
|
Enabled bool
|
||||||
|
RequestsPerMinute int
|
||||||
|
BurstSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimiter implements per-IP rate limiting using a token bucket algorithm
|
||||||
|
type RateLimiter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
visitors map[string]*visitor
|
||||||
|
rate rate.Limit
|
||||||
|
burst int
|
||||||
|
ttl time.Duration
|
||||||
|
enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type visitor struct {
|
||||||
|
limiter *rate.Limiter
|
||||||
|
lastSeen time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRateLimiter creates a new rate limiter with the given configuration
|
||||||
|
func NewRateLimiter(cfg RateLimitConfig) *RateLimiter {
|
||||||
|
// Convert requests per minute to events per second
|
||||||
|
rateLimit := rate.Limit(float64(cfg.RequestsPerMinute) / 60.0)
|
||||||
|
burst := cfg.BurstSize
|
||||||
|
if burst <= 0 {
|
||||||
|
burst = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RateLimiter{
|
||||||
|
mu: sync.Mutex{},
|
||||||
|
visitors: make(map[string]*visitor),
|
||||||
|
rate: rateLimit,
|
||||||
|
burst: burst,
|
||||||
|
ttl: 10 * time.Minute,
|
||||||
|
enabled: cfg.Enabled,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getVisitor returns the rate limiter for the given IP, creating one if needed.
|
||||||
|
// It performs TTL-based eviction of stale entries.
|
||||||
|
func (rl *RateLimiter) getVisitor(ip string) *rate.Limiter {
|
||||||
|
if !rl.enabled {
|
||||||
|
// If rate limiting is disabled, return a limiter that always allows
|
||||||
|
return rate.NewLimiter(rate.Inf, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
// Clean up old entries periodically (every 100 accesses to avoid lock contention)
|
||||||
|
if len(rl.visitors) > 0 && len(rl.visitors)%100 == 0 {
|
||||||
|
rl.cleanupOldVisitors(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
v, exists := rl.visitors[ip]
|
||||||
|
if !exists || now.Sub(v.lastSeen) > rl.ttl {
|
||||||
|
// Create new limiter for this IP
|
||||||
|
limiter := rate.NewLimiter(rl.rate, rl.burst)
|
||||||
|
rl.visitors[ip] = &visitor{
|
||||||
|
limiter: limiter,
|
||||||
|
lastSeen: now,
|
||||||
|
}
|
||||||
|
return limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update last seen time
|
||||||
|
v.lastSeen = now
|
||||||
|
return v.limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupOldVisitors removes entries that haven't been seen in more than ttl
|
||||||
|
func (rl *RateLimiter) cleanupOldVisitors(now time.Time) {
|
||||||
|
for ip, v := range rl.visitors {
|
||||||
|
if now.Sub(v.lastSeen) > rl.ttl {
|
||||||
|
delete(rl.visitors, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clientIP extracts the client IP address from the request
|
||||||
|
func (rl *RateLimiter) clientIP(r *http.Request) string {
|
||||||
|
// Try X-Forwarded-For header first
|
||||||
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
|
// X-Forwarded-For can contain multiple IPs: client, proxy1, proxy2, ...
|
||||||
|
// The leftmost is the original client
|
||||||
|
ips := strings.Split(xff, ",")
|
||||||
|
if len(ips) > 0 {
|
||||||
|
return strings.TrimSpace(ips[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try X-Real-IP header
|
||||||
|
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||||
|
return strings.TrimSpace(xri)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to RemoteAddr (strip port if present)
|
||||||
|
addr := r.RemoteAddr
|
||||||
|
if colonIdx := strings.LastIndex(addr, ":"); colonIdx != -1 {
|
||||||
|
return addr[:colonIdx]
|
||||||
|
}
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware returns the rate limiting middleware function
|
||||||
|
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ip := rl.clientIP(r)
|
||||||
|
limiter := rl.getVisitor(ip)
|
||||||
|
|
||||||
|
if !limiter.Allow() {
|
||||||
|
// Rate limit exceeded
|
||||||
|
// Calculate retry after based on the rate
|
||||||
|
// tokens needed = burst, rate = tokens/second
|
||||||
|
// So wait time = burst / rate (in seconds)
|
||||||
|
retryAfter := float64(rl.burst) / float64(rl.rate)
|
||||||
|
if retryAfter <= 0 {
|
||||||
|
retryAfter = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", retryAfter))
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"error": "rate_limited",
|
||||||
|
"retry_after_seconds": int(retryAfter),
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(response)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
310
pkg/middleware/ratelimit_test.go
Normal file
310
pkg/middleware/ratelimit_test.go
Normal file
@@ -0,0 +1,310 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRateLimiter_AllowsRequestsWithinBurst(t *testing.T) {
|
||||||
|
cfg := RateLimitConfig{
|
||||||
|
Enabled: true,
|
||||||
|
RequestsPerMinute: 60,
|
||||||
|
BurstSize: 5,
|
||||||
|
}
|
||||||
|
rl := NewRateLimiter(cfg)
|
||||||
|
|
||||||
|
// Create a simple handler that returns 200 OK
|
||||||
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Make 5 requests (equal to burst size) - all should succeed
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_BlocksRequestsExceedingBurst(t *testing.T) {
|
||||||
|
cfg := RateLimitConfig{
|
||||||
|
Enabled: true,
|
||||||
|
RequestsPerMinute: 60,
|
||||||
|
BurstSize: 3,
|
||||||
|
}
|
||||||
|
rl := NewRateLimiter(cfg)
|
||||||
|
|
||||||
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Make 4 requests (exceeding burst of 3) - 4th should be rate limited
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.2:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4th request should be rate limited
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.2:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("Request 4: expected status 429, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify response body
|
||||||
|
var response map[string]interface{}
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
|
||||||
|
t.Fatalf("Failed to decode response body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if response["error"] != "rate_limited" {
|
||||||
|
t.Errorf("Expected error 'rate_limited', got %v", response["error"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := response["retry_after_seconds"]; !ok {
|
||||||
|
t.Error("Expected retry_after_seconds in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify Retry-After header
|
||||||
|
if retryAfter := rr.Header().Get("Retry-After"); retryAfter == "" {
|
||||||
|
t.Error("Expected Retry-After header to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_DifferentIPsIndependent(t *testing.T) {
|
||||||
|
cfg := RateLimitConfig{
|
||||||
|
Enabled: true,
|
||||||
|
RequestsPerMinute: 60,
|
||||||
|
BurstSize: 2,
|
||||||
|
}
|
||||||
|
rl := NewRateLimiter(cfg)
|
||||||
|
|
||||||
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// IP1 makes 2 requests (fills its burst)
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "10.0.0.1:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("IP1 request %d: expected status 200, got %d", i+1, rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IP1's 3rd request should be rate limited
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "10.0.0.1:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("IP1 request 3: expected status 429, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IP2 should still be able to make requests (independent rate limit)
|
||||||
|
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req2.RemoteAddr = "10.0.0.2:12345"
|
||||||
|
rr2 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr2, req2)
|
||||||
|
|
||||||
|
if rr2.Code != http.StatusOK {
|
||||||
|
t.Errorf("IP2 request 1: expected status 200, got %d", rr2.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_Disabled(t *testing.T) {
|
||||||
|
cfg := RateLimitConfig{
|
||||||
|
Enabled: false,
|
||||||
|
RequestsPerMinute: 60,
|
||||||
|
BurstSize: 1,
|
||||||
|
}
|
||||||
|
rl := NewRateLimiter(cfg)
|
||||||
|
|
||||||
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Make many requests - all should succeed when disabled
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.100:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("Request %d with disabled rate limiter: expected status 200, got %d", i+1, rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_TTLExpiration(t *testing.T) {
|
||||||
|
cfg := RateLimitConfig{
|
||||||
|
Enabled: true,
|
||||||
|
RequestsPerMinute: 60,
|
||||||
|
BurstSize: 2,
|
||||||
|
}
|
||||||
|
rl := NewRateLimiter(cfg)
|
||||||
|
|
||||||
|
// Manually set a short TTL for testing
|
||||||
|
rl.ttl = 50 * time.Millisecond
|
||||||
|
|
||||||
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// IP makes 2 requests (fills burst)
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "10.0.0.50:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3rd request should be rate limited
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "10.0.0.50:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("Request 3: expected status 429, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for TTL to expire
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
|
||||||
|
// New request should succeed (new limiter created after TTL expiration)
|
||||||
|
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req2.RemoteAddr = "10.0.0.50:12345"
|
||||||
|
rr2 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr2, req2)
|
||||||
|
|
||||||
|
if rr2.Code != http.StatusOK {
|
||||||
|
t.Errorf("Request after TTL: expected status 200, got %d", rr2.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_ClientIPExtraction(t *testing.T) {
|
||||||
|
rl := NewRateLimiter(RateLimitConfig{Enabled: true, RequestsPerMinute: 60, BurstSize: 10})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
header map[string]string
|
||||||
|
remoteAddr string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For single IP",
|
||||||
|
header: map[string]string{"X-Forwarded-For": "203.0.113.195"},
|
||||||
|
remoteAddr: "127.0.0.1:12345",
|
||||||
|
expected: "203.0.113.195",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For multiple IPs",
|
||||||
|
header: map[string]string{"X-Forwarded-For": "203.0.113.195, 70.41.3.18, 150.172.238.178"},
|
||||||
|
remoteAddr: "127.0.0.1:12345",
|
||||||
|
expected: "203.0.113.195",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Real-IP",
|
||||||
|
header: map[string]string{"X-Real-IP": "203.0.113.50"},
|
||||||
|
remoteAddr: "127.0.0.1:12345",
|
||||||
|
expected: "203.0.113.50",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RemoteAddr with port",
|
||||||
|
header: map[string]string{},
|
||||||
|
remoteAddr: "203.0.113.100:54321",
|
||||||
|
expected: "203.0.113.100",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RemoteAddr without port",
|
||||||
|
header: map[string]string{},
|
||||||
|
remoteAddr: "203.0.113.101",
|
||||||
|
expected: "203.0.113.101",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For takes precedence over X-Real-IP",
|
||||||
|
header: map[string]string{"X-Forwarded-For": "203.0.113.200", "X-Real-IP": "203.0.113.201"},
|
||||||
|
remoteAddr: "127.0.0.1:12345",
|
||||||
|
expected: "203.0.113.200",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
for k, v := range tt.header {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
req.RemoteAddr = tt.remoteAddr
|
||||||
|
|
||||||
|
ip := rl.clientIP(req)
|
||||||
|
if ip != tt.expected {
|
||||||
|
t.Errorf("clientIP() = %q, expected %q", ip, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_ContentTypeHeader(t *testing.T) {
|
||||||
|
cfg := RateLimitConfig{
|
||||||
|
Enabled: true,
|
||||||
|
RequestsPerMinute: 60,
|
||||||
|
BurstSize: 1,
|
||||||
|
}
|
||||||
|
rl := NewRateLimiter(cfg)
|
||||||
|
|
||||||
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Make 1 request to fill burst
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.200:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
// 2nd request should be rate limited
|
||||||
|
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req2.RemoteAddr = "192.168.1.200:12345"
|
||||||
|
rr2 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr2, req2)
|
||||||
|
|
||||||
|
if rr2.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("Expected status 429, got %d", rr2.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Content-Type header is JSON
|
||||||
|
contentType := rr2.Header().Get("Content-Type")
|
||||||
|
if contentType != "application/json" {
|
||||||
|
t.Errorf("Expected Content-Type: application/json, got %q", contentType)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,12 +13,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
chimiddleware "github.com/go-chi/chi/v5/middleware"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
httpSwagger "github.com/swaggo/http-swagger"
|
httpSwagger "github.com/swaggo/http-swagger"
|
||||||
|
|
||||||
"dance-lessons-coach/pkg/config"
|
"dance-lessons-coach/pkg/config"
|
||||||
"dance-lessons-coach/pkg/greet"
|
"dance-lessons-coach/pkg/greet"
|
||||||
|
"dance-lessons-coach/pkg/middleware"
|
||||||
"dance-lessons-coach/pkg/telemetry"
|
"dance-lessons-coach/pkg/telemetry"
|
||||||
"dance-lessons-coach/pkg/user"
|
"dance-lessons-coach/pkg/user"
|
||||||
userapi "dance-lessons-coach/pkg/user/api"
|
userapi "dance-lessons-coach/pkg/user/api"
|
||||||
@@ -125,7 +126,7 @@ func initializeUserServices(cfg *config.Config) (user.UserRepository, user.UserS
|
|||||||
|
|
||||||
func (s *Server) setupRoutes() {
|
func (s *Server) setupRoutes() {
|
||||||
// Use Zerolog middleware instead of Chi's default logger
|
// Use Zerolog middleware instead of Chi's default logger
|
||||||
s.router.Use(middleware.RequestLogger(&middleware.DefaultLogFormatter{
|
s.router.Use(chimiddleware.RequestLogger(&chimiddleware.DefaultLogFormatter{
|
||||||
Logger: &log.Logger,
|
Logger: &log.Logger,
|
||||||
NoColor: false,
|
NoColor: false,
|
||||||
}))
|
}))
|
||||||
@@ -177,6 +178,13 @@ func (s *Server) registerApiV1Routes(r chi.Router) {
|
|||||||
greetService := greet.NewService()
|
greetService := greet.NewService()
|
||||||
greetHandler := greet.NewApiV1GreetHandler(greetService)
|
greetHandler := greet.NewApiV1GreetHandler(greetService)
|
||||||
|
|
||||||
|
// Create rate limit middleware
|
||||||
|
rateLimitMiddleware := middleware.NewRateLimiter(middleware.RateLimitConfig{
|
||||||
|
Enabled: s.config.GetRateLimitEnabled(),
|
||||||
|
RequestsPerMinute: s.config.GetRateLimitRequestsPerMinute(),
|
||||||
|
BurstSize: s.config.GetRateLimitBurstSize(),
|
||||||
|
})
|
||||||
|
|
||||||
// Create auth middleware if available
|
// Create auth middleware if available
|
||||||
var authMiddleware *AuthMiddleware
|
var authMiddleware *AuthMiddleware
|
||||||
if s.userService != nil {
|
if s.userService != nil {
|
||||||
@@ -184,6 +192,8 @@ func (s *Server) registerApiV1Routes(r chi.Router) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.Route("/greet", func(r chi.Router) {
|
r.Route("/greet", func(r chi.Router) {
|
||||||
|
// Add rate limiting middleware for greet endpoint
|
||||||
|
r.Use(rateLimitMiddleware.Middleware)
|
||||||
// Add optional authentication middleware
|
// Add optional authentication middleware
|
||||||
if authMiddleware != nil {
|
if authMiddleware != nil {
|
||||||
r.Use(authMiddleware.Middleware)
|
r.Use(authMiddleware.Middleware)
|
||||||
@@ -220,8 +230,8 @@ func (s *Server) registerApiV2Routes(r chi.Router) {
|
|||||||
// getAllMiddlewares returns all middleware including OpenTelemetry if enabled
|
// getAllMiddlewares returns all middleware including OpenTelemetry if enabled
|
||||||
func (s *Server) getAllMiddlewares() []func(http.Handler) http.Handler {
|
func (s *Server) getAllMiddlewares() []func(http.Handler) http.Handler {
|
||||||
middlewares := []func(http.Handler) http.Handler{
|
middlewares := []func(http.Handler) http.Handler{
|
||||||
middleware.StripSlashes,
|
chimiddleware.StripSlashes,
|
||||||
middleware.Recoverer,
|
chimiddleware.Recoverer,
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.withOTEL {
|
if s.withOTEL {
|
||||||
|
|||||||
Reference in New Issue
Block a user