From 3f9a1aaeb8bab9fe7ab5ecdcb61f40439a2daa6b Mon Sep 17 00:00:00 2001 From: Piotr Bulawa Date: Mon, 4 Dec 2023 13:47:26 +0100 Subject: [PATCH] Add more tests, change if conditions --- connection_util.go | 4 ++-- statement_test.go | 32 +++++++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/connection_util.go b/connection_util.go index b5b1ea8af..af19e73f4 100644 --- a/connection_util.go +++ b/connection_util.go @@ -198,7 +198,7 @@ func isDml(v int64) bool { } func isDql(data *execResponseData) bool { - return data.StatementTypeID == statementTypeIDSelect && data.RowType[0].Name != "multiple statement execution" + return data.StatementTypeID == statementTypeIDSelect && !isMultiStmt(data) } func updateRows(data execResponseData) (int64, error) { @@ -302,7 +302,7 @@ func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error { func isStatementContext(ctx context.Context) bool { v := ctx.Value(executionType) - if v != nil && v == executionTypeStatement { + if v == executionTypeStatement { return true } return false diff --git a/statement_test.go b/statement_test.go index 85c7740c4..a565cb046 100644 --- a/statement_test.go +++ b/statement_test.go @@ -42,9 +42,18 @@ func openConn(t *testing.T) *sql.Conn { func TestExecStmt(t *testing.T) { dqlQuery := "SELECT 1" + dmlQuery := "INSERT INTO TestDDLExec VALUES (1)" ddlQuery := "CREATE OR REPLACE TABLE TestDDLExec (num NUMBER)" + multiStmtQuery := "DELETE FROM TestDDLExec;\n" + + "SELECT 1;\n" + + "SELECT 2;" ctx := context.Background() + multiStmtCtx, err := WithMultiStatement(ctx, 3) + if err != nil { + t.Error(err) + } runDBTest(t, func(dbt *DBTest) { + dbt.mustExec(ddlQuery) defer dbt.mustExec("DROP TABLE IF EXISTS TestDDLExec") testcases := []struct { name string @@ -79,6 +88,27 @@ func TestExecStmt(t *testing.T) { return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) }, }, + { + name: "dml Exec", + query: dmlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.Exec(nil) + }, + }, + { + name: "dml ExecContext", + query: dmlQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + }, + { + name: "multistmt ExecContext", + query: multiStmtQuery, + f: func(stmt driver.Stmt) (any, error) { + return stmt.(driver.StmtExecContext).ExecContext(multiStmtCtx, nil) + }, + }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { @@ -91,7 +121,7 @@ func TestExecStmt(t *testing.T) { t.Error("queryId should be empty before executing any query") } if _, err := tc.f(stmt); err != nil { - t.Error("should have not failed to execute the query") + t.Errorf("should have not failed to execute the query, err: %s\n", err) } if stmt.(SnowflakeStmt).GetQueryID() == "" { t.Error("should have set the query id")