Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-978164: Fix stmt.Exec for DML #978

Merged
merged 13 commits into from
Dec 4, 2023
Merged
16 changes: 16 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ const (
queryResultType resultType = "query"
)

type execKey string

const (
executionType execKey = "executionType"
executionTypeStatement string = "statement"
)

const privateLinkSuffix = "privatelink.snowflakecomputing.com"

type snowflakeConn struct {
Expand Down Expand Up @@ -333,8 +340,17 @@ func (sc *snowflakeConn) ExecContext(
}, nil // last insert id is not supported by Snowflake
} else if isMultiStmt(&data.Data) {
return sc.handleMultiExec(ctx, data.Data)
} else if isDql(&data.Data) {
logger.WithContext(ctx).Debugf("DQL")
if isStatementContext(ctx) {
return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil
}
return driver.ResultNoRows, nil
}
logger.Debug("DDL")
if isStatementContext(ctx) {
return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil
}
return driver.ResultNoRows, nil
}

Expand Down
12 changes: 12 additions & 0 deletions connection_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,14 @@
return statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert
}

func isDql(data *execResponseData) bool {
return data.StatementTypeID == statementTypeIDSelect && !isMultiStmt(data)
}

func updateRows(data execResponseData) (int64, error) {
if data.RowSet == nil {
return 0, nil
}

Check warning on line 207 in connection_util.go

View check run for this annotation

Codecov / codecov/patch

connection_util.go#L206-L207

Added lines #L206 - L207 were not covered by tests
var count int64
for i, n := 0, len(data.RowType); i < n; i++ {
v, err := strconv.ParseInt(*data.RowSet[0][i], 10, 64)
Expand Down Expand Up @@ -292,3 +299,8 @@
}
return nil
}

func isStatementContext(ctx context.Context) bool {
v := ctx.Value(executionType)
return v == executionTypeStatement
}
18 changes: 18 additions & 0 deletions result.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

package gosnowflake

import "errors"

type queryStatus string

const (
Expand Down Expand Up @@ -73,3 +75,19 @@
}
return nil
}

type snowflakeResultNoRows struct {
queryID string
}

func (*snowflakeResultNoRows) LastInsertId() (int64, error) {
return 0, errors.New("no LastInsertId available")

Check warning on line 84 in result.go

View check run for this annotation

Codecov / codecov/patch

result.go#L83-L84

Added lines #L83 - L84 were not covered by tests
}

func (*snowflakeResultNoRows) RowsAffected() (int64, error) {
return 0, errors.New("no RowsAffected available")

Check warning on line 88 in result.go

View check run for this annotation

Codecov / codecov/patch

result.go#L87-L88

Added lines #L87 - L88 were not covered by tests
}

func (rnr *snowflakeResultNoRows) GetQueryID() string {
return rnr.queryID
}
28 changes: 16 additions & 12 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,7 @@

func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext")
result, err := stmt.sc.ExecContext(ctx, stmt.query, args)
if err != nil {
stmt.setQueryIDFromError(err)
return nil, err
}
r, ok := result.(SnowflakeResult)
if !ok {
return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result)
}
stmt.lastQueryID = r.GetQueryID()
return result, err
return stmt.execInternal(ctx, args)
}

func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
Expand All @@ -64,11 +54,25 @@

func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) {
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec")
result, err := stmt.sc.Exec(stmt.query, args)
return stmt.execInternal(context.Background(), toNamedValues(args))
}

func (stmt *snowflakeStmt) execInternal(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
logger.WithContext(stmt.sc.ctx).Debugln("Stmt.execInternal")
if ctx == nil {
ctx = context.Background()
}

Check warning on line 64 in statement.go

View check run for this annotation

Codecov / codecov/patch

statement.go#L63-L64

Added lines #L63 - L64 were not covered by tests
stmtCtx := context.WithValue(ctx, executionType, executionTypeStatement)
result, err := stmt.sc.ExecContext(stmtCtx, stmt.query, args)
if err != nil {
stmt.setQueryIDFromError(err)
return nil, err
}
rnr, ok := result.(*snowflakeResultNoRows)
if ok {
stmt.lastQueryID = rnr.GetQueryID()
return driver.ResultNoRows, nil
}
r, ok := result.(SnowflakeResult)
if !ok {
return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result)
Expand Down
96 changes: 96 additions & 0 deletions statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,102 @@ func openConn(t *testing.T) *sql.Conn {
return conn
}

func TestExecStmt(t *testing.T) {
dqlQuery := "SELECT 1"
sfc-gh-pfus marked this conversation as resolved.
Show resolved Hide resolved
dmlQuery := "INSERT INTO TestDDLExec VALUES (1)"
ddlQuery := "CREATE OR REPLACE TABLE TestDDLExec (num NUMBER)"
multiStmtQuery := "DELETE FROM TestDDLExec;\n" +
"SELECT 1;\n" +
"SELECT 2;"
ctx := context.Background()
multiStmtCtx, err := WithMultiStatement(ctx, 3)
if err != nil {
t.Error(err)
}
runDBTest(t, func(dbt *DBTest) {
dbt.mustExec(ddlQuery)
defer dbt.mustExec("DROP TABLE IF EXISTS TestDDLExec")
testcases := []struct {
name string
query string
f func(stmt driver.Stmt) (any, error)
}{
{
name: "dql Exec",
query: dqlQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.Exec(nil)
},
},
{
name: "dql ExecContext",
query: dqlQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
},
},
{
name: "ddl Exec",
query: ddlQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.Exec(nil)
},
},
{
name: "ddl ExecContext",
query: ddlQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
},
},
{
name: "dml Exec",
query: dmlQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.Exec(nil)
},
},
{
name: "dml ExecContext",
query: dmlQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
},
},
{
name: "multistmt ExecContext",
query: multiStmtQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.(driver.StmtExecContext).ExecContext(multiStmtCtx, nil)
},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
err := dbt.conn.Raw(func(x any) error {
stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query)
if err != nil {
t.Error(err)
}
if stmt.(SnowflakeStmt).GetQueryID() != "" {
t.Error("queryId should be empty before executing any query")
}
if _, err := tc.f(stmt); err != nil {
t.Errorf("should have not failed to execute the query, err: %s\n", err)
}
if stmt.(SnowflakeStmt).GetQueryID() == "" {
t.Error("should have set the query id")
}
return nil
})
if err != nil {
t.Fatal(err)
}
})
}
})
}

func TestFailedQueryIdInSnowflakeError(t *testing.T) {
failingQuery := "SELECTT 1"
failingExec := "INSERT 1 INTO NON_EXISTENT_TABLE"
Expand Down
Loading