Skip to content

Commit

Permalink
SNOW-1825476 Implement programmatic access token (PAT)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Feb 4, 2025
1 parent e926883 commit b4261df
Show file tree
Hide file tree
Showing 16 changed files with 634 additions and 6 deletions.
14 changes: 12 additions & 2 deletions .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ concurrency:
jobs:
lint:
runs-on: ubuntu-latest
strategy:
fail-fast: false
name: Check linter
steps:
- uses: actions/checkout@v4
Expand All @@ -52,6 +50,10 @@ jobs:
name: ${{ matrix.cloud }} Go ${{ matrix.go }} on Ubuntu
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 11
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
Expand All @@ -78,6 +80,10 @@ jobs:
name: ${{ matrix.cloud }} Go ${{ matrix.go }} on Mac
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 11
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
Expand All @@ -103,6 +109,10 @@ jobs:
name: ${{ matrix.cloud }} Go ${{ matrix.go }} on Windows
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 11
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
Expand Down
24 changes: 24 additions & 0 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"io"
"net/http"
"net/url"
"os"
"runtime"
"strconv"
"strings"
Expand Down Expand Up @@ -51,6 +52,8 @@ const (
AuthTypeTokenAccessor
// AuthTypeUsernamePasswordMFA is to use username and password with mfa
AuthTypeUsernamePasswordMFA
// AuthTypePat is to use programmatic access token
AuthTypePat
)

