Skip to content

Commit

Permalink
SNOW-981533 Separate retry strategy for auth endpoints and the remain…
Browse files Browse the repository at this point in the history
…ing ones
  • Loading branch information
sfc-gh-pfus committed Dec 1, 2023
1 parent 5c79db8 commit 17ed58a
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 28 deletions.
2 changes: 1 addition & 1 deletion monitoring.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (sc *snowflakeConn) checkQueryStatus(
if tok, _, _ := sc.rest.TokenAccessor.GetTokens(); tok != "" {
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, tok)
}
resultPath := fmt.Sprintf("/monitoring/queries/%s", qid)
resultPath := fmt.Sprintf("%s/%s", monitoringQueriesPath, qid)
url := sc.rest.getFullURL(resultPath, &param)

res, err := sc.rest.FuncGet(ctx, sc.rest, url, headers, sc.rest.RequestTimeout)
Expand Down
1 change: 1 addition & 0 deletions restful.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ const (
tokenRequestPath = "/session/token-request"
abortRequestPath = "/queries/v1/abort-request"
authenticatorRequestPath = "/session/authenticator-request"
monitoringQueriesPath = "/monitoring/queries"
sessionRequestPath = "/session"
heartBeatPath = "/session/heartbeat"
)
Expand Down
54 changes: 39 additions & 15 deletions retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@ import (
type waitAlgo struct {
mutex *sync.Mutex // required for *rand.Rand usage
random *rand.Rand
base time.Duration // base wait time
cap time.Duration // maximum wait time
}

var random *rand.Rand
var defaultWaitAlgo *waitAlgo

var endpointsEligibleForRetry = []string{
var authEndpoints = []string{
loginRequestPath,
tokenRequestPath,
authenticatorRequestPath,
queryRequestPath,
abortRequestPath,
sessionRequestPath,
}

var clientErrorsStatusCodesEligibleForRetry = []int{
Expand All @@ -43,7 +42,7 @@ var clientErrorsStatusCodesEligibleForRetry = []int{

func init() {
random = rand.New(rand.NewSource(time.Now().UnixNano()))
defaultWaitAlgo = &waitAlgo{mutex: &sync.Mutex{}, random: random}
defaultWaitAlgo = &waitAlgo{mutex: &sync.Mutex{}, random: random, base: 5 * time.Second, cap: 160 * time.Second}
}

const (
Expand Down Expand Up @@ -205,12 +204,30 @@ func isQueryRequest(url *url.URL) bool {
}

// jitter backoff in seconds
func (w *waitAlgo) calculateWaitBeforeRetry(attempt int, currWaitTime float64) float64 {
func (w *waitAlgo) calculateWaitBeforeRetryForAuthRequest(attempt int, currWaitTimeDuration time.Duration) time.Duration {
w.mutex.Lock()
defer w.mutex.Unlock()
jitterAmount := w.getJitter(currWaitTime)
jitteredSleepTime := chooseRandomFromRange(currWaitTime+jitterAmount, math.Pow(2, float64(attempt))+jitterAmount)
return jitteredSleepTime
currWaitTimeInSeconds := currWaitTimeDuration.Seconds()
jitterAmount := w.getJitter(currWaitTimeInSeconds)
jitteredSleepTime := chooseRandomFromRange(currWaitTimeInSeconds+jitterAmount, math.Pow(2, float64(attempt))+jitterAmount)
return time.Duration(jitteredSleepTime * float64(time.Second))
}

func (w *waitAlgo) calculateWaitBeforeRetry(attempt int, sleep time.Duration) time.Duration {
w.mutex.Lock()
defer w.mutex.Unlock()
t := 3*sleep - w.base
switch {
case t > 0:
return durationMin(w.cap, randSecondDuration(t)+w.base)
case t < 0:
return durationMin(w.cap, randSecondDuration(-t)+3*sleep)
}
return w.base
}

func randSecondDuration(n time.Duration) time.Duration {
return time.Duration(random.Int63n(int64(n/time.Second))) * time.Second
}

func (w *waitAlgo) getJitter(currWaitTime float64) float64 {
Expand Down Expand Up @@ -284,7 +301,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) {
totalTimeout := r.timeout
logger.WithContext(r.ctx).Infof("retryHTTP.totalTimeout: %v", totalTimeout)
retryCounter := 0
sleepTime := 1.0 // seconds
sleepTime := time.Duration(time.Second)
clientStartTime := strconv.FormatInt(r.currentTimeProvider.currentTime(), 10)

var requestGUIDReplacer requestGUIDReplacer
Expand Down Expand Up @@ -324,12 +341,16 @@ func (r *retryHTTP) execute() (res *http.Response, err error) {
}
// uses exponential jitter backoff
retryCounter++
sleepTime = defaultWaitAlgo.calculateWaitBeforeRetry(retryCounter, sleepTime)
if isLoginRequest(req) {
sleepTime = defaultWaitAlgo.calculateWaitBeforeRetryForAuthRequest(retryCounter, sleepTime)
} else {
sleepTime = defaultWaitAlgo.calculateWaitBeforeRetry(retryCounter, sleepTime)
}

if totalTimeout > 0 {
logger.WithContext(r.ctx).Infof("to timeout: %v", totalTimeout)
// if any timeout is set
totalTimeout -= time.Duration(sleepTime * float64(time.Second))
totalTimeout -= sleepTime
if totalTimeout <= 0 || retryCounter > r.maxRetryCount {
if err != nil {
return nil, err
Expand Down Expand Up @@ -360,7 +381,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) {
logger.WithContext(r.ctx).Infof("sleeping %v. to timeout: %v. retrying", sleepTime, totalTimeout)
logger.WithContext(r.ctx).Infof("retry count: %v, retry reason: %v", retryCounter, retryReason)

await := time.NewTimer(time.Duration(sleepTime * float64(time.Second)))
await := time.NewTimer(sleepTime)
select {
case <-await.C:
// retry the request
Expand All @@ -378,10 +399,13 @@ func isRetryableError(req *http.Request, res *http.Response, err error) (bool, e
if res == nil || req == nil {
return false, err
}
isRetryableURL := contains(endpointsEligibleForRetry, req.URL.Path)
return isRetryableURL && isRetryableStatus(res.StatusCode), err
return isRetryableStatus(res.StatusCode), err
}

func isRetryableStatus(statusCode int) bool {
return (statusCode >= 500 && statusCode < 600) || contains(clientErrorsStatusCodesEligibleForRetry, statusCode)
}

func isLoginRequest(req *http.Request) bool {
return contains(authEndpoints, req.URL.Path)
}
20 changes: 8 additions & 12 deletions retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,12 +493,6 @@ func TestIsRetryable(t *testing.T) {
err: nil,
expected: false,
},
{
req: &http.Request{URL: &url.URL{Path: heartBeatPath}},
res: &http.Response{StatusCode: http.StatusBadRequest},
err: nil,
expected: false,
},
{
req: &http.Request{URL: &url.URL{Path: loginRequestPath}},
res: &http.Response{StatusCode: http.StatusNotFound},
Expand All @@ -525,10 +519,12 @@ func TestIsRetryable(t *testing.T) {
}

for _, tc := range tcs {
result, _ := isRetryableError(tc.req, tc.res, tc.err)
if result != tc.expected {
t.Fatalf("expected %v, got %v; request: %v, response: %v", tc.expected, result, tc.req, tc.res)
}
t.Run(fmt.Sprintf("req %v, resp %v", tc.req, tc.res), func(t *testing.T) {
result, _ := isRetryableError(tc.req, tc.res, tc.err)
if result != tc.expected {
t.Fatalf("expected %v, got %v; request: %v, response: %v", tc.expected, result, tc.req, tc.res)
}
})
}
}

Expand Down Expand Up @@ -605,8 +601,8 @@ func TestCalculateRetryWait(t *testing.T) {

for _, tc := range tcs {
t.Run(fmt.Sprintf("attmept: %v", tc.attempt), func(t *testing.T) {
result := defaultWaitAlgo.calculateWaitBeforeRetry(tc.attempt, tc.currWaitTime)
assertBetweenE(t, result, tc.minSleepTime, tc.maxSleepTime)
result := defaultWaitAlgo.calculateWaitBeforeRetryForAuthRequest(tc.attempt, time.Duration(tc.currWaitTime*float64(time.Second)))
assertBetweenE(t, result.Seconds(), tc.minSleepTime, tc.maxSleepTime)
})
}
}

0 comments on commit 17ed58a

Please sign in to comment.