Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-981533 Separate retry strategy for auth endpoints and the remaining ones #982

Merged
merged 1 commit into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monitoring.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (sc *snowflakeConn) checkQueryStatus(
if tok, _, _ := sc.rest.TokenAccessor.GetTokens(); tok != "" {
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, tok)
}
resultPath := fmt.Sprintf("/monitoring/queries/%s", qid)
resultPath := fmt.Sprintf("%s/%s", monitoringQueriesPath, qid)
url := sc.rest.getFullURL(resultPath, &param)

res, err := sc.rest.FuncGet(ctx, sc.rest, url, headers, sc.rest.RequestTimeout)
sfc-gh-dheyman marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
1 change: 1 addition & 0 deletions restful.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ const (
tokenRequestPath = "/session/token-request"
abortRequestPath = "/queries/v1/abort-request"
authenticatorRequestPath = "/session/authenticator-request"
monitoringQueriesPath = "/monitoring/queries"
sessionRequestPath = "/session"
heartBeatPath = "/session/heartbeat"
)
Expand Down
54 changes: 39 additions & 15 deletions retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@ import (
// waitAlgo carries the state used to compute how long to wait before a retry.
type waitAlgo struct {
mutex *sync.Mutex // required for *rand.Rand usage; locked by the wait-calculation methods
random *rand.Rand // random number source (see the note on mutex)
base time.Duration // base wait time
cap time.Duration // maximum wait time
}

// random is the package-level random source; it is assigned in init.
var random *rand.Rand
// defaultWaitAlgo is the package-level waitAlgo instance; it is assigned in init.
var defaultWaitAlgo *waitAlgo

var endpointsEligibleForRetry = []string{
var authEndpoints = []string{
loginRequestPath,
tokenRequestPath,
authenticatorRequestPath,
queryRequestPath,
abortRequestPath,
sessionRequestPath,
}

var clientErrorsStatusCodesEligibleForRetry = []int{
Expand All @@ -43,7 +42,7 @@ var clientErrorsStatusCodesEligibleForRetry = []int{

func init() {
random = rand.New(rand.NewSource(time.Now().UnixNano()))
defaultWaitAlgo = &waitAlgo{mutex: &sync.Mutex{}, random: random}
defaultWaitAlgo = &waitAlgo{mutex: &sync.Mutex{}, random: random, base: 5 * time.Second, cap: 160 * time.Second}
}

const (
Expand Down Expand Up @@ -205,12 +204,30 @@ func isQueryRequest(url *url.URL) bool {
}

// jitter backoff in seconds
func (w *waitAlgo) calculateWaitBeforeRetry(attempt int, currWaitTime float64) float64 {
func (w *waitAlgo) calculateWaitBeforeRetryForAuthRequest(attempt int, currWaitTimeDuration time.Duration) time.Duration {
w.mutex.Lock()
defer w.mutex.Unlock()
jitterAmount := w.getJitter(currWaitTime)
jitteredSleepTime := chooseRandomFromRange(currWaitTime+jitterAmount, math.Pow(2, float64(attempt))+jitterAmount)
return jitteredSleepTime
currWaitTimeInSeconds := currWaitTimeDuration.Seconds()
jitterAmount := w.getJitter(currWaitTimeInSeconds)
jitteredSleepTime := chooseRandomFromRange(currWaitTimeInSeconds+jitterAmount, math.Pow(2, float64(attempt))+jitterAmount)
return time.Duration(jitteredSleepTime * float64(time.Second))
}

// calculateWaitBeforeRetry derives the next wait time from the previous one.
//
// The result is picked between w.base and three times the previous sleep
// (whichever order those two bounds happen to be in), using a random offset
// with whole-second granularity, and is then passed through durationMin
// together with w.cap. When 3*sleep equals w.base exactly, w.base is
// returned as is.
//
// attempt is accepted for signature compatibility but is not used here.
// The method holds w.mutex for its whole duration.
func (w *waitAlgo) calculateWaitBeforeRetry(attempt int, sleep time.Duration) time.Duration {
	w.mutex.Lock()
	defer w.mutex.Unlock()

	tripled := 3 * sleep
	spread := tripled - w.base

	if spread == 0 {
		return w.base
	}
	if spread > 0 {
		// base is the lower bound: base + random offset below the spread.
		return durationMin(w.cap, randSecondDuration(spread)+w.base)
	}
	// 3*sleep is the lower bound: 3*sleep + random offset below the (positive) spread.
	return durationMin(w.cap, randSecondDuration(-spread)+tripled)
}

// randSecondDuration returns a random whole number of seconds in [0, n),
// where n is first truncated to whole seconds.
//
// When n is shorter than one second (or not positive) there is no whole
// second to choose from, so 0 is returned. Without this guard the call to
// Int63n would receive an argument <= 0 and panic.
//
// The value is drawn from the package-level random source.
func randSecondDuration(n time.Duration) time.Duration {
	secs := int64(n / time.Second)
	if secs <= 0 {
		return 0
	}
	return time.Duration(random.Int63n(secs)) * time.Second
}

func (w *waitAlgo) getJitter(currWaitTime float64) float64 {
Expand Down Expand Up @@ -284,7 +301,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) {
totalTimeout := r.timeout
logger.WithContext(r.ctx).Infof("retryHTTP.totalTimeout: %v", totalTimeout)
retryCounter := 0
sleepTime := 1.0 // seconds
sleepTime := time.Duration(time.Second)
clientStartTime := strconv.FormatInt(r.currentTimeProvider.currentTime(), 10)

var requestGUIDReplacer requestGUIDReplacer
Expand Down Expand Up @@ -324,12 +341,16 @@ func (r *retryHTTP) execute() (res *http.Response, err error) {
}
// uses exponential jitter backoff
retryCounter++
sleepTime = defaultWaitAlgo.calculateWaitBeforeRetry(retryCounter, sleepTime)
if isLoginRequest(req) {
sleepTime = defaultWaitAlgo.calculateWaitBeforeRetryForAuthRequest(retryCounter, sleepTime)
} else {
sleepTime = defaultWaitAlgo.calculateWaitBeforeRetry(retryCounter, sleepTime)
}

if totalTimeout > 0 {
logger.WithContext(r.ctx).Infof("to timeout: %v", totalTimeout)
// if any timeout is set
totalTimeout -= time.Duration(sleepTime * float64(time.Second))
totalTimeout -= sleepTime
if totalTimeout <= 0 || retryCounter > r.maxRetryCount {
if err != nil {
return nil, err
Expand Down Expand Up @@ -360,7 +381,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) {
logger.WithContext(r.ctx).Infof("sleeping %v. to timeout: %v. retrying", sleepTime, totalTimeout)
logger.WithContext(r.ctx).Infof("retry count: %v, retry reason: %v", retryCounter, retryReason)

await := time.NewTimer(time.Duration(sleepTime * float64(time.Second)))
await := time.NewTimer(sleepTime)
select {
case <-await.C:
// retry the request
Expand All @@ -378,10 +399,13 @@ func isRetryableError(req *http.Request, res *http.Response, err error) (bool, e
if res == nil || req == nil {
return false, err
}
isRetryableURL := contains(endpointsEligibleForRetry, req.URL.Path)
return isRetryableURL && isRetryableStatus(res.StatusCode), err
return isRetryableStatus(res.StatusCode), err
}

// isRetryableStatus reports whether a response with the given HTTP status
// code should be retried: any 5xx code qualifies, and so does any code found
// in clientErrorsStatusCodesEligibleForRetry.
func isRetryableStatus(statusCode int) bool {
	if statusCode >= 500 && statusCode < 600 {
		return true
	}
	return contains(clientErrorsStatusCodesEligibleForRetry, statusCode)
}

// isLoginRequest reports whether req targets one of the paths listed in
// authEndpoints (looked up with contains on req.URL.Path).
//
// A nil request, or a request without a URL, is never a login request; the
// guard keeps this helper from panicking on such input.
func isLoginRequest(req *http.Request) bool {
	if req == nil || req.URL == nil {
		return false
	}
	return contains(authEndpoints, req.URL.Path)
}
20 changes: 8 additions & 12 deletions retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,12 +493,6 @@ func TestIsRetryable(t *testing.T) {
err: nil,
expected: false,
},
{
req: &http.Request{URL: &url.URL{Path: heartBeatPath}},
res: &http.Response{StatusCode: http.StatusBadRequest},
err: nil,
expected: false,
},
{
req: &http.Request{URL: &url.URL{Path: loginRequestPath}},
res: &http.Response{StatusCode: http.StatusNotFound},
Expand All @@ -525,10 +519,12 @@ func TestIsRetryable(t *testing.T) {
}

for _, tc := range tcs {
result, _ := isRetryableError(tc.req, tc.res, tc.err)
if result != tc.expected {
t.Fatalf("expected %v, got %v; request: %v, response: %v", tc.expected, result, tc.req, tc.res)
}
t.Run(fmt.Sprintf("req %v, resp %v", tc.req, tc.res), func(t *testing.T) {
result, _ := isRetryableError(tc.req, tc.res, tc.err)
if result != tc.expected {
t.Fatalf("expected %v, got %v; request: %v, response: %v", tc.expected, result, tc.req, tc.res)
}
})
}
}

Expand Down Expand Up @@ -605,8 +601,8 @@ func TestCalculateRetryWait(t *testing.T) {

for _, tc := range tcs {
t.Run(fmt.Sprintf("attmept: %v", tc.attempt), func(t *testing.T) {
result := defaultWaitAlgo.calculateWaitBeforeRetry(tc.attempt, tc.currWaitTime)
assertBetweenE(t, result, tc.minSleepTime, tc.maxSleepTime)
result := defaultWaitAlgo.calculateWaitBeforeRetryForAuthRequest(tc.attempt, time.Duration(tc.currWaitTime*float64(time.Second)))
assertBetweenE(t, result.Seconds(), tc.minSleepTime, tc.maxSleepTime)
})
}
}