Skip to content

Commit 7f7419b

Browse files
authored
Merge pull request #1 from leetcode-golang-classroom/sync-map-solution
✨ handle concurrency with sync.Map
2 parents 8f4d71a + ea64382 commit 7f7419b

File tree

2 files changed

+35
-9
lines changed

2 files changed

+35
-9
lines changed

README.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,32 @@ func (r *RateLimiter) RateLimiterMiddleware(next http.Handler, limit rate.Limit,
9191
})
9292
}
9393
```
94+
95+
96+
## handle concurrency problem with sync.Map
97+
98+
```golang
99+
var ipLimiterMap sync.Map
100+
101+
// RateLimiterMiddleware - 建立 ratelimiter middleware
102+
func (r *RateLimiter) RateLimiterMiddleware(next http.Handler, limit rate.Limit, burst int) http.Handler {
103+
104+
// var mu sync.Mutex
105+
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
106+
// Fetch IP
107+
ip := r.getIP(req)
108+
// Create limiter if not present for IP
109+
limiterAny, _ := ipLimiterMap.LoadOrStore(ip, rate.NewLimiter(limit, burst))
110+
limiter := limiterAny.(*rate.Limiter)
111+
// return error if the limit has been reached
112+
if !limiter.Allow() {
113+
w.Header().Set("Content-Type", "application/json")
114+
w.WriteHeader(http.StatusTooManyRequests)
115+
json.NewEncoder(w).Encode(map[string]string{"error": "Too many requests"})
116+
return
117+
}
118+
next.ServeHTTP(w, req)
119+
})
120+
}
121+
```
122+

internal/service/rate_limiter/rate-limiter.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,18 @@ func (r *RateLimiter) getIP(req *http.Request) string {
2222
return host
2323
}
2424

25+
var ipLimiterMap sync.Map
26+
2527
// RateLimiterMiddleware - 建立 ratelimiter middleware
2628
func (r *RateLimiter) RateLimiterMiddleware(next http.Handler, limit rate.Limit, burst int) http.Handler {
27-
ipLimiterMap := make(map[string]*rate.Limiter)
28-
var mu sync.Mutex
29+
30+
// var mu sync.Mutex
2931
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
3032
// Fetch IP
3133
ip := r.getIP(req)
3234
// Create limiter if not present for IP
33-
mu.Lock()
34-
limiter, exists := ipLimiterMap[ip]
35-
if !exists {
36-
limiter = rate.NewLimiter(limit, burst)
37-
ipLimiterMap[ip] = limiter
38-
}
39-
mu.Unlock()
35+
limiterAny, _ := ipLimiterMap.LoadOrStore(ip, rate.NewLimiter(limit, burst))
36+
limiter := limiterAny.(*rate.Limiter)
4037
// return error if the limit has been reached
4138
if !limiter.Allow() {
4239
w.Header().Set("Content-Type", "application/json")

0 commit comments

Comments
 (0)