Skip to content

Commit

Permalink
SNOW-985356 Do not start chunk downloader when first batch causes err…
Browse files Browse the repository at this point in the history
…or (#993)
  • Loading branch information
sfc-gh-pfus authored Dec 7, 2023
1 parent cf94c15 commit a4c3557
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 6 deletions.
8 changes: 4 additions & 4 deletions arrow_chunk.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ func (arc *arrowResultChunk) decodeArrowBatch(scd *snowflakeChunkDownloader) (*[
}

// Build arrow chunk based on RowSet of base64
func buildFirstArrowChunk(rowsetBase64 string, loc *time.Location, alloc memory.Allocator) arrowResultChunk {
func buildFirstArrowChunk(rowsetBase64 string, loc *time.Location, alloc memory.Allocator) (arrowResultChunk, error) {
rowSetBytes, err := base64.StdEncoding.DecodeString(rowsetBase64)
if err != nil {
return arrowResultChunk{}
return arrowResultChunk{}, err
}
rr, err := ipc.NewReader(bytes.NewReader(rowSetBytes), ipc.WithAllocator(alloc))
if err != nil {
return arrowResultChunk{}
return arrowResultChunk{}, err
}

return arrowResultChunk{rr, 0, loc, alloc}
return arrowResultChunk{rr, 0, loc, alloc}, nil
}
10 changes: 8 additions & 2 deletions chunk_downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ func (scd *snowflakeChunkDownloader) start() error {
if scd.sc != nil && scd.sc.cfg != nil {
loc = getCurrentLocation(scd.sc.cfg.Params)
}
firstArrowChunk := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool)
firstArrowChunk, err := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool)
if err != nil {
return err
}
higherPrecision := higherPrecisionEnabled(scd.ctx)
scd.CurrentChunk, err = firstArrowChunk.decodeArrowChunk(scd.RowSet.RowType, higherPrecision)
scd.CurrentChunkSize = firstArrowChunk.rowCount
Expand Down Expand Up @@ -274,7 +277,10 @@ func (scd *snowflakeChunkDownloader) startArrowBatches() error {
if scd.sc != nil && scd.sc.cfg != nil {
loc = getCurrentLocation(scd.sc.cfg.Params)
}
firstArrowChunk := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool)
firstArrowChunk, err := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool)
if err != nil {
return err
}
scd.FirstBatch = &ArrowBatch{
idx: 0,
scd: scd,
Expand Down
28 changes: 28 additions & 0 deletions chunk_downloader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package gosnowflake

import (
"context"
"testing"
)

func TestChunkDownloaderDoesNotStartWhenArrowParsingCausesError(t *testing.T) {
tcs := []string{
"invalid base64",
"aW52YWxpZCBhcnJvdw==", // valid base64, but invalid arrow
}
for _, tc := range tcs {
t.Run(tc, func(t *testing.T) {
scd := snowflakeChunkDownloader{
ctx: context.Background(),
QueryResultFormat: "arrow",
RowSet: rowSetType{
RowSetBase64: tc,
},
}

err := scd.start()

assertNotNilF(t, err)
})
}
}

0 comments on commit a4c3557

Please sign in to comment.