Skip to content

Commit

Permalink
implement multiple SAML
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-ext-simba-jl committed Jan 9, 2024
1 parent 7328b30 commit 193d6fe
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 51 deletions.
3 changes: 2 additions & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ func authenticateWithConfig(sc *snowflakeConn) error {
sc.cfg.Account,
sc.cfg.User,
sc.cfg.Password,
sc.cfg.ExternalBrowserTimeout)
sc.cfg.ExternalBrowserTimeout,
sc.cfg.DisableConsoleLogin)
if err != nil {
sc.cleanup()
return err
Expand Down
16 changes: 16 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,22 @@ func TestUnitAuthenticateWithConfigOkta(t *testing.T) {
assertEqualE(t, err.Error(), "failed to get SAML response")
}

func TestUnitAuthenticateWithConfigExternalBrowser(t *testing.T) {
var err error
sr := &snowflakeRestful{
FuncPostAuthSAML: postAuthSAMLError,
TokenAccessor: getSimpleTokenAccessor(),
}
sc := getDefaultSnowflakeConn()
sc.cfg.Authenticator = AuthTypeExternalBrowser
sc.rest = sr
sc.ctx = context.Background()
err = authenticateWithConfig(sc)
if err == nil {
t.Fatalf("should have failed.")
}
}

func TestUnitAuthenticateExternalBrowser(t *testing.T) {
var err error
sr := &snowflakeRestful{
Expand Down
52 changes: 44 additions & 8 deletions authexternalbrowser.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package gosnowflake
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -70,11 +71,11 @@ func createLocalTCPListener() (*net.TCPListener, error) {
return tcpListener, nil
}

// Opens a browser window (or new tab) with the configured IDP Url.
// Opens a browser window (or new tab) with the configured login Url.
// This can / will fail if running inside a shell with no display, ie
// ssh'ing into a box attempting to authenticate via external browser.
func openBrowser(idpURL string) error {
err := browser.OpenURL(idpURL)
func openBrowser(loginURL string) error {
err := browser.OpenURL(loginURL)
if err != nil {
logger.Infof("failed to open a browser. err: %v", err)
return err
Expand All @@ -91,6 +92,7 @@ func getIdpURLProofKey(
authenticator string,
application string,
account string,
user string,
callbackPort int) (string, string, error) {

headers := make(map[string]string)
Expand All @@ -108,6 +110,7 @@ func getIdpURLProofKey(
ClientAppID: clientType,
ClientAppVersion: SnowflakeGoDriverVersion,
AccountName: account,
LoginName: user,
ClientEnvironment: clientEnvironment,
Authenticator: authenticator,
BrowserModeRedirectPort: strconv.Itoa(callbackPort),
Expand Down Expand Up @@ -144,6 +147,28 @@ func getIdpURLProofKey(
return respd.Data.SSOURL, respd.Data.ProofKey, nil
}

// Gets the login URL for multiple SAML
func getLoginURL(
sr *snowflakeRestful,
user string,
callbackPort int) (string, string, error) {

proofKey := generateProofKey()

params := &url.Values{}
params.Add("login_name", user)
params.Add("browser_mode_redirect_port", strconv.Itoa(callbackPort))
params.Add("proof_key", proofKey)
url := sr.getFullURL(consoleLoginRequestPath, params)

return url.String(), proofKey, nil
}

func generateProofKey() string {
randomness := getSecureRandom(32)
return base64.StdEncoding.WithPadding(base64.StdPadding).EncodeToString(randomness)
}

// The response returned from Snowflake looks like so:
// GET /?token=encodedSamlToken
// Host: localhost:54001
Expand Down Expand Up @@ -187,10 +212,11 @@ func authenticateByExternalBrowser(
user string,
password string,
externalBrowserTimeout time.Duration,
disableConsoleLogin ConfigBool,
) ([]byte, []byte, error) {
resultChan := make(chan authenticateByExternalBrowserResult, 1)
go func() {
resultChan <- doAuthenticateByExternalBrowser(ctx, sr, authenticator, application, account, user, password)
resultChan <- doAuthenticateByExternalBrowser(ctx, sr, authenticator, application, account, user, password, disableConsoleLogin)
}()
select {
case <-time.After(externalBrowserTimeout):
Expand All @@ -204,7 +230,7 @@ func authenticateByExternalBrowser(
// - the golang snowflake driver communicates to Snowflake that the user wishes to
// authenticate via external browser
// - snowflake sends back the IDP Url configured at the Snowflake side for the
// provided account
// provided account, or use the multiple SAML way via console login
// - the default browser is opened to that URL
// - user authenticates at the IDP, and is redirected to Snowflake
// - Snowflake directs the user back to the driver
Expand All @@ -217,6 +243,7 @@ func doAuthenticateByExternalBrowser(
account string,
user string,
password string,
disableConsoleLogin ConfigBool,
) authenticateByExternalBrowserResult {
l, err := createLocalTCPListener()
if err != nil {
Expand All @@ -225,13 +252,22 @@ func doAuthenticateByExternalBrowser(
defer l.Close()

callbackPort := l.Addr().(*net.TCPAddr).Port
idpURL, proofKey, err := getIdpURLProofKey(
ctx, sr, authenticator, application, account, callbackPort)

var loginURL string
var proofKey string
if disableConsoleLogin == ConfigBoolTrue {
// Gets the IDP URL and Proof Key from Snowflake
loginURL, proofKey, err = getIdpURLProofKey(ctx, sr, authenticator, application, account, user, callbackPort)
} else {
// Multiple SAML way to do authentication via console login
loginURL, proofKey, err = getLoginURL(sr, user, callbackPort)
}

if err != nil {
return authenticateByExternalBrowserResult{nil, nil, err}
}

if err = openBrowser(idpURL); err != nil {
if err = openBrowser(loginURL); err != nil {
return authenticateByExternalBrowserResult{nil, nil, err}
}

Expand Down
8 changes: 4 additions & 4 deletions authexternalbrowser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,17 @@ func TestUnitAuthenticateByExternalBrowser(t *testing.T) {
FuncPostAuthSAML: postAuthExternalBrowserError,
TokenAccessor: getSimpleTokenAccessor(),
}
_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthExternalBrowserFail
_, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
_, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthExternalBrowserFailWithCode
_, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
_, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue)
if err == nil {
t.Fatal("should have failed.")
}
Expand All @@ -128,7 +128,7 @@ func TestAuthenticationTimeout(t *testing.T) {
FuncPostAuthSAML: postAuthExternalBrowserError,
TokenAccessor: getSimpleTokenAccessor(),
}
_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue)
if err.Error() != "authentication timed out" {
t.Fatal("should have timed out")
}
Expand Down
2 changes: 2 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ The following connection parameters are supported:
- clientConfigFile: specifies the location of the client configuration json file.
In this file you can configure Easy Logging feature.
- disableConsoleLogin: true by default. Set to true false to disable console login.
All other parameters are interpreted as session parameters (https://docs.snowflake.com/en/sql-reference/parameters.html).
For example, the TIMESTAMP_OUTPUT_FORMAT session parameter can be set by adding:
Expand Down
20 changes: 20 additions & 0 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ type Config struct {
IncludeRetryReason ConfigBool // Should retried request contain retry reason

ClientConfigFile string // File path to the client configuration json file

DisableConsoleLogin ConfigBool // Indicates whether console login should be disabled
}

// Validate enables testing if config is correct.
Expand Down Expand Up @@ -262,6 +264,9 @@ func DSN(cfg *Config) (dsn string, err error) {
if cfg.ClientConfigFile != "" {
params.Add("clientConfigFile", cfg.ClientConfigFile)
}
if cfg.DisableConsoleLogin != configBoolNotSet {
params.Add("disableConsoleLogin", strconv.FormatBool(cfg.DisableConsoleLogin != ConfigBoolFalse))
}

dsn = fmt.Sprintf("%v:%v@%v:%v", url.QueryEscape(cfg.User), url.QueryEscape(cfg.Password), cfg.Host, cfg.Port)
if params.Encode() != "" {
Expand Down Expand Up @@ -495,6 +500,10 @@ func fillMissingConfigParameters(cfg *Config) error {
cfg.IncludeRetryReason = ConfigBoolTrue
}

if cfg.DisableConsoleLogin == configBoolNotSet {
cfg.DisableConsoleLogin = ConfigBoolTrue
}

if strings.HasSuffix(cfg.Host, defaultDomain) && len(cfg.Host) == len(defaultDomain) {
return &SnowflakeError{
Number: ErrCodeFailedToParseHost,
Expand Down Expand Up @@ -754,6 +763,17 @@ func parseDSNParams(cfg *Config, params string) (err error) {
}
case "clientConfigFile":
cfg.ClientConfigFile = value
case "disableConsoleLogin":
var vv bool
vv, err = strconv.ParseBool(value)
if err != nil {
return
}

Check warning on line 771 in dsn.go

View check run for this annotation

Codecov / codecov/patch

dsn.go#L770-L771

Added lines #L770 - L771 were not covered by tests
if vv {
cfg.DisableConsoleLogin = ConfigBoolTrue
} else {
cfg.DisableConsoleLogin = ConfigBoolFalse
}
default:
if cfg.Params == nil {
cfg.Params = make(map[string]*string)
Expand Down
Loading

0 comments on commit 193d6fe

Please sign in to comment.