From 096292303b7c354e87817bdf8e9da5f0410afa17 Mon Sep 17 00:00:00 2001 From: sfc-gh-dszmolka Date: Sun, 12 Jan 2025 14:48:01 +0100 Subject: [PATCH] SNOW-1859664 * pass Config to cloud storage functions, to be able to create OCSP-verifying and non-OCSP-verifying cloud clients, depending on user's configuration * modify existing tests accordingly, to also pass Config --- azure_storage_client.go | 16 +++++++------ azure_storage_client_test.go | 24 ++++++++++++------- file_transfer_agent.go | 4 ++-- file_transfer_agent_test.go | 16 ++++++------- gcs_storage_client.go | 16 ++++++++----- gcs_storage_client_test.go | 36 +++++++++++++++++++---------- put_get_with_aws_test.go | 6 ++--- s3_storage_client.go | 5 ++-- s3_storage_client_test.go | 33 +++++++++++++++++--------- storage_client.go | 4 ++-- util.go | 15 ++++++++++++ util_test.go | 45 ++++++++++++++++++++++++++++++++++++ 12 files changed, 159 insertions(+), 61 deletions(-) diff --git a/azure_storage_client.go b/azure_storage_client.go index de611fcb3..80d3742fc 100644 --- a/azure_storage_client.go +++ b/azure_storage_client.go @@ -39,12 +39,13 @@ type azureAPI interface { GetProperties(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) } -func (util *snowflakeAzureClient) createClient(info *execResponseStageInfo, _ bool) (cloudClient, error) { +func (util *snowflakeAzureClient) createClient(info *execResponseStageInfo, _ bool, cfg *Config) (cloudClient, error) { sasToken := info.Creds.AzureSasToken u, err := url.Parse(fmt.Sprintf("https://%s.%s/%s%s", info.StorageAccount, info.EndPoint, info.Path, sasToken)) if err != nil { return nil, err } + transport := getTransport(cfg) client, err := azblob.NewClientWithNoCredential(u.String(), &azblob.ClientOptions{ ClientOptions: azcore.ClientOptions{ Retry: policy.RetryOptions{ @@ -52,7 +53,7 @@ func (util *snowflakeAzureClient) createClient(info *execResponseStageInfo, _ bo RetryDelay: 2 * time.Second, }, Transport: &http.Client{ - Transport: SnowflakeTransport, + Transport: transport, }, }, }) @@ -74,7 +75,7 @@ func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename str return nil, err } path := azureLoc.path + strings.TrimLeft(filename, "/") - containerClient, err := createContainerClient(client.URL()) + containerClient, err := createContainerClient(client.URL(), util.cfg) if err != nil { return nil, &SnowflakeError{ Message: "failed to create container client", @@ -188,7 +189,7 @@ func (util *snowflakeAzureClient) uploadFile( Message: "failed to cast to azure client", } } - containerClient, err := createContainerClient(client.URL()) + containerClient, err := createContainerClient(client.URL(), util.cfg) if err != nil { return &SnowflakeError{ @@ -273,7 +274,7 @@ func (util *snowflakeAzureClient) nativeDownloadFile( Message: "failed to cast to azure client", } } - containerClient, err := createContainerClient(client.URL()) + containerClient, err := createContainerClient(client.URL(), util.cfg) if err != nil { return &SnowflakeError{ Message: "failed to create container client", @@ -348,10 +349,11 @@ func (util *snowflakeAzureClient) detectAzureTokenExpireError(resp *http.Respons strings.Contains(errStr, "Server failed to authenticate the request") } -func createContainerClient(clientURL string) (*container.Client, error) { +func createContainerClient(clientURL string, cfg *Config) (*container.Client, error) { + transport := getTransport(cfg) return container.NewClientWithNoCredential(clientURL, &container.ClientOptions{ClientOptions: azcore.ClientOptions{ Transport: &http.Client{ - Transport: SnowflakeTransport, + Transport: transport, }, }}) } diff --git a/azure_storage_client_test.go b/azure_storage_client_test.go index 1490e3a6d..764aefab8 100644 --- a/azure_storage_client_test.go +++ b/azure_storage_client_test.go @@ -152,7 +152,8 @@ func TestUploadFileWithAzureUploadFailedError(t *testing.T) { SMKID: 92019681909886, } - azureCli, err := new(snowflakeAzureClient).createClient(&info, false) + sac := new(snowflakeAzureClient) + azureCli, err := sac.createClient(&info, false, sac.cfg) if err != nil { t.Error(err) } @@ -210,7 +211,8 @@ func TestUploadStreamWithAzureUploadFailedError(t *testing.T) { SMKID: 92019681909886, } - azureCli, err := new(snowflakeAzureClient).createClient(&info, false) + sac := new(snowflakeAzureClient) + azureCli, err := sac.createClient(&info, false, sac.cfg) if err != nil { t.Error(err) } @@ -273,7 +275,8 @@ func TestUploadFileWithAzureUploadTokenExpired(t *testing.T) { panic(err) } - azureCli, err := new(snowflakeAzureClient).createClient(&info, false) + sac := new(snowflakeAzureClient) + azureCli, err := sac.createClient(&info, false, sac.cfg) if err != nil { t.Error(err) } @@ -349,7 +352,8 @@ func TestUploadFileWithAzureUploadNeedsRetry(t *testing.T) { panic(err) } - azureCli, err := new(snowflakeAzureClient).createClient(&info, false) + sac := new(snowflakeAzureClient) + azureCli, err := sac.createClient(&info, false, sac.cfg) if err != nil { t.Error(err) } @@ -412,7 +416,8 @@ func TestDownloadOneFileToAzureFailed(t *testing.T) { t.Error(err) } - azureCli, err := new(snowflakeAzureClient).createClient(&info, false) + sac := new(snowflakeAzureClient) + azureCli, err := sac.createClient(&info, false, sac.cfg) if err != nil { t.Error(err) } @@ -456,7 +461,8 @@ func TestGetFileHeaderErrorStatus(t *testing.T) { LocationType: "AZURE", } - azureCli, err := new(snowflakeAzureClient).createClient(&info, false) + sac := new(snowflakeAzureClient) + azureCli, err := sac.createClient(&info, false, sac.cfg) if err != nil { t.Error(err) } @@ -558,7 +564,8 @@ func TestUploadFileToAzureClientCastFail(t *testing.T) { t.Error(err) } - s3Cli, err := new(snowflakeS3Client).createClient(&info, false) + ss3c := new(snowflakeS3Client) + s3Cli, err := ss3c.createClient(&info, false, ss3c.cfg) if err != nil { t.Error(err) } @@ -600,7 +607,8 @@ func TestAzureGetHeaderClientCastFail(t *testing.T) { Location: "azblob/rwyi-testacco/users/9220/", LocationType: "AZURE", } - s3Cli, err := new(snowflakeS3Client).createClient(&info, false) + ss3c := new(snowflakeS3Client) + s3Cli, err := ss3c.createClient(&info, false, ss3c.cfg) if err != nil { t.Error(err) } diff --git a/file_transfer_agent.go b/file_transfer_agent.go index c30f9868c..ccae18716 100644 --- a/file_transfer_agent.go +++ b/file_transfer_agent.go @@ -597,7 +597,7 @@ type s3BucketAccelerateConfigGetter interface { type s3ClientCreator interface { extractBucketNameAndPath(location string) (*s3Location, error) - createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) + createClient(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) } func (sfa *snowflakeFileTransferAgent) transferAccelerateConfigWithUtil(s3Util s3ClientCreator) error { @@ -605,7 +605,7 @@ func (sfa *snowflakeFileTransferAgent) transferAccelerateConfigWithUtil(s3Util s if err != nil { return err } - s3Cli, err := s3Util.createClient(sfa.stageInfo, false) + s3Cli, err := s3Util.createClient(sfa.stageInfo, false, sfa.sc.cfg) if err != nil { return err } diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go index bf6c6a5bc..51b791830 100644 --- a/file_transfer_agent_test.go +++ b/file_transfer_agent_test.go @@ -53,15 +53,15 @@ func TestGetBucketAccelerateConfiguration(t *testing.T) { type s3ClientCreatorMock struct { extract func(string) (*s3Location, error) - create func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) + create func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) } func (mock *s3ClientCreatorMock) extractBucketNameAndPath(location string) (*s3Location, error) { return mock.extract(location) } -func (mock *s3ClientCreatorMock) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) { - return mock.create(info, useAccelerateEndpoint) +func (mock *s3ClientCreatorMock) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) { + return mock.create(info, useAccelerateEndpoint, cfg) } type s3BucketAccelerateConfigGetterMock struct { @@ -96,7 +96,7 @@ func TestGetBucketAccelerateConfigurationTooManyRetries(t *testing.T) { extract: func(s string) (*s3Location, error) { return &s3Location{bucketName: "test", s3Path: "test"}, nil }, - create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) { + create: func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) { return &s3BucketAccelerateConfigGetterMock{err: errors.New("testing")}, nil }, }) @@ -146,7 +146,7 @@ func TestGetBucketAccelerateConfigurationFailedCreateClient(t *testing.T) { extract: func(s string) (*s3Location, error) { return &s3Location{bucketName: "test", s3Path: "test"}, nil }, - create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) { + create: func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) { return nil, errors.New("failed creation") }, }) @@ -172,7 +172,7 @@ func TestGetBucketAccelerateConfigurationInvalidClient(t *testing.T) { extract: func(s string) (*s3Location, error) { return &s3Location{bucketName: "test", s3Path: "test"}, nil }, - create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) { + create: func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) { return 1, nil }, }) @@ -472,7 +472,7 @@ func TestUpdateMetadataWithPresignedUrl(t *testing.T) { }, nil } - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, sct.sc.cfg) if err != nil { t.Error(err) } @@ -525,7 +525,7 @@ func TestUpdateMetadataWithPresignedUrlForDownload(t *testing.T) { testURL := "https://storage.google.com/gcs-blob/storage/users/456?Signature=testsignature123" - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, sct.sc.cfg) if err != nil { t.Error(err) } diff --git a/gcs_storage_client.go b/gcs_storage_client.go index 8558094ba..8e59356f7 100644 --- a/gcs_storage_client.go +++ b/gcs_storage_client.go @@ -32,7 +32,10 @@ type gcsLocation struct { path string } -func (util *snowflakeGcsClient) createClient(info *execResponseStageInfo, _ bool) (cloudClient, error) { +func (util *snowflakeGcsClient) createClient(info *execResponseStageInfo, _ bool, cfg *Config) (cloudClient, error) { + // we don't seem to actually return the client from createClient here, but to implement + // the interface, need to have the same spec as in snowflakeS3Client and snowflakeAzureClient + _ = cfg if info.Creds.GcsAccessToken != "" { logger.Debug("Using GCS downscoped token") return info.Creds.GcsAccessToken, nil @@ -73,7 +76,7 @@ func (util *snowflakeGcsClient) getFileHeader(meta *fileMetadata, filename strin for k, v := range gcsHeaders { req.Header.Add(k, v) } - client := newGcsClient() + client := newGcsClient(util.cfg) // for testing only if meta.mockGcsClient != nil { client = meta.mockGcsClient @@ -221,7 +224,7 @@ func (util *snowflakeGcsClient) uploadFile( for k, v := range gcsHeaders { req.Header.Add(k, v) } - client := newGcsClient() + client := newGcsClient(util.cfg) // for testing only if meta.mockGcsClient != nil { client = meta.mockGcsClient @@ -302,7 +305,7 @@ func (util *snowflakeGcsClient) nativeDownloadFile( for k, v := range gcsHeaders { req.Header.Add(k, v) } - client := newGcsClient() + client := newGcsClient(util.cfg) // for testing only if meta.mockGcsClient != nil { client = meta.mockGcsClient @@ -404,9 +407,10 @@ func (util *snowflakeGcsClient) isTokenExpired(resp *http.Response) bool { return resp.StatusCode == 401 } -func newGcsClient() gcsAPI { +func newGcsClient(cfg *Config) gcsAPI { + transport := getTransport(cfg) return &http.Client{ - Transport: SnowflakeTransport, + Transport: transport, } } diff --git a/gcs_storage_client_test.go b/gcs_storage_client_test.go index 88c24176e..f0b18431e 100644 --- a/gcs_storage_client_test.go +++ b/gcs_storage_client_test.go @@ -141,7 +141,8 @@ func TestUploadFileWithGcsUploadFailedError(t *testing.T) { t.Error(err) } - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + sgc := new(snowflakeGcsClient) + gcsCli, err := sgc.createClient(&info, false, sgc.cfg) if err != nil { t.Error(err) } @@ -201,7 +202,8 @@ func TestUploadFileWithGcsUploadFailedWithRetry(t *testing.T) { t.Error(err) } - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + sgc := new(snowflakeGcsClient) + gcsCli, err := sgc.createClient(&info, false, sgc.cfg) if err != nil { t.Error(err) } @@ -268,7 +270,8 @@ func TestUploadFileWithGcsUploadFailedWithTokenExpired(t *testing.T) { t.Error(err) } - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + sgc := new(snowflakeGcsClient) + gcsCli, err := sgc.createClient(&info, false, sgc.cfg) if err != nil { t.Error(err) } @@ -329,7 +332,8 @@ func TestDownloadOneFileFromGcsFailed(t *testing.T) { t.Error(err) } - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + sgc := new(snowflakeGcsClient) + gcsCli, err := sgc.createClient(&info, false, sgc.cfg) if err != nil { t.Error(err) } @@ -375,7 +379,8 @@ func TestDownloadOneFileFromGcsFailedWithRetry(t *testing.T) { t.Error(err) } - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + sgc := new(snowflakeGcsClient) + gcsCli, err := sgc.createClient(&info, false, sgc.cfg) if err != nil { t.Error(err) } @@ -432,7 +437,8 @@ func TestDownloadOneFileFromGcsFailedWithTokenExpired(t *testing.T) { t.Error(err) } - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + sgc := new(snowflakeGcsClient) + gcsCli, err := sgc.createClient(&info, false, sgc.cfg) if err != nil { t.Error(err) } @@ -489,7 +495,8 @@ func TestDownloadOneFileFromGcsFailedWithFileNotFound(t *testing.T) { t.Error(err) } - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + sgc := new(snowflakeGcsClient) + gcsCli, err := sgc.createClient(&info, false, sgc.cfg) if err != nil { t.Error(err) } @@ -735,7 +742,8 @@ func TestUploadStreamFailed(t *testing.T) { initialParallel := int64(100) src := []byte{65, 66, 67} - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + sgc := new(snowflakeGcsClient) + gcsCli, err := sgc.createClient(&info, false, sgc.cfg) if err != nil { t.Error(err) } @@ -785,7 +793,8 @@ func TestUploadFileWithBadRequest(t *testing.T) { t.Error(err) } - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + sgc := new(snowflakeGcsClient) + gcsCli, err := sgc.createClient(&info, false, sgc.cfg) if err != nil { t.Error(err) } @@ -942,7 +951,8 @@ func TestUploadFileToGcsNoStatus(t *testing.T) { t.Error(err) } - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + sgc := new(snowflakeGcsClient) + gcsCli, err := sgc.createClient(&info, false, sgc.cfg) if err != nil { t.Error(err) } @@ -1000,7 +1010,8 @@ func TestDownloadFileFromGcsError(t *testing.T) { t.Error(err) } - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + sgc := new(snowflakeGcsClient) + gcsCli, err := sgc.createClient(&info, false, sgc.cfg) if err != nil { t.Error(err) } @@ -1049,7 +1060,8 @@ func TestDownloadFileWithBadRequest(t *testing.T) { t.Error(err) } - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + sgc := new(snowflakeGcsClient) + gcsCli, err := sgc.createClient(&info, false, sgc.cfg) if err != nil { t.Error(err) } diff --git a/put_get_with_aws_test.go b/put_get_with_aws_test.go index dde9676fa..58a81e56a 100644 --- a/put_get_with_aws_test.go +++ b/put_get_with_aws_test.go @@ -127,7 +127,7 @@ func TestPutWithInvalidToken(t *testing.T) { } s3Util := new(snowflakeS3Client) - s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false) + s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false, s3Util.cfg) if err != nil { t.Error(err) } @@ -170,7 +170,7 @@ func TestPutWithInvalidToken(t *testing.T) { AwsSecretKey: data.Data.StageInfo.Creds.AwsSecretKey, }, } - s3Cli, err = s3Util.createClient(&info, false) + s3Cli, err = s3Util.createClient(&info, false, s3Util.cfg) if err != nil { t.Error(err) } @@ -226,7 +226,7 @@ func TestPretendToPutButList(t *testing.T) { } s3Util := new(snowflakeS3Client) - s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false) + s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false, s3Util.cfg) if err != nil { t.Error(err) } diff --git a/s3_storage_client.go b/s3_storage_client.go index 35d962bfc..e98b871ae 100644 --- a/s3_storage_client.go +++ b/s3_storage_client.go @@ -45,10 +45,11 @@ type s3Location struct { // See https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws#ClientLogMode for allowed values. var S3LoggingMode aws.ClientLogMode -func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) { +func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) { stageCredentials := info.Creds s3Logger := logging.LoggerFunc(s3LoggingFunc) endPoint := getS3CustomEndpoint(info) + transport := getTransport(cfg) return s3.New(s3.Options{ Region: info.Region, @@ -59,7 +60,7 @@ func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAcce BaseEndpoint: endPoint, UseAccelerate: useAccelerateEndpoint, HTTPClient: &http.Client{ - Transport: SnowflakeTransport, + Transport: transport, }, ClientLogMode: S3LoggingMode, Logger: s3Logger, diff --git a/s3_storage_client_test.go b/s3_storage_client_test.go index d7db9de2c..f3d1db9d8 100644 --- a/s3_storage_client_test.go +++ b/s3_storage_client_test.go @@ -69,7 +69,8 @@ func TestUploadOneFileToS3WSAEConnAborted(t *testing.T) { t.Error(err) } - s3Cli, err := new(snowflakeS3Client).createClient(&info, false) + ss3c := new(snowflakeS3Client) + s3Cli, err := ss3c.createClient(&info, false, ss3c.cfg) if err != nil { t.Error(err) } @@ -145,7 +146,8 @@ func TestUploadOneFileToS3ConnReset(t *testing.T) { t.Error(err) } - s3Cli, err := new(snowflakeS3Client).createClient(&info, false) + ss3c := new(snowflakeS3Client) + s3Cli, err := ss3c.createClient(&info, false, ss3c.cfg) if err != nil { t.Error(err) } @@ -204,7 +206,8 @@ func TestUploadFileWithS3UploadFailedError(t *testing.T) { t.Error(err) } - s3Cli, err := new(snowflakeS3Client).createClient(&info, false) + ss3c := new(snowflakeS3Client) + s3Cli, err := ss3c.createClient(&info, false, ss3c.cfg) if err != nil { t.Error(err) } @@ -375,7 +378,8 @@ func TestDownloadFileWithS3TokenExpired(t *testing.T) { t.Error(err) } - s3Cli, err := new(snowflakeS3Client).createClient(&info, false) + ss3c := new(snowflakeS3Client) + s3Cli, err := ss3c.createClient(&info, false, ss3c.cfg) if err != nil { t.Error(err) } @@ -429,7 +433,8 @@ func TestDownloadFileWithS3ConnReset(t *testing.T) { t.Error(err) } - s3Cli, err := new(snowflakeS3Client).createClient(&info, false) + ss3c := new(snowflakeS3Client) + s3Cli, err := ss3c.createClient(&info, false, ss3c.cfg) if err != nil { t.Error(err) } @@ -482,7 +487,8 @@ func TestDownloadOneFileToS3WSAEConnAborted(t *testing.T) { t.Error(err) } - s3Cli, err := new(snowflakeS3Client).createClient(&info, false) + ss3c := new(snowflakeS3Client) + s3Cli, err := ss3c.createClient(&info, false, ss3c.cfg) if err != nil { t.Error(err) } @@ -536,7 +542,8 @@ func TestDownloadOneFileToS3Failed(t *testing.T) { t.Error(err) } - s3Cli, err := new(snowflakeS3Client).createClient(&info, false) + ss3c := new(snowflakeS3Client) + s3Cli, err := ss3c.createClient(&info, false, ss3c.cfg) if err != nil { t.Error(err) } @@ -587,7 +594,8 @@ func TestUploadFileToS3ClientCastFail(t *testing.T) { t.Error(err) } - azureCli, err := new(snowflakeAzureClient).createClient(&info, false) + sac := new(snowflakeAzureClient) + azureCli, err := sac.createClient(&info, false, sac.cfg) if err != nil { t.Error(err) } @@ -629,7 +637,8 @@ func TestGetHeaderClientCastFail(t *testing.T) { Location: "sfc-customer-stage/rwyi-testacco/users/9220/", LocationType: "S3", } - azureCli, err := new(snowflakeAzureClient).createClient(&info, false) + sac := new(snowflakeAzureClient) + azureCli, err := sac.createClient(&info, false, sac.cfg) if err != nil { t.Error(err) } @@ -666,7 +675,8 @@ func TestS3UploadRetryWithHeaderNotFound(t *testing.T) { t.Error(err) } - s3Cli, err := new(snowflakeS3Client).createClient(&info, false) + ss3c := new(snowflakeS3Client) + s3Cli, err := ss3c.createClient(&info, false, ss3c.cfg) if err != nil { t.Error(err) } @@ -726,7 +736,8 @@ func TestS3UploadStreamFailed(t *testing.T) { initialParallel := int64(100) src := []byte{65, 66, 67} - s3Cli, err := new(snowflakeS3Client).createClient(&info, false) + ss3c := new(snowflakeS3Client) + s3Cli, err := ss3c.createClient(&info, false, ss3c.cfg) if err != nil { t.Error(err) } diff --git a/storage_client.go b/storage_client.go index 316c5ad38..b2fcc9b1c 100644 --- a/storage_client.go +++ b/storage_client.go @@ -26,7 +26,7 @@ type storageUtil interface { // implemented by snowflakeS3Util, snowflakeAzureUtil and snowflakeGcsUtil type cloudUtil interface { - createClient(*execResponseStageInfo, bool) (cloudClient, error) + createClient(*execResponseStageInfo, bool, *Config) (cloudClient, error) getFileHeader(*fileMetadata, string) (*fileHeader, error) uploadFile(string, *fileMetadata, *encryptMetadata, int, int64) error nativeDownloadFile(*fileMetadata, string, int64) error @@ -58,7 +58,7 @@ func (rsu *remoteStorageUtil) getNativeCloudType(cli string, cfg *Config) cloudU // call cloud utils' native create client methods func (rsu *remoteStorageUtil) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) { utilClass := rsu.getNativeCloudType(info.LocationType, cfg) - return utilClass.createClient(info, useAccelerateEndpoint) + return utilClass.createClient(info, useAccelerateEndpoint, cfg) } func (rsu *remoteStorageUtil) uploadOneFile(meta *fileMetadata) error { diff --git a/util.go b/util.go index 3319777dd..804d341d5 100644 --- a/util.go +++ b/util.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "math/rand" + "net/http" "os" "strings" "sync" @@ -348,3 +349,17 @@ func findByPrefix(in []string, prefix string) int { } return -1 } + +func getTransport(cfg *Config) *http.Transport { + if cfg == nil { + logger.Debug("getTransport: got nil Config, will perform OCSP validation for cloud storage") + return SnowflakeTransport + } else { + if cfg.DisableOCSPChecks || cfg.InsecureMode { + logger.Debug("getTransport: skipping OCSP validation for cloud storage") + return snowflakeInsecureTransport + } + logger.Debug("getTransport: will perform OCSP validation for cloud storage") + return SnowflakeTransport + } +} diff --git a/util_test.go b/util_test.go index bed222559..bb0902af9 100644 --- a/util_test.go +++ b/util_test.go @@ -7,6 +7,7 @@ import ( "database/sql/driver" "fmt" "math/rand" + "net/http" "os" "strconv" "sync" @@ -420,3 +421,47 @@ func TestFindByPrefix(t *testing.T) { assertEqualE(t, findByPrefix(nonEmpty, "dd"), -1) assertEqualE(t, findByPrefix([]string{}, "dd"), -1) } + +type tcConfigWithTransport struct { + name string + cfg *Config + transport *http.Transport +} + +func TestGetTransport(t *testing.T) { + testcases := []tcConfigWithTransport{ + { + name: "DisableOCSPChecks and InsecureMode false", + cfg: &Config{Account: "one", DisableOCSPChecks: false, InsecureMode: false}, + transport: SnowflakeTransport, + }, + { + name: "DisableOCSPChecks true and InsecureMode false", + cfg: &Config{Account: "two", DisableOCSPChecks: true, InsecureMode: false}, + transport: snowflakeInsecureTransport, + }, + { + name: "DisableOCSPChecks false and InsecureMode true", + cfg: &Config{Account: "three", DisableOCSPChecks: false, InsecureMode: true}, + transport: snowflakeInsecureTransport, + }, + { + name: "DisableOCSPChecks and InsecureMode missing from Config", + cfg: &Config{Account: "four"}, + transport: SnowflakeTransport, + }, + { + name: "whole Config is missing", + cfg: nil, + transport: SnowflakeTransport, + }, + } + for _, test := range testcases { + t.Run(test.name, func(t *testing.T) { + result := getTransport(test.cfg) + if test.transport != result { + t.Errorf("Failed to return the correct transport, input :%#v, expected: %v, got: %v", test.cfg, test.transport, result) + } + }) + } +}