diff --git a/connection.go b/connection.go index 4ba97d14b..5fcb355bb 100644 --- a/connection.go +++ b/connection.go @@ -60,6 +60,13 @@ const ( queryResultType resultType = "query" ) +type execKey string + +const ( + executionType execKey = "executionType" + executionTypeStatement string = "statement" +) + const privateLinkSuffix = "privatelink.snowflakecomputing.com" type snowflakeConn struct { @@ -333,8 +340,17 @@ func (sc *snowflakeConn) ExecContext( }, nil // last insert id is not supported by Snowflake } else if isMultiStmt(&data.Data) { return sc.handleMultiExec(ctx, data.Data) + } else if isDql(&data.Data) { + logger.WithContext(ctx).Debugf("DQL") + if isStatementContext(ctx) { + return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil + } + return driver.ResultNoRows, nil } logger.Debug("DDL") + if isStatementContext(ctx) { + return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil + } return driver.ResultNoRows, nil } diff --git a/connection_util.go b/connection_util.go index 4d37dea28..54390522a 100644 --- a/connection_util.go +++ b/connection_util.go @@ -197,7 +197,14 @@ func isDml(v int64) bool { return statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert } +func isDql(data *execResponseData) bool { + return data.StatementTypeID == statementTypeIDSelect && !isMultiStmt(data) +} + func updateRows(data execResponseData) (int64, error) { + if data.RowSet == nil { + return 0, nil + } var count int64 for i, n := 0, len(data.RowType); i < n; i++ { v, err := strconv.ParseInt(*data.RowSet[0][i], 10, 64) @@ -292,3 +299,8 @@ func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error { } return nil } + +func isStatementContext(ctx context.Context) bool { + v := ctx.Value(executionType) + return v == executionTypeStatement +} diff --git a/monitoring.go b/monitoring.go index 07b17c0aa..444a7c4bf 100644 --- a/monitoring.go +++ b/monitoring.go @@ -136,7 +136,7 @@ func (sc *snowflakeConn) checkQueryStatus( if tok, _, _ := sc.rest.TokenAccessor.GetTokens(); tok != "" { headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, tok) } - resultPath := fmt.Sprintf("/monitoring/queries/%s", qid) + resultPath := fmt.Sprintf("%s/%s", monitoringQueriesPath, qid) url := sc.rest.getFullURL(resultPath, ¶m) res, err := sc.rest.FuncGet(ctx, sc.rest, url, headers, sc.rest.RequestTimeout) diff --git a/restful.go b/restful.go index 777d94df4..c92d9c762 100644 --- a/restful.go +++ b/restful.go @@ -37,6 +37,7 @@ const ( tokenRequestPath = "/session/token-request" abortRequestPath = "/queries/v1/abort-request" authenticatorRequestPath = "/session/authenticator-request" + monitoringQueriesPath = "/monitoring/queries" sessionRequestPath = "/session" heartBeatPath = "/session/heartbeat" ) diff --git a/result.go b/result.go index e08f41902..c2a718308 100644 --- a/result.go +++ b/result.go @@ -2,6 +2,8 @@ package gosnowflake +import "errors" + type queryStatus string const ( @@ -73,3 +75,19 @@ func (res *snowflakeResult) waitForAsyncExecStatus() error { } return nil } + +type snowflakeResultNoRows struct { + queryID string +} + +func (*snowflakeResultNoRows) LastInsertId() (int64, error) { + return 0, errors.New("no LastInsertId available") +} + +func (*snowflakeResultNoRows) RowsAffected() (int64, error) { + return 0, errors.New("no RowsAffected available") +} + +func (rnr *snowflakeResultNoRows) GetQueryID() string { + return rnr.queryID +} diff --git a/retry.go b/retry.go index e46148fa7..b70b97e07 100644 --- a/retry.go +++ b/retry.go @@ -20,18 +20,17 @@ import ( type waitAlgo struct { mutex *sync.Mutex // required for *rand.Rand usage random *rand.Rand + base time.Duration // base wait time + cap time.Duration // maximum wait time } var random *rand.Rand var defaultWaitAlgo *waitAlgo -var endpointsEligibleForRetry = []string{ +var authEndpoints = []string{ loginRequestPath, tokenRequestPath, authenticatorRequestPath, - queryRequestPath, - abortRequestPath, - sessionRequestPath, } var clientErrorsStatusCodesEligibleForRetry = []int{ @@ -43,7 +42,7 @@ var clientErrorsStatusCodesEligibleForRetry = []int{ func init() { random = rand.New(rand.NewSource(time.Now().UnixNano())) - defaultWaitAlgo = &waitAlgo{mutex: &sync.Mutex{}, random: random} + defaultWaitAlgo = &waitAlgo{mutex: &sync.Mutex{}, random: random, base: 5 * time.Second, cap: 160 * time.Second} } const ( @@ -205,12 +204,30 @@ func isQueryRequest(url *url.URL) bool { } // jitter backoff in seconds -func (w *waitAlgo) calculateWaitBeforeRetry(attempt int, currWaitTime float64) float64 { +func (w *waitAlgo) calculateWaitBeforeRetryForAuthRequest(attempt int, currWaitTimeDuration time.Duration) time.Duration { w.mutex.Lock() defer w.mutex.Unlock() - jitterAmount := w.getJitter(currWaitTime) - jitteredSleepTime := chooseRandomFromRange(currWaitTime+jitterAmount, math.Pow(2, float64(attempt))+jitterAmount) - return jitteredSleepTime + currWaitTimeInSeconds := currWaitTimeDuration.Seconds() + jitterAmount := w.getJitter(currWaitTimeInSeconds) + jitteredSleepTime := chooseRandomFromRange(currWaitTimeInSeconds+jitterAmount, math.Pow(2, float64(attempt))+jitterAmount) + return time.Duration(jitteredSleepTime * float64(time.Second)) +} + +func (w *waitAlgo) calculateWaitBeforeRetry(attempt int, sleep time.Duration) time.Duration { + w.mutex.Lock() + defer w.mutex.Unlock() + t := 3*sleep - w.base + switch { + case t > 0: + return durationMin(w.cap, randSecondDuration(t)+w.base) + case t < 0: + return durationMin(w.cap, randSecondDuration(-t)+3*sleep) + } + return w.base +} + +func randSecondDuration(n time.Duration) time.Duration { + return time.Duration(random.Int63n(int64(n/time.Second))) * time.Second } func (w *waitAlgo) getJitter(currWaitTime float64) float64 { @@ -284,7 +301,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { totalTimeout := r.timeout logger.WithContext(r.ctx).Infof("retryHTTP.totalTimeout: %v", totalTimeout) retryCounter := 0 - sleepTime := 1.0 // seconds + sleepTime := time.Duration(time.Second) clientStartTime := strconv.FormatInt(r.currentTimeProvider.currentTime(), 10) var requestGUIDReplacer requestGUIDReplacer @@ -324,12 +341,16 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { } // uses exponential jitter backoff retryCounter++ - sleepTime = defaultWaitAlgo.calculateWaitBeforeRetry(retryCounter, sleepTime) + if isLoginRequest(req) { + sleepTime = defaultWaitAlgo.calculateWaitBeforeRetryForAuthRequest(retryCounter, sleepTime) + } else { + sleepTime = defaultWaitAlgo.calculateWaitBeforeRetry(retryCounter, sleepTime) + } if totalTimeout > 0 { logger.WithContext(r.ctx).Infof("to timeout: %v", totalTimeout) // if any timeout is set - totalTimeout -= time.Duration(sleepTime * float64(time.Second)) + totalTimeout -= sleepTime if totalTimeout <= 0 || retryCounter > r.maxRetryCount { if err != nil { return nil, err @@ -360,7 +381,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { logger.WithContext(r.ctx).Infof("sleeping %v. to timeout: %v. retrying", sleepTime, totalTimeout) logger.WithContext(r.ctx).Infof("retry count: %v, retry reason: %v", retryCounter, retryReason) - await := time.NewTimer(time.Duration(sleepTime * float64(time.Second))) + await := time.NewTimer(sleepTime) select { case <-await.C: // retry the request @@ -378,10 +399,13 @@ func isRetryableError(req *http.Request, res *http.Response, err error) (bool, e if res == nil || req == nil { return false, err } - isRetryableURL := contains(endpointsEligibleForRetry, req.URL.Path) - return isRetryableURL && isRetryableStatus(res.StatusCode), err + return isRetryableStatus(res.StatusCode), err } func isRetryableStatus(statusCode int) bool { return (statusCode >= 500 && statusCode < 600) || contains(clientErrorsStatusCodesEligibleForRetry, statusCode) } + +func isLoginRequest(req *http.Request) bool { + return contains(authEndpoints, req.URL.Path) +} diff --git a/retry_test.go b/retry_test.go index 90c77dea6..0b280bf45 100644 --- a/retry_test.go +++ b/retry_test.go @@ -493,12 +493,6 @@ func TestIsRetryable(t *testing.T) { err: nil, expected: false, }, - { - req: &http.Request{URL: &url.URL{Path: heartBeatPath}}, - res: &http.Response{StatusCode: http.StatusBadRequest}, - err: nil, - expected: false, - }, { req: &http.Request{URL: &url.URL{Path: loginRequestPath}}, res: &http.Response{StatusCode: http.StatusNotFound}, @@ -525,10 +519,12 @@ func TestIsRetryable(t *testing.T) { } for _, tc := range tcs { - result, _ := isRetryableError(tc.req, tc.res, tc.err) - if result != tc.expected { - t.Fatalf("expected %v, got %v; request: %v, response: %v", tc.expected, result, tc.req, tc.res) - } + t.Run(fmt.Sprintf("req %v, resp %v", tc.req, tc.res), func(t *testing.T) { + result, _ := isRetryableError(tc.req, tc.res, tc.err) + if result != tc.expected { + t.Fatalf("expected %v, got %v; request: %v, response: %v", tc.expected, result, tc.req, tc.res) + } + }) } } @@ -605,8 +601,8 @@ func TestCalculateRetryWait(t *testing.T) { for _, tc := range tcs { t.Run(fmt.Sprintf("attmept: %v", tc.attempt), func(t *testing.T) { - result := defaultWaitAlgo.calculateWaitBeforeRetry(tc.attempt, tc.currWaitTime) - assertBetweenE(t, result, tc.minSleepTime, tc.maxSleepTime) + result := defaultWaitAlgo.calculateWaitBeforeRetryForAuthRequest(tc.attempt, time.Duration(tc.currWaitTime*float64(time.Second))) + assertBetweenE(t, result.Seconds(), tc.minSleepTime, tc.maxSleepTime) }) } } diff --git a/statement.go b/statement.go index 3647a7679..e3ce5b744 100644 --- a/statement.go +++ b/statement.go @@ -34,17 +34,7 @@ func (stmt *snowflakeStmt) NumInput() int { func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext") - result, err := stmt.sc.ExecContext(ctx, stmt.query, args) - if err != nil { - stmt.setQueryIDFromError(err) - return nil, err - } - r, ok := result.(SnowflakeResult) - if !ok { - return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result) - } - stmt.lastQueryID = r.GetQueryID() - return result, err + return stmt.execInternal(ctx, args) } func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { @@ -64,11 +54,25 @@ func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.Named func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec") - result, err := stmt.sc.Exec(stmt.query, args) + return stmt.execInternal(context.Background(), toNamedValues(args)) +} + +func (stmt *snowflakeStmt) execInternal(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + logger.WithContext(stmt.sc.ctx).Debugln("Stmt.execInternal") + if ctx == nil { + ctx = context.Background() + } + stmtCtx := context.WithValue(ctx, executionType, executionTypeStatement) + result, err := stmt.sc.ExecContext(stmtCtx, stmt.query, args) if err != nil { stmt.setQueryIDFromError(err) return nil, err } + rnr, ok := result.(*snowflakeResultNoRows) + if ok { + stmt.lastQueryID = rnr.GetQueryID() + return driver.ResultNoRows, nil + } r, ok := result.(SnowflakeResult) if !ok { return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result) diff --git a/statement_test.go b/statement_test.go index 5fe5058ec..a565cb046 100644 --- a/statement_test.go +++ b/statement_test.go @@ -40,6 +40,102 @@ func openConn(t *testing.T) *sql.Conn { return conn } +func TestExecStmt(t *testing.T) { + dqlQuery := "SELECT 1" + dmlQuery := "INSERT INTO TestDDLExec VALUES (1)" + ddlQuery := "CREATE OR REPLACE TABLE TestDDLExec (num NUMBER)" + multiStmtQuery := "DELETE FROM TestDDLExec;\n" + + "SELECT 1;\n" + + "SELECT 2;" + ctx := context.Background() + multiStmtCtx, err := WithMultiStatement(ctx, 3) + if err != nil { + t.Error(err) + } + runDBTest(t, func(dbt *DBTest) { + dbt.mustExec(ddlQuery) + defer dbt.mustExec("DROP TABLE IF EXISTS TestDDLExec") + testcases := []struct { + name string + query string + f func(stmt driver.Stmt) (any, error) + }{ + { + name: "dql Exec", + query: dqlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.Exec(nil) + }, + }, + { + name: "dql ExecContext", + query: dqlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + }, + { + name: "ddl Exec", + query: ddlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.Exec(nil) + }, + }, + { + name: "ddl ExecContext", + query: ddlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + }, + { + name: "dml Exec", + query: dmlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.Exec(nil) + }, + }, + { + name: "dml ExecContext", + query: dmlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + }, + { + name: "multistmt ExecContext", + query: multiStmtQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.(driver.StmtExecContext).ExecContext(multiStmtCtx, nil) + }, + }, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := dbt.conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query) + if err != nil { + t.Error(err) + } + if stmt.(SnowflakeStmt).GetQueryID() != "" { + t.Error("queryId should be empty before executing any query") + } + if _, err := tc.f(stmt); err != nil { + t.Errorf("should have not failed to execute the query, err: %s\n", err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("should have set the query id") + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } + }) +} + func TestFailedQueryIdInSnowflakeError(t *testing.T) { failingQuery := "SELECTT 1" failingExec := "INSERT 1 INTO NON_EXISTENT_TABLE"