Skip to content

Commit

Permalink
SNOW-1859664
Browse files Browse the repository at this point in the history
* pass Config to cloud storage functions, to be able to create OCSP-verifying and non-OCSP-verifying cloud clients, depending on user's configuration
* modify existing tests accordingly, to also pass Config
  • Loading branch information
sfc-gh-dszmolka committed Jan 12, 2025
1 parent 8257f91 commit 0962923
Show file tree
Hide file tree
Showing 12 changed files with 159 additions and 61 deletions.
16 changes: 9 additions & 7 deletions azure_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,21 @@ type azureAPI interface {
GetProperties(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error)
}

func (util *snowflakeAzureClient) createClient(info *execResponseStageInfo, _ bool) (cloudClient, error) {
func (util *snowflakeAzureClient) createClient(info *execResponseStageInfo, _ bool, cfg *Config) (cloudClient, error) {
sasToken := info.Creds.AzureSasToken
u, err := url.Parse(fmt.Sprintf("https://%s.%s/%s%s", info.StorageAccount, info.EndPoint, info.Path, sasToken))
if err != nil {
return nil, err
}
transport := getTransport(cfg)
client, err := azblob.NewClientWithNoCredential(u.String(), &azblob.ClientOptions{
ClientOptions: azcore.ClientOptions{
Retry: policy.RetryOptions{
MaxRetries: 60,
RetryDelay: 2 * time.Second,
},
Transport: &http.Client{
Transport: SnowflakeTransport,
Transport: transport,
},
},
})
Expand All @@ -74,7 +75,7 @@ func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename str
return nil, err
}
path := azureLoc.path + strings.TrimLeft(filename, "/")
containerClient, err := createContainerClient(client.URL())
containerClient, err := createContainerClient(client.URL(), util.cfg)
if err != nil {
return nil, &SnowflakeError{
Message: "failed to create container client",
Expand Down Expand Up @@ -188,7 +189,7 @@ func (util *snowflakeAzureClient) uploadFile(
Message: "failed to cast to azure client",
}
}
containerClient, err := createContainerClient(client.URL())
containerClient, err := createContainerClient(client.URL(), util.cfg)

if err != nil {
return &SnowflakeError{
Expand Down Expand Up @@ -273,7 +274,7 @@ func (util *snowflakeAzureClient) nativeDownloadFile(
Message: "failed to cast to azure client",
}
}
containerClient, err := createContainerClient(client.URL())
containerClient, err := createContainerClient(client.URL(), util.cfg)
if err != nil {
return &SnowflakeError{
Message: "failed to create container client",
Expand Down Expand Up @@ -348,10 +349,11 @@ func (util *snowflakeAzureClient) detectAzureTokenExpireError(resp *http.Respons
strings.Contains(errStr, "Server failed to authenticate the request")
}

func createContainerClient(clientURL string) (*container.Client, error) {
func createContainerClient(clientURL string, cfg *Config) (*container.Client, error) {
transport := getTransport(cfg)
return container.NewClientWithNoCredential(clientURL, &container.ClientOptions{ClientOptions: azcore.ClientOptions{
Transport: &http.Client{
Transport: SnowflakeTransport,
Transport: transport,
},
}})
}
24 changes: 16 additions & 8 deletions azure_storage_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ func TestUploadFileWithAzureUploadFailedError(t *testing.T) {
SMKID: 92019681909886,
}

azureCli, err := new(snowflakeAzureClient).createClient(&info, false)
sac := new(snowflakeAzureClient)
azureCli, err := sac.createClient(&info, false, sac.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -210,7 +211,8 @@ func TestUploadStreamWithAzureUploadFailedError(t *testing.T) {
SMKID: 92019681909886,
}

azureCli, err := new(snowflakeAzureClient).createClient(&info, false)
sac := new(snowflakeAzureClient)
azureCli, err := sac.createClient(&info, false, sac.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -273,7 +275,8 @@ func TestUploadFileWithAzureUploadTokenExpired(t *testing.T) {
panic(err)
}

azureCli, err := new(snowflakeAzureClient).createClient(&info, false)
sac := new(snowflakeAzureClient)
azureCli, err := sac.createClient(&info, false, sac.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -349,7 +352,8 @@ func TestUploadFileWithAzureUploadNeedsRetry(t *testing.T) {
panic(err)
}

azureCli, err := new(snowflakeAzureClient).createClient(&info, false)
sac := new(snowflakeAzureClient)
azureCli, err := sac.createClient(&info, false, sac.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -412,7 +416,8 @@ func TestDownloadOneFileToAzureFailed(t *testing.T) {
t.Error(err)
}

azureCli, err := new(snowflakeAzureClient).createClient(&info, false)
sac := new(snowflakeAzureClient)
azureCli, err := sac.createClient(&info, false, sac.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -456,7 +461,8 @@ func TestGetFileHeaderErrorStatus(t *testing.T) {
LocationType: "AZURE",
}

azureCli, err := new(snowflakeAzureClient).createClient(&info, false)
sac := new(snowflakeAzureClient)
azureCli, err := sac.createClient(&info, false, sac.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -558,7 +564,8 @@ func TestUploadFileToAzureClientCastFail(t *testing.T) {
t.Error(err)
}

s3Cli, err := new(snowflakeS3Client).createClient(&info, false)
ss3c := new(snowflakeS3Client)
s3Cli, err := ss3c.createClient(&info, false, ss3c.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -600,7 +607,8 @@ func TestAzureGetHeaderClientCastFail(t *testing.T) {
Location: "azblob/rwyi-testacco/users/9220/",
LocationType: "AZURE",
}
s3Cli, err := new(snowflakeS3Client).createClient(&info, false)
ss3c := new(snowflakeS3Client)
s3Cli, err := ss3c.createClient(&info, false, ss3c.cfg)
if err != nil {
t.Error(err)
}
Expand Down
4 changes: 2 additions & 2 deletions file_transfer_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -597,15 +597,15 @@ type s3BucketAccelerateConfigGetter interface {

type s3ClientCreator interface {
extractBucketNameAndPath(location string) (*s3Location, error)
createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error)
createClient(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error)
}

func (sfa *snowflakeFileTransferAgent) transferAccelerateConfigWithUtil(s3Util s3ClientCreator) error {
s3Loc, err := s3Util.extractBucketNameAndPath(sfa.stageInfo.Location)
if err != nil {
return err
}
s3Cli, err := s3Util.createClient(sfa.stageInfo, false)
s3Cli, err := s3Util.createClient(sfa.stageInfo, false, sfa.sc.cfg)
if err != nil {
return err
}
Expand Down
16 changes: 8 additions & 8 deletions file_transfer_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ func TestGetBucketAccelerateConfiguration(t *testing.T) {

type s3ClientCreatorMock struct {
extract func(string) (*s3Location, error)
create func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error)
create func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error)
}

func (mock *s3ClientCreatorMock) extractBucketNameAndPath(location string) (*s3Location, error) {
return mock.extract(location)
}

func (mock *s3ClientCreatorMock) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
return mock.create(info, useAccelerateEndpoint)
func (mock *s3ClientCreatorMock) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) {
return mock.create(info, useAccelerateEndpoint, cfg)
}

type s3BucketAccelerateConfigGetterMock struct {
Expand Down Expand Up @@ -96,7 +96,7 @@ func TestGetBucketAccelerateConfigurationTooManyRetries(t *testing.T) {
extract: func(s string) (*s3Location, error) {
return &s3Location{bucketName: "test", s3Path: "test"}, nil
},
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) {
return &s3BucketAccelerateConfigGetterMock{err: errors.New("testing")}, nil
},
})
Expand Down Expand Up @@ -146,7 +146,7 @@ func TestGetBucketAccelerateConfigurationFailedCreateClient(t *testing.T) {
extract: func(s string) (*s3Location, error) {
return &s3Location{bucketName: "test", s3Path: "test"}, nil
},
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) {
return nil, errors.New("failed creation")
},
})
Expand All @@ -172,7 +172,7 @@ func TestGetBucketAccelerateConfigurationInvalidClient(t *testing.T) {
extract: func(s string) (*s3Location, error) {
return &s3Location{bucketName: "test", s3Path: "test"}, nil
},
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) {
return 1, nil
},
})
Expand Down Expand Up @@ -472,7 +472,7 @@ func TestUpdateMetadataWithPresignedUrl(t *testing.T) {
}, nil
}

gcsCli, err := new(snowflakeGcsClient).createClient(&info, false)
gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, sct.sc.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -525,7 +525,7 @@ func TestUpdateMetadataWithPresignedUrlForDownload(t *testing.T) {

testURL := "https://storage.google.com/gcs-blob/storage/users/456?Signature=testsignature123"

gcsCli, err := new(snowflakeGcsClient).createClient(&info, false)
gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, sct.sc.cfg)
if err != nil {
t.Error(err)
}
Expand Down
16 changes: 10 additions & 6 deletions gcs_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ type gcsLocation struct {
path string
}

func (util *snowflakeGcsClient) createClient(info *execResponseStageInfo, _ bool) (cloudClient, error) {
func (util *snowflakeGcsClient) createClient(info *execResponseStageInfo, _ bool, cfg *Config) (cloudClient, error) {
// we don't seem to actually return the client from createClient here, but to implement
// the interface, need to have the same spec as in snowflakeS3Client and snowflakeAzureClient
_ = cfg
if info.Creds.GcsAccessToken != "" {
logger.Debug("Using GCS downscoped token")
return info.Creds.GcsAccessToken, nil
Expand Down Expand Up @@ -73,7 +76,7 @@ func (util *snowflakeGcsClient) getFileHeader(meta *fileMetadata, filename strin
for k, v := range gcsHeaders {
req.Header.Add(k, v)
}
client := newGcsClient()
client := newGcsClient(util.cfg)
// for testing only
if meta.mockGcsClient != nil {
client = meta.mockGcsClient
Expand Down Expand Up @@ -221,7 +224,7 @@ func (util *snowflakeGcsClient) uploadFile(
for k, v := range gcsHeaders {
req.Header.Add(k, v)
}
client := newGcsClient()
client := newGcsClient(util.cfg)
// for testing only
if meta.mockGcsClient != nil {
client = meta.mockGcsClient
Expand Down Expand Up @@ -302,7 +305,7 @@ func (util *snowflakeGcsClient) nativeDownloadFile(
for k, v := range gcsHeaders {
req.Header.Add(k, v)
}
client := newGcsClient()
client := newGcsClient(util.cfg)
// for testing only
if meta.mockGcsClient != nil {
client = meta.mockGcsClient
Expand Down Expand Up @@ -404,9 +407,10 @@ func (util *snowflakeGcsClient) isTokenExpired(resp *http.Response) bool {
return resp.StatusCode == 401
}

func newGcsClient() gcsAPI {
func newGcsClient(cfg *Config) gcsAPI {
transport := getTransport(cfg)
return &http.Client{
Transport: SnowflakeTransport,
Transport: transport,
}
}

Expand Down
36 changes: 24 additions & 12 deletions gcs_storage_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ func TestUploadFileWithGcsUploadFailedError(t *testing.T) {
t.Error(err)
}

gcsCli, err := new(snowflakeGcsClient).createClient(&info, false)
sgc := new(snowflakeGcsClient)
gcsCli, err := sgc.createClient(&info, false, sgc.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -201,7 +202,8 @@ func TestUploadFileWithGcsUploadFailedWithRetry(t *testing.T) {
t.Error(err)
}

gcsCli, err := new(snowflakeGcsClient).createClient(&info, false)
sgc := new(snowflakeGcsClient)
gcsCli, err := sgc.createClient(&info, false, sgc.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -268,7 +270,8 @@ func TestUploadFileWithGcsUploadFailedWithTokenExpired(t *testing.T) {
t.Error(err)
}

gcsCli, err := new(snowflakeGcsClient).createClient(&info, false)
sgc := new(snowflakeGcsClient)
gcsCli, err := sgc.createClient(&info, false, sgc.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -329,7 +332,8 @@ func TestDownloadOneFileFromGcsFailed(t *testing.T) {
t.Error(err)
}

gcsCli, err := new(snowflakeGcsClient).createClient(&info, false)
sgc := new(snowflakeGcsClient)
gcsCli, err := sgc.createClient(&info, false, sgc.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -375,7 +379,8 @@ func TestDownloadOneFileFromGcsFailedWithRetry(t *testing.T) {
t.Error(err)
}

gcsCli, err := new(snowflakeGcsClient).createClient(&info, false)
sgc := new(snowflakeGcsClient)
gcsCli, err := sgc.createClient(&info, false, sgc.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -432,7 +437,8 @@ func TestDownloadOneFileFromGcsFailedWithTokenExpired(t *testing.T) {
t.Error(err)
}

gcsCli, err := new(snowflakeGcsClient).createClient(&info, false)
sgc := new(snowflakeGcsClient)
gcsCli, err := sgc.createClient(&info, false, sgc.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -489,7 +495,8 @@ func TestDownloadOneFileFromGcsFailedWithFileNotFound(t *testing.T) {
t.Error(err)
}

gcsCli, err := new(snowflakeGcsClient).createClient(&info, false)
sgc := new(snowflakeGcsClient)
gcsCli, err := sgc.createClient(&info, false, sgc.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -735,7 +742,8 @@ func TestUploadStreamFailed(t *testing.T) {
initialParallel := int64(100)
src := []byte{65, 66, 67}

gcsCli, err := new(snowflakeGcsClient).createClient(&info, false)
sgc := new(snowflakeGcsClient)
gcsCli, err := sgc.createClient(&info, false, sgc.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -785,7 +793,8 @@ func TestUploadFileWithBadRequest(t *testing.T) {
t.Error(err)
}

gcsCli, err := new(snowflakeGcsClient).createClient(&info, false)
sgc := new(snowflakeGcsClient)
gcsCli, err := sgc.createClient(&info, false, sgc.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -942,7 +951,8 @@ func TestUploadFileToGcsNoStatus(t *testing.T) {
t.Error(err)
}

gcsCli, err := new(snowflakeGcsClient).createClient(&info, false)
sgc := new(snowflakeGcsClient)
gcsCli, err := sgc.createClient(&info, false, sgc.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -1000,7 +1010,8 @@ func TestDownloadFileFromGcsError(t *testing.T) {
t.Error(err)
}

gcsCli, err := new(snowflakeGcsClient).createClient(&info, false)
sgc := new(snowflakeGcsClient)
gcsCli, err := sgc.createClient(&info, false, sgc.cfg)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -1049,7 +1060,8 @@ func TestDownloadFileWithBadRequest(t *testing.T) {
t.Error(err)
}

gcsCli, err := new(snowflakeGcsClient).createClient(&info, false)
sgc := new(snowflakeGcsClient)
gcsCli, err := sgc.createClient(&info, false, sgc.cfg)
if err != nil {
t.Error(err)
}
Expand Down
Loading

0 comments on commit 0962923

Please sign in to comment.