Skip to content

Commit

Permalink
SNOW-1859664 use correct transport for calling cloud providers (#1288)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dszmolka authored Jan 17, 2025
1 parent 7f77aea commit 72a121f
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 28 deletions.
12 changes: 6 additions & 6 deletions azure_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (util *snowflakeAzureClient) createClient(info *execResponseStageInfo, _ bo
RetryDelay: 2 * time.Second,
},
Transport: &http.Client{
Transport: SnowflakeTransport,
Transport: getTransport(util.cfg),
},
},
})
Expand All @@ -74,7 +74,7 @@ func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename str
return nil, err
}
path := azureLoc.path + strings.TrimLeft(filename, "/")
containerClient, err := createContainerClient(client.URL())
containerClient, err := createContainerClient(client.URL(), util.cfg)
if err != nil {
return nil, &SnowflakeError{
Message: "failed to create container client",
Expand Down Expand Up @@ -188,7 +188,7 @@ func (util *snowflakeAzureClient) uploadFile(
Message: "failed to cast to azure client",
}
}
containerClient, err := createContainerClient(client.URL())
containerClient, err := createContainerClient(client.URL(), util.cfg)

if err != nil {
return &SnowflakeError{
Expand Down Expand Up @@ -273,7 +273,7 @@ func (util *snowflakeAzureClient) nativeDownloadFile(
Message: "failed to cast to azure client",
}
}
containerClient, err := createContainerClient(client.URL())
containerClient, err := createContainerClient(client.URL(), util.cfg)
if err != nil {
return &SnowflakeError{
Message: "failed to create container client",
Expand Down Expand Up @@ -348,10 +348,10 @@ func (util *snowflakeAzureClient) detectAzureTokenExpireError(resp *http.Respons
strings.Contains(errStr, "Server failed to authenticate the request")
}

func createContainerClient(clientURL string) (*container.Client, error) {
func createContainerClient(clientURL string, cfg *Config) (*container.Client, error) {
return container.NewClientWithNoCredential(clientURL, &container.ClientOptions{ClientOptions: azcore.ClientOptions{
Transport: &http.Client{
Transport: SnowflakeTransport,
Transport: getTransport(cfg),
},
}})
}
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (t *DummyTransport) RoundTrip(r *http.Request) (*http.Response, error) {
}
return &http.Response{StatusCode: 200}, nil
}
return snowflakeInsecureTransport.RoundTrip(r)
return snowflakeNoOcspTransport.RoundTrip(r)
}

func TestInternalClient(t *testing.T) {
Expand Down
20 changes: 19 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, err
if sc.cfg.Transporter == nil {
if sc.cfg.DisableOCSPChecks || sc.cfg.InsecureMode {
// no revocation check with OCSP. Think twice when you want to enable this option.
st = snowflakeInsecureTransport
st = snowflakeNoOcspTransport
} else {
// set OCSP fail open mode
ocspResponseCacheLock.Lock()
Expand Down Expand Up @@ -856,3 +856,21 @@ func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, err

return sc, nil
}

func getTransport(cfg *Config) http.RoundTripper {
if cfg == nil {
logger.Debug("getTransport: got nil Config, will perform OCSP validation for cloud storage")
return SnowflakeTransport
}
// if user configured a custom Transporter, prioritize that
if cfg.Transporter != nil {
logger.Debug("getTransport: using Transporter configured by the user")
return cfg.Transporter
}
if cfg.DisableOCSPChecks || cfg.InsecureMode {
logger.Debug("getTransport: skipping OCSP validation for cloud storage")
return snowflakeNoOcspTransport
}
logger.Debug("getTransport: will perform OCSP validation for cloud storage")
return SnowflakeTransport
}
53 changes: 53 additions & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -826,3 +826,56 @@ func TestBeginCreatesTransaction(t *testing.T) {
}
})
}

type EmptyTransporter struct{}

func (t EmptyTransporter) RoundTrip(*http.Request) (*http.Response, error) {
return nil, nil
}

func TestGetTransport(t *testing.T) {
testcases := []struct {
name string
cfg *Config
transport http.RoundTripper
}{
{
name: "DisableOCSPChecks and InsecureMode false",
cfg: &Config{Account: "one", DisableOCSPChecks: false, InsecureMode: false},
transport: SnowflakeTransport,
},
{
name: "DisableOCSPChecks true and InsecureMode false",
cfg: &Config{Account: "two", DisableOCSPChecks: true, InsecureMode: false},
transport: snowflakeNoOcspTransport,
},
{
name: "DisableOCSPChecks false and InsecureMode true",
cfg: &Config{Account: "three", DisableOCSPChecks: false, InsecureMode: true},
transport: snowflakeNoOcspTransport,
},
{
name: "DisableOCSPChecks and InsecureMode missing from Config",
cfg: &Config{Account: "four"},
transport: SnowflakeTransport,
},
{
name: "whole Config is missing",
cfg: nil,
transport: SnowflakeTransport,
},
{
name: "Using custom Transporter",
cfg: &Config{Account: "five", DisableOCSPChecks: true, InsecureMode: false, Transporter: EmptyTransporter{}},
transport: EmptyTransporter{},
},
}
for _, test := range testcases {
t.Run(test.name, func(t *testing.T) {
result := getTransport(test.cfg)
if test.transport != result {
t.Errorf("Failed to return the correct transport, input :%#v, expected: %v, got: %v", test.cfg, test.transport, result)
}
})
}
}
2 changes: 1 addition & 1 deletion driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1993,7 +1993,7 @@ type CountingTransport struct {

func (t *CountingTransport) RoundTrip(r *http.Request) (*http.Response, error) {
t.requests++
return snowflakeInsecureTransport.RoundTrip(r)
return snowflakeNoOcspTransport.RoundTrip(r)
}

func TestOpenWithTransport(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions file_transfer_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -597,15 +597,15 @@ type s3BucketAccelerateConfigGetter interface {

type s3ClientCreator interface {
extractBucketNameAndPath(location string) (*s3Location, error)
createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error)
createClientWithConfig(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error)
}

func (sfa *snowflakeFileTransferAgent) transferAccelerateConfigWithUtil(s3Util s3ClientCreator) error {
s3Loc, err := s3Util.extractBucketNameAndPath(sfa.stageInfo.Location)
if err != nil {
return err
}
s3Cli, err := s3Util.createClient(sfa.stageInfo, false)
s3Cli, err := s3Util.createClientWithConfig(sfa.stageInfo, false, sfa.sc.cfg)
if err != nil {
return err
}
Expand Down
12 changes: 6 additions & 6 deletions file_transfer_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ func TestGetBucketAccelerateConfiguration(t *testing.T) {

type s3ClientCreatorMock struct {
extract func(string) (*s3Location, error)
create func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error)
create func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error)
}

func (mock *s3ClientCreatorMock) extractBucketNameAndPath(location string) (*s3Location, error) {
return mock.extract(location)
}

func (mock *s3ClientCreatorMock) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
return mock.create(info, useAccelerateEndpoint)
func (mock *s3ClientCreatorMock) createClientWithConfig(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) {
return mock.create(info, useAccelerateEndpoint, cfg)
}

type s3BucketAccelerateConfigGetterMock struct {
Expand Down Expand Up @@ -96,7 +96,7 @@ func TestGetBucketAccelerateConfigurationTooManyRetries(t *testing.T) {
extract: func(s string) (*s3Location, error) {
return &s3Location{bucketName: "test", s3Path: "test"}, nil
},
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) {
return &s3BucketAccelerateConfigGetterMock{err: errors.New("testing")}, nil
},
})
Expand Down Expand Up @@ -146,7 +146,7 @@ func TestGetBucketAccelerateConfigurationFailedCreateClient(t *testing.T) {
extract: func(s string) (*s3Location, error) {
return &s3Location{bucketName: "test", s3Path: "test"}, nil
},
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) {
return nil, errors.New("failed creation")
},
})
Expand All @@ -172,7 +172,7 @@ func TestGetBucketAccelerateConfigurationInvalidClient(t *testing.T) {
extract: func(s string) (*s3Location, error) {
return &s3Location{bucketName: "test", s3Path: "test"}, nil
},
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) {
return 1, nil
},
})
Expand Down
10 changes: 5 additions & 5 deletions gcs_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (util *snowflakeGcsClient) getFileHeader(meta *fileMetadata, filename strin
for k, v := range gcsHeaders {
req.Header.Add(k, v)
}
client := newGcsClient()
client := newGcsClient(util.cfg)
// for testing only
if meta.mockGcsClient != nil {
client = meta.mockGcsClient
Expand Down Expand Up @@ -226,7 +226,7 @@ func (util *snowflakeGcsClient) uploadFile(
for k, v := range gcsHeaders {
req.Header.Add(k, v)
}
client := newGcsClient()
client := newGcsClient(util.cfg)
// for testing only
if meta.mockGcsClient != nil {
client = meta.mockGcsClient
Expand Down Expand Up @@ -307,7 +307,7 @@ func (util *snowflakeGcsClient) nativeDownloadFile(
for k, v := range gcsHeaders {
req.Header.Add(k, v)
}
client := newGcsClient()
client := newGcsClient(util.cfg)
// for testing only
if meta.mockGcsClient != nil {
client = meta.mockGcsClient
Expand Down Expand Up @@ -409,9 +409,9 @@ func (util *snowflakeGcsClient) isTokenExpired(resp *http.Response) bool {
return resp.StatusCode == 401
}

func newGcsClient() gcsAPI {
func newGcsClient(cfg *Config) gcsAPI {
return &http.Client{
Transport: SnowflakeTransport,
Transport: getTransport(cfg),
}
}

Expand Down
8 changes: 4 additions & 4 deletions ocsp.go
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ func getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate)
}
ocspClient := &http.Client{
Timeout: timeout,
Transport: snowflakeInsecureTransport,
Transport: snowflakeNoOcspTransport,
}
ocspRes, ocspResBytes, ocspS := retryOCSP(
ctx, ocspClient, http.NewRequest, u, headers, ocspReq, issuer, timeout)
Expand Down Expand Up @@ -786,7 +786,7 @@ func downloadOCSPCacheServer() {
}
ocspClient := &http.Client{
Timeout: timeout,
Transport: snowflakeInsecureTransport,
Transport: snowflakeNoOcspTransport,
}
ret, ocspStatus := checkOCSPCacheServer(context.Background(), ocspClient, http.NewRequest, u, timeout)
if ocspStatus.code != ocspSuccess {
Expand Down Expand Up @@ -1075,8 +1075,8 @@ func init() {
initOCSPCache()
}

// snowflakeInsecureTransport is the transport object that doesn't do certificate revocation check.
var snowflakeInsecureTransport = &http.Transport{
// snowflakeNoOcspTransport is the transport object that doesn't do certificate revocation check with OCSP.
var snowflakeNoOcspTransport = &http.Transport{
MaxIdleConns: 10,
IdleConnTimeout: 30 * time.Minute,
Proxy: http.ProxyFromEnvironment,
Expand Down
2 changes: 1 addition & 1 deletion ocsp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestOCSP(t *testing.T) {
}

transports := []*http.Transport{
snowflakeInsecureTransport,
snowflakeNoOcspTransport,
SnowflakeTransport,
}

Expand Down
9 changes: 8 additions & 1 deletion s3_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,20 @@ func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAcce
BaseEndpoint: endPoint,
UseAccelerate: useAccelerateEndpoint,
HTTPClient: &http.Client{
Transport: SnowflakeTransport,
Transport: getTransport(util.cfg),
},
ClientLogMode: S3LoggingMode,
Logger: s3Logger,
}), nil
}

// to be used with S3 transferAccelerateConfigWithUtil
func (util *snowflakeS3Client) createClientWithConfig(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) {
// copy snowflakeFileTransferAgent's config onto the cloud client so we could decide which Transport to use
util.cfg = cfg
return util.createClient(info, useAccelerateEndpoint)
}

func getS3CustomEndpoint(info *execResponseStageInfo) *string {
var endPoint *string
isRegionalURLEnabled := info.UseRegionalURL || info.UseS3RegionalURL
Expand Down

0 comments on commit 72a121f

Please sign in to comment.