diff --git a/connection.go b/connection.go index bd7dc084a..8900bfba2 100644 --- a/connection.go +++ b/connection.go @@ -856,3 +856,21 @@ func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, err return sc, nil } + +func getTransport(cfg *Config) http.RoundTripper { + if cfg == nil { + logger.Debug("getTransport: got nil Config, will perform OCSP validation for cloud storage") + return SnowflakeTransport + } + // if user configured a custom Transporter, prioritize that + if cfg.Transporter != nil { + logger.Debug("getTransport: using Transporter configured by the user") + return cfg.Transporter + } + if cfg.DisableOCSPChecks || cfg.InsecureMode { + logger.Debug("getTransport: skipping OCSP validation for cloud storage") + return snowflakeInsecureTransport + } + logger.Debug("getTransport: will perform OCSP validation for cloud storage") + return SnowflakeTransport +} diff --git a/connection_test.go b/connection_test.go index e28cb3e92..06a5ff7a0 100644 --- a/connection_test.go +++ b/connection_test.go @@ -826,3 +826,56 @@ func TestBeginCreatesTransaction(t *testing.T) { } }) } + +type EmptyTransporter struct{} + +func (t EmptyTransporter) RoundTrip(req *http.Request) (*http.Response, error) { + return snowflakeInsecureTransport.RoundTrip(req) +} + +func TestGetTransport(t *testing.T) { + testcases := []struct { + name string + cfg *Config + transport http.RoundTripper + }{ + { + name: "DisableOCSPChecks and InsecureMode false", + cfg: &Config{Account: "one", DisableOCSPChecks: false, InsecureMode: false}, + transport: SnowflakeTransport, + }, + { + name: "DisableOCSPChecks true and InsecureMode false", + cfg: &Config{Account: "two", DisableOCSPChecks: true, InsecureMode: false}, + transport: snowflakeInsecureTransport, + }, + { + name: "DisableOCSPChecks false and InsecureMode true", + cfg: &Config{Account: "three", DisableOCSPChecks: false, InsecureMode: true}, + transport: snowflakeInsecureTransport, + }, + { + name: "DisableOCSPChecks and InsecureMode missing from Config", + cfg: &Config{Account: "four"}, + transport: SnowflakeTransport, + }, + { + name: "whole Config is missing", + cfg: nil, + transport: SnowflakeTransport, + }, + { + name: "Using custom Transporter", + cfg: &Config{Account: "five", DisableOCSPChecks: true, InsecureMode: false, Transporter: EmptyTransporter{}}, + transport: EmptyTransporter{}, + }, + } + for _, test := range testcases { + t.Run(test.name, func(t *testing.T) { + result := getTransport(test.cfg) + if test.transport != result { + t.Errorf("Failed to return the correct transport, input :%#v, expected: %v, got: %v", test.cfg, test.transport, result) + } + }) + } +} diff --git a/util.go b/util.go index c1a078470..3319777dd 100644 --- a/util.go +++ b/util.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "math/rand" - "net/http" "os" "strings" "sync" @@ -349,16 +348,3 @@ func findByPrefix(in []string, prefix string) int { } return -1 } - -func getTransport(cfg *Config) *http.Transport { - if cfg == nil { - logger.Debug("getTransport: got nil Config, will perform OCSP validation for cloud storage") - return SnowflakeTransport - } - if cfg.DisableOCSPChecks || cfg.InsecureMode { - logger.Debug("getTransport: skipping OCSP validation for cloud storage") - return snowflakeInsecureTransport - } - logger.Debug("getTransport: will perform OCSP validation for cloud storage") - return SnowflakeTransport -} diff --git a/util_test.go b/util_test.go index 9744f8e22..bed222559 100644 --- a/util_test.go +++ b/util_test.go @@ -7,7 +7,6 @@ import ( "database/sql/driver" "fmt" "math/rand" - "net/http" "os" "strconv" "sync" @@ -421,45 +420,3 @@ func TestFindByPrefix(t *testing.T) { assertEqualE(t, findByPrefix(nonEmpty, "dd"), -1) assertEqualE(t, findByPrefix([]string{}, "dd"), -1) } - -func TestGetTransport(t *testing.T) { - testcases := []struct { - name string - cfg *Config - transport *http.Transport - }{ - { - name: "DisableOCSPChecks and InsecureMode false", - cfg: &Config{Account: "one", DisableOCSPChecks: false, InsecureMode: false}, - transport: SnowflakeTransport, - }, - { - name: "DisableOCSPChecks true and InsecureMode false", - cfg: &Config{Account: "two", DisableOCSPChecks: true, InsecureMode: false}, - transport: snowflakeInsecureTransport, - }, - { - name: "DisableOCSPChecks false and InsecureMode true", - cfg: &Config{Account: "three", DisableOCSPChecks: false, InsecureMode: true}, - transport: snowflakeInsecureTransport, - }, - { - name: "DisableOCSPChecks and InsecureMode missing from Config", - cfg: &Config{Account: "four"}, - transport: SnowflakeTransport, - }, - { - name: "whole Config is missing", - cfg: nil, - transport: SnowflakeTransport, - }, - } - for _, test := range testcases { - t.Run(test.name, func(t *testing.T) { - result := getTransport(test.cfg) - if test.transport != result { - t.Errorf("Failed to return the correct transport, input :%#v, expected: %v, got: %v", test.cfg, test.transport, result) - } - }) - } -}