Skip to content

added basic rate limits per endpoint #62

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ P2P_PORT=
LOG_LEVEL=
METRICS_ENABLED=true
METRICS_HOST="[::]"
METRICS_PORT=4000
METRICS_PORT=4000
DEFAULT_RATE_LIMIT=100
REGISTER_IDENTITY_RATE_LIMIT=30
2 changes: 2 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ services:
METRICS_ENABLED:
METRICS_HOST:
METRICS_PORT:
DEFAULT_RATE_LIMIT:
REGISTER_IDENTITY_RATE_LIMIT:
ports:
- "8001:8001"
- "23003:23003"
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ require (
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/supranational/blst v0.3.12
github.com/supranational/blst v0.3.14
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,8 @@ github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8
github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0=
github.com/supranational/blst v0.3.12 h1:Vfas2U2CFHhniv2QkUm2OVa1+pGTdqtpqm9NnhUUbZ8=
github.com/supranational/blst v0.3.12/go.mod h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/j+w+fVonLw=
github.com/supranational/blst v0.3.14 h1:xNMoHRJOTwMn63ip6qoWJ2Ymgvj7E2b9jY2FAwY+qRo=
github.com/supranational/blst v0.3.14/go.mod h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/j+w+fVonLw=
github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE=
github.com/swaggo/files v1.0.1/go.mod h1:0qXmMNH6sXNf+73t65aKeB+ApmgxdnkQzVTAj2uaMUg=
github.com/swaggo/gin-swagger v1.6.0 h1:y8sxvQ3E20/RCyrXeFfg60r6H0Z+SwpTjMYsMm+zy8M=
Expand Down
290 changes: 290 additions & 0 deletions internal/ipratelimiter/ipratelimiter.go
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will be much better to move this under the middleware package

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking that its better to have separate package so we do not merge any other middleware feature with rate limiting

Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
package ipratelimiter

import (
"context"
"net"
"net/http"
"sync"
"time"

"github.com/gin-gonic/gin"
"github.com/shutter-network/rolling-shutter/rolling-shutter/medley/service"
)

// EndpointLimit defines a rate limit for a specific endpoint
type EndpointLimit struct {
Path string
MaxRequestsPerDay int
}

// Request stores information about an API request
type Request struct {
Timestamp time.Time
}

// IPEndpointLimiter implements rate limiting per IP per endpoint on a daily basis
type IPEndpointLimiter struct {
mu sync.Mutex
limitsPerIP map[string]map[string][]Request // map[ip]map[endpoint][]requests
endpointSettings map[string]*EndpointLimit // map[endpoint]*EndpointLimit
defaultLimit *EndpointLimit
cleanup *time.Ticker
}

// NewIPEndpointLimiter creates a rate limiter with endpoint-specific daily limits
func NewIPEndpointLimiter(defaultMaxRequestsPerDay int) *IPEndpointLimiter {
if defaultMaxRequestsPerDay < 0 {
defaultMaxRequestsPerDay = 0
}
return &IPEndpointLimiter{
limitsPerIP: make(map[string]map[string][]Request),
endpointSettings: make(map[string]*EndpointLimit),
defaultLimit: &EndpointLimit{
MaxRequestsPerDay: defaultMaxRequestsPerDay,
},
cleanup: time.NewTicker(1 * time.Hour),
}
}

// cleanupExpiredData removes requests older than 1 day
func (rl *IPEndpointLimiter) cleanupExpiredData() {
rl.mu.Lock()
defer rl.mu.Unlock()

now := time.Now().UTC()
today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
oldestToKeep := today.AddDate(0, 0, -1)

// For each IP address
for ip, endpoints := range rl.limitsPerIP {
endpointsToRemove := []string{}

// For each endpoint this IP has accessed
for endpoint, requests := range endpoints {
var recentRequests []Request

// Keep only recent requests
for _, req := range requests {
requestTime := req.Timestamp
if requestTime.After(oldestToKeep) || requestTime.Equal(oldestToKeep) {
recentRequests = append(recentRequests, req)
}
}

// Update or mark for removal
if len(recentRequests) > 0 {
rl.limitsPerIP[ip][endpoint] = recentRequests
} else {
endpointsToRemove = append(endpointsToRemove, endpoint)
}
}

// Remove empty endpoints
for _, endpoint := range endpointsToRemove {
delete(rl.limitsPerIP[ip], endpoint)
}

// Remove IP if no endpoints left
if len(rl.limitsPerIP[ip]) == 0 {
delete(rl.limitsPerIP, ip)
}
}
}

// Start runs the cleanup routine - periodically removes requests from previous days
func (rl *IPEndpointLimiter) Start(ctx context.Context, runner service.Runner) error {
runner.Go(func() error {
defer rl.Close()

for {
select {
case <-ctx.Done():
return nil
case <-rl.cleanup.C:
rl.cleanupExpiredData()
}
}
})
return nil
}

// SetLimit adds or updates a rate limit for a specific endpoint
func (rl *IPEndpointLimiter) SetLimit(path string, maxRequestsPerDay int) {
if path == "" {
return
}
if maxRequestsPerDay < 0 {
maxRequestsPerDay = 0
}

rl.mu.Lock()
defer rl.mu.Unlock()

rl.endpointSettings[path] = &EndpointLimit{
Path: path,
MaxRequestsPerDay: maxRequestsPerDay,
}
}

// getEndpointLimit returns the limit settings for the given path
func (rl *IPEndpointLimiter) getEndpointLimit(path string) *EndpointLimit {
rl.mu.Lock()
defer rl.mu.Unlock()

if limit, exists := rl.endpointSettings[path]; exists {
return limit
}
return rl.defaultLimit
}

// Allow checks if a request from the given IP to the given endpoint is allowed
func (rl *IPEndpointLimiter) Allow(ip, endpoint string) bool {
if ip == "" || endpoint == "" {
return false
}

// Validate IP address
if net.ParseIP(ip) == nil {
return false
}

rl.mu.Lock()
defer rl.mu.Unlock()

now := time.Now().UTC()
today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)

// Get the limit for this endpoint
limit := rl.defaultLimit
if l, exists := rl.endpointSettings[endpoint]; exists {
limit = l
}

// Initialize IP map if not exists
if _, exists := rl.limitsPerIP[ip]; !exists {
rl.limitsPerIP[ip] = make(map[string][]Request)
}

// Initialize endpoint in IP map if not exists
if _, exists := rl.limitsPerIP[ip][endpoint]; !exists {
rl.limitsPerIP[ip][endpoint] = []Request{}
}

// Count requests from the current day
requests := rl.limitsPerIP[ip][endpoint]
var todayRequests []Request
for _, req := range requests {
requestTime := time.Date(req.Timestamp.Year(), req.Timestamp.Month(), req.Timestamp.Day(), 0, 0, 0, 0, time.UTC)
if requestTime.Equal(today) {
todayRequests = append(todayRequests, req)
}
}

// Check if adding this request would exceed the limit
if len(todayRequests) >= limit.MaxRequestsPerDay {
return false
}

// Record and allow
newRequest := Request{
Timestamp: now,
}
rl.limitsPerIP[ip][endpoint] = append(requests, newRequest)
return true
}

