Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-955538: Multiple SAML Integrations Support #1025

Merged
merged 17 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ func authenticateWithConfig(sc *snowflakeConn) error {
sc.cfg.Account,
sc.cfg.User,
sc.cfg.Password,
sc.cfg.ExternalBrowserTimeout)
sc.cfg.ExternalBrowserTimeout,
sc.cfg.DisableConsoleLogin)
if err != nil {
sc.cleanup()
return err
Expand Down
16 changes: 16 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,22 @@ func TestUnitAuthenticateWithConfigOkta(t *testing.T) {
assertEqualE(t, err.Error(), "failed to get SAML response")
}

func TestUnitAuthenticateWithConfigExternalBrowser(t *testing.T) {
var err error
sr := &snowflakeRestful{
FuncPostAuthSAML: postAuthSAMLError,
TokenAccessor: getSimpleTokenAccessor(),
}
sc := getDefaultSnowflakeConn()
sc.cfg.Authenticator = AuthTypeExternalBrowser
sc.rest = sr
sc.ctx = context.Background()
err = authenticateWithConfig(sc)
if err == nil {
t.Fatalf("should have failed.")
sfc-gh-pfus marked this conversation as resolved.
Show resolved Hide resolved
}
}

func TestUnitAuthenticateExternalBrowser(t *testing.T) {
var err error
sr := &snowflakeRestful{
Expand Down
52 changes: 44 additions & 8 deletions authexternalbrowser.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package gosnowflake
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -70,11 +71,11 @@ func createLocalTCPListener() (*net.TCPListener, error) {
return tcpListener, nil
}

// Opens a browser window (or new tab) with the configured IDP Url.
// Opens a browser window (or new tab) with the configured login Url.
// This can / will fail if running inside a shell with no display, ie
// ssh'ing into a box attempting to authenticate via external browser.
func openBrowser(idpURL string) error {
err := browser.OpenURL(idpURL)
func openBrowser(loginURL string) error {
err := browser.OpenURL(loginURL)
if err != nil {
logger.Infof("failed to open a browser. err: %v", err)
return err
Expand All @@ -91,6 +92,7 @@ func getIdpURLProofKey(
authenticator string,
application string,
account string,
user string,
callbackPort int) (string, string, error) {

headers := make(map[string]string)
Expand All @@ -108,6 +110,7 @@ func getIdpURLProofKey(
ClientAppID: clientType,
ClientAppVersion: SnowflakeGoDriverVersion,
AccountName: account,
LoginName: user,
ClientEnvironment: clientEnvironment,
Authenticator: authenticator,
BrowserModeRedirectPort: strconv.Itoa(callbackPort),
Expand Down Expand Up @@ -144,6 +147,28 @@ func getIdpURLProofKey(
return respd.Data.SSOURL, respd.Data.ProofKey, nil
}

// Gets the login URL for multiple SAML
func getLoginURL(
sr *snowflakeRestful,
sfc-gh-pfus marked this conversation as resolved.
Show resolved Hide resolved
user string,
callbackPort int) (string, string, error) {
sfc-gh-ext-simba-jl marked this conversation as resolved.
Show resolved Hide resolved

proofKey := generateProofKey()

params := &url.Values{}
params.Add("login_name", user)
params.Add("browser_mode_redirect_port", strconv.Itoa(callbackPort))
params.Add("proof_key", proofKey)
url := sr.getFullURL(consoleLoginRequestPath, params)

return url.String(), proofKey, nil
}

func generateProofKey() string {
randomness := getSecureRandom(32)
return base64.StdEncoding.WithPadding(base64.StdPadding).EncodeToString(randomness)
}

// The response returned from Snowflake looks like so:
// GET /?token=encodedSamlToken
// Host: localhost:54001
Expand Down Expand Up @@ -187,10 +212,11 @@ func authenticateByExternalBrowser(
user string,
password string,
externalBrowserTimeout time.Duration,
disableConsoleLogin ConfigBool,
) ([]byte, []byte, error) {
resultChan := make(chan authenticateByExternalBrowserResult, 1)
go func() {
resultChan <- doAuthenticateByExternalBrowser(ctx, sr, authenticator, application, account, user, password)
resultChan <- doAuthenticateByExternalBrowser(ctx, sr, authenticator, application, account, user, password, disableConsoleLogin)
}()
select {
case <-time.After(externalBrowserTimeout):
Expand All @@ -204,7 +230,7 @@ func authenticateByExternalBrowser(
// - the golang snowflake driver communicates to Snowflake that the user wishes to
// authenticate via external browser
// - snowflake sends back the IDP Url configured at the Snowflake side for the
// provided account
// provided account, or use the multiple SAML way via console login
// - the default browser is opened to that URL
// - user authenticates at the IDP, and is redirected to Snowflake
// - Snowflake directs the user back to the driver
Expand All @@ -217,6 +243,7 @@ func doAuthenticateByExternalBrowser(
account string,
user string,
password string,
disableConsoleLogin ConfigBool,
) authenticateByExternalBrowserResult {
l, err := createLocalTCPListener()
if err != nil {
Expand All @@ -225,13 +252,22 @@ func doAuthenticateByExternalBrowser(
defer l.Close()

callbackPort := l.Addr().(*net.TCPAddr).Port
idpURL, proofKey, err := getIdpURLProofKey(
ctx, sr, authenticator, application, account, callbackPort)

var loginURL string
var proofKey string
if disableConsoleLogin == ConfigBoolTrue {
// Gets the IDP URL and Proof Key from Snowflake
loginURL, proofKey, err = getIdpURLProofKey(ctx, sr, authenticator, application, account, user, callbackPort)
} else {
// Multiple SAML way to do authentication via console login
loginURL, proofKey, err = getLoginURL(sr, user, callbackPort)
}

if err != nil {
return authenticateByExternalBrowserResult{nil, nil, err}
}

if err = openBrowser(idpURL); err != nil {
if err = openBrowser(loginURL); err != nil {
return authenticateByExternalBrowserResult{nil, nil, err}
}

Expand Down
8 changes: 4 additions & 4 deletions authexternalbrowser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,17 @@ func TestUnitAuthenticateByExternalBrowser(t *testing.T) {
FuncPostAuthSAML: postAuthExternalBrowserError,
TokenAccessor: getSimpleTokenAccessor(),
}
_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthExternalBrowserFail
_, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
_, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthExternalBrowserFailWithCode
_, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
_, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue)
if err == nil {
t.Fatal("should have failed.")
}
Expand All @@ -128,7 +128,7 @@ func TestAuthenticationTimeout(t *testing.T) {
FuncPostAuthSAML: postAuthExternalBrowserError,
TokenAccessor: getSimpleTokenAccessor(),
}
_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue)
if err.Error() != "authentication timed out" {
t.Fatal("should have timed out")
}
Expand Down
2 changes: 2 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@
- clientConfigFile: specifies the location of the client configuration json file.
In this file you can configure Easy Logging feature.

- disableConsoleLogin: true by default. Set to true false to disable console login.
sfc-gh-pfus marked this conversation as resolved.
Show resolved Hide resolved

All other parameters are interpreted as session parameters (https://docs.snowflake.com/en/sql-reference/parameters.html).
For example, the TIMESTAMP_OUTPUT_FORMAT session parameter can be set by adding:

Expand Down Expand Up @@ -959,7 +961,7 @@
db.Query() function:

db.Query("GET file://<local_file> <stage_identifier> <optional_parameters>")

Check failure on line 964 in doc.go

View workflow job for this annotation

GitHub Actions / AWS Go 1.19 on Windows

package comment is detached; there should be no blank lines between it and the package statement

Check failure on line 964 in doc.go

View workflow job for this annotation

GitHub Actions / AWS Go 1.20 on Windows

package comment is detached; there should be no blank lines between it and the package statement

Check failure on line 964 in doc.go

View workflow job for this annotation

GitHub Actions / AWS Go 1.21 on Windows

package comment is detached; there should be no blank lines between it and the package statement

Check failure on line 964 in doc.go

View workflow job for this annotation

GitHub Actions / AZURE Go 1.19 on Windows

package comment is detached; there should be no blank lines between it and the package statement

Check failure on line 964 in doc.go

View workflow job for this annotation

GitHub Actions / AZURE Go 1.20 on Windows

package comment is detached; there should be no blank lines between it and the package statement

Check failure on line 964 in doc.go

View workflow job for this annotation

GitHub Actions / AZURE Go 1.21 on Windows

package comment is detached; there should be no blank lines between it and the package statement

Check failure on line 964 in doc.go

View workflow job for this annotation

GitHub Actions / GCP Go 1.19 on Windows

package comment is detached; there should be no blank lines between it and the package statement

Check failure on line 964 in doc.go

View workflow job for this annotation

GitHub Actions / GCP Go 1.20 on Windows

package comment is detached; there should be no blank lines between it and the package statement

Check failure on line 964 in doc.go

View workflow job for this annotation

GitHub Actions / GCP Go 1.21 on Windows

package comment is detached; there should be no blank lines between it and the package statement
"<local_file>" should include the file path as well as the name. Snowflake recommends using
an absolute path rather than a relative path. For example:

Expand Down
20 changes: 20 additions & 0 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@
IncludeRetryReason ConfigBool // Should retried request contain retry reason

ClientConfigFile string // File path to the client configuration json file

DisableConsoleLogin ConfigBool // Indicates whether console login should be disabled
}

// Validate enables testing if config is correct.
Expand Down Expand Up @@ -262,6 +264,9 @@
if cfg.ClientConfigFile != "" {
params.Add("clientConfigFile", cfg.ClientConfigFile)
}
if cfg.DisableConsoleLogin != configBoolNotSet {
params.Add("disableConsoleLogin", strconv.FormatBool(cfg.DisableConsoleLogin != ConfigBoolFalse))
}

dsn = fmt.Sprintf("%v:%v@%v:%v", url.QueryEscape(cfg.User), url.QueryEscape(cfg.Password), cfg.Host, cfg.Port)
if params.Encode() != "" {
Expand Down Expand Up @@ -495,6 +500,10 @@
cfg.IncludeRetryReason = ConfigBoolTrue
}

if cfg.DisableConsoleLogin == configBoolNotSet {
cfg.DisableConsoleLogin = ConfigBoolTrue
}

if strings.HasSuffix(cfg.Host, defaultDomain) && len(cfg.Host) == len(defaultDomain) {
return &SnowflakeError{
Number: ErrCodeFailedToParseHost,
Expand Down Expand Up @@ -754,6 +763,17 @@
}
case "clientConfigFile":
cfg.ClientConfigFile = value
case "disableConsoleLogin":
var vv bool
vv, err = strconv.ParseBool(value)
if err != nil {
return
}

Check warning on line 771 in dsn.go

View check run for this annotation

Codecov / codecov/patch

dsn.go#L770-L771

Added lines #L770 - L771 were not covered by tests
if vv {
cfg.DisableConsoleLogin = ConfigBoolTrue
} else {
cfg.DisableConsoleLogin = ConfigBoolFalse
}
default:
if cfg.Params == nil {
cfg.Params = make(map[string]*string)
Expand Down
Loading
Loading