diff --git a/connection.go b/connection.go index 4ba97d14b..5fcb355bb 100644 --- a/connection.go +++ b/connection.go @@ -60,6 +60,13 @@ const ( queryResultType resultType = "query" ) +type execKey string + +const ( + executionType execKey = "executionType" + executionTypeStatement string = "statement" +) + const privateLinkSuffix = "privatelink.snowflakecomputing.com" type snowflakeConn struct { @@ -333,8 +340,17 @@ func (sc *snowflakeConn) ExecContext( }, nil // last insert id is not supported by Snowflake } else if isMultiStmt(&data.Data) { return sc.handleMultiExec(ctx, data.Data) + } else if isDql(&data.Data) { + logger.WithContext(ctx).Debugf("DQL") + if isStatementContext(ctx) { + return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil + } + return driver.ResultNoRows, nil } logger.Debug("DDL") + if isStatementContext(ctx) { + return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil + } return driver.ResultNoRows, nil } diff --git a/connection_util.go b/connection_util.go index 4d37dea28..54390522a 100644 --- a/connection_util.go +++ b/connection_util.go @@ -197,7 +197,14 @@ func isDml(v int64) bool { return statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert } +func isDql(data *execResponseData) bool { + return data.StatementTypeID == statementTypeIDSelect && !isMultiStmt(data) +} + func updateRows(data execResponseData) (int64, error) { + if data.RowSet == nil { + return 0, nil + } var count int64 for i, n := 0, len(data.RowType); i < n; i++ { v, err := strconv.ParseInt(*data.RowSet[0][i], 10, 64) @@ -292,3 +299,8 @@ func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error { } return nil } + +func isStatementContext(ctx context.Context) bool { + v := ctx.Value(executionType) + return v == executionTypeStatement +} diff --git a/result.go b/result.go index e08f41902..c2a718308 100644 --- a/result.go +++ b/result.go @@ -2,6 +2,8 @@ package gosnowflake +import "errors" + type queryStatus string const ( @@ -73,3 +75,19 @@ func (res *snowflakeResult) waitForAsyncExecStatus() error { } return nil } + +type snowflakeResultNoRows struct { + queryID string +} + +func (*snowflakeResultNoRows) LastInsertId() (int64, error) { + return 0, errors.New("no LastInsertId available") +} + +func (*snowflakeResultNoRows) RowsAffected() (int64, error) { + return 0, errors.New("no RowsAffected available") +} + +func (rnr *snowflakeResultNoRows) GetQueryID() string { + return rnr.queryID +} diff --git a/statement.go b/statement.go index 3647a7679..e3ce5b744 100644 --- a/statement.go +++ b/statement.go @@ -34,17 +34,7 @@ func (stmt *snowflakeStmt) NumInput() int { func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext") - result, err := stmt.sc.ExecContext(ctx, stmt.query, args) - if err != nil { - stmt.setQueryIDFromError(err) - return nil, err - } - r, ok := result.(SnowflakeResult) - if !ok { - return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result) - } - stmt.lastQueryID = r.GetQueryID() - return result, err + return stmt.execInternal(ctx, args) } func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { @@ -64,11 +54,25 @@ func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.Named func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec") - result, err := stmt.sc.Exec(stmt.query, args) + return stmt.execInternal(context.Background(), toNamedValues(args)) +} + +func (stmt *snowflakeStmt) execInternal(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + logger.WithContext(stmt.sc.ctx).Debugln("Stmt.execInternal") + if ctx == nil { + ctx = context.Background() + } + stmtCtx := context.WithValue(ctx, executionType, executionTypeStatement) + result, err := stmt.sc.ExecContext(stmtCtx, stmt.query, args) if err != nil { stmt.setQueryIDFromError(err) return nil, err } + rnr, ok := result.(*snowflakeResultNoRows) + if ok { + stmt.lastQueryID = rnr.GetQueryID() + return driver.ResultNoRows, nil + } r, ok := result.(SnowflakeResult) if !ok { return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result) diff --git a/statement_test.go b/statement_test.go index 5fe5058ec..a565cb046 100644 --- a/statement_test.go +++ b/statement_test.go @@ -40,6 +40,102 @@ func openConn(t *testing.T) *sql.Conn { return conn } +func TestExecStmt(t *testing.T) { + dqlQuery := "SELECT 1" + dmlQuery := "INSERT INTO TestDDLExec VALUES (1)" + ddlQuery := "CREATE OR REPLACE TABLE TestDDLExec (num NUMBER)" + multiStmtQuery := "DELETE FROM TestDDLExec;\n" + + "SELECT 1;\n" + + "SELECT 2;" + ctx := context.Background() + multiStmtCtx, err := WithMultiStatement(ctx, 3) + if err != nil { + t.Error(err) + } + runDBTest(t, func(dbt *DBTest) { + dbt.mustExec(ddlQuery) + defer dbt.mustExec("DROP TABLE IF EXISTS TestDDLExec") + testcases := []struct { + name string + query string + f func(stmt driver.Stmt) (any, error) + }{ + { + name: "dql Exec", + query: dqlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.Exec(nil) + }, + }, + { + name: "dql ExecContext", + query: dqlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + }, + { + name: "ddl Exec", + query: ddlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.Exec(nil) + }, + }, + { + name: "ddl ExecContext", + query: ddlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + }, + { + name: "dml Exec", + query: dmlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.Exec(nil) + }, + }, + { + name: "dml ExecContext", + query: dmlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + }, + { + name: "multistmt ExecContext", + query: multiStmtQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.(driver.StmtExecContext).ExecContext(multiStmtCtx, nil) + }, + }, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := dbt.conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query) + if err != nil { + t.Error(err) + } + if stmt.(SnowflakeStmt).GetQueryID() != "" { + t.Error("queryId should be empty before executing any query") + } + if _, err := tc.f(stmt); err != nil { + t.Errorf("should have not failed to execute the query, err: %s\n", err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("should have set the query id") + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } + }) +} + func TestFailedQueryIdInSnowflakeError(t *testing.T) { failingQuery := "SELECTT 1" failingExec := "INSERT 1 INTO NON_EXISTENT_TABLE"