Skip to content

Commit

Permalink
added support for TLS (#74)
Browse files Browse the repository at this point in the history
* added support for TLS

* added no-tls tests

---------

Co-authored-by: Lucas Theisen <[email protected]>
  • Loading branch information
lucastheisen and Lucas Theisen authored Jul 30, 2024
1 parent c9657f0 commit d37d262
Show file tree
Hide file tree
Showing 3 changed files with 274 additions and 7 deletions.
17 changes: 15 additions & 2 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,23 @@ func (d drv) Open(dsn string) (driver.Conn, error) {
if err != nil {
return nil, err
}
socket, err := thrift.NewTSocket(cfg.Addr)

tlsCfg, err := cfg.TLSCfg.Load()
if err != nil {
return nil, err
return nil, fmt.Errorf("load tls config: %w", err)
}

var socket thrift.TTransport
if tlsCfg != nil {
socket = thrift.NewTSSLSocketConf(
cfg.Addr,
&thrift.TConfiguration{TLSConfig: tlsCfg})
} else {
socket = thrift.NewTSocketConf(
cfg.Addr,
&thrift.TConfiguration{})
}

var transport thrift.TTransport
if cfg.Auth == "NOSASL" {
transport = thrift.NewTBufferedTransport(socket, 4096)
Expand Down
108 changes: 103 additions & 5 deletions dsn.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package gohive

import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"net/url"
"os"
"regexp"
"strconv"
"strings"
Expand All @@ -16,6 +20,13 @@ type Config struct {
Auth string
Batch int
SessionCfg map[string]string
TLSCfg *TLSConfig
}

type TLSConfig struct {
InsecureSkipVerify bool
RootCAs []string
RootCAFiles []string
}

var (
Expand All @@ -25,11 +36,15 @@ var (
)

const (
sessionConfPrefix = "session."
authConfName = "auth"
defaultAuth = "NOSASL"
batchSizeName = "batch"
defaultBatchSize = 10000
sessionConfPrefix = "session."
tlsConfPrefix = "tls."
tlsInsecureSkipVerify = "insecure_skip_verify"
tlsRootCA = "root_ca"
tlsRootCAFile = "root_ca_file"
authConfName = "auth"
defaultAuth = "NOSASL"
batchSizeName = "batch"
defaultBatchSize = 10000
)

// ParseDSN requires DSN names in the format [user[:password]@]addr/dbname.
Expand Down Expand Up @@ -61,6 +76,8 @@ func ParseDSN(dsn string) (*Config, error) {
auth := defaultAuth
batch := defaultBatchSize
sc := make(map[string]string)
var tls *TLSConfig
var err error
if len(sub[3]) > 0 && sub[3][0] == '?' {
qry, _ := url.ParseQuery(sub[3][1:])

Expand All @@ -79,6 +96,35 @@ func ParseDSN(dsn string) (*Config, error) {
if strings.HasPrefix(k, sessionConfPrefix) {
sc[k[len(sessionConfPrefix):]] = v[0]
}
if strings.HasPrefix(k, tlsConfPrefix) {
if tls == nil {
tls = &TLSConfig{}
}

key := k[len(tlsConfPrefix):]
switch key {
case tlsInsecureSkipVerify:
tls.InsecureSkipVerify, err = strconv.ParseBool(v[0])
if err != nil {
return nil, fmt.Errorf("parse insecure_skip_verify: %w", err)
}
case tlsRootCA:
for _, val := range v {
pem, err := base64.URLEncoding.DecodeString(val)
if err != nil {
return nil, fmt.Errorf("decode root ca: %w", err)
}
tls.RootCAs = append(tls.RootCAs, string(pem))
}
case tlsRootCAFile:
for _, val := range v {
tls.RootCAFiles = append(tls.RootCAFiles, val)
}
default:
return nil, fmt.Errorf("unsupported tls option: [%s]", key)
}
sc[k[len(sessionConfPrefix):]] = v[0]
}
}
}

Expand All @@ -90,6 +136,7 @@ func ParseDSN(dsn string) (*Config, error) {
Auth: auth,
Batch: batch,
SessionCfg: sc,
TLSCfg: tls,
}, nil
}

Expand All @@ -108,5 +155,56 @@ func (cfg *Config) FormatDSN() string {
dsn += fmt.Sprintf("&%s%s=%s", sessionConfPrefix, k, v)
}
}
if cfg.TLSCfg != nil {
if cfg.TLSCfg.InsecureSkipVerify {
dsn += fmt.Sprintf(
"&%s%s=%t",
tlsConfPrefix,
tlsInsecureSkipVerify,
cfg.TLSCfg.InsecureSkipVerify)
}
for _, ca := range cfg.TLSCfg.RootCAs {
dsn += fmt.Sprintf(
"&%s%s=%s",
tlsConfPrefix,
tlsRootCA,
base64.URLEncoding.EncodeToString([]byte(ca)))
}
for _, caFile := range cfg.TLSCfg.RootCAFiles {
dsn += fmt.Sprintf(
"&%s%s=%s",
tlsConfPrefix,
tlsRootCAFile,
caFile)
}
}
return dsn
}

func (c *TLSConfig) Load() (*tls.Config, error) {
if c == nil {
return nil, nil
}

cfg := tls.Config{
InsecureSkipVerify: c.InsecureSkipVerify,
}

var err error
cfg.RootCAs, err = x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("load system certs: %w", err)
}
for _, f := range c.RootCAFiles {
pem, err := os.ReadFile(f)
if err != nil {
return nil, fmt.Errorf("read %s: %w", f, err)
}
cfg.RootCAs.AppendCertsFromPEM(pem)
}
for _, pem := range c.RootCAs {
cfg.RootCAs.AppendCertsFromPEM([]byte(pem))
}

return &cfg, nil
}
156 changes: 156 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,43 @@
package gohive

import (
"encoding/base64"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const (
// This certificate was sourced from the examples in go's documentation
// for [x509 Certificate.Verify].
//
// [x509 Certificate.Verify]: https://pkg.go.dev/crypto/x509#example-Certificate.Verify
rootPEM = `-----BEGIN CERTIFICATE-----
MIIEBDCCAuygAwIBAgIDAjppMA0GCSqGSIb3DQEBBQUAMEIxCzAJBgNVBAYTAlVT
MRYwFAYDVQQKEw1HZW9UcnVzdCBJbmMuMRswGQYDVQQDExJHZW9UcnVzdCBHbG9i
YWwgQ0EwHhcNMTMwNDA1MTUxNTU1WhcNMTUwNDA0MTUxNTU1WjBJMQswCQYDVQQG
EwJVUzETMBEGA1UEChMKR29vZ2xlIEluYzElMCMGA1UEAxMcR29vZ2xlIEludGVy
bmV0IEF1dGhvcml0eSBHMjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
AJwqBHdc2FCROgajguDYUEi8iT/xGXAaiEZ+4I/F8YnOIe5a/mENtzJEiaB0C1NP
VaTOgmKV7utZX8bhBYASxF6UP7xbSDj0U/ck5vuR6RXEz/RTDfRK/J9U3n2+oGtv
h8DQUB8oMANA2ghzUWx//zo8pzcGjr1LEQTrfSTe5vn8MXH7lNVg8y5Kr0LSy+rE
ahqyzFPdFUuLH8gZYR/Nnag+YyuENWllhMgZxUYi+FOVvuOAShDGKuy6lyARxzmZ
EASg8GF6lSWMTlJ14rbtCMoU/M4iarNOz0YDl5cDfsCx3nuvRTPPuj5xt970JSXC
DTWJnZ37DhF5iR43xa+OcmkCAwEAAaOB+zCB+DAfBgNVHSMEGDAWgBTAephojYn7
qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1dvWBtrtiGrpagS8wEgYD
VR0TAQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAQYwOgYDVR0fBDMwMTAvoC2g
K4YpaHR0cDovL2NybC5nZW90cnVzdC5jb20vY3Jscy9ndGdsb2JhbC5jcmwwPQYI
KwYBBQUHAQEEMTAvMC0GCCsGAQUFBzABhiFodHRwOi8vZ3RnbG9iYWwtb2NzcC5n
ZW90cnVzdC5jb20wFwYDVR0gBBAwDjAMBgorBgEEAdZ5AgUBMA0GCSqGSIb3DQEB
BQUAA4IBAQA21waAESetKhSbOHezI6B1WLuxfoNCunLaHtiONgaX4PCVOzf9G0JY
/iLIa704XtE7JW4S615ndkZAkNoUyHgN7ZVm2o6Gb4ChulYylYbc3GrKBIxbf/a/
zG+FA1jDaFETzf3I93k9mTXwVqO94FntT0QJo544evZG0R0SnU++0ED8Vf4GXjza
HFa9llF7b1cq26KqltyMdMKVvvBulRP/F/A8rLIQjcxz++iPAsbw+zOzlTvjwsto
WHPbqCRiOwY1nQ2pM714A5AuTHhdUDqB1O6gyHA43LL5Z/qHQF1hwFGPa4NrzQU6
yuGnBXj8ytqU0CwIPX4WecigUCAkVDNx
-----END CERTIFICATE-----`
)

func TestParseDSNWithSessionConf(t *testing.T) {
Expand Down Expand Up @@ -100,6 +134,59 @@ func TestParseDSNWithoutDBName(t *testing.T) {
assert.Equal(t, cfg.Addr, "127.0.0.1")
}

func TestParseDSNWithTLSConfig(t *testing.T) {
b64 := base64.URLEncoding.EncodeToString([]byte(rootPEM))
t.Run("no tls", func(t *testing.T) {
cfg, e := ParseDSN("127.0.0.1")
require.NoError(t, e)
require.Nil(t, cfg.TLSCfg)
})

t.Run("one ca", func(t *testing.T) {
cfg, e := ParseDSN(fmt.Sprintf("127.0.0.1?tls.root_ca=%s", b64))
require.NoError(t, e)
require.NotNil(t, cfg.TLSCfg)
require.Len(t, cfg.TLSCfg.RootCAs, 1)
require.Equal(t, cfg.TLSCfg.RootCAs[0], rootPEM)
})

t.Run("two cas", func(t *testing.T) {
cfg, e := ParseDSN(fmt.Sprintf("127.0.0.1?tls.root_ca=%s&tls.root_ca=%s", b64, b64))
require.NoError(t, e)
require.NotNil(t, cfg.TLSCfg)
require.Len(t, cfg.TLSCfg.RootCAs, 2)
require.Equal(t, cfg.TLSCfg.RootCAs[0], rootPEM)
require.Equal(t, cfg.TLSCfg.RootCAs[1], rootPEM)
})

t.Run("one ca fiel", func(t *testing.T) {
file := "cert.pem"
cfg, e := ParseDSN(fmt.Sprintf("127.0.0.1?tls.root_ca_file=%s", file))
require.NoError(t, e)
require.NotNil(t, cfg.TLSCfg)
require.Len(t, cfg.TLSCfg.RootCAFiles, 1)
require.Equal(t, cfg.TLSCfg.RootCAFiles[0], file)
})

t.Run("two ca files", func(t *testing.T) {
file := "cert.pem"
file2 := "cert2.pem"
cfg, e := ParseDSN(fmt.Sprintf("127.0.0.1?tls.root_ca_file=%s&tls.root_ca_file=%s", file, file2))
require.NoError(t, e)
require.NotNil(t, cfg.TLSCfg)
require.Len(t, cfg.TLSCfg.RootCAFiles, 2)
require.Equal(t, cfg.TLSCfg.RootCAFiles[0], file)
require.Equal(t, cfg.TLSCfg.RootCAFiles[1], file2)
})

t.Run("insecure skip verify", func(t *testing.T) {
cfg, e := ParseDSN("127.0.0.1?tls.insecure_skip_verify=true")
require.NoError(t, e)
require.NotNil(t, cfg.TLSCfg)
require.True(t, cfg.TLSCfg.InsecureSkipVerify)
})
}

func TestFormatDSNWithDBName(t *testing.T) {
ds := "user:[email protected]/mnist?batch=100000&auth=NOSASL"
cfg, e := ParseDSN(ds)
Expand All @@ -117,3 +204,72 @@ func TestFormatDSNWithoutDBName(t *testing.T) {
ds2 := cfg.FormatDSN()
assert.Equal(t, ds2, ds)
}

func TestFormatDSNWithTLSConfig(t *testing.T) {
b64 := base64.URLEncoding.EncodeToString([]byte(rootPEM))
t.Run("no tls", func(t *testing.T) {
require.Equal(t,
":@127.0.0.1?batch=0",
(&Config{
Addr: "127.0.0.1",
}).FormatDSN())
})

t.Run("one ca", func(t *testing.T) {
require.Equal(t,
fmt.Sprintf(":@127.0.0.1?batch=0&tls.root_ca=%s", b64),
(&Config{
Addr: "127.0.0.1",
TLSCfg: &TLSConfig{
RootCAs: []string{rootPEM},
},
}).FormatDSN())
})

t.Run("two cas", func(t *testing.T) {
require.Equal(t,
fmt.Sprintf(":@127.0.0.1?batch=0&tls.root_ca=%s&tls.root_ca=%s", b64, b64),
(&Config{
Addr: "127.0.0.1",
TLSCfg: &TLSConfig{
RootCAs: []string{rootPEM, rootPEM},
},
}).FormatDSN())
})

t.Run("one ca file", func(t *testing.T) {
file := "cert.pem"
require.Equal(t,
fmt.Sprintf(":@127.0.0.1?batch=0&tls.root_ca_file=%s", file),
(&Config{
Addr: "127.0.0.1",
TLSCfg: &TLSConfig{
RootCAFiles: []string{file},
},
}).FormatDSN())
})

t.Run("two ca files", func(t *testing.T) {
file := "cert.pem"
file2 := "cert2.pem"
require.Equal(t,
fmt.Sprintf(":@127.0.0.1?batch=0&tls.root_ca_file=%s&tls.root_ca_file=%s", file, file2),
(&Config{
Addr: "127.0.0.1",
TLSCfg: &TLSConfig{
RootCAFiles: []string{file, file2},
},
}).FormatDSN())
})

t.Run("insecure skip verify", func(t *testing.T) {
require.Equal(t,
":@127.0.0.1?batch=0&tls.insecure_skip_verify=true",
(&Config{
Addr: "127.0.0.1",
TLSCfg: &TLSConfig{
InsecureSkipVerify: true,
},
}).FormatDSN())
})
}

0 comments on commit d37d262

Please sign in to comment.