From 922bfb7f455992e6eaa356df6b3c95260acc9f6b Mon Sep 17 00:00:00 2001 From: Piotr Bulawa Date: Tue, 28 Nov 2023 11:22:55 +0100 Subject: [PATCH 01/10] Return ResultNoRows --- statement.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/statement.go b/statement.go index 3647a7679..783e80c90 100644 --- a/statement.go +++ b/statement.go @@ -41,7 +41,7 @@ func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedV } r, ok := result.(SnowflakeResult) if !ok { - return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result) + return driver.ResultNoRows, nil } stmt.lastQueryID = r.GetQueryID() return result, err @@ -71,7 +71,7 @@ func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) { } r, ok := result.(SnowflakeResult) if !ok { - return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result) + return driver.ResultNoRows, nil } stmt.lastQueryID = r.GetQueryID() return result, err From 61e84d10cb6e43e8bf23b67444518ba1d825aa3f Mon Sep 17 00:00:00 2001 From: Piotr Bulawa Date: Wed, 29 Nov 2023 16:55:14 +0100 Subject: [PATCH 02/10] Return SnowflakeResult instead of driver.ResultNoRows --- connection.go | 6 ++++- connection_util.go | 5 +++- statement_test.go | 62 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 2 deletions(-) diff --git a/connection.go b/connection.go index 4ba97d14b..538677de6 100644 --- a/connection.go +++ b/connection.go @@ -335,7 +335,11 @@ func (sc *snowflakeConn) ExecContext( return sc.handleMultiExec(ctx, data.Data) } logger.Debug("DDL") - return driver.ResultNoRows, nil + return &snowflakeResult{ + affectedRows: 0, + insertID: -1, + queryID: data.Data.QueryID, + }, nil } func (sc *snowflakeConn) QueryContext( diff --git a/connection_util.go b/connection_util.go index 4d37dea28..5e0e92c97 100644 --- a/connection_util.go +++ b/connection_util.go @@ -194,10 +194,13 @@ func getResultType(ctx context.Context) resultType { // isDml returns true if the statement type code is in the range of DML. func isDml(v int64) bool { - return statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert + return (statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert) || v == statementTypeIDSelect } func updateRows(data execResponseData) (int64, error) { + if data.RowSet == nil { + return 0, nil + } var count int64 for i, n := 0, len(data.RowType); i < n; i++ { v, err := strconv.ParseInt(*data.RowSet[0][i], 10, 64) diff --git a/statement_test.go b/statement_test.go index 5fe5058ec..ea4a77930 100644 --- a/statement_test.go +++ b/statement_test.go @@ -40,6 +40,68 @@ func openConn(t *testing.T) *sql.Conn { return conn } +func TestDMLExec(t *testing.T) { + query := "SELECT 1" + runDBTest(t, func(dbt *DBTest) { + testcases := []struct { + name string + f func(dbt *DBTest) (any, error) + }{ + { + name: "Exec", + f: func(dbt *DBTest) (any, error) { + stmt, _ := dbt.prepare(query) + return stmt.Exec() + }, + }, + { + name: "ExecContext", + f: func(dbt *DBTest) (any, error) { + stmt, _ := dbt.prepare(query) + return stmt.ExecContext(context.Background()) + }, + }, + } + for _, tc := range testcases { + _, err := tc.f(dbt) + if err != nil { + t.Error(err) + } + } + }) +} + +func TestDDLExec(t *testing.T) { + query := "CREATE OR REPLACE TABLE TestDDLExec (num NUMBER)" + runDBTest(t, func(dbt *DBTest) { + testcases := []struct { + name string + f func(dbt *DBTest) (any, error) + }{ + { + name: "Exec", + f: func(dbt *DBTest) (any, error) { + stmt, _ := dbt.prepare(query) + return stmt.Exec() + }, + }, + { + name: "ExecContext", + f: func(dbt *DBTest) (any, error) { + stmt, _ := dbt.prepare(query) + return stmt.ExecContext(context.Background()) + }, + }, + } + for _, tc := range testcases { + _, err := tc.f(dbt) + if err != nil { + t.Error(err) + } + } + }) +} + func TestFailedQueryIdInSnowflakeError(t *testing.T) { failingQuery := "SELECTT 1" failingExec := "INSERT 1 INTO NON_EXISTENT_TABLE" From a5234e560f3ca177df99d57ef73fef11270f1888 Mon Sep 17 00:00:00 2001 From: Piotr Bulawa Date: Thu, 30 Nov 2023 11:00:48 +0100 Subject: [PATCH 03/10] Remove select from Dml and add separate method for select --- connection.go | 7 +++++++ connection_util.go | 6 +++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/connection.go b/connection.go index 538677de6..82f6b0147 100644 --- a/connection.go +++ b/connection.go @@ -333,6 +333,13 @@ func (sc *snowflakeConn) ExecContext( }, nil // last insert id is not supported by Snowflake } else if isMultiStmt(&data.Data) { return sc.handleMultiExec(ctx, data.Data) + } else if isSelect(data.Data.StatementTypeID) { + logger.WithContext(ctx).Debugf("SELECT") + return &snowflakeResult{ + affectedRows: 0, + insertID: -1, + queryID: data.Data.QueryID, + }, nil } logger.Debug("DDL") return &snowflakeResult{ diff --git a/connection_util.go b/connection_util.go index 5e0e92c97..812984161 100644 --- a/connection_util.go +++ b/connection_util.go @@ -194,7 +194,11 @@ func getResultType(ctx context.Context) resultType { // isDml returns true if the statement type code is in the range of DML. func isDml(v int64) bool { - return (statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert) || v == statementTypeIDSelect + return statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert +} + +func isSelect(v int64) bool { + return v == statementTypeIDSelect } func updateRows(data execResponseData) (int64, error) { From 7892788c380c9e284b1cf0c05f533a7171f1f10b Mon Sep 17 00:00:00 2001 From: Piotr Bulawa Date: Thu, 30 Nov 2023 11:47:26 +0100 Subject: [PATCH 04/10] Use ResultNoRows for DDL/DQL. Drop table after tests --- connection.go | 16 ++++------------ connection_util.go | 4 ++-- statement.go | 6 ++++++ statement_test.go | 15 ++++++++------- 4 files changed, 20 insertions(+), 21 deletions(-) diff --git a/connection.go b/connection.go index 82f6b0147..3d6c3a0a5 100644 --- a/connection.go +++ b/connection.go @@ -333,20 +333,12 @@ func (sc *snowflakeConn) ExecContext( }, nil // last insert id is not supported by Snowflake } else if isMultiStmt(&data.Data) { return sc.handleMultiExec(ctx, data.Data) - } else if isSelect(data.Data.StatementTypeID) { - logger.WithContext(ctx).Debugf("SELECT") - return &snowflakeResult{ - affectedRows: 0, - insertID: -1, - queryID: data.Data.QueryID, - }, nil + } else if isDql(&data.Data) { + logger.WithContext(ctx).Debugf("DQL") + return driver.ResultNoRows, nil } logger.Debug("DDL") - return &snowflakeResult{ - affectedRows: 0, - insertID: -1, - queryID: data.Data.QueryID, - }, nil + return driver.ResultNoRows, nil } func (sc *snowflakeConn) QueryContext( diff --git a/connection_util.go b/connection_util.go index 812984161..75f819c9b 100644 --- a/connection_util.go +++ b/connection_util.go @@ -197,8 +197,8 @@ func isDml(v int64) bool { return statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert } -func isSelect(v int64) bool { - return v == statementTypeIDSelect +func isDql(data *execResponseData) bool { + return data.StatementTypeID == statementTypeIDSelect && data.RowType[0].Name != "multiple statement execution" } func updateRows(data execResponseData) (int64, error) { diff --git a/statement.go b/statement.go index 3647a7679..62f032fec 100644 --- a/statement.go +++ b/statement.go @@ -39,6 +39,9 @@ func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedV stmt.setQueryIDFromError(err) return nil, err } + if result == driver.ResultNoRows { + return result, nil + } r, ok := result.(SnowflakeResult) if !ok { return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result) @@ -69,6 +72,9 @@ func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) { stmt.setQueryIDFromError(err) return nil, err } + if result == driver.ResultNoRows { + return result, nil + } r, ok := result.(SnowflakeResult) if !ok { return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result) diff --git a/statement_test.go b/statement_test.go index ea4a77930..0bb011dbf 100644 --- a/statement_test.go +++ b/statement_test.go @@ -40,23 +40,23 @@ func openConn(t *testing.T) *sql.Conn { return conn } -func TestDMLExec(t *testing.T) { +func TestDQLExec(t *testing.T) { query := "SELECT 1" runDBTest(t, func(dbt *DBTest) { testcases := []struct { name string - f func(dbt *DBTest) (any, error) + f func(dbt *DBTest) (driver.Result, error) }{ { name: "Exec", - f: func(dbt *DBTest) (any, error) { + f: func(dbt *DBTest) (driver.Result, error) { stmt, _ := dbt.prepare(query) return stmt.Exec() }, }, { name: "ExecContext", - f: func(dbt *DBTest) (any, error) { + f: func(dbt *DBTest) (driver.Result, error) { stmt, _ := dbt.prepare(query) return stmt.ExecContext(context.Background()) }, @@ -74,20 +74,21 @@ func TestDMLExec(t *testing.T) { func TestDDLExec(t *testing.T) { query := "CREATE OR REPLACE TABLE TestDDLExec (num NUMBER)" runDBTest(t, func(dbt *DBTest) { + defer dbt.mustExec("DROP TABLE IF EXISTS TestDDLExec") testcases := []struct { name string - f func(dbt *DBTest) (any, error) + f func(dbt *DBTest) (driver.Result, error) }{ { name: "Exec", - f: func(dbt *DBTest) (any, error) { + f: func(dbt *DBTest) (driver.Result, error) { stmt, _ := dbt.prepare(query) return stmt.Exec() }, }, { name: "ExecContext", - f: func(dbt *DBTest) (any, error) { + f: func(dbt *DBTest) (driver.Result, error) { stmt, _ := dbt.prepare(query) return stmt.ExecContext(context.Background()) }, From c82aa67989b2253503f1508cf3a72106b2ad8a4f Mon Sep 17 00:00:00 2001 From: Piotr Bulawa Date: Thu, 30 Nov 2023 11:49:29 +0100 Subject: [PATCH 05/10] Reintroduced TestResultNoRows --- rows_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/rows_test.go b/rows_test.go index cbb01045b..fabb6c3bd 100644 --- a/rows_test.go +++ b/rows_test.go @@ -74,6 +74,22 @@ func TestRowsClose(t *testing.T) { }) } +func TestResultNoRows(t *testing.T) { + // DDL + runDBTest(t, func(dbt *DBTest) { + row, err := dbt.exec("CREATE OR REPLACE TABLE test(c1 int)") + if err != nil { + t.Fatalf("failed to execute DDL. err: %v", err) + } + if _, err = row.RowsAffected(); err == nil { + t.Fatal("should have failed to get RowsAffected") + } + if _, err = row.LastInsertId(); err == nil { + t.Fatal("should have failed to get LastInsertID") + } + }) +} + func TestRowsWithoutChunkDownloader(t *testing.T) { sts1 := "1" sts2 := "Test1" From c56556215391afb30c69a028afa824ef9006c4db Mon Sep 17 00:00:00 2001 From: Piotr Bulawa Date: Fri, 1 Dec 2023 16:58:15 +0100 Subject: [PATCH 06/10] Introduce snowflakeResultNoRows, use context to determine if statement is used --- connection.go | 11 ++++++ connection_util.go | 8 +++++ result.go | 18 ++++++++++ statement.go | 32 ++++++++--------- statement_test.go | 87 ++++++++++++++++++++++++---------------------- 5 files changed, 97 insertions(+), 59 deletions(-) diff --git a/connection.go b/connection.go index 3d6c3a0a5..f68648a96 100644 --- a/connection.go +++ b/connection.go @@ -60,6 +60,11 @@ const ( queryResultType resultType = "query" ) +const ( + executionType = "executionType" + executionTypeStatement = "statement" +) + const privateLinkSuffix = "privatelink.snowflakecomputing.com" type snowflakeConn struct { @@ -335,9 +340,15 @@ func (sc *snowflakeConn) ExecContext( return sc.handleMultiExec(ctx, data.Data) } else if isDql(&data.Data) { logger.WithContext(ctx).Debugf("DQL") + if isStatementContext(ctx) { + return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil + } return driver.ResultNoRows, nil } logger.Debug("DDL") + if isStatementContext(ctx) { + return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil + } return driver.ResultNoRows, nil } diff --git a/connection_util.go b/connection_util.go index 75f819c9b..b5b1ea8af 100644 --- a/connection_util.go +++ b/connection_util.go @@ -299,3 +299,11 @@ func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error { } return nil } + +func isStatementContext(ctx context.Context) bool { + v := ctx.Value(executionType) + if v != nil && v == executionTypeStatement { + return true + } + return false +} diff --git a/result.go b/result.go index e08f41902..1b7545686 100644 --- a/result.go +++ b/result.go @@ -2,6 +2,8 @@ package gosnowflake +import "errors" + type queryStatus string const ( @@ -73,3 +75,19 @@ func (res *snowflakeResult) waitForAsyncExecStatus() error { } return nil } + +type snowflakeResultNoRows struct { + queryID string +} + +func (*snowflakeResultNoRows) LastInsertId() (int64, error) { + return 0, errors.New("no LastInsertId available after DDL statement") +} + +func (*snowflakeResultNoRows) RowsAffected() (int64, error) { + return 0, errors.New("no RowsAffected available after DDL statement") +} + +func (rnr *snowflakeResultNoRows) GetQueryID() string { + return rnr.queryID +} diff --git a/statement.go b/statement.go index 62f032fec..3cb9a9285 100644 --- a/statement.go +++ b/statement.go @@ -34,20 +34,7 @@ func (stmt *snowflakeStmt) NumInput() int { func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext") - result, err := stmt.sc.ExecContext(ctx, stmt.query, args) - if err != nil { - stmt.setQueryIDFromError(err) - return nil, err - } - if result == driver.ResultNoRows { - return result, nil - } - r, ok := result.(SnowflakeResult) - if !ok { - return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result) - } - stmt.lastQueryID = r.GetQueryID() - return result, err + return stmt.execInternal(ctx, args) } func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { @@ -67,13 +54,24 @@ func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.Named func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec") - result, err := stmt.sc.Exec(stmt.query, args) + return stmt.execInternal(nil, toNamedValues(args)) +} + +func (stmt *snowflakeStmt) execInternal(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + logger.WithContext(stmt.sc.ctx).Debugln("Stmt.execInternal") + if ctx == nil { + ctx = context.Background() + } + stmtCtx := context.WithValue(ctx, executionType, executionTypeStatement) + result, err := stmt.sc.ExecContext(stmtCtx, stmt.query, args) if err != nil { stmt.setQueryIDFromError(err) return nil, err } - if result == driver.ResultNoRows { - return result, nil + rnr, ok := result.(*snowflakeResultNoRows) + if ok { + stmt.lastQueryID = rnr.GetQueryID() + return driver.ResultNoRows, nil } r, ok := result.(SnowflakeResult) if !ok { diff --git a/statement_test.go b/statement_test.go index 0bb011dbf..85c7740c4 100644 --- a/statement_test.go +++ b/statement_test.go @@ -40,65 +40,68 @@ func openConn(t *testing.T) *sql.Conn { return conn } -func TestDQLExec(t *testing.T) { - query := "SELECT 1" +func TestExecStmt(t *testing.T) { + dqlQuery := "SELECT 1" + ddlQuery := "CREATE OR REPLACE TABLE TestDDLExec (num NUMBER)" + ctx := context.Background() runDBTest(t, func(dbt *DBTest) { + defer dbt.mustExec("DROP TABLE IF EXISTS TestDDLExec") testcases := []struct { - name string - f func(dbt *DBTest) (driver.Result, error) + name string + query string + f func(stmt driver.Stmt) (any, error) }{ { - name: "Exec", - f: func(dbt *DBTest) (driver.Result, error) { - stmt, _ := dbt.prepare(query) - return stmt.Exec() + name: "dql Exec", + query: dqlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.Exec(nil) }, }, { - name: "ExecContext", - f: func(dbt *DBTest) (driver.Result, error) { - stmt, _ := dbt.prepare(query) - return stmt.ExecContext(context.Background()) + name: "dql ExecContext", + query: dqlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) }, }, - } - for _, tc := range testcases { - _, err := tc.f(dbt) - if err != nil { - t.Error(err) - } - } - }) -} - -func TestDDLExec(t *testing.T) { - query := "CREATE OR REPLACE TABLE TestDDLExec (num NUMBER)" - runDBTest(t, func(dbt *DBTest) { - defer dbt.mustExec("DROP TABLE IF EXISTS TestDDLExec") - testcases := []struct { - name string - f func(dbt *DBTest) (driver.Result, error) - }{ { - name: "Exec", - f: func(dbt *DBTest) (driver.Result, error) { - stmt, _ := dbt.prepare(query) - return stmt.Exec() + name: "ddl Exec", + query: ddlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.Exec(nil) }, }, { - name: "ExecContext", - f: func(dbt *DBTest) (driver.Result, error) { - stmt, _ := dbt.prepare(query) - return stmt.ExecContext(context.Background()) + name: "ddl ExecContext", + query: ddlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) }, }, } for _, tc := range testcases { - _, err := tc.f(dbt) - if err != nil { - t.Error(err) - } + t.Run(tc.name, func(t *testing.T) { + err := dbt.conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query) + if err != nil { + t.Error(err) + } + if stmt.(SnowflakeStmt).GetQueryID() != "" { + t.Error("queryId should be empty before executing any query") + } + if _, err := tc.f(stmt); err != nil { + t.Error("should have not failed to execute the query") + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("should have set the query id") + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }) } }) } From 49a007a91e15ac776354a1258f3e94f9bf3fa3ee Mon Sep 17 00:00:00 2001 From: Piotr Bulawa Date: Sun, 3 Dec 2023 13:41:45 +0100 Subject: [PATCH 07/10] lint --- connection.go | 6 ++++-- go.sum | 10 ---------- statement.go | 2 +- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/connection.go b/connection.go index f68648a96..5fcb355bb 100644 --- a/connection.go +++ b/connection.go @@ -60,9 +60,11 @@ const ( queryResultType resultType = "query" ) +type execKey string + const ( - executionType = "executionType" - executionTypeStatement = "statement" + executionType execKey = "executionType" + executionTypeStatement string = "statement" ) const privateLinkSuffix = "privatelink.snowflakecomputing.com" diff --git a/go.sum b/go.sum index 4dae0767b..f2e9fd274 100644 --- a/go.sum +++ b/go.sum @@ -134,8 +134,6 @@ github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= -golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= golang.org/x/exp v0.0.0-20230206171751-46f607a40771 h1:xP7rWLUr1e1n2xkK5YB4LI0hPEy3LJC6Wk+D4pGlOJg= @@ -145,8 +143,6 @@ golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -158,17 +154,11 @@ golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= -golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.14.0 h1:LGK9IlZ8T9jvdy6cTdfKUCltatMFOehAQo9SRC46UQ8= golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= diff --git a/statement.go b/statement.go index 3cb9a9285..e3ce5b744 100644 --- a/statement.go +++ b/statement.go @@ -54,7 +54,7 @@ func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.Named func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec") - return stmt.execInternal(nil, toNamedValues(args)) + return stmt.execInternal(context.Background(), toNamedValues(args)) } func (stmt *snowflakeStmt) execInternal(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { From 3f9a1aaeb8bab9fe7ab5ecdcb61f40439a2daa6b Mon Sep 17 00:00:00 2001 From: Piotr Bulawa Date: Mon, 4 Dec 2023 13:47:26 +0100 Subject: [PATCH 08/10] Add more tests, change if conditions --- connection_util.go | 4 ++-- statement_test.go | 32 +++++++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/connection_util.go b/connection_util.go index b5b1ea8af..af19e73f4 100644 --- a/connection_util.go +++ b/connection_util.go @@ -198,7 +198,7 @@ func isDml(v int64) bool { } func isDql(data *execResponseData) bool { - return data.StatementTypeID == statementTypeIDSelect && data.RowType[0].Name != "multiple statement execution" + return data.StatementTypeID == statementTypeIDSelect && !isMultiStmt(data) } func updateRows(data execResponseData) (int64, error) { @@ -302,7 +302,7 @@ func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error { func isStatementContext(ctx context.Context) bool { v := ctx.Value(executionType) - if v != nil && v == executionTypeStatement { + if v == executionTypeStatement { return true } return false diff --git a/statement_test.go b/statement_test.go index 85c7740c4..a565cb046 100644 --- a/statement_test.go +++ b/statement_test.go @@ -42,9 +42,18 @@ func openConn(t *testing.T) *sql.Conn { func TestExecStmt(t *testing.T) { dqlQuery := "SELECT 1" + dmlQuery := "INSERT INTO TestDDLExec VALUES (1)" ddlQuery := "CREATE OR REPLACE TABLE TestDDLExec (num NUMBER)" + multiStmtQuery := "DELETE FROM TestDDLExec;\n" + + "SELECT 1;\n" + + "SELECT 2;" ctx := context.Background() + multiStmtCtx, err := WithMultiStatement(ctx, 3) + if err != nil { + t.Error(err) + } runDBTest(t, func(dbt *DBTest) { + dbt.mustExec(ddlQuery) defer dbt.mustExec("DROP TABLE IF EXISTS TestDDLExec") testcases := []struct { name string @@ -79,6 +88,27 @@ func TestExecStmt(t *testing.T) { return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) }, }, + { + name: "dml Exec", + query: dmlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.Exec(nil) + }, + }, + { + name: "dml ExecContext", + query: dmlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + }, + { + name: "multistmt ExecContext", + query: multiStmtQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.(driver.StmtExecContext).ExecContext(multiStmtCtx, nil) + }, + }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { @@ -91,7 +121,7 @@ func TestExecStmt(t *testing.T) { t.Error("queryId should be empty before executing any query") } if _, err := tc.f(stmt); err != nil { - t.Error("should have not failed to execute the query") + t.Errorf("should have not failed to execute the query, err: %s\n", err) } if stmt.(SnowflakeStmt).GetQueryID() == "" { t.Error("should have set the query id") From 7f670a2f6f1ef18ee6b3931b91be8792e4049ca0 Mon Sep 17 00:00:00 2001 From: Piotr Bulawa Date: Mon, 4 Dec 2023 14:00:20 +0100 Subject: [PATCH 09/10] lint --- connection_util.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/connection_util.go b/connection_util.go index af19e73f4..54390522a 100644 --- a/connection_util.go +++ b/connection_util.go @@ -302,8 +302,5 @@ func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error { func isStatementContext(ctx context.Context) bool { v := ctx.Value(executionType) - if v == executionTypeStatement { - return true - } - return false + return v == executionTypeStatement } From 8e85efc985b49781bcb17fb381428a8339618a98 Mon Sep 17 00:00:00 2001 From: Piotr Bulawa Date: Mon, 4 Dec 2023 14:31:11 +0100 Subject: [PATCH 10/10] Change error message, revert go.sum change --- go.sum | 10 ++++++++++ result.go | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/go.sum b/go.sum index f2e9fd274..4dae0767b 100644 --- a/go.sum +++ b/go.sum @@ -134,6 +134,8 @@ github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= golang.org/x/exp v0.0.0-20230206171751-46f607a40771 h1:xP7rWLUr1e1n2xkK5YB4LI0hPEy3LJC6Wk+D4pGlOJg= @@ -143,6 +145,8 @@ golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -154,11 +158,17 @@ golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.14.0 h1:LGK9IlZ8T9jvdy6cTdfKUCltatMFOehAQo9SRC46UQ8= golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= +golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= diff --git a/result.go b/result.go index 1b7545686..c2a718308 100644 --- a/result.go +++ b/result.go @@ -81,11 +81,11 @@ type snowflakeResultNoRows struct { } func (*snowflakeResultNoRows) LastInsertId() (int64, error) { - return 0, errors.New("no LastInsertId available after DDL statement") + return 0, errors.New("no LastInsertId available") } func (*snowflakeResultNoRows) RowsAffected() (int64, error) { - return 0, errors.New("no RowsAffected available after DDL statement") + return 0, errors.New("no RowsAffected available") } func (rnr *snowflakeResultNoRows) GetQueryID() string {