diff --git a/okta/client.go b/okta/client.go index ad9987c8..28908003 100644 --- a/okta/client.go +++ b/okta/client.go @@ -51,6 +51,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" "unicode/utf8" @@ -75,6 +76,12 @@ const ( DpopAccessTokenPrivateKey = "DPOP_OKTA_ACCESS_TOKEN_PRIVATE_KEY" ) +type RateLimit struct { + Limit int + Remaining int + Reset int64 +} + // APIClient manages communication with the Okta Admin Management API v5.1.0 // In most cases there should be only one, shared, APIClient. type APIClient struct { @@ -84,6 +91,9 @@ type APIClient struct { tokenCache *goCache.Cache freshcache bool + rateLimit *RateLimit + rateLimitLock sync.Mutex + // API Services AgentPoolsAPI AgentPoolsAPI @@ -1068,6 +1078,26 @@ func (c *APIClient) RefreshNext() *APIClient { return c } +func parseRateLimit(resp *http.Response) (*RateLimit, error) { + limit, err := strconv.Atoi(resp.Header.Get("X-Rate-Limit-Limit")) + if err != nil { + return nil, err + } + remaining, err := strconv.Atoi(resp.Header.Get("X-Rate-Limit-Remaining")) + if err != nil { + return nil, err + } + reset, err := Get429BackoffTime(resp) + if err != nil { + return nil, err + } + return &RateLimit{ + Limit: limit, + Remaining: remaining, + Reset: reset, + }, nil +} + func (c *APIClient) do(ctx context.Context, req *http.Request) (*http.Response, error) { cacheKey := CreateCacheKey(req) if req.Method != http.MethodGet { @@ -1080,11 +1110,36 @@ func (c *APIClient) do(ctx context.Context, req *http.Request) (*http.Response, c.freshcache = false } if !inCache { + if c.cfg.Okta.Client.RateLimit.Prevent { + c.rateLimitLock.Lock() + limit := c.rateLimit + c.rateLimitLock.Unlock() + if limit != nil && limit.Remaining <= 0 { + timer := time.NewTimer(time.Second * time.Duration(limit.Reset)) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return nil, ctx.Err() + case <-timer.C: + } + } + } + resp, err := c.doWithRetries(ctx, req) if err != nil { return nil, err } if resp.StatusCode >= 200 && resp.StatusCode <= 299 && req.Method == http.MethodGet { + if c.cfg.Okta.Client.RateLimit.Prevent { + c.rateLimitLock.Lock() + newLimit, err := parseRateLimit(resp) + if err == nil { + c.rateLimit = newLimit + } + c.rateLimitLock.Unlock() + } c.cache.Set(cacheKey, resp) } return resp, err diff --git a/okta/configuration.go b/okta/configuration.go index cce4c11e..046569a8 100644 --- a/okta/configuration.go +++ b/okta/configuration.go @@ -142,6 +142,7 @@ type Configuration struct { RateLimit struct { MaxRetries int32 `yaml:"maxRetries" envconfig:"OKTA_CLIENT_RATE_LIMIT_MAX_RETRIES"` MaxBackoff int64 `yaml:"maxBackoff" envconfig:"OKTA_CLIENT_RATE_LIMIT_MAX_BACKOFF"` + Prevent bool `yaml:"prevent" envconfig:"OKTA_CLIENT_RATE_LIMIT_PREVENT"` } `yaml:"rateLimit"` OrgUrl string `yaml:"orgUrl" envconfig:"OKTA_CLIENT_ORGURL"` Token string `yaml:"token" envconfig:"OKTA_CLIENT_TOKEN"` @@ -477,6 +478,12 @@ func WithRateLimitMaxRetries(maxRetries int32) ConfigSetter { } } +func WithRateLimitPrevent(prevent bool) ConfigSetter { + return func(c *Configuration) { + c.Okta.Client.RateLimit.Prevent = prevent + } +} + func WithRateLimitMaxBackOff(maxBackoff int64) ConfigSetter { return func(c *Configuration) { c.Okta.Client.RateLimit.MaxBackoff = maxBackoff