// GetCurrentUsage returns the number of requests made today
func (rl *IPEndpointLimiter) GetCurrentUsage(ip, endpoint string) int {
rl.mu.Lock()
defer rl.mu.Unlock()

now := time.Now().UTC()
today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)

// Check if IP exists
endpoints, exists := rl.limitsPerIP[ip]
if !exists {
return 0
}

// Check if endpoint exists for this IP
requests, exists := endpoints[endpoint]
if !exists {
return 0
}

// Count requests from today
count := 0
for _, req := range requests {
requestTime := time.Date(req.Timestamp.Year(), req.Timestamp.Month(), req.Timestamp.Day(), 0, 0, 0, 0, time.UTC)
if requestTime.Equal(today) {
count++
}
}

return count
}

// GetRemainingTime returns the time until rate limit reset
func (rl *IPEndpointLimiter) GetRemainingTime() time.Duration {
now := time.Now().UTC()
tomorrow := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, time.UTC)
return tomorrow.Sub(now)
}

// Close stops the cleanup ticker
func (rl *IPEndpointLimiter) Close() {
rl.cleanup.Stop()
}

// RateLimitMiddleware applies rate limiting to routes
func (rl *IPEndpointLimiter) RateLimitMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if c == nil || c.Request == nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}

ip := getClientIP(c)
if ip == "" {
c.AbortWithStatus(http.StatusBadRequest)
return
}

endpoint := c.Request.URL.Path
if !rl.Allow(ip, endpoint) {
c.AbortWithStatus(http.StatusTooManyRequests)
return
}

c.Next()
}
}

// getClientIP extracts the client IP from the request
func getClientIP(c *gin.Context) string {
if c == nil || c.Request == nil {
return ""
}

// Get the remote address
remoteAddr := c.Request.RemoteAddr
if remoteAddr == "" {
return ""
}

// Remove port if present
host, _, err := net.SplitHostPort(remoteAddr)
if err == nil {
if net.ParseIP(host) != nil {
return host
}
} else {
// Try parsing the whole string as an IP
if net.ParseIP(remoteAddr) != nil {
return remoteAddr
}
}

return ""
}
Loading