Skip to content

Commit

Permalink
SNOW-1859664 honour OCSP check settings in cloud storage clients too
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dszmolka committed Jan 9, 2025
1 parent 8257f91 commit c2219f0
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 12 deletions.
14 changes: 8 additions & 6 deletions azure_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,15 @@ func (util *snowflakeAzureClient) createClient(info *execResponseStageInfo, _ bo
if err != nil {
return nil, err
}
transport := getTransport(util.cfg)
client, err := azblob.NewClientWithNoCredential(u.String(), &azblob.ClientOptions{
ClientOptions: azcore.ClientOptions{
Retry: policy.RetryOptions{
MaxRetries: 60,
RetryDelay: 2 * time.Second,
},
Transport: &http.Client{
Transport: SnowflakeTransport,
Transport: transport,
},
},
})
Expand All @@ -74,7 +75,7 @@ func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename str
return nil, err
}
path := azureLoc.path + strings.TrimLeft(filename, "/")
containerClient, err := createContainerClient(client.URL())
containerClient, err := createContainerClient(client.URL(), util.cfg)
if err != nil {
return nil, &SnowflakeError{
Message: "failed to create container client",
Expand Down Expand Up @@ -188,7 +189,7 @@ func (util *snowflakeAzureClient) uploadFile(
Message: "failed to cast to azure client",
}
}
containerClient, err := createContainerClient(client.URL())
containerClient, err := createContainerClient(client.URL(), util.cfg)

if err != nil {
return &SnowflakeError{
Expand Down Expand Up @@ -273,7 +274,7 @@ func (util *snowflakeAzureClient) nativeDownloadFile(
Message: "failed to cast to azure client",
}
}
containerClient, err := createContainerClient(client.URL())
containerClient, err := createContainerClient(client.URL(), util.cfg)
if err != nil {
return &SnowflakeError{
Message: "failed to create container client",
Expand Down Expand Up @@ -348,10 +349,11 @@ func (util *snowflakeAzureClient) detectAzureTokenExpireError(resp *http.Respons
strings.Contains(errStr, "Server failed to authenticate the request")
}

func createContainerClient(clientURL string) (*container.Client, error) {
func createContainerClient(clientURL string, cfg *Config) (*container.Client, error) {
transport := getTransport(cfg)
return container.NewClientWithNoCredential(clientURL, &container.ClientOptions{ClientOptions: azcore.ClientOptions{
Transport: &http.Client{
Transport: SnowflakeTransport,
Transport: transport,
},
}})
}
11 changes: 6 additions & 5 deletions gcs_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (util *snowflakeGcsClient) getFileHeader(meta *fileMetadata, filename strin
for k, v := range gcsHeaders {
req.Header.Add(k, v)
}
client := newGcsClient()
client := newGcsClient(util.cfg)
// for testing only
if meta.mockGcsClient != nil {
client = meta.mockGcsClient
Expand Down Expand Up @@ -221,7 +221,7 @@ func (util *snowflakeGcsClient) uploadFile(
for k, v := range gcsHeaders {
req.Header.Add(k, v)
}
client := newGcsClient()
client := newGcsClient(util.cfg)
// for testing only
if meta.mockGcsClient != nil {
client = meta.mockGcsClient
Expand Down Expand Up @@ -302,7 +302,7 @@ func (util *snowflakeGcsClient) nativeDownloadFile(
for k, v := range gcsHeaders {
req.Header.Add(k, v)
}
client := newGcsClient()
client := newGcsClient(util.cfg)
// for testing only
if meta.mockGcsClient != nil {
client = meta.mockGcsClient
Expand Down Expand Up @@ -404,9 +404,10 @@ func (util *snowflakeGcsClient) isTokenExpired(resp *http.Response) bool {
return resp.StatusCode == 401
}

func newGcsClient() gcsAPI {
func newGcsClient(cfg *Config) gcsAPI {
transport := getTransport(cfg)
return &http.Client{
Transport: SnowflakeTransport,
Transport: transport,
}
}

Expand Down
3 changes: 2 additions & 1 deletion s3_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAcce
stageCredentials := info.Creds
s3Logger := logging.LoggerFunc(s3LoggingFunc)
endPoint := getS3CustomEndpoint(info)
transport := getTransport(util.cfg)

return s3.New(s3.Options{
Region: info.Region,
Expand All @@ -59,7 +60,7 @@ func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAcce
BaseEndpoint: endPoint,
UseAccelerate: useAccelerateEndpoint,
HTTPClient: &http.Client{
Transport: SnowflakeTransport,
Transport: transport,
},
ClientLogMode: S3LoggingMode,
Logger: s3Logger,
Expand Down
10 changes: 10 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"math/rand"
"net/http"
"os"
"strings"
"sync"
Expand Down Expand Up @@ -348,3 +349,12 @@ func findByPrefix(in []string, prefix string) int {
}
return -1
}

func getTransport(cfg *Config) *http.Transport {
if cfg.DisableOCSPChecks || cfg.InsecureMode {
logger.Debug("getTransport: won't perform OCSP validation")
return snowflakeInsecureTransport
}
logger.Debug("getTransport: will perform OCSP validation")
return SnowflakeTransport
}

0 comments on commit c2219f0

Please sign in to comment.