Skip to content

Commit

Permalink
* moving getTransport util.go -> connection.go
Browse files Browse the repository at this point in the history
* adding ability to handle custom Transporter in getTransport
* adding test for the above scenario
  • Loading branch information
sfc-gh-dszmolka committed Jan 14, 2025
1 parent f257d25 commit 4b27f9c
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 57 deletions.
18 changes: 18 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -856,3 +856,21 @@ func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, err

return sc, nil
}

func getTransport(cfg *Config) http.RoundTripper {
if cfg == nil {
logger.Debug("getTransport: got nil Config, will perform OCSP validation for cloud storage")
return SnowflakeTransport
}
// if user configured a custom Transporter, prioritize that
if cfg.Transporter != nil {
logger.Debug("getTransport: using Transporter configured by the user")
return cfg.Transporter
}
if cfg.DisableOCSPChecks || cfg.InsecureMode {
logger.Debug("getTransport: skipping OCSP validation for cloud storage")
return snowflakeInsecureTransport
}
logger.Debug("getTransport: will perform OCSP validation for cloud storage")
return SnowflakeTransport
}
53 changes: 53 additions & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -826,3 +826,56 @@ func TestBeginCreatesTransaction(t *testing.T) {
}
})
}

type EmptyTransporter struct{}

func (t EmptyTransporter) RoundTrip(req *http.Request) (*http.Response, error) {
return snowflakeInsecureTransport.RoundTrip(req)
}

func TestGetTransport(t *testing.T) {
testcases := []struct {
name string
cfg *Config
transport http.RoundTripper
}{
{
name: "DisableOCSPChecks and InsecureMode false",
cfg: &Config{Account: "one", DisableOCSPChecks: false, InsecureMode: false},
transport: SnowflakeTransport,
},
{
name: "DisableOCSPChecks true and InsecureMode false",
cfg: &Config{Account: "two", DisableOCSPChecks: true, InsecureMode: false},
transport: snowflakeInsecureTransport,
},
{
name: "DisableOCSPChecks false and InsecureMode true",
cfg: &Config{Account: "three", DisableOCSPChecks: false, InsecureMode: true},
transport: snowflakeInsecureTransport,
},
{
name: "DisableOCSPChecks and InsecureMode missing from Config",
cfg: &Config{Account: "four"},
transport: SnowflakeTransport,
},
{
name: "whole Config is missing",
cfg: nil,
transport: SnowflakeTransport,
},
{
name: "Using custom Transporter",
cfg: &Config{Account: "five", DisableOCSPChecks: true, InsecureMode: false, Transporter: EmptyTransporter{}},
transport: EmptyTransporter{},
},
}
for _, test := range testcases {
t.Run(test.name, func(t *testing.T) {
result := getTransport(test.cfg)
if test.transport != result {
t.Errorf("Failed to return the correct transport, input :%#v, expected: %v, got: %v", test.cfg, test.transport, result)
}
})
}
}
14 changes: 0 additions & 14 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"io"
"math/rand"
"net/http"
"os"
"strings"
"sync"
Expand Down Expand Up @@ -349,16 +348,3 @@ func findByPrefix(in []string, prefix string) int {
}
return -1
}

func getTransport(cfg *Config) *http.Transport {
if cfg == nil {
logger.Debug("getTransport: got nil Config, will perform OCSP validation for cloud storage")
return SnowflakeTransport
}
if cfg.DisableOCSPChecks || cfg.InsecureMode {
logger.Debug("getTransport: skipping OCSP validation for cloud storage")
return snowflakeInsecureTransport
}
logger.Debug("getTransport: will perform OCSP validation for cloud storage")
return SnowflakeTransport
}
43 changes: 0 additions & 43 deletions util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"database/sql/driver"
"fmt"
"math/rand"
"net/http"
"os"
"strconv"
"sync"
Expand Down Expand Up @@ -421,45 +420,3 @@ func TestFindByPrefix(t *testing.T) {
assertEqualE(t, findByPrefix(nonEmpty, "dd"), -1)
assertEqualE(t, findByPrefix([]string{}, "dd"), -1)
}

func TestGetTransport(t *testing.T) {
testcases := []struct {
name string
cfg *Config
transport *http.Transport
}{
{
name: "DisableOCSPChecks and InsecureMode false",
cfg: &Config{Account: "one", DisableOCSPChecks: false, InsecureMode: false},
transport: SnowflakeTransport,
},
{
name: "DisableOCSPChecks true and InsecureMode false",
cfg: &Config{Account: "two", DisableOCSPChecks: true, InsecureMode: false},
transport: snowflakeInsecureTransport,
},
{
name: "DisableOCSPChecks false and InsecureMode true",
cfg: &Config{Account: "three", DisableOCSPChecks: false, InsecureMode: true},
transport: snowflakeInsecureTransport,
},
{
name: "DisableOCSPChecks and InsecureMode missing from Config",
cfg: &Config{Account: "four"},
transport: SnowflakeTransport,
},
{
name: "whole Config is missing",
cfg: nil,
transport: SnowflakeTransport,
},
}
for _, test := range testcases {
t.Run(test.name, func(t *testing.T) {
result := getTransport(test.cfg)
if test.transport != result {
t.Errorf("Failed to return the correct transport, input :%#v, expected: %v, got: %v", test.cfg, test.transport, result)
}
})
}
}

0 comments on commit 4b27f9c

Please sign in to comment.