diff --git a/gcs_storage_client.go b/gcs_storage_client.go index b45c51504..8558094ba 100644 --- a/gcs_storage_client.go +++ b/gcs_storage_client.go @@ -20,6 +20,7 @@ const ( gcsMetadataMatdescKey = gcsMetadataPrefix + "matdesc" gcsMetadataEncryptionDataProp = gcsMetadataPrefix + "encryptiondata" gcsFileHeaderDigest = "gcs-file-header-digest" + gcsRegionMeCentral2 = "me-central2" ) type snowflakeGcsClient struct { @@ -52,7 +53,7 @@ func (util *snowflakeGcsClient) getFileHeader(meta *fileMetadata, filename strin if meta.presignedURL != nil { meta.resStatus = notFoundFile } else { - URL, err := util.generateFileURL(meta.stageInfo.Location, strings.TrimLeft(filename, "/")) + URL, err := util.generateFileURL(meta.stageInfo, strings.TrimLeft(filename, "/")) if err != nil { return nil, err } @@ -147,7 +148,7 @@ func (util *snowflakeGcsClient) uploadFile( var err error if uploadURL == nil { - uploadURL, err = util.generateFileURL(meta.stageInfo.Location, strings.TrimLeft(meta.dstFileName, "/")) + uploadURL, err = util.generateFileURL(meta.stageInfo, strings.TrimLeft(meta.dstFileName, "/")) if err != nil { return err } @@ -279,7 +280,7 @@ func (util *snowflakeGcsClient) nativeDownloadFile( gcsHeaders := make(map[string]string) if downloadURL == nil || downloadURL.String() == "" { - downloadURL, err = util.generateFileURL(meta.stageInfo.Location, strings.TrimLeft(meta.srcFileName, "/")) + downloadURL, err = util.generateFileURL(meta.stageInfo, strings.TrimLeft(meta.srcFileName, "/")) if err != nil { return err } @@ -388,10 +389,11 @@ func (util *snowflakeGcsClient) extractBucketNameAndPath(location string) *gcsLo return &gcsLocation{containerName, path} } -func (util *snowflakeGcsClient) generateFileURL(stageLocation string, filename string) (*url.URL, error) { - gcsLoc := util.extractBucketNameAndPath(stageLocation) +func (util *snowflakeGcsClient) generateFileURL(stageInfo *execResponseStageInfo, filename string) (*url.URL, error) { + gcsLoc := util.extractBucketNameAndPath(stageInfo.Location) fullFilePath := gcsLoc.path + filename - URL, err := url.Parse("https://storage.googleapis.com/" + gcsLoc.bucketName + "/" + url.QueryEscape(fullFilePath)) + endPoint := getGcsCustomEndpoint(stageInfo) + URL, err := url.Parse(endPoint + "/" + gcsLoc.bucketName + "/" + url.QueryEscape(fullFilePath)) if err != nil { return nil, err } @@ -407,3 +409,16 @@ func newGcsClient() gcsAPI { Transport: SnowflakeTransport, } } + +func getGcsCustomEndpoint(info *execResponseStageInfo) string { + endpoint := "https://storage.googleapis.com" + + // TODO: SNOW-1789759 hardcoded region will be replaced in the future + isRegionalURLEnabled := (strings.ToLower(info.Region) == gcsRegionMeCentral2) || info.UseRegionalURL + if info.EndPoint != "" { + endpoint = fmt.Sprintf("https://%s", info.EndPoint) + } else if info.Region != "" && isRegionalURLEnabled { + endpoint = fmt.Sprintf("https://storage.%s.rep.googleapis.com", strings.ToLower(info.Region)) + } + return endpoint +} diff --git a/gcs_storage_client_test.go b/gcs_storage_client_test.go index 3c360f7f7..88c24176e 100644 --- a/gcs_storage_client_test.go +++ b/gcs_storage_client_test.go @@ -105,7 +105,9 @@ func TestGenerateFileURL(t *testing.T) { } for _, test := range testcases { t.Run(test.location, func(t *testing.T) { - gcsURL, err := gcsUtil.generateFileURL(test.location, test.fname) + stageInfo := &execResponseStageInfo{} + stageInfo.Location = test.location + gcsURL, err := gcsUtil.generateFileURL(stageInfo, test.fname) if err != nil { t.Error(err) } @@ -1126,3 +1128,93 @@ func Test_snowflakeGcsClient_nativeDownloadFile(t *testing.T) { t.Error("should have raised an error") } } + +func TestGetGcsCustomEndpoint(t *testing.T) { + testcases := []struct { + desc string + in execResponseStageInfo + out string + }{ + { + desc: "when the endPoint is not specified and UseRegionalURL is false", + in: execResponseStageInfo{ + UseRegionalURL: false, + EndPoint: "", + Region: "WEST-1", + }, + out: "https://storage.googleapis.com", + }, + { + desc: "when the useRegionalURL is only enabled", + in: execResponseStageInfo{ + UseRegionalURL: true, + EndPoint: "", + Region: "mockLocation", + }, + out: "https://storage.mocklocation.rep.googleapis.com", + }, + { + desc: "when the region is me-central2", + in: execResponseStageInfo{ + UseRegionalURL: false, + EndPoint: "", + Region: "me-central2", + }, + out: "https://storage.me-central2.rep.googleapis.com", + }, + { + desc: "when the region is me-central2 (mixed case)", + in: execResponseStageInfo{ + UseRegionalURL: false, + EndPoint: "", + Region: "ME-cEntRal2", + }, + out: "https://storage.me-central2.rep.googleapis.com", + }, + { + desc: "when the region is me-central2 (uppercase)", + in: execResponseStageInfo{ + UseRegionalURL: false, + EndPoint: "", + Region: "ME-CENTRAL2", + }, + out: "https://storage.me-central2.rep.googleapis.com", + }, + { + desc: "when the endPoint is specified", + in: execResponseStageInfo{ + UseRegionalURL: false, + EndPoint: "storage.specialEndPoint.rep.googleapis.com", + Region: "ME-cEntRal1", + }, + out: "https://storage.specialEndPoint.rep.googleapis.com", + }, + { + desc: "when both the endPoint and the useRegionalUrl are specified", + in: execResponseStageInfo{ + UseRegionalURL: true, + EndPoint: "storage.specialEndPoint.rep.googleapis.com", + Region: "ME-cEntRal1", + }, + out: "https://storage.specialEndPoint.rep.googleapis.com", + }, + { + desc: "when both the endPoint is specified and the region is me-central2", + in: execResponseStageInfo{ + UseRegionalURL: true, + EndPoint: "storage.specialEndPoint.rep.googleapis.com", + Region: "ME-CENTRAL2", + }, + out: "https://storage.specialEndPoint.rep.googleapis.com", + }, + } + + for _, test := range testcases { + t.Run(test.desc, func(t *testing.T) { + endpoint := getGcsCustomEndpoint(&test.in) + if endpoint != test.out { + t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.out, endpoint) + } + }) + } +} diff --git a/query.go b/query.go index d9234ffbe..96c788ffe 100644 --- a/query.go +++ b/query.go @@ -108,6 +108,8 @@ type execResponseStageInfo struct { Creds execResponseCredentials `json:"creds,omitempty"` PresignedURL string `json:"presignedUrl,omitempty"` EndPoint string `json:"endPoint,omitempty"` + UseS3RegionalURL bool `json:"useS3RegionalUrl,omitempty"` + UseRegionalURL bool `json:"useRegionalUrl,omitempty"` } // make all data field optional diff --git a/s3_storage_client.go b/s3_storage_client.go index b27cef302..35d962bfc 100644 --- a/s3_storage_client.go +++ b/s3_storage_client.go @@ -7,16 +7,17 @@ import ( "context" "errors" "fmt" + "io" + "net/http" + "os" + "strings" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/smithy-go" "github.com/aws/smithy-go/logging" - "io" - "net/http" - "os" - "strings" ) const ( @@ -47,12 +48,7 @@ var S3LoggingMode aws.ClientLogMode func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) { stageCredentials := info.Creds s3Logger := logging.LoggerFunc(s3LoggingFunc) - - var endpoint *string - if info.EndPoint != "" { - tmp := "https://" + info.EndPoint - endpoint = &tmp - } + endPoint := getS3CustomEndpoint(info) return s3.New(s3.Options{ Region: info.Region, @@ -60,7 +56,7 @@ func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAcce stageCredentials.AwsKeyID, stageCredentials.AwsSecretKey, stageCredentials.AwsToken)), - BaseEndpoint: endpoint, + BaseEndpoint: endPoint, UseAccelerate: useAccelerateEndpoint, HTTPClient: &http.Client{ Transport: SnowflakeTransport, @@ -70,6 +66,23 @@ func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAcce }), nil } +func getS3CustomEndpoint(info *execResponseStageInfo) *string { + var endPoint *string + isRegionalURLEnabled := info.UseRegionalURL || info.UseS3RegionalURL + if info.EndPoint != "" { + tmp := fmt.Sprintf("https://%s", info.EndPoint) + endPoint = &tmp + } else if info.Region != "" && isRegionalURLEnabled { + domainSuffixForRegionalURL := "amazonaws.com" + if strings.HasPrefix(strings.ToLower(info.Region), "cn-") { + domainSuffixForRegionalURL = "amazonaws.com.cn" + } + tmp := fmt.Sprintf("https://s3.%s.%s", info.Region, domainSuffixForRegionalURL) + endPoint = &tmp + } + return endPoint +} + func s3LoggingFunc(classification logging.Classification, format string, v ...interface{}) { switch classification { case logging.Debug: diff --git a/s3_storage_client_test.go b/s3_storage_client_test.go index a9c3eb5c9..d7db9de2c 100644 --- a/s3_storage_client_test.go +++ b/s3_storage_client_test.go @@ -793,3 +793,102 @@ func TestConvertContentLength(t *testing.T) { }) } } + +func TestGetS3Endpoint(t *testing.T) { + testcases := []struct { + desc string + in execResponseStageInfo + out string + }{ + + { + desc: "when UseRegionalURL is valid and the region does not start with cn-", + in: execResponseStageInfo{ + UseS3RegionalURL: false, + UseRegionalURL: true, + EndPoint: "", + Region: "WEST-1", + }, + out: "https://s3.WEST-1.amazonaws.com", + }, + { + desc: "when UseS3RegionalURL is valid and the region does not start with cn-", + in: execResponseStageInfo{ + UseS3RegionalURL: true, + UseRegionalURL: false, + EndPoint: "", + Region: "WEST-1", + }, + out: "https://s3.WEST-1.amazonaws.com", + }, + { + desc: "when endPoint is enabled and the region does not start with cn-", + in: execResponseStageInfo{ + UseS3RegionalURL: false, + UseRegionalURL: false, + EndPoint: "s3.endpoint", + Region: "mockLocation", + }, + out: "https://s3.endpoint", + }, + { + desc: "when endPoint is enabled and the region starts with cn-", + in: execResponseStageInfo{ + UseS3RegionalURL: false, + UseRegionalURL: false, + EndPoint: "s3.endpoint", + Region: "cn-mockLocation", + }, + out: "https://s3.endpoint", + }, + { + desc: "when useS3RegionalURL is valid and domain starts with cn", + in: execResponseStageInfo{ + UseS3RegionalURL: true, + UseRegionalURL: false, + EndPoint: "", + Region: "cn-mockLocation", + }, + out: "https://s3.cn-mockLocation.amazonaws.com.cn", + }, + { + desc: "when useRegionalURL is valid and domain starts with cn", + in: execResponseStageInfo{ + UseS3RegionalURL: true, + UseRegionalURL: false, + EndPoint: "", + Region: "cn-mockLocation", + }, + out: "https://s3.cn-mockLocation.amazonaws.com.cn", + }, + { + desc: "when useRegionalURL is valid and domain starts with cn", + in: execResponseStageInfo{ + UseS3RegionalURL: true, + UseRegionalURL: false, + EndPoint: "", + Region: "cn-mockLocation", + }, + out: "https://s3.cn-mockLocation.amazonaws.com.cn", + }, + { + desc: "when endPoint is specified, both UseRegionalURL and useS3PRegionalUrl are valid, and the region starts with cn", + in: execResponseStageInfo{ + UseS3RegionalURL: true, + UseRegionalURL: true, + EndPoint: "s3.endpoint", + Region: "cn-mockLocation", + }, + out: "https://s3.endpoint", + }, + } + + for _, test := range testcases { + t.Run(test.desc, func(t *testing.T) { + endpoint := getS3CustomEndpoint(&test.in) + if *endpoint != test.out { + t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.out, *endpoint) + } + }) + } +}