func determineAuthenticatorType(cfg *Config, value string) error {
Expand All @@ -74,6 +77,9 @@ func determineAuthenticatorType(cfg *Config, value string) error {
} else if upperCaseValue == AuthTypeTokenAccessor.String() {
cfg.Authenticator = AuthTypeTokenAccessor
return nil
} else if upperCaseValue == AuthTypePat.String() && experimentalAuthEnabled() {
cfg.Authenticator = AuthTypePat
return nil
} else {
// possibly Okta case
oktaURLString, err := url.QueryUnescape(lowerCaseValue)
Expand Down Expand Up @@ -123,6 +129,8 @@ func (authType AuthType) String() string {
return "TOKENACCESSOR"
case AuthTypeUsernamePasswordMFA:
return "USERNAME_PASSWORD_MFA"
case AuthTypePat:
return "PROGRAMMATIC_ACCESS_TOKEN"
default:
return "UNKNOWN"
}
Expand Down Expand Up @@ -442,6 +450,17 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
return nil, err
}
requestMain.Token = jwtTokenString
case AuthTypePat:
if !experimentalAuthEnabled() {
return nil, errors.New("PAT is not ready to use")
}
logger.WithContext(sc.ctx).Info("Programmatic access token")
requestMain.Authenticator = AuthTypePat.String()
requestMain.LoginName = sc.cfg.User
requestMain.Token = sc.cfg.Token
if sc.cfg.Password != "" && sc.cfg.Token == "" {
requestMain.Token = sc.cfg.Password
}
case AuthTypeSnowflake:
logger.WithContext(sc.ctx).Info("Username and password")
requestMain.LoginName = sc.cfg.User
Expand Down Expand Up @@ -581,3 +600,8 @@ func fillCachedIDToken(sc *snowflakeConn) {
func fillCachedMfaToken(sc *snowflakeConn) {
credentialsStorage.getCredential(sc, mfaToken)
}

func experimentalAuthEnabled() bool {
val, ok := os.LookupEnv("ENABLE_EXPERIMENTAL_AUTHENTICATION")
return ok && strings.EqualFold(val, "true")
}
68 changes: 68 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1003,3 +1003,71 @@ func TestContextPropagatedToAuthWhenUsingOpenDB(t *testing.T) {
assertStringContainsE(t, err.Error(), "context deadline exceeded")
cancel()
}

func TestPatSuccessfulFlow(t *testing.T) {
skipOnJenkins(t, "wiremock is not enabled")
enableExperimentalAuth(t)
wiremock.registerMappings(t,
wiremockMapping{filePath: "auth/pat/successful_flow.json"},
wiremockMapping{filePath: "select1.json", params: map[string]string{
"%AUTHORIZATION_HEADER%": "Snowflake Token=\\\"session token\\\""},
},
)
cfg := wiremock.connectionConfig()
cfg.Authenticator = AuthTypePat
cfg.Token = "some PAT"
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
rows, err := db.Query("SELECT 1")
assertNilF(t, err)
var v int
assertTrueE(t, rows.Next())
assertNilF(t, rows.Scan(&v))
assertEqualE(t, v, 1)
}

func enableExperimentalAuth(t *testing.T) {
err := os.Setenv("ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
assertNilF(t, err)
}

func TestPatSuccessfulFlowWithPatAsPasswordWithPatAuthenticator(t *testing.T) {
skipOnJenkins(t, "wiremock is not enabled")
enableExperimentalAuth(t)
wiremock.registerMappings(t,
wiremockMapping{filePath: "auth/pat/successful_flow.json"},
wiremockMapping{filePath: "select1.json", params: map[string]string{
"%AUTHORIZATION_HEADER%": "Snowflake Token=\\\"session token\\\""},
},
)
cfg := wiremock.connectionConfig()
cfg.Authenticator = AuthTypePat
cfg.Password = "some PAT"
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
rows, err := db.Query("SELECT 1")
assertNilF(t, err)
var v int
assertTrueE(t, rows.Next())
assertNilF(t, rows.Scan(&v))
assertEqualE(t, v, 1)
}

func TestPatInvalidToken(t *testing.T) {
skipOnJenkins(t, "wiremock is not enabled")
enableExperimentalAuth(t)
wiremock.registerMappings(t,
wiremockMapping{filePath: "auth/pat/invalid_token.json"},
)
cfg := wiremock.connectionConfig()
cfg.Authenticator = AuthTypePat
cfg.Token = "some PAT"
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
_, err := db.Query("SELECT 1")
assertNotNilF(t, err)
var se *SnowflakeError
assertTrueF(t, errors.As(err, &se))
assertEqualE(t, se.Number, 394400)
assertEqualE(t, se.Message, "Programmatic access token is invalid.")
}
3 changes: 3 additions & 0 deletions ci/test.bat
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ setlocal EnableDelayedExpansion

start /b python ci\scripts\hang_webserver.py 12345

curl -O https://repo1.maven.org/maven2/org/wiremock/wiremock-standalone/3.11.0/wiremock-standalone-3.11.0.jar
START /B java -jar wiremock-standalone-3.11.0.jar --port 14355

if "%CLOUD_PROVIDER%"=="AWS" set PARAMETER_FILENAME=parameters_aws_golang.json.gpg
if "%CLOUD_PROVIDER%"=="AZURE" set PARAMETER_FILENAME=parameters_azure_golang.json.gpg
if "%CLOUD_PROVIDER%"=="GCP" set PARAMETER_FILENAME=parameters_gcp_golang.json.gpg
Expand Down
3 changes: 3 additions & 0 deletions ci/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ set -o pipefail

CI_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

curl -O https://repo1.maven.org/maven2/org/wiremock/wiremock-standalone/3.11.0/wiremock-standalone-3.11.0.jar
java -jar wiremock-standalone-3.11.0.jar --port 14355 &

if [[ -n "$JENKINS_HOME" ]]; then
ROOT_DIR="$(cd "${CI_DIR}/.." && pwd)"
export WORKSPACE=${WORKSPACE:-/tmp}
Expand Down
1 change: 1 addition & 0 deletions cmd/programmatic_access_token/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pat
16 changes: 16 additions & 0 deletions cmd/programmatic_access_token/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
include ../../gosnowflake.mak
CMD_TARGET=pat

## Install
install: cinstall

## Run
run: crun

## Lint
lint: clint

## Format source codes
fmt: cfmt

.PHONY: install run lint fmt
52 changes: 52 additions & 0 deletions cmd/programmatic_access_token/pat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// you have to configure PAT on your user

package main

import (
"database/sql"
"flag"
"fmt"
sf "github.com/snowflakedb/gosnowflake"
"log"
)

func main() {
if !flag.Parsed() {
flag.Parse()
}

cfg, err := sf.GetConfigFromEnv([]*sf.ConfigParam{
{Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true},
{Name: "User", EnvName: "SNOWFLAKE_TEST_USER", FailOnMissing: true},
{Name: "Token", EnvName: "SNOWFLAKE_TEST_PAT", FailOnMissing: true},
{Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false},
{Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false},
{Name: "Protocol", EnvName: "SNOWFLAKE_TEST_PROTOCOL", FailOnMissing: false},
})
cfg.Authenticator = sf.AuthTypePat
if err != nil {
log.Fatalf("cannot build config. %v", err)
}

connector := sf.NewConnector(sf.SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
defer db.Close()

query := "SELECT 1"
rows, err := db.Query(query)
if err != nil {
log.Fatalf("failed to run a query. %v, err: %v", query, err)
}
defer rows.Close()
var v int
if !rows.Next() {
log.Fatalf("no rows returned")
}
if err = rows.Scan(&v); err != nil {
log.Fatalf("failed to scan rows. %v", err)
}
if v != 1 {
log.Fatalf("unexpected result, expected 1, got %v", v)
}
fmt.Printf("Congrats! You have successfully run %v with Snowflake DB!\n", query)
}
3 changes: 3 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ func (d SnowflakeDriver) OpenWithConfig(ctx context.Context, config Config) (dri
if err := config.Validate(); err != nil {
return nil, err
}
if config.Params == nil {
config.Params = make(map[string]*string)
}
if config.Tracing != "" {
if err := logger.SetLogLevel(config.Tracing); err != nil {
return nil, err
Expand Down
11 changes: 8 additions & 3 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -576,14 +576,16 @@ func buildHostFromAccountAndRegion(account, region string) string {
func authRequiresUser(cfg *Config) bool {
return cfg.Authenticator != AuthTypeOAuth &&
cfg.Authenticator != AuthTypeTokenAccessor &&
cfg.Authenticator != AuthTypeExternalBrowser
cfg.Authenticator != AuthTypeExternalBrowser &&
cfg.Authenticator != AuthTypePat
}

func authRequiresPassword(cfg *Config) bool {
return cfg.Authenticator != AuthTypeOAuth &&
cfg.Authenticator != AuthTypeTokenAccessor &&
cfg.Authenticator != AuthTypeExternalBrowser &&
cfg.Authenticator != AuthTypeJwt
cfg.Authenticator != AuthTypeJwt &&
cfg.Authenticator != AuthTypePat
}

// transformAccountToHost transforms account to host
Expand Down Expand Up @@ -905,7 +907,7 @@ type ConfigParam struct {

// GetConfigFromEnv is used to parse the environment variable values to specific fields of the Config
func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) {
var account, user, password, role, host, portStr, protocol, warehouse, database, schema, region, passcode, application string
var account, user, password, token, role, host, portStr, protocol, warehouse, database, schema, region, passcode, application string
var privateKey *rsa.PrivateKey
var err error
if len(properties) == 0 || properties == nil {
Expand All @@ -923,6 +925,8 @@ func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) {
user = value
case "Password":
password = value
case "Token":
token = value
case "Role":
role = value
case "Host":
Expand Down Expand Up @@ -963,6 +967,7 @@ func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) {
Account: account,
User: user,
Password: password,
Token: token,
Role: role,
Host: host,
Port: port,
Expand Down
33 changes: 32 additions & 1 deletion dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,25 @@ func TestParseDSN(t *testing.T) {
ocspMode: ocspModeFailOpen,
err: nil,
},
{
dsn: "u:[email protected]:9876?account=a&protocol=http&authenticator=PROGRAMMATIC_ACCESS_TOKEN&disableSamlURLCheck=false&token=t",
config: &Config{
Account: "a", User: "u", Password: "p",
Authenticator: AuthTypePat,
Protocol: "http", Host: "a.snowflake.local", Port: 9876,
OCSPFailOpen: OCSPFailOpenTrue,
ValidateDefaultParameters: ConfigBoolTrue,
ClientTimeout: defaultClientTimeout,
JWTClientTimeout: defaultJWTClientTimeout,
ExternalBrowserTimeout: defaultExternalBrowserTimeout,
CloudStorageTimeout: defaultCloudStorageTimeout,
IncludeRetryReason: ConfigBoolTrue,
DisableSamlURLCheck: ConfigBoolFalse,
Token: "t",
},
ocspMode: ocspModeFailOpen,
err: nil,
},
}

for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} {
Expand Down Expand Up @@ -1213,7 +1232,8 @@ func TestParseDSN(t *testing.T) {
if test.config.DisableSamlURLCheck != cfg.DisableSamlURLCheck {
t.Fatalf("%v: Failed to match DisableSamlURLCheck. expected: %v, got: %v", i, test.config.DisableSamlURLCheck, cfg.DisableSamlURLCheck)
}
assertEqualF(t, cfg.ClientConfigFile, test.config.ClientConfigFile, "client config file")
assertEqualE(t, cfg.Token, test.config.Token, "token")
assertEqualE(t, cfg.ClientConfigFile, test.config.ClientConfigFile, "client config file")
case test.err != nil:
driverErrE, okE := test.err.(*SnowflakeError)
driverErrG, okG := err.(*SnowflakeError)
Expand Down Expand Up @@ -1465,6 +1485,17 @@ func TestDSN(t *testing.T) {
},
dsn: "u:[email protected]:443?authenticator=externalbrowser&clientStoreTemporaryCredential=false&ocspFailOpen=true&validateDefaultParameters=true",
},
{
cfg: &Config{
User: "u",
Password: "p",
Account: "a",
Token: "t",
Authenticator: AuthTypePat,
ClientStoreTemporaryCredential: ConfigBoolFalse,
},
dsn: "u:[email protected]:443?authenticator=programmatic_access_token&clientStoreTemporaryCredential=false&ocspFailOpen=true&token=t&validateDefaultParameters=true",
},
{
cfg: &Config{
User: "u",
Expand Down
Loading

0 comments on commit b4261df

Please sign in to comment.