Skip to content

Commit

Permalink
SNOW-978164: Fix stmt.Exec for DML (#978)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pbulawa authored Dec 4, 2023
1 parent d56c0f2 commit 1341c39
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 12 deletions.
16 changes: 16 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ const (
queryResultType resultType = "query"
)

type execKey string

const (
executionType execKey = "executionType"
executionTypeStatement string = "statement"
)

const privateLinkSuffix = "privatelink.snowflakecomputing.com"

type snowflakeConn struct {
Expand Down Expand Up @@ -333,8 +340,17 @@ func (sc *snowflakeConn) ExecContext(
}, nil // last insert id is not supported by Snowflake
} else if isMultiStmt(&data.Data) {
return sc.handleMultiExec(ctx, data.Data)
} else if isDql(&data.Data) {
logger.WithContext(ctx).Debugf("DQL")
if isStatementContext(ctx) {
return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil
}
return driver.ResultNoRows, nil
}
logger.Debug("DDL")
if isStatementContext(ctx) {
return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil
}
return driver.ResultNoRows, nil
}

Expand Down
12 changes: 12 additions & 0 deletions connection_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,14 @@ func isDml(v int64) bool {
return statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert
}

func isDql(data *execResponseData) bool {
return data.StatementTypeID == statementTypeIDSelect && !isMultiStmt(data)
}

func updateRows(data execResponseData) (int64, error) {
if data.RowSet == nil {
return 0, nil
}
var count int64
for i, n := 0, len(data.RowType); i < n; i++ {
v, err := strconv.ParseInt(*data.RowSet[0][i], 10, 64)
Expand Down Expand Up @@ -292,3 +299,8 @@ func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error {
}
return nil
}

func isStatementContext(ctx context.Context) bool {
v := ctx.Value(executionType)
return v == executionTypeStatement
}
18 changes: 18 additions & 0 deletions result.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

package gosnowflake

import "errors"

type queryStatus string

const (
Expand Down Expand Up @@ -73,3 +75,19 @@ func (res *snowflakeResult) waitForAsyncExecStatus() error {
}
return nil
}

type snowflakeResultNoRows struct {
queryID string
}

func (*snowflakeResultNoRows) LastInsertId() (int64, error) {
return 0, errors.New("no LastInsertId available")
}

func (*snowflakeResultNoRows) RowsAffected() (int64, error) {
return 0, errors.New("no RowsAffected available")
}

func (rnr *snowflakeResultNoRows) GetQueryID() string {
return rnr.queryID
}
28 changes: 16 additions & 12 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,7 @@ func (stmt *snowflakeStmt) NumInput() int {

func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext")
result, err := stmt.sc.ExecContext(ctx, stmt.query, args)
if err != nil {
stmt.setQueryIDFromError(err)
return nil, err
}
r, ok := result.(SnowflakeResult)
if !ok {
return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result)
}
stmt.lastQueryID = r.GetQueryID()
return result, err
return stmt.execInternal(ctx, args)
}

func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
Expand All @@ -64,11 +54,25 @@ func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.Named

func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) {
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec")
result, err := stmt.sc.Exec(stmt.query, args)
return stmt.execInternal(context.Background(), toNamedValues(args))
}

func (stmt *snowflakeStmt) execInternal(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
logger.WithContext(stmt.sc.ctx).Debugln("Stmt.execInternal")
if ctx == nil {
ctx = context.Background()
}
stmtCtx := context.WithValue(ctx, executionType, executionTypeStatement)
result, err := stmt.sc.ExecContext(stmtCtx, stmt.query, args)
if err != nil {
stmt.setQueryIDFromError(err)
return nil, err
}
rnr, ok := result.(*snowflakeResultNoRows)
if ok {
stmt.lastQueryID = rnr.GetQueryID()
return driver.ResultNoRows, nil
}
r, ok := result.(SnowflakeResult)
if !ok {
return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result)
Expand Down
96 changes: 96 additions & 0 deletions statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,102 @@ func openConn(t *testing.T) *sql.Conn {
return conn
}

func TestExecStmt(t *testing.T) {
dqlQuery := "SELECT 1"
dmlQuery := "INSERT INTO TestDDLExec VALUES (1)"
ddlQuery := "CREATE OR REPLACE TABLE TestDDLExec (num NUMBER)"
multiStmtQuery := "DELETE FROM TestDDLExec;\n" +
"SELECT 1;\n" +
"SELECT 2;"
ctx := context.Background()
multiStmtCtx, err := WithMultiStatement(ctx, 3)
if err != nil {
t.Error(err)
}
runDBTest(t, func(dbt *DBTest) {
dbt.mustExec(ddlQuery)
defer dbt.mustExec("DROP TABLE IF EXISTS TestDDLExec")
testcases := []struct {
name string
query string
f func(stmt driver.Stmt) (any, error)
}{
{
name: "dql Exec",
query: dqlQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.Exec(nil)
},
},
{
name: "dql ExecContext",
query: dqlQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
},
},
{
name: "ddl Exec",
query: ddlQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.Exec(nil)
},
},
{
name: "ddl ExecContext",
query: ddlQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
},
},
{
name: "dml Exec",
query: dmlQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.Exec(nil)
},
},
{
name: "dml ExecContext",
query: dmlQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
},
},
{
name: "multistmt ExecContext",
query: multiStmtQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.(driver.StmtExecContext).ExecContext(multiStmtCtx, nil)
},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
err := dbt.conn.Raw(func(x any) error {
stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query)
if err != nil {
t.Error(err)
}
if stmt.(SnowflakeStmt).GetQueryID() != "" {
t.Error("queryId should be empty before executing any query")
}
if _, err := tc.f(stmt); err != nil {
t.Errorf("should have not failed to execute the query, err: %s\n", err)
}
if stmt.(SnowflakeStmt).GetQueryID() == "" {
t.Error("should have set the query id")
}
return nil
})
if err != nil {
t.Fatal(err)
}
})
}
})
}

func TestFailedQueryIdInSnowflakeError(t *testing.T) {
failingQuery := "SELECTT 1"
failingExec := "INSERT 1 INTO NON_EXISTENT_TABLE"
Expand Down

0 comments on commit 1341c39

Please sign in to comment.