diff --git a/driver.go b/driver.go index 2d0d25c..5f28e25 100644 --- a/driver.go +++ b/driver.go @@ -18,10 +18,23 @@ func (d drv) Open(dsn string) (driver.Conn, error) { if err != nil { return nil, err } - socket, err := thrift.NewTSocket(cfg.Addr) + + tlsCfg, err := cfg.TLSCfg.Load() if err != nil { - return nil, err + return nil, fmt.Errorf("load tls config: %w", err) + } + + var socket thrift.TTransport + if tlsCfg != nil { + socket = thrift.NewTSSLSocketConf( + cfg.Addr, + &thrift.TConfiguration{TLSConfig: tlsCfg}) + } else { + socket = thrift.NewTSocketConf( + cfg.Addr, + &thrift.TConfiguration{}) } + var transport thrift.TTransport if cfg.Auth == "NOSASL" { transport = thrift.NewTBufferedTransport(socket, 4096) diff --git a/dsn.go b/dsn.go index 2c7c591..621442b 100644 --- a/dsn.go +++ b/dsn.go @@ -1,8 +1,12 @@ package gohive import ( + "crypto/tls" + "crypto/x509" + "encoding/base64" "fmt" "net/url" + "os" "regexp" "strconv" "strings" @@ -16,6 +20,13 @@ type Config struct { Auth string Batch int SessionCfg map[string]string + TLSCfg *TLSConfig +} + +type TLSConfig struct { + InsecureSkipVerify bool + RootCAs []string + RootCAFiles []string } var ( @@ -25,11 +36,15 @@ var ( ) const ( - sessionConfPrefix = "session." - authConfName = "auth" - defaultAuth = "NOSASL" - batchSizeName = "batch" - defaultBatchSize = 10000 + sessionConfPrefix = "session." + tlsConfPrefix = "tls." + tlsInsecureSkipVerify = "insecure_skip_verify" + tlsRootCA = "root_ca" + tlsRootCAFile = "root_ca_file" + authConfName = "auth" + defaultAuth = "NOSASL" + batchSizeName = "batch" + defaultBatchSize = 10000 ) // ParseDSN requires DSN names in the format [user[:password]@]addr/dbname. @@ -61,6 +76,8 @@ func ParseDSN(dsn string) (*Config, error) { auth := defaultAuth batch := defaultBatchSize sc := make(map[string]string) + var tls *TLSConfig + var err error if len(sub[3]) > 0 && sub[3][0] == '?' { qry, _ := url.ParseQuery(sub[3][1:]) @@ -79,6 +96,35 @@ func ParseDSN(dsn string) (*Config, error) { if strings.HasPrefix(k, sessionConfPrefix) { sc[k[len(sessionConfPrefix):]] = v[0] } + if strings.HasPrefix(k, tlsConfPrefix) { + if tls == nil { + tls = &TLSConfig{} + } + + key := k[len(tlsConfPrefix):] + switch key { + case tlsInsecureSkipVerify: + tls.InsecureSkipVerify, err = strconv.ParseBool(v[0]) + if err != nil { + return nil, fmt.Errorf("parse insecure_skip_verify: %w", err) + } + case tlsRootCA: + for _, val := range v { + pem, err := base64.URLEncoding.DecodeString(val) + if err != nil { + return nil, fmt.Errorf("decode root ca: %w", err) + } + tls.RootCAs = append(tls.RootCAs, string(pem)) + } + case tlsRootCAFile: + for _, val := range v { + tls.RootCAFiles = append(tls.RootCAFiles, val) + } + default: + return nil, fmt.Errorf("unsupported tls option: [%s]", key) + } + sc[k[len(sessionConfPrefix):]] = v[0] + } } } @@ -90,6 +136,7 @@ func ParseDSN(dsn string) (*Config, error) { Auth: auth, Batch: batch, SessionCfg: sc, + TLSCfg: tls, }, nil } @@ -108,5 +155,56 @@ func (cfg *Config) FormatDSN() string { dsn += fmt.Sprintf("&%s%s=%s", sessionConfPrefix, k, v) } } + if cfg.TLSCfg != nil { + if cfg.TLSCfg.InsecureSkipVerify { + dsn += fmt.Sprintf( + "&%s%s=%t", + tlsConfPrefix, + tlsInsecureSkipVerify, + cfg.TLSCfg.InsecureSkipVerify) + } + for _, ca := range cfg.TLSCfg.RootCAs { + dsn += fmt.Sprintf( + "&%s%s=%s", + tlsConfPrefix, + tlsRootCA, + base64.URLEncoding.EncodeToString([]byte(ca))) + } + for _, caFile := range cfg.TLSCfg.RootCAFiles { + dsn += fmt.Sprintf( + "&%s%s=%s", + tlsConfPrefix, + tlsRootCAFile, + caFile) + } + } return dsn } + +func (c *TLSConfig) Load() (*tls.Config, error) { + if c == nil { + return nil, nil + } + + cfg := tls.Config{ + InsecureSkipVerify: c.InsecureSkipVerify, + } + + var err error + cfg.RootCAs, err = x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("load system certs: %w", err) + } + for _, f := range c.RootCAFiles { + pem, err := os.ReadFile(f) + if err != nil { + return nil, fmt.Errorf("read %s: %w", f, err) + } + cfg.RootCAs.AppendCertsFromPEM(pem) + } + for _, pem := range c.RootCAs { + cfg.RootCAs.AppendCertsFromPEM([]byte(pem)) + } + + return &cfg, nil +} diff --git a/dsn_test.go b/dsn_test.go index d41ca75..8827b44 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -1,9 +1,43 @@ package gohive import ( + "encoding/base64" + "fmt" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + // This certificate was sourced from the examples in go's documentation + // for [x509 Certificate.Verify]. + // + // [x509 Certificate.Verify]: https://pkg.go.dev/crypto/x509#example-Certificate.Verify + rootPEM = `-----BEGIN CERTIFICATE----- +MIIEBDCCAuygAwIBAgIDAjppMA0GCSqGSIb3DQEBBQUAMEIxCzAJBgNVBAYTAlVT +MRYwFAYDVQQKEw1HZW9UcnVzdCBJbmMuMRswGQYDVQQDExJHZW9UcnVzdCBHbG9i +YWwgQ0EwHhcNMTMwNDA1MTUxNTU1WhcNMTUwNDA0MTUxNTU1WjBJMQswCQYDVQQG +EwJVUzETMBEGA1UEChMKR29vZ2xlIEluYzElMCMGA1UEAxMcR29vZ2xlIEludGVy +bmV0IEF1dGhvcml0eSBHMjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +AJwqBHdc2FCROgajguDYUEi8iT/xGXAaiEZ+4I/F8YnOIe5a/mENtzJEiaB0C1NP +VaTOgmKV7utZX8bhBYASxF6UP7xbSDj0U/ck5vuR6RXEz/RTDfRK/J9U3n2+oGtv +h8DQUB8oMANA2ghzUWx//zo8pzcGjr1LEQTrfSTe5vn8MXH7lNVg8y5Kr0LSy+rE +ahqyzFPdFUuLH8gZYR/Nnag+YyuENWllhMgZxUYi+FOVvuOAShDGKuy6lyARxzmZ +EASg8GF6lSWMTlJ14rbtCMoU/M4iarNOz0YDl5cDfsCx3nuvRTPPuj5xt970JSXC +DTWJnZ37DhF5iR43xa+OcmkCAwEAAaOB+zCB+DAfBgNVHSMEGDAWgBTAephojYn7 +qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1dvWBtrtiGrpagS8wEgYD +VR0TAQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAQYwOgYDVR0fBDMwMTAvoC2g +K4YpaHR0cDovL2NybC5nZW90cnVzdC5jb20vY3Jscy9ndGdsb2JhbC5jcmwwPQYI +KwYBBQUHAQEEMTAvMC0GCCsGAQUFBzABhiFodHRwOi8vZ3RnbG9iYWwtb2NzcC5n +ZW90cnVzdC5jb20wFwYDVR0gBBAwDjAMBgorBgEEAdZ5AgUBMA0GCSqGSIb3DQEB +BQUAA4IBAQA21waAESetKhSbOHezI6B1WLuxfoNCunLaHtiONgaX4PCVOzf9G0JY +/iLIa704XtE7JW4S615ndkZAkNoUyHgN7ZVm2o6Gb4ChulYylYbc3GrKBIxbf/a/ +zG+FA1jDaFETzf3I93k9mTXwVqO94FntT0QJo544evZG0R0SnU++0ED8Vf4GXjza +HFa9llF7b1cq26KqltyMdMKVvvBulRP/F/A8rLIQjcxz++iPAsbw+zOzlTvjwsto +WHPbqCRiOwY1nQ2pM714A5AuTHhdUDqB1O6gyHA43LL5Z/qHQF1hwFGPa4NrzQU6 +yuGnBXj8ytqU0CwIPX4WecigUCAkVDNx +-----END CERTIFICATE-----` ) func TestParseDSNWithSessionConf(t *testing.T) { @@ -100,6 +134,59 @@ func TestParseDSNWithoutDBName(t *testing.T) { assert.Equal(t, cfg.Addr, "127.0.0.1") } +func TestParseDSNWithTLSConfig(t *testing.T) { + b64 := base64.URLEncoding.EncodeToString([]byte(rootPEM)) + t.Run("no tls", func(t *testing.T) { + cfg, e := ParseDSN("127.0.0.1") + require.NoError(t, e) + require.Nil(t, cfg.TLSCfg) + }) + + t.Run("one ca", func(t *testing.T) { + cfg, e := ParseDSN(fmt.Sprintf("127.0.0.1?tls.root_ca=%s", b64)) + require.NoError(t, e) + require.NotNil(t, cfg.TLSCfg) + require.Len(t, cfg.TLSCfg.RootCAs, 1) + require.Equal(t, cfg.TLSCfg.RootCAs[0], rootPEM) + }) + + t.Run("two cas", func(t *testing.T) { + cfg, e := ParseDSN(fmt.Sprintf("127.0.0.1?tls.root_ca=%s&tls.root_ca=%s", b64, b64)) + require.NoError(t, e) + require.NotNil(t, cfg.TLSCfg) + require.Len(t, cfg.TLSCfg.RootCAs, 2) + require.Equal(t, cfg.TLSCfg.RootCAs[0], rootPEM) + require.Equal(t, cfg.TLSCfg.RootCAs[1], rootPEM) + }) + + t.Run("one ca fiel", func(t *testing.T) { + file := "cert.pem" + cfg, e := ParseDSN(fmt.Sprintf("127.0.0.1?tls.root_ca_file=%s", file)) + require.NoError(t, e) + require.NotNil(t, cfg.TLSCfg) + require.Len(t, cfg.TLSCfg.RootCAFiles, 1) + require.Equal(t, cfg.TLSCfg.RootCAFiles[0], file) + }) + + t.Run("two ca files", func(t *testing.T) { + file := "cert.pem" + file2 := "cert2.pem" + cfg, e := ParseDSN(fmt.Sprintf("127.0.0.1?tls.root_ca_file=%s&tls.root_ca_file=%s", file, file2)) + require.NoError(t, e) + require.NotNil(t, cfg.TLSCfg) + require.Len(t, cfg.TLSCfg.RootCAFiles, 2) + require.Equal(t, cfg.TLSCfg.RootCAFiles[0], file) + require.Equal(t, cfg.TLSCfg.RootCAFiles[1], file2) + }) + + t.Run("insecure skip verify", func(t *testing.T) { + cfg, e := ParseDSN("127.0.0.1?tls.insecure_skip_verify=true") + require.NoError(t, e) + require.NotNil(t, cfg.TLSCfg) + require.True(t, cfg.TLSCfg.InsecureSkipVerify) + }) +} + func TestFormatDSNWithDBName(t *testing.T) { ds := "user:passwd@127.0.0.1/mnist?batch=100000&auth=NOSASL" cfg, e := ParseDSN(ds) @@ -117,3 +204,72 @@ func TestFormatDSNWithoutDBName(t *testing.T) { ds2 := cfg.FormatDSN() assert.Equal(t, ds2, ds) } + +func TestFormatDSNWithTLSConfig(t *testing.T) { + b64 := base64.URLEncoding.EncodeToString([]byte(rootPEM)) + t.Run("no tls", func(t *testing.T) { + require.Equal(t, + ":@127.0.0.1?batch=0", + (&Config{ + Addr: "127.0.0.1", + }).FormatDSN()) + }) + + t.Run("one ca", func(t *testing.T) { + require.Equal(t, + fmt.Sprintf(":@127.0.0.1?batch=0&tls.root_ca=%s", b64), + (&Config{ + Addr: "127.0.0.1", + TLSCfg: &TLSConfig{ + RootCAs: []string{rootPEM}, + }, + }).FormatDSN()) + }) + + t.Run("two cas", func(t *testing.T) { + require.Equal(t, + fmt.Sprintf(":@127.0.0.1?batch=0&tls.root_ca=%s&tls.root_ca=%s", b64, b64), + (&Config{ + Addr: "127.0.0.1", + TLSCfg: &TLSConfig{ + RootCAs: []string{rootPEM, rootPEM}, + }, + }).FormatDSN()) + }) + + t.Run("one ca file", func(t *testing.T) { + file := "cert.pem" + require.Equal(t, + fmt.Sprintf(":@127.0.0.1?batch=0&tls.root_ca_file=%s", file), + (&Config{ + Addr: "127.0.0.1", + TLSCfg: &TLSConfig{ + RootCAFiles: []string{file}, + }, + }).FormatDSN()) + }) + + t.Run("two ca files", func(t *testing.T) { + file := "cert.pem" + file2 := "cert2.pem" + require.Equal(t, + fmt.Sprintf(":@127.0.0.1?batch=0&tls.root_ca_file=%s&tls.root_ca_file=%s", file, file2), + (&Config{ + Addr: "127.0.0.1", + TLSCfg: &TLSConfig{ + RootCAFiles: []string{file, file2}, + }, + }).FormatDSN()) + }) + + t.Run("insecure skip verify", func(t *testing.T) { + require.Equal(t, + ":@127.0.0.1?batch=0&tls.insecure_skip_verify=true", + (&Config{ + Addr: "127.0.0.1", + TLSCfg: &TLSConfig{ + InsecureSkipVerify: true, + }, + }).FormatDSN()) + }) +}