From 3f88621ee1779ae2918d96e8b1baad55cb9e715d Mon Sep 17 00:00:00 2001 From: blockchainluffy Date: Mon, 21 Apr 2025 15:15:08 +0530 Subject: [PATCH 1/3] added basic rate limits per endpoint --- .env.example | 4 +- docker-compose.yml | 2 + go.mod | 2 +- go.sum | 2 + internal/ipratelimiter.go/ipratelimiter.go | 313 +++++++++++++++++++++ internal/router/router.go | 31 +- main.go | 17 +- tests/integration/init_test.go | 2 +- 8 files changed, 365 insertions(+), 8 deletions(-) create mode 100644 internal/ipratelimiter.go/ipratelimiter.go diff --git a/.env.example b/.env.example index 87631e9..088842e 100644 --- a/.env.example +++ b/.env.example @@ -21,4 +21,6 @@ P2P_PORT= LOG_LEVEL= METRICS_ENABLED=true METRICS_HOST="[::]" -METRICS_PORT=4000 \ No newline at end of file +METRICS_PORT=4000 +DEFAULT_RATE_LIMIT=100 +REGISTER_IDENTITY_RATE_LIMIT=30 \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 358c823..f5d4232 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -27,6 +27,8 @@ services: METRICS_ENABLED: METRICS_HOST: METRICS_PORT: + DEFAULT_RATE_LIMIT: + REGISTER_IDENTITY_RATE_LIMIT: ports: - "8001:8001" - "23003:23003" diff --git a/go.mod b/go.mod index e399f00..b419715 100644 --- a/go.mod +++ b/go.mod @@ -218,7 +218,7 @@ require ( github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/stretchr/objx v0.5.2 // indirect - github.com/supranational/blst v0.3.12 + github.com/supranational/blst v0.3.14 github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect diff --git a/go.sum b/go.sum index 2c53537..44d178c 100644 --- a/go.sum +++ b/go.sum @@ -767,6 +767,8 @@ github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8 github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0= github.com/supranational/blst v0.3.12 h1:Vfas2U2CFHhniv2QkUm2OVa1+pGTdqtpqm9NnhUUbZ8= github.com/supranational/blst v0.3.12/go.mod h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/j+w+fVonLw= +github.com/supranational/blst v0.3.14 h1:xNMoHRJOTwMn63ip6qoWJ2Ymgvj7E2b9jY2FAwY+qRo= +github.com/supranational/blst v0.3.14/go.mod h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/j+w+fVonLw= github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE= github.com/swaggo/files v1.0.1/go.mod h1:0qXmMNH6sXNf+73t65aKeB+ApmgxdnkQzVTAj2uaMUg= github.com/swaggo/gin-swagger v1.6.0 h1:y8sxvQ3E20/RCyrXeFfg60r6H0Z+SwpTjMYsMm+zy8M= diff --git a/internal/ipratelimiter.go/ipratelimiter.go b/internal/ipratelimiter.go/ipratelimiter.go new file mode 100644 index 0000000..ebfaf81 --- /dev/null +++ b/internal/ipratelimiter.go/ipratelimiter.go @@ -0,0 +1,313 @@ +package ipratelimiter + +import ( + "context" + "fmt" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/service" +) + +// EndpointLimit defines a rate limit for a specific endpoint +type EndpointLimit struct { + Path string + MaxRequestsPerDay int +} + +// Request stores information about an API request +type Request struct { + Timestamp time.Time + Day int // Day of month (1-31) + Month int // Month (1-12) + Year int // Year (e.g., 2025) +} + +// IPEndpointLimiter implements rate limiting per IP per endpoint on a daily basis +type IPEndpointLimiter struct { + mu sync.Mutex + limitsPerIP map[string]map[string][]Request // map[ip]map[endpoint][]requests + endpointSettings map[string]*EndpointLimit // map[endpoint]*EndpointLimit + defaultLimit *EndpointLimit + cleanup *time.Ticker +} + +// NewIPEndpointLimiter creates a rate limiter with endpoint-specific daily limits +func NewIPEndpointLimiter(defaultMaxRequestsPerDay int) *IPEndpointLimiter { + limiter := &IPEndpointLimiter{ + limitsPerIP: make(map[string]map[string][]Request), + endpointSettings: make(map[string]*EndpointLimit), + defaultLimit: &EndpointLimit{ + Path: "*", + MaxRequestsPerDay: defaultMaxRequestsPerDay, + }, + cleanup: time.NewTicker(1 * time.Hour), // Hourly cleanup for daily limits + } + + return limiter +} + +// Start runs the cleanup routine - periodically removes requests from previous days +func (rl *IPEndpointLimiter) Start(ctx context.Context, runner service.Runner) error { + runner.Go(func() error { + defer rl.Close() + + for range rl.cleanup.C { + rl.mu.Lock() + now := time.Now() + today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.Local) + + // Keep only the last 7 days of data for analysis purposes + oldestToKeep := today.AddDate(0, 0, -7) + + // For each IP address + for ip, endpoints := range rl.limitsPerIP { + endpointsToRemove := []string{} + + // For each endpoint this IP has accessed + for endpoint, requests := range endpoints { + var recentRequests []Request + + // Keep only recent requests + for _, req := range requests { + requestTime := time.Date(req.Year, time.Month(req.Month), req.Day, 0, 0, 0, 0, time.UTC) + if requestTime.After(oldestToKeep) || requestTime.Equal(oldestToKeep) { + recentRequests = append(recentRequests, req) + } + } + + // Update or mark for removal + if len(recentRequests) > 0 { + endpoints[endpoint] = recentRequests + } else { + endpointsToRemove = append(endpointsToRemove, endpoint) + } + } + + // Remove empty endpoints + for _, endpoint := range endpointsToRemove { + delete(endpoints, endpoint) + } + + // Remove IP if no endpoints left + if len(endpoints) == 0 { + delete(rl.limitsPerIP, ip) + } + } + rl.mu.Unlock() + } + return nil + }) + return nil +} + +// SetLimit adds or updates a rate limit for a specific endpoint +func (rl *IPEndpointLimiter) SetLimit(path string, maxRequestsPerDay int) { + rl.mu.Lock() + defer rl.mu.Unlock() + + rl.endpointSettings[path] = &EndpointLimit{ + Path: path, + MaxRequestsPerDay: maxRequestsPerDay, + } +} + +// getEndpointLimit returns the limit settings for the given path +func (rl *IPEndpointLimiter) getEndpointLimit(path string) *EndpointLimit { + rl.mu.Lock() + defer rl.mu.Unlock() + + if limit, exists := rl.endpointSettings[path]; exists { + return limit + } + return rl.defaultLimit +} + +// Allow checks if a request from the given IP to the given endpoint is allowed +func (rl *IPEndpointLimiter) Allow(ip, endpoint string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now().UTC() + currentDay := now.Day() + currentMonth := int(now.Month()) + currentYear := now.Year() + + // Get the limit for this endpoint + limit := rl.defaultLimit + if l, exists := rl.endpointSettings[endpoint]; exists { + limit = l + } + + // Initialize IP map if not exists + if _, exists := rl.limitsPerIP[ip]; !exists { + rl.limitsPerIP[ip] = make(map[string][]Request) + } + + //TODO: need to check if we have to calculate all requests in the rate limit or manage them individualy + + // Initialize endpoint in IP map if not exists + if _, exists := rl.limitsPerIP[ip][endpoint]; !exists { + newRequest := Request{ + Timestamp: now, + Day: currentDay, + Month: currentMonth, + Year: currentYear, + } + rl.limitsPerIP[ip][endpoint] = []Request{newRequest} + return true + } + + // Count requests from the current day + requests := rl.limitsPerIP[ip][endpoint] + var todayRequests []Request + for _, req := range requests { + if req.Day == currentDay && req.Month == currentMonth && req.Year == currentYear { + todayRequests = append(todayRequests, req) + } + } + + // Check if adding this request would exceed the limit + if len(todayRequests) >= limit.MaxRequestsPerDay { + // Record this request anyway for proper counting + newRequest := Request{ + Timestamp: now, + Day: currentDay, + Month: currentMonth, + Year: currentYear, + } + rl.limitsPerIP[ip][endpoint] = append(requests, newRequest) + return false + } + + // Record and allow + newRequest := Request{ + Timestamp: now, + Day: currentDay, + Month: currentMonth, + Year: currentYear, + } + rl.limitsPerIP[ip][endpoint] = append(requests, newRequest) + return true +} + +// GetCurrentUsage returns the number of requests made today +func (rl *IPEndpointLimiter) GetCurrentUsage(ip, endpoint string) int { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + currentDay := now.Day() + currentMonth := int(now.Month()) + currentYear := now.Year() + + // Check if IP exists + endpoints, exists := rl.limitsPerIP[ip] + if !exists { + return 0 + } + + // Check if endpoint exists for this IP + requests, exists := endpoints[endpoint] + if !exists { + return 0 + } + + // Count requests from today + count := 0 + for _, req := range requests { + if req.Day == currentDay && req.Month == currentMonth && req.Year == currentYear { + count++ + } + } + + return count +} + +// GetRemainingTime returns the time until rate limit reset +func (rl *IPEndpointLimiter) GetRemainingTime() time.Duration { + now := time.Now() + tomorrow := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, time.Local) + return tomorrow.Sub(now) +} + +// Close stops the cleanup ticker +func (rl *IPEndpointLimiter) Close() { + rl.cleanup.Stop() +} + +// RateLimitMiddleware applies rate limiting to routes +func (rl *IPEndpointLimiter) RateLimitMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + // Get client IP + ip := getClientIP(c) + + // Get endpoint path + path := c.FullPath() + if path == "" { + path = c.Request.URL.Path + } + + // Check if allowed + if !rl.Allow(ip, path) { + limit := rl.getEndpointLimit(path) + usage := rl.GetCurrentUsage(ip, path) + remaining := 0 + if limit.MaxRequestsPerDay > usage { + remaining = limit.MaxRequestsPerDay - usage + } + + // Get seconds until reset + resetSeconds := int(rl.GetRemainingTime().Seconds()) + resetHours := int(rl.GetRemainingTime().Hours()) + resetMinutes := int(rl.GetRemainingTime().Minutes()) % 60 + + c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", limit.MaxRequestsPerDay)) + c.Header("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining)) + c.Header("X-RateLimit-Used", fmt.Sprintf("%d", usage)) + c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Unix()+int64(resetSeconds))) + + c.JSON(http.StatusTooManyRequests, gin.H{ + "error": "Daily rate limit exceeded for this endpoint.", + "limit": limit.MaxRequestsPerDay, + "used": usage, + "reset_in": fmt.Sprintf("%dh %dm", resetHours, resetMinutes), + }) + c.Abort() + return + } + + c.Next() + } +} + +// getClientIP extracts the client IP from the request +func getClientIP(c *gin.Context) string { + // Try X-Forwarded-For header first + if xForwardedFor := c.Request.Header.Get("X-Forwarded-For"); xForwardedFor != "" { + ips := strings.Split(xForwardedFor, ",") + if len(ips) > 0 { + clientIP := strings.TrimSpace(ips[0]) + if clientIP != "" { + return clientIP + } + } + } + + // Try X-Real-IP header + if xRealIP := c.Request.Header.Get("X-Real-IP"); xRealIP != "" { + return strings.TrimSpace(xRealIP) + } + + // Fall back to RemoteAddr + ip, _, err := net.SplitHostPort(c.Request.RemoteAddr) + if err != nil { + return c.Request.RemoteAddr + } + + return ip +} diff --git a/internal/router/router.go b/internal/router/router.go index 7f6628c..dc2d3f8 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -1,12 +1,17 @@ package router import ( + "fmt" + "os" + "strconv" + "github.com/ethereum/go-ethereum/ethclient" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" "github.com/jackc/pgx/v5/pgxpool" "github.com/shutter-network/shutter-api/common" "github.com/shutter-network/shutter-api/docs" + "github.com/shutter-network/shutter-api/internal/ipratelimiter.go" "github.com/shutter-network/shutter-api/internal/middleware" "github.com/shutter-network/shutter-api/internal/service" swaggerFiles "github.com/swaggo/files" @@ -18,13 +23,35 @@ func NewRouter( contract *common.Contract, ethClient *ethclient.Client, config *common.Config, -) *gin.Engine { +) (*gin.Engine, *ipratelimiter.IPEndpointLimiter) { + + defaultRateLimitStr := os.Getenv("DEFAULT_RATE_LIMIT") + defaultRateLimit, err := strconv.ParseInt(defaultRateLimitStr, 10, 0) + if err != nil { + panic(fmt.Errorf("failed to convert DEFAULT_RATE_LIMIT to int: %w", err)) + } + + // Create limiter with default settings + limiter := ipratelimiter.NewIPEndpointLimiter(int(defaultRateLimit)) + + registerRateLimitStr := os.Getenv("REGISTER_IDENTITY_RATE_LIMIT") + registerRateLimit, err := strconv.ParseInt(registerRateLimitStr, 10, 0) + if err != nil { + panic(fmt.Errorf("failed to convert REGISTER_IDENTITY_RATE_LIMIT to int: %w", err)) + } + + // Configure endpoint-specific monthly limits + limiter.SetLimit("/api/register_identity", int(registerRateLimit)) + router := gin.New() router.Use(gin.Logger()) router.Use(gin.Recovery()) router.Use(cors.Default()) router.Use(middleware.ErrorHandler()) + // Apply rate limiting to all routes + router.Use(limiter.RateLimitMiddleware()) + cryptoService := service.NewCryptoService(db, contract, ethClient, config) docs.SwaggerInfo.BasePath = "/api" api := router.Group("/api") @@ -37,5 +64,5 @@ func NewRouter( router.GET("/docs/*any", ginSwagger.WrapHandler(swaggerFiles.Handler, func(c *ginSwagger.Config) { c.Title = "Shutter-API" })) - return router + return router, limiter } diff --git a/main.go b/main.go index f04170d..5b4a65e 100644 --- a/main.go +++ b/main.go @@ -179,11 +179,21 @@ func main() { log.Err(err).Msg("unable to parse keyper http url") return } - app := router.NewRouter(db, contract, client, config) - watcher := watcher.NewWatcher(config, db) - group, deferFn := service.RunBackground(ctx, watcher) + app, limiter := router.NewRouter(db, contract, client, config) + group, deferFn := service.RunBackground(ctx, limiter) defer deferFn() + go func() { + if err := group.Wait(); err != nil { + log.Err(err).Msg("ipratelimiter service failed") + panic(err) + } + }() + + watcher := watcher.NewWatcher(config, db) + group, watcherdeferFn := service.RunBackground(ctx, watcher) + defer watcherdeferFn() + if metricsConfig.Enabled { group, deferFn := service.RunBackground(ctx, metricsServer) defer deferFn() @@ -197,6 +207,7 @@ func main() { go func() { if err := group.Wait(); err != nil { log.Err(err).Msg("watcher service failed") + panic(err) } }() app.Run("0.0.0.0:" + port) diff --git a/tests/integration/init_test.go b/tests/integration/init_test.go index 743277c..dfb9967 100644 --- a/tests/integration/init_test.go +++ b/tests/integration/init_test.go @@ -107,7 +107,7 @@ func (s *TestShutterService) SetupSuite() { go func() { s.Require().NoError(group.Wait()) }() - s.router = router.NewRouter(s.db, s.contract, s.ethClient, s.config) + s.router, _ = router.NewRouter(s.db, s.contract, s.ethClient, s.config) s.testServer = httptest.NewServer(s.router) } From 16675b74e3ad3e2efe04954a567098cec9060bb4 Mon Sep 17 00:00:00 2001 From: blockchainluffy Date: Tue, 22 Apr 2025 15:08:32 +0530 Subject: [PATCH 2/3] added test for ipratelimiter --- .../ipratelimiter.go | 243 ++++++------- internal/ipratelimiter/ipratelimiter_test.go | 320 ++++++++++++++++++ internal/router/router.go | 2 +- 3 files changed, 431 insertions(+), 134 deletions(-) rename internal/{ipratelimiter.go => ipratelimiter}/ipratelimiter.go (50%) create mode 100644 internal/ipratelimiter/ipratelimiter_test.go diff --git a/internal/ipratelimiter.go/ipratelimiter.go b/internal/ipratelimiter/ipratelimiter.go similarity index 50% rename from internal/ipratelimiter.go/ipratelimiter.go rename to internal/ipratelimiter/ipratelimiter.go index ebfaf81..e70def4 100644 --- a/internal/ipratelimiter.go/ipratelimiter.go +++ b/internal/ipratelimiter/ipratelimiter.go @@ -2,10 +2,8 @@ package ipratelimiter import ( "context" - "fmt" "net" "net/http" - "strings" "sync" "time" @@ -22,9 +20,6 @@ type EndpointLimit struct { // Request stores information about an API request type Request struct { Timestamp time.Time - Day int // Day of month (1-31) - Month int // Month (1-12) - Year int // Year (e.g., 2025) } // IPEndpointLimiter implements rate limiting per IP per endpoint on a daily basis @@ -38,17 +33,62 @@ type IPEndpointLimiter struct { // NewIPEndpointLimiter creates a rate limiter with endpoint-specific daily limits func NewIPEndpointLimiter(defaultMaxRequestsPerDay int) *IPEndpointLimiter { - limiter := &IPEndpointLimiter{ + if defaultMaxRequestsPerDay < 0 { + defaultMaxRequestsPerDay = 0 + } + return &IPEndpointLimiter{ limitsPerIP: make(map[string]map[string][]Request), endpointSettings: make(map[string]*EndpointLimit), defaultLimit: &EndpointLimit{ - Path: "*", MaxRequestsPerDay: defaultMaxRequestsPerDay, }, - cleanup: time.NewTicker(1 * time.Hour), // Hourly cleanup for daily limits + cleanup: time.NewTicker(1 * time.Hour), } +} + +// cleanupExpiredData removes requests older than 1 day +func (rl *IPEndpointLimiter) cleanupExpiredData() { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now().UTC() + today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) + oldestToKeep := today.AddDate(0, 0, -1) + + // For each IP address + for ip, endpoints := range rl.limitsPerIP { + endpointsToRemove := []string{} + + // For each endpoint this IP has accessed + for endpoint, requests := range endpoints { + var recentRequests []Request + + // Keep only recent requests + for _, req := range requests { + requestTime := req.Timestamp + if requestTime.After(oldestToKeep) || requestTime.Equal(oldestToKeep) { + recentRequests = append(recentRequests, req) + } + } + + // Update or mark for removal + if len(recentRequests) > 0 { + rl.limitsPerIP[ip][endpoint] = recentRequests + } else { + endpointsToRemove = append(endpointsToRemove, endpoint) + } + } - return limiter + // Remove empty endpoints + for _, endpoint := range endpointsToRemove { + delete(rl.limitsPerIP[ip], endpoint) + } + + // Remove IP if no endpoints left + if len(rl.limitsPerIP[ip]) == 0 { + delete(rl.limitsPerIP, ip) + } + } } // Start runs the cleanup routine - periodically removes requests from previous days @@ -56,57 +96,27 @@ func (rl *IPEndpointLimiter) Start(ctx context.Context, runner service.Runner) e runner.Go(func() error { defer rl.Close() - for range rl.cleanup.C { - rl.mu.Lock() - now := time.Now() - today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.Local) - - // Keep only the last 7 days of data for analysis purposes - oldestToKeep := today.AddDate(0, 0, -7) - - // For each IP address - for ip, endpoints := range rl.limitsPerIP { - endpointsToRemove := []string{} - - // For each endpoint this IP has accessed - for endpoint, requests := range endpoints { - var recentRequests []Request - - // Keep only recent requests - for _, req := range requests { - requestTime := time.Date(req.Year, time.Month(req.Month), req.Day, 0, 0, 0, 0, time.UTC) - if requestTime.After(oldestToKeep) || requestTime.Equal(oldestToKeep) { - recentRequests = append(recentRequests, req) - } - } - - // Update or mark for removal - if len(recentRequests) > 0 { - endpoints[endpoint] = recentRequests - } else { - endpointsToRemove = append(endpointsToRemove, endpoint) - } - } - - // Remove empty endpoints - for _, endpoint := range endpointsToRemove { - delete(endpoints, endpoint) - } - - // Remove IP if no endpoints left - if len(endpoints) == 0 { - delete(rl.limitsPerIP, ip) - } + for { + select { + case <-ctx.Done(): + return nil + case <-rl.cleanup.C: + rl.cleanupExpiredData() } - rl.mu.Unlock() } - return nil }) return nil } // SetLimit adds or updates a rate limit for a specific endpoint func (rl *IPEndpointLimiter) SetLimit(path string, maxRequestsPerDay int) { + if path == "" { + return + } + if maxRequestsPerDay < 0 { + maxRequestsPerDay = 0 + } + rl.mu.Lock() defer rl.mu.Unlock() @@ -129,13 +139,20 @@ func (rl *IPEndpointLimiter) getEndpointLimit(path string) *EndpointLimit { // Allow checks if a request from the given IP to the given endpoint is allowed func (rl *IPEndpointLimiter) Allow(ip, endpoint string) bool { + if ip == "" || endpoint == "" { + return false + } + + // Validate IP address + if net.ParseIP(ip) == nil { + return false + } + rl.mu.Lock() defer rl.mu.Unlock() now := time.Now().UTC() - currentDay := now.Day() - currentMonth := int(now.Month()) - currentYear := now.Year() + today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) // Get the limit for this endpoint limit := rl.defaultLimit @@ -148,48 +165,29 @@ func (rl *IPEndpointLimiter) Allow(ip, endpoint string) bool { rl.limitsPerIP[ip] = make(map[string][]Request) } - //TODO: need to check if we have to calculate all requests in the rate limit or manage them individualy - // Initialize endpoint in IP map if not exists if _, exists := rl.limitsPerIP[ip][endpoint]; !exists { - newRequest := Request{ - Timestamp: now, - Day: currentDay, - Month: currentMonth, - Year: currentYear, - } - rl.limitsPerIP[ip][endpoint] = []Request{newRequest} - return true + rl.limitsPerIP[ip][endpoint] = []Request{} } // Count requests from the current day requests := rl.limitsPerIP[ip][endpoint] var todayRequests []Request for _, req := range requests { - if req.Day == currentDay && req.Month == currentMonth && req.Year == currentYear { + requestTime := time.Date(req.Timestamp.Year(), req.Timestamp.Month(), req.Timestamp.Day(), 0, 0, 0, 0, time.UTC) + if requestTime.Equal(today) { todayRequests = append(todayRequests, req) } } // Check if adding this request would exceed the limit if len(todayRequests) >= limit.MaxRequestsPerDay { - // Record this request anyway for proper counting - newRequest := Request{ - Timestamp: now, - Day: currentDay, - Month: currentMonth, - Year: currentYear, - } - rl.limitsPerIP[ip][endpoint] = append(requests, newRequest) return false } // Record and allow newRequest := Request{ Timestamp: now, - Day: currentDay, - Month: currentMonth, - Year: currentYear, } rl.limitsPerIP[ip][endpoint] = append(requests, newRequest) return true @@ -200,10 +198,8 @@ func (rl *IPEndpointLimiter) GetCurrentUsage(ip, endpoint string) int { rl.mu.Lock() defer rl.mu.Unlock() - now := time.Now() - currentDay := now.Day() - currentMonth := int(now.Month()) - currentYear := now.Year() + now := time.Now().UTC() + today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) // Check if IP exists endpoints, exists := rl.limitsPerIP[ip] @@ -220,7 +216,8 @@ func (rl *IPEndpointLimiter) GetCurrentUsage(ip, endpoint string) int { // Count requests from today count := 0 for _, req := range requests { - if req.Day == currentDay && req.Month == currentMonth && req.Year == currentYear { + requestTime := time.Date(req.Timestamp.Year(), req.Timestamp.Month(), req.Timestamp.Day(), 0, 0, 0, 0, time.UTC) + if requestTime.Equal(today) { count++ } } @@ -230,8 +227,8 @@ func (rl *IPEndpointLimiter) GetCurrentUsage(ip, endpoint string) int { // GetRemainingTime returns the time until rate limit reset func (rl *IPEndpointLimiter) GetRemainingTime() time.Duration { - now := time.Now() - tomorrow := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, time.Local) + now := time.Now().UTC() + tomorrow := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, time.UTC) return tomorrow.Sub(now) } @@ -243,41 +240,20 @@ func (rl *IPEndpointLimiter) Close() { // RateLimitMiddleware applies rate limiting to routes func (rl *IPEndpointLimiter) RateLimitMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - // Get client IP - ip := getClientIP(c) - - // Get endpoint path - path := c.FullPath() - if path == "" { - path = c.Request.URL.Path + if c == nil || c.Request == nil { + c.AbortWithStatus(http.StatusInternalServerError) + return } - // Check if allowed - if !rl.Allow(ip, path) { - limit := rl.getEndpointLimit(path) - usage := rl.GetCurrentUsage(ip, path) - remaining := 0 - if limit.MaxRequestsPerDay > usage { - remaining = limit.MaxRequestsPerDay - usage - } + ip := getClientIP(c) + if ip == "" { + c.AbortWithStatus(http.StatusBadRequest) + return + } - // Get seconds until reset - resetSeconds := int(rl.GetRemainingTime().Seconds()) - resetHours := int(rl.GetRemainingTime().Hours()) - resetMinutes := int(rl.GetRemainingTime().Minutes()) % 60 - - c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", limit.MaxRequestsPerDay)) - c.Header("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining)) - c.Header("X-RateLimit-Used", fmt.Sprintf("%d", usage)) - c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Unix()+int64(resetSeconds))) - - c.JSON(http.StatusTooManyRequests, gin.H{ - "error": "Daily rate limit exceeded for this endpoint.", - "limit": limit.MaxRequestsPerDay, - "used": usage, - "reset_in": fmt.Sprintf("%dh %dm", resetHours, resetMinutes), - }) - c.Abort() + endpoint := c.Request.URL.Path + if !rl.Allow(ip, endpoint) { + c.AbortWithStatus(http.StatusTooManyRequests) return } @@ -287,27 +263,28 @@ func (rl *IPEndpointLimiter) RateLimitMiddleware() gin.HandlerFunc { // getClientIP extracts the client IP from the request func getClientIP(c *gin.Context) string { - // Try X-Forwarded-For header first - if xForwardedFor := c.Request.Header.Get("X-Forwarded-For"); xForwardedFor != "" { - ips := strings.Split(xForwardedFor, ",") - if len(ips) > 0 { - clientIP := strings.TrimSpace(ips[0]) - if clientIP != "" { - return clientIP - } - } + if c == nil || c.Request == nil { + return "" } - // Try X-Real-IP header - if xRealIP := c.Request.Header.Get("X-Real-IP"); xRealIP != "" { - return strings.TrimSpace(xRealIP) + // Get the remote address + remoteAddr := c.Request.RemoteAddr + if remoteAddr == "" { + return "" } - // Fall back to RemoteAddr - ip, _, err := net.SplitHostPort(c.Request.RemoteAddr) - if err != nil { - return c.Request.RemoteAddr + // Remove port if present + host, _, err := net.SplitHostPort(remoteAddr) + if err == nil { + if net.ParseIP(host) != nil { + return host + } + } else { + // Try parsing the whole string as an IP + if net.ParseIP(remoteAddr) != nil { + return remoteAddr + } } - return ip + return "" } diff --git a/internal/ipratelimiter/ipratelimiter_test.go b/internal/ipratelimiter/ipratelimiter_test.go new file mode 100644 index 0000000..0e77379 --- /dev/null +++ b/internal/ipratelimiter/ipratelimiter_test.go @@ -0,0 +1,320 @@ +package ipratelimiter + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestNewIPEndpointLimiter(t *testing.T) { + defaultLimit := 100 + limiter := NewIPEndpointLimiter(defaultLimit) + + assert.NotNil(t, limiter) + assert.Equal(t, defaultLimit, limiter.defaultLimit.MaxRequestsPerDay) + assert.NotNil(t, limiter.limitsPerIP) + assert.NotNil(t, limiter.endpointSettings) + assert.NotNil(t, limiter.cleanup) +} + +func TestSetLimit(t *testing.T) { + limiter := NewIPEndpointLimiter(100) + path := "/test" + maxRequests := 50 + + limiter.SetLimit(path, maxRequests) + + limit := limiter.getEndpointLimit(path) + assert.Equal(t, maxRequests, limit.MaxRequestsPerDay) + assert.Equal(t, path, limit.Path) +} + +func TestAllow(t *testing.T) { + limiter := NewIPEndpointLimiter(2) // Set a small limit for testing + ip := "127.0.0.1" + endpoint := "/test" + + // First request should be allowed + assert.True(t, limiter.Allow(ip, endpoint)) + assert.Equal(t, 1, limiter.GetCurrentUsage(ip, endpoint)) + + // Second request should be allowed + assert.True(t, limiter.Allow(ip, endpoint)) + assert.Equal(t, 2, limiter.GetCurrentUsage(ip, endpoint)) + + // Third request should be denied + assert.False(t, limiter.Allow(ip, endpoint)) + assert.Equal(t, 2, limiter.GetCurrentUsage(ip, endpoint)) +} + +func TestGetCurrentUsage(t *testing.T) { + limiter := NewIPEndpointLimiter(100) + ip := "127.0.0.1" + endpoint := "/test" + + // Initial usage should be 0 + assert.Equal(t, 0, limiter.GetCurrentUsage(ip, endpoint)) + + // Make a request and check usage + limiter.Allow(ip, endpoint) + assert.Equal(t, 1, limiter.GetCurrentUsage(ip, endpoint)) + + // Make another request and check usage + limiter.Allow(ip, endpoint) + assert.Equal(t, 2, limiter.GetCurrentUsage(ip, endpoint)) +} + +func TestGetRemainingTime(t *testing.T) { + limiter := NewIPEndpointLimiter(100) + remainingTime := limiter.GetRemainingTime() + + // Remaining time should be less than 24 hours + assert.True(t, remainingTime < 24*time.Hour) + // Remaining time should be positive + assert.True(t, remainingTime > 0) +} + +func TestRateLimitMiddleware(t *testing.T) { + limiter := NewIPEndpointLimiter(2) // Set a small limit for testing + router := gin.New() + router.Use(limiter.RateLimitMiddleware()) + + // Setup test endpoint + router.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"message": "success"}) + }) + + // First request should succeed + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "127.0.0.1:1234" // Set RemoteAddr for the request + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + // Second request should succeed + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "127.0.0.1:1234" // Set RemoteAddr for the request + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + // Third request should be rate limited + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "127.0.0.1:1234" // Set RemoteAddr for the request + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusTooManyRequests, w.Code) + + // Test with different IP should succeed + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "127.0.0.2:1234" // Different IP + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + // Test with invalid IP should fail + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "invalid-ip" // Invalid IP + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestGetClientIP(t *testing.T) { + // Test valid RemoteAddr with port + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request, _ = http.NewRequest("GET", "/", nil) + c.Request.RemoteAddr = "192.168.1.1:1234" + assert.Equal(t, "192.168.1.1", getClientIP(c)) + + // Test valid RemoteAddr without port + w = httptest.NewRecorder() + c, _ = gin.CreateTestContext(w) + c.Request, _ = http.NewRequest("GET", "/", nil) + c.Request.RemoteAddr = "192.168.1.2" + assert.Equal(t, "192.168.1.2", getClientIP(c)) + + // Test nil context + assert.Equal(t, "", getClientIP(nil)) + + // Test nil request + w = httptest.NewRecorder() + c, _ = gin.CreateTestContext(w) + assert.Equal(t, "", getClientIP(c)) +} + +func TestGetClientIPWithInvalidHeaders(t *testing.T) { + // Test with empty RemoteAddr + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request, _ = http.NewRequest("GET", "/", nil) + c.Request.RemoteAddr = "" + assert.Equal(t, "", getClientIP(c)) + + // Test with invalid RemoteAddr format + w = httptest.NewRecorder() + c, _ = gin.CreateTestContext(w) + c.Request, _ = http.NewRequest("GET", "/", nil) + c.Request.RemoteAddr = "invalid-address" + assert.Equal(t, "", getClientIP(c)) + + // Test with invalid IP in RemoteAddr + w = httptest.NewRecorder() + c, _ = gin.CreateTestContext(w) + c.Request, _ = http.NewRequest("GET", "/", nil) + c.Request.RemoteAddr = "256.256.256.256:1234" + assert.Equal(t, "", getClientIP(c)) +} + +func TestInvalidIPAddress(t *testing.T) { + limiter := NewIPEndpointLimiter(100) + endpoint := "/test" + + // Test with empty IP + assert.False(t, limiter.Allow("", endpoint)) + assert.Equal(t, 0, limiter.GetCurrentUsage("", endpoint)) + + // Test with invalid IP format + assert.False(t, limiter.Allow("invalid.ip.address", endpoint)) + assert.Equal(t, 0, limiter.GetCurrentUsage("invalid.ip.address", endpoint)) +} + +func TestInvalidEndpoint(t *testing.T) { + limiter := NewIPEndpointLimiter(100) + ip := "127.0.0.1" + + // Test with empty endpoint + assert.False(t, limiter.Allow(ip, "")) + assert.Equal(t, 0, limiter.GetCurrentUsage(ip, "")) + + // Test with non-existent endpoint (should use default limit) + assert.True(t, limiter.Allow(ip, "/nonexistent")) + assert.Equal(t, 1, limiter.GetCurrentUsage(ip, "/nonexistent")) + + // Test with non-existent endpoint after reaching limit + // Set a very low limit for testing + limiter = NewIPEndpointLimiter(1) + assert.True(t, limiter.Allow(ip, "/nonexistent")) + assert.False(t, limiter.Allow(ip, "/nonexistent")) + assert.Equal(t, 1, limiter.GetCurrentUsage(ip, "/nonexistent")) +} + +func TestNegativeLimit(t *testing.T) { + // Test with negative default limit + limiter := NewIPEndpointLimiter(-1) + assert.Equal(t, 0, limiter.defaultLimit.MaxRequestsPerDay) + + // Test setting negative limit for endpoint + limiter.SetLimit("/test", -1) + limit := limiter.getEndpointLimit("/test") + assert.Equal(t, 0, limit.MaxRequestsPerDay) +} + +func TestConcurrentAccess(t *testing.T) { + limiter := NewIPEndpointLimiter(1000) + ip := "127.0.0.1" + endpoint := "/test" + concurrentRequests := 100 + + // Channel to collect results + results := make(chan bool, concurrentRequests) + + // Launch concurrent requests + for i := 0; i < concurrentRequests; i++ { + go func() { + results <- limiter.Allow(ip, endpoint) + }() + } + + // Collect results + allowedCount := 0 + for i := 0; i < concurrentRequests; i++ { + if <-results { + allowedCount++ + } + } + + // Verify we didn't exceed the limit + assert.True(t, allowedCount <= 1000, "Concurrent requests exceeded limit") + assert.Equal(t, allowedCount, limiter.GetCurrentUsage(ip, endpoint)) +} + +func TestCleanupWithExpiredData(t *testing.T) { + limiter := NewIPEndpointLimiter(100) + + // Add some test data with old timestamps + ip := "127.0.0.1" + endpoint := "/test" + + now := time.Now().UTC() + today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) + + // Create a request from 24 hours ago + oldTime := today.AddDate(0, 0, -1).Add(-1 * time.Second) + oldRequest := Request{ + Timestamp: oldTime, + } + + // Add the old request to the limiter + limiter.mu.Lock() + if _, exists := limiter.limitsPerIP[ip]; !exists { + limiter.limitsPerIP[ip] = make(map[string][]Request) + } + limiter.limitsPerIP[ip][endpoint] = []Request{oldRequest} + limiter.mu.Unlock() + + // Run cleanup directly + limiter.cleanupExpiredData() + + // Verify old data was cleaned up + limiter.mu.Lock() + requests, exists := limiter.limitsPerIP[ip][endpoint] + limiter.mu.Unlock() + + if exists { + assert.Equal(t, 0, len(requests), "Old data should have been cleaned up") + } + + // Add a recent request and verify it's not cleaned up + recentTime := time.Now().UTC() + recentRequest := Request{ + Timestamp: recentTime, + } + + limiter.mu.Lock() + if _, exists := limiter.limitsPerIP[ip]; !exists { + limiter.limitsPerIP[ip] = make(map[string][]Request) + } + limiter.limitsPerIP[ip][endpoint] = []Request{recentRequest} + limiter.mu.Unlock() + + // Run cleanup again + limiter.cleanupExpiredData() + + // Verify recent data is still there + limiter.mu.Lock() + requests, exists = limiter.limitsPerIP[ip][endpoint] + limiter.mu.Unlock() + + assert.True(t, exists, "Recent data should not be cleaned up") + assert.Equal(t, 1, len(requests), "Recent data should not be cleaned up") + + limiter.Close() +} + +func TestMiddlewareWithInvalidContext(t *testing.T) { + limiter := NewIPEndpointLimiter(100) + router := gin.New() + router.Use(limiter.RateLimitMiddleware()) + + // Test with nil context + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) +} diff --git a/internal/router/router.go b/internal/router/router.go index dc2d3f8..9653adb 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -11,7 +11,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/shutter-network/shutter-api/common" "github.com/shutter-network/shutter-api/docs" - "github.com/shutter-network/shutter-api/internal/ipratelimiter.go" + "github.com/shutter-network/shutter-api/internal/ipratelimiter" "github.com/shutter-network/shutter-api/internal/middleware" "github.com/shutter-network/shutter-api/internal/service" swaggerFiles "github.com/swaggo/files" From a6396da392ed0d4d0b73ee9b42fd71abefb50598 Mon Sep 17 00:00:00 2001 From: blockchainluffy Date: Wed, 23 Apr 2025 12:15:13 +0530 Subject: [PATCH 3/3] added multiple endpoint tests --- internal/ipratelimiter/ipratelimiter_test.go | 71 ++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/internal/ipratelimiter/ipratelimiter_test.go b/internal/ipratelimiter/ipratelimiter_test.go index 0e77379..2430166 100644 --- a/internal/ipratelimiter/ipratelimiter_test.go +++ b/internal/ipratelimiter/ipratelimiter_test.go @@ -318,3 +318,74 @@ func TestMiddlewareWithInvalidContext(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, http.StatusBadRequest, w.Code) } + +func TestDifferentEndpointLimits(t *testing.T) { + limiter := NewIPEndpointLimiter(100) // Default limit + ip := "127.0.0.1" + + // Set up two endpoints with different limits + endpoint1 := "/endpoint1" + endpoint2 := "/endpoint2" + limiter.SetLimit(endpoint1, 2) // Small limit + limiter.SetLimit(endpoint2, 5) // Larger limit + + // Test endpoint1 with small limit + assert.True(t, limiter.Allow(ip, endpoint1)) // First request + assert.True(t, limiter.Allow(ip, endpoint1)) // Second request + assert.False(t, limiter.Allow(ip, endpoint1)) // Should be blocked + assert.Equal(t, 2, limiter.GetCurrentUsage(ip, endpoint1)) + + // Test endpoint2 with larger limit + assert.True(t, limiter.Allow(ip, endpoint2)) // First request + assert.True(t, limiter.Allow(ip, endpoint2)) // Second request + assert.True(t, limiter.Allow(ip, endpoint2)) // Third request + assert.True(t, limiter.Allow(ip, endpoint2)) // Fourth request + assert.True(t, limiter.Allow(ip, endpoint2)) // Fifth request + assert.False(t, limiter.Allow(ip, endpoint2)) // Should be blocked + assert.Equal(t, 5, limiter.GetCurrentUsage(ip, endpoint2)) +} + +func TestEndpointLimitsIndependent(t *testing.T) { + limiter := NewIPEndpointLimiter(100) // Default limit + ip := "127.0.0.1" + + // Set up two endpoints with different limits + endpoint1 := "/endpoint1" + endpoint2 := "/endpoint2" + limiter.SetLimit(endpoint1, 2) // Small limit + limiter.SetLimit(endpoint2, 5) // Larger limit + + // Max out the larger limit endpoint first + for i := 0; i < 5; i++ { + assert.True(t, limiter.Allow(ip, endpoint2)) + } + assert.False(t, limiter.Allow(ip, endpoint2)) // Should be blocked + assert.Equal(t, 5, limiter.GetCurrentUsage(ip, endpoint2)) + + // Verify smaller limit endpoint is still accessible + assert.True(t, limiter.Allow(ip, endpoint1)) // First request + assert.True(t, limiter.Allow(ip, endpoint1)) // Second request + assert.False(t, limiter.Allow(ip, endpoint1)) // Should be blocked + assert.Equal(t, 2, limiter.GetCurrentUsage(ip, endpoint1)) + + // Now max out the smaller limit endpoint + // Reset the test by creating a new limiter + limiter = NewIPEndpointLimiter(100) + limiter.SetLimit(endpoint1, 2) + limiter.SetLimit(endpoint2, 5) + + // Max out the smaller limit endpoint + assert.True(t, limiter.Allow(ip, endpoint1)) // First request + assert.True(t, limiter.Allow(ip, endpoint1)) // Second request + assert.False(t, limiter.Allow(ip, endpoint1)) // Should be blocked + assert.Equal(t, 2, limiter.GetCurrentUsage(ip, endpoint1)) + + // Verify larger limit endpoint is still accessible + assert.True(t, limiter.Allow(ip, endpoint2)) // First request + assert.True(t, limiter.Allow(ip, endpoint2)) // Second request + assert.True(t, limiter.Allow(ip, endpoint2)) // Third request + assert.True(t, limiter.Allow(ip, endpoint2)) // Fourth request + assert.True(t, limiter.Allow(ip, endpoint2)) // Fifth request + assert.False(t, limiter.Allow(ip, endpoint2)) // Should be blocked + assert.Equal(t, 5, limiter.GetCurrentUsage(ip, endpoint2)) +}