Skip to content

Commit

Permalink
change put get test to test with stream and non-stream to increase co…
Browse files Browse the repository at this point in the history
…decov
  • Loading branch information
sfc-gh-ext-simba-jl committed Jan 3, 2025
1 parent 7c8944b commit 624cafc
Showing 1 changed file with 61 additions and 19 deletions.
80 changes: 61 additions & 19 deletions put_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,18 +699,39 @@ func TestPutGetGcsDownscopedCredential(t *testing.T) {
})
}

func TestPutGetLargeFile(t *testing.T) {
func TestPutGetLargeFileNonStream(t *testing.T) {
testPutGetLargeFile(t, false)
}

func TestPutGetLargeFileStream(t *testing.T) {
testPutGetLargeFile(t, true)
}

func testPutGetLargeFile(t *testing.T, isStream bool) {
sourceDir, err := os.Getwd()
assertNilF(t, err)
fname := filepath.Join(sourceDir, "/test_data/largefile.txt")
fnamegz := "largefile.txt.gz"

runDBTest(t, func(dbt *DBTest) {
stageDir := "test_put_largefile_" + randomString(10)
dbt.mustExec("rm @~/" + stageDir)

ctx := context.Background()
if isStream {
f, err := os.Open(fname)
assertNilF(t, err)
defer func() {
assertNilF(t, f.Close())
}()
ctx = WithFileStream(ctx, f)
}

// PUT test
putQuery := fmt.Sprintf("put file://%v/test_data/largefile.txt @~/%v", sourceDir, stageDir)
putQuery := fmt.Sprintf("put file://%v @~/%v", fname, stageDir)
sqlText := strings.ReplaceAll(putQuery, "\\", "\\\\")
dbt.mustExec(sqlText)
_ = dbt.mustExecContext(ctx, sqlText)

defer dbt.mustExec("rm @~/" + stageDir)
rows := dbt.mustQuery("ls @~/" + stageDir)
defer func() {
Expand All @@ -722,15 +743,20 @@ func TestPutGetLargeFile(t *testing.T) {
assertNilF(t, err)
}

if !strings.Contains(file, "largefile.txt.gz") {
if !strings.Contains(file, fnamegz) {
t.Fatalf("should contain file. got: %v", file)
}

// GET test with stream
// GET test
var streamBuf bytes.Buffer
ctx := WithFileTransferOptions(context.Background(), &SnowflakeFileTransferOptions{GetFileToStream: true})
ctx = WithFileGetStream(ctx, &streamBuf)
sql := fmt.Sprintf("get @~/%v/largefile.txt.gz 'file://%v'", stageDir, t.TempDir())
ctx = context.Background()
if isStream {
ctx = WithFileTransferOptions(ctx, &SnowflakeFileTransferOptions{GetFileToStream: true})
ctx = WithFileGetStream(ctx, &streamBuf)
}

tmpDir := t.TempDir()
sql := fmt.Sprintf("get @~/%v/%v 'file://%v'", stageDir, fnamegz, tmpDir)
sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
rows2 := dbt.mustQueryContext(ctx, sqlText)
defer func() {
Expand All @@ -739,21 +765,35 @@ func TestPutGetLargeFile(t *testing.T) {
for rows2.Next() {
err = rows2.Scan(&file, &s1, &s2, &s3)
assertNilE(t, err)
assertTrueE(t, strings.HasPrefix(file, "largefile.txt.gz"), "a file was not downloaded by GET")
assertTrueE(t, strings.HasPrefix(file, fnamegz), "a file was not downloaded by GET")
v, err := strconv.Atoi(s1)
assertNilE(t, err)
assertEqualE(t, v, 424821, "did not return the right file size")
assertEqualE(t, s2, "DOWNLOADED", "did not return DOWNLOADED status")
assertEqualE(t, s3, "")
}

// convert the compressed stream to string
var contents string
gz, err := gzip.NewReader(&streamBuf)
assertNilE(t, err)
// convert the compressed contents to string
var gz *gzip.Reader
if isStream {
gz, err = gzip.NewReader(&streamBuf)
assertNilE(t, err)
} else {
downloadedFile := filepath.Join(tmpDir, fnamegz)
f, err := os.Open(downloadedFile)
assertNilE(t, err)
defer func() {
assertNilF(t, f.Close())
}()

gz, err = gzip.NewReader(f)
assertNilE(t, err)
}
defer func() {
assertNilF(t, gz.Close())
}()

var contents string
for {
c := make([]byte, defaultChunkBufferSize)
if n, err := gz.Read(c); err != nil {
Expand All @@ -767,12 +807,14 @@ func TestPutGetLargeFile(t *testing.T) {
}
}

// verify the downloaded stream with the original file
fname := filepath.Join(sourceDir, "/test_data/largefile.txt")
f, err := os.Open(fname)
assertNilE(t, err)
defer f.Close()
originalContents, err := io.ReadAll(f)
// verify the downloaded contents with the original file
originalFile, err := os.Open(fname)
assertNilF(t, err)
defer func() {
assertNilF(t, originalFile.Close())
}()

originalContents, err := io.ReadAll(originalFile)
assertNilE(t, err)
assertEqualF(t, contents, string(originalContents), "data did not match content")
})
Expand Down

0 comments on commit 624cafc

Please sign in to comment.