From 624cafc82e6c85e24795912a8ba558c18bc4707c Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Thu, 2 Jan 2025 19:52:11 -0800 Subject: [PATCH] change put get test to test with stream and non-stream to increase codecov --- put_get_test.go | 80 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 61 insertions(+), 19 deletions(-) diff --git a/put_get_test.go b/put_get_test.go index 68ca087d4..3732ee6ad 100644 --- a/put_get_test.go +++ b/put_get_test.go @@ -699,18 +699,39 @@ func TestPutGetGcsDownscopedCredential(t *testing.T) { }) } -func TestPutGetLargeFile(t *testing.T) { +func TestPutGetLargeFileNonStream(t *testing.T) { + testPutGetLargeFile(t, false) +} + +func TestPutGetLargeFileStream(t *testing.T) { + testPutGetLargeFile(t, true) +} + +func testPutGetLargeFile(t *testing.T, isStream bool) { sourceDir, err := os.Getwd() assertNilF(t, err) + fname := filepath.Join(sourceDir, "/test_data/largefile.txt") + fnamegz := "largefile.txt.gz" runDBTest(t, func(dbt *DBTest) { stageDir := "test_put_largefile_" + randomString(10) dbt.mustExec("rm @~/" + stageDir) + ctx := context.Background() + if isStream { + f, err := os.Open(fname) + assertNilF(t, err) + defer func() { + assertNilF(t, f.Close()) + }() + ctx = WithFileStream(ctx, f) + } + // PUT test - putQuery := fmt.Sprintf("put file://%v/test_data/largefile.txt @~/%v", sourceDir, stageDir) + putQuery := fmt.Sprintf("put file://%v @~/%v", fname, stageDir) sqlText := strings.ReplaceAll(putQuery, "\\", "\\\\") - dbt.mustExec(sqlText) + _ = dbt.mustExecContext(ctx, sqlText) + defer dbt.mustExec("rm @~/" + stageDir) rows := dbt.mustQuery("ls @~/" + stageDir) defer func() { @@ -722,15 +743,20 @@ func TestPutGetLargeFile(t *testing.T) { assertNilF(t, err) } - if !strings.Contains(file, "largefile.txt.gz") { + if !strings.Contains(file, fnamegz) { t.Fatalf("should contain file. got: %v", file) } - // GET test with stream + // GET test var streamBuf bytes.Buffer - ctx := WithFileTransferOptions(context.Background(), &SnowflakeFileTransferOptions{GetFileToStream: true}) - ctx = WithFileGetStream(ctx, &streamBuf) - sql := fmt.Sprintf("get @~/%v/largefile.txt.gz 'file://%v'", stageDir, t.TempDir()) + ctx = context.Background() + if isStream { + ctx = WithFileTransferOptions(ctx, &SnowflakeFileTransferOptions{GetFileToStream: true}) + ctx = WithFileGetStream(ctx, &streamBuf) + } + + tmpDir := t.TempDir() + sql := fmt.Sprintf("get @~/%v/%v 'file://%v'", stageDir, fnamegz, tmpDir) sqlText = strings.ReplaceAll(sql, "\\", "\\\\") rows2 := dbt.mustQueryContext(ctx, sqlText) defer func() { @@ -739,7 +765,7 @@ func TestPutGetLargeFile(t *testing.T) { for rows2.Next() { err = rows2.Scan(&file, &s1, &s2, &s3) assertNilE(t, err) - assertTrueE(t, strings.HasPrefix(file, "largefile.txt.gz"), "a file was not downloaded by GET") + assertTrueE(t, strings.HasPrefix(file, fnamegz), "a file was not downloaded by GET") v, err := strconv.Atoi(s1) assertNilE(t, err) assertEqualE(t, v, 424821, "did not return the right file size") @@ -747,13 +773,27 @@ func TestPutGetLargeFile(t *testing.T) { assertEqualE(t, s3, "") } - // convert the compressed stream to string - var contents string - gz, err := gzip.NewReader(&streamBuf) - assertNilE(t, err) + // convert the compressed contents to string + var gz *gzip.Reader + if isStream { + gz, err = gzip.NewReader(&streamBuf) + assertNilE(t, err) + } else { + downloadedFile := filepath.Join(tmpDir, fnamegz) + f, err := os.Open(downloadedFile) + assertNilE(t, err) + defer func() { + assertNilF(t, f.Close()) + }() + + gz, err = gzip.NewReader(f) + assertNilE(t, err) + } defer func() { assertNilF(t, gz.Close()) }() + + var contents string for { c := make([]byte, defaultChunkBufferSize) if n, err := gz.Read(c); err != nil { @@ -767,12 +807,14 @@ func TestPutGetLargeFile(t *testing.T) { } } - // verify the downloaded stream with the original file - fname := filepath.Join(sourceDir, "/test_data/largefile.txt") - f, err := os.Open(fname) - assertNilE(t, err) - defer f.Close() - originalContents, err := io.ReadAll(f) + // verify the downloaded contents with the original file + originalFile, err := os.Open(fname) + assertNilF(t, err) + defer func() { + assertNilF(t, originalFile.Close()) + }() + + originalContents, err := io.ReadAll(originalFile) assertNilE(t, err) assertEqualF(t, contents, string(originalContents), "data did not match content") })