Skip to content

Commit

Permalink
SNOW-1825790 Token cache refactor - v2 (#1299)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus authored Feb 5, 2025
1 parent e926883 commit d8df82e
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 152 deletions.
22 changes: 6 additions & 16 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ const (
)

const (
idToken = "ID_TOKEN"
mfaToken = "MFATOKEN"
clientStoreTemporaryCredential = "CLIENT_STORE_TEMPORARY_CREDENTIAL"
clientRequestMfaToken = "CLIENT_REQUEST_MFA_TOKEN"
idTokenAuthenticator = "ID_TOKEN"
Expand Down Expand Up @@ -365,10 +363,10 @@ func authenticate(
logger.WithContext(ctx).Errorln("Authentication FAILED")
sc.rest.TokenAccessor.SetTokens("", "", -1)
if sessionParameters[clientRequestMfaToken] == true {
credentialsStorage.deleteCredential(sc, mfaToken)
credentialsStorage.deleteCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
}
if sessionParameters[clientStoreTemporaryCredential] == true {
credentialsStorage.deleteCredential(sc, idToken)
credentialsStorage.deleteCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
}
code, err := strconv.Atoi(respd.Code)
if err != nil {
Expand All @@ -384,11 +382,11 @@ func authenticate(
sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID)
if sessionParameters[clientRequestMfaToken] == true {
token := respd.Data.MfaToken
credentialsStorage.setCredential(sc, mfaToken, token)
credentialsStorage.setCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User), token)
}
if sessionParameters[clientStoreTemporaryCredential] == true {
token := respd.Data.IDToken
credentialsStorage.setCredential(sc, idToken, token)
credentialsStorage.setCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User), token)
}
return &respd.Data, nil
}
Expand Down Expand Up @@ -523,7 +521,7 @@ func authenticateWithConfig(sc *snowflakeConn) error {
sc.cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
}
if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
fillCachedIDToken(sc)
sc.cfg.IDToken = credentialsStorage.getCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
}
// Disable console login by default
if sc.cfg.DisableConsoleLogin == configBoolNotSet {
Expand All @@ -536,7 +534,7 @@ func authenticateWithConfig(sc *snowflakeConn) error {
sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
}
if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue {
fillCachedMfaToken(sc)
sc.cfg.MfaToken = credentialsStorage.getCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
}
}

Expand Down Expand Up @@ -573,11 +571,3 @@ func authenticateWithConfig(sc *snowflakeConn) error {
sc.ctx = context.WithValue(sc.ctx, SFSessionIDKey, authData.SessionID)
return nil
}

func fillCachedIDToken(sc *snowflakeConn) {
credentialsStorage.getCredential(sc, idToken)
}

func fillCachedMfaToken(sc *snowflakeConn) {
credentialsStorage.getCredential(sc, mfaToken)
}
146 changes: 82 additions & 64 deletions secure_storage_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,61 @@ import (
"github.com/99designs/keyring"
)

type tokenType string

const (
idToken tokenType = "ID_TOKEN"
mfaToken tokenType = "MFATOKEN"
)

const (
driverName = "SNOWFLAKE-GO-DRIVER"
credCacheDirEnv = "SF_TEMPORARY_CREDENTIAL_CACHE_DIR"
credCacheFileName = "temporary_credential.json"
)

type secureTokenSpec struct {
host, user string
tokenType tokenType
}

func (t *secureTokenSpec) buildKey() string {
return buildCredentialsKey(t.host, t.user, t.tokenType)
}

func newMfaTokenSpec(host, user string) *secureTokenSpec {
return &secureTokenSpec{
host,
user,
mfaToken,
}
}

func newIDTokenSpec(host, user string) *secureTokenSpec {
return &secureTokenSpec{
host,
user,
idToken,
}
}

type secureStorageManager interface {
setCredential(sc *snowflakeConn, credType, token string)
getCredential(sc *snowflakeConn, credType string)
deleteCredential(sc *snowflakeConn, credType string)
setCredential(tokenSpec *secureTokenSpec, value string)
getCredential(tokenSpec *secureTokenSpec) string
deleteCredential(tokenSpec *secureTokenSpec)
}

var credentialsStorage = newSecureStorageManager()

func newSecureStorageManager() secureStorageManager {
switch runtime.GOOS {
case "linux":
return newFileBasedSecureStorageManager()
ssm, err := newFileBasedSecureStorageManager()
if err != nil {
logger.Debugf("failed to create credentials cache dir. %v", err)
return newNoopSecureStorageManager()
}
return ssm
case "darwin", "windows":
return newKeyringBasedSecureStorageManager()
default:
Expand All @@ -46,20 +83,19 @@ type fileBasedSecureStorageManager struct {
credCacheLock sync.RWMutex
}

func newFileBasedSecureStorageManager() secureStorageManager {
func newFileBasedSecureStorageManager() (*fileBasedSecureStorageManager, error) {
ssm := &fileBasedSecureStorageManager{
localCredCache: map[string]string{},
credCacheLock: sync.RWMutex{},
}
credCacheDir := ssm.buildCredCacheDirPath()
if err := ssm.createCacheDir(credCacheDir); err != nil {
logger.Debugf("failed to create credentials cache dir. %v", err)
return newNoopSecureStorageManager()
return nil, err
}
credCacheFilePath := filepath.Join(credCacheDir, credCacheFileName)
logger.Infof("Credentials cache path: %v", credCacheFilePath)
ssm.credCacheFilePath = credCacheFilePath
return ssm
return ssm, nil
}

func (ssm *fileBasedSecureStorageManager) createCacheDir(credCacheDir string) error {
Expand Down Expand Up @@ -87,14 +123,14 @@ func (ssm *fileBasedSecureStorageManager) buildCredCacheDirPath() string {
return credCacheDir
}

func (ssm *fileBasedSecureStorageManager) setCredential(sc *snowflakeConn, credType, token string) {
if token == "" {
func (ssm *fileBasedSecureStorageManager) setCredential(tokenSpec *secureTokenSpec, value string) {
if value == "" {
logger.Debug("no token provided")
} else {
credentialsKey := buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType)
credentialsKey := tokenSpec.buildKey()
ssm.credCacheLock.Lock()
defer ssm.credCacheLock.Unlock()
ssm.localCredCache[credentialsKey] = token
ssm.localCredCache[credentialsKey] = value

j, err := json.Marshal(ssm.localCredCache)
if err != nil {
Expand Down Expand Up @@ -135,8 +171,8 @@ func (ssm *fileBasedSecureStorageManager) setCredential(sc *snowflakeConn, credT
}
}

func (ssm *fileBasedSecureStorageManager) getCredential(sc *snowflakeConn, credType string) {
credentialsKey := buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType)
func (ssm *fileBasedSecureStorageManager) getCredential(tokenSpec *secureTokenSpec) string {
credentialsKey := tokenSpec.buildKey()
ssm.credCacheLock.Lock()
defer ssm.credCacheLock.Unlock()
localCredCache := ssm.readTemporaryCacheFile()
Expand All @@ -146,14 +182,7 @@ func (ssm *fileBasedSecureStorageManager) getCredential(sc *snowflakeConn, credT
} else {
logger.Debug("Returned credential is empty")
}

if credType == idToken {
sc.cfg.IDToken = cred
} else if credType == mfaToken {
sc.cfg.MfaToken = cred
} else {
logger.Debugf("Unrecognized type %v for local cached credential", credType)
}
return cred
}

func (ssm *fileBasedSecureStorageManager) readTemporaryCacheFile() map[string]string {
Expand All @@ -171,10 +200,10 @@ func (ssm *fileBasedSecureStorageManager) readTemporaryCacheFile() map[string]st
return ssm.localCredCache
}

func (ssm *fileBasedSecureStorageManager) deleteCredential(sc *snowflakeConn, credType string) {
func (ssm *fileBasedSecureStorageManager) deleteCredential(tokenSpec *secureTokenSpec) {
ssm.credCacheLock.Lock()
defer ssm.credCacheLock.Unlock()
credentialsKey := buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType)
credentialsKey := tokenSpec.buildKey()
delete(ssm.localCredCache, credentialsKey)
j, err := json.Marshal(ssm.localCredCache)
if err != nil {
Expand Down Expand Up @@ -220,37 +249,35 @@ func (ssm *fileBasedSecureStorageManager) writeTemporaryCacheFile(input []byte)
type keyringSecureStorageManager struct {
}

func newKeyringBasedSecureStorageManager() secureStorageManager {
func newKeyringBasedSecureStorageManager() *keyringSecureStorageManager {
return &keyringSecureStorageManager{}
}

func (ssm *keyringSecureStorageManager) setCredential(sc *snowflakeConn, credType, token string) {
if token == "" {
func (ssm *keyringSecureStorageManager) setCredential(tokenSpec *secureTokenSpec, value string) {
if value == "" {
logger.Debug("no token provided")
} else {
var credentialsKey string
credentialsKey := tokenSpec.buildKey()
if runtime.GOOS == "windows" {
credentialsKey = driverName + ":" + credType
ring, _ := keyring.Open(keyring.Config{
WinCredPrefix: strings.ToUpper(sc.cfg.Host),
ServiceName: strings.ToUpper(sc.cfg.User),
WinCredPrefix: strings.ToUpper(tokenSpec.host),
ServiceName: strings.ToUpper(tokenSpec.user),
})
item := keyring.Item{
Key: credentialsKey,
Data: []byte(token),
Data: []byte(value),
}
if err := ring.Set(item); err != nil {
logger.Debugf("Failed to write to Windows credential manager. Err: %v", err)
}
} else if runtime.GOOS == "darwin" {
credentialsKey = buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType)
ring, _ := keyring.Open(keyring.Config{
ServiceName: credentialsKey,
})
account := strings.ToUpper(sc.cfg.User)
account := strings.ToUpper(tokenSpec.user)
item := keyring.Item{
Key: account,
Data: []byte(token),
Data: []byte(value),
}
if err := ring.Set(item); err != nil {
logger.Debugf("Failed to write to keychain. Err: %v", err)
Expand All @@ -259,26 +286,24 @@ func (ssm *keyringSecureStorageManager) setCredential(sc *snowflakeConn, credTyp
}
}

func (ssm *keyringSecureStorageManager) getCredential(sc *snowflakeConn, credType string) {
var credentialsKey string
func (ssm *keyringSecureStorageManager) getCredential(tokenSpec *secureTokenSpec) string {
cred := ""
credentialsKey := tokenSpec.buildKey()
if runtime.GOOS == "windows" {
credentialsKey = driverName + ":" + credType
ring, _ := keyring.Open(keyring.Config{
WinCredPrefix: strings.ToUpper(sc.cfg.Host),
ServiceName: strings.ToUpper(sc.cfg.User),
WinCredPrefix: strings.ToUpper(tokenSpec.host),
ServiceName: strings.ToUpper(tokenSpec.user),
})
i, err := ring.Get(credentialsKey)
if err != nil {
logger.Debugf("Failed to read credentialsKey or could not find it in Windows Credential Manager. Error: %v", err)
}
cred = string(i.Data)
} else if runtime.GOOS == "darwin" {
credentialsKey = buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType)
ring, _ := keyring.Open(keyring.Config{
ServiceName: credentialsKey,
})
account := strings.ToUpper(sc.cfg.User)
account := strings.ToUpper(tokenSpec.user)
i, err := ring.Get(account)
if err != nil {
logger.Debugf("Failed to find the item in keychain or item does not exist. Error: %v", err)
Expand All @@ -290,59 +315,52 @@ func (ssm *keyringSecureStorageManager) getCredential(sc *snowflakeConn, credTyp
logger.Debug("Successfully read token. Returning as string")
}
}

if credType == idToken {
sc.cfg.IDToken = cred
} else if credType == mfaToken {
sc.cfg.MfaToken = cred
} else {
logger.Debugf("Unrecognized type %v for local cached credential", credType)
}
return cred
}

func (ssm *keyringSecureStorageManager) deleteCredential(sc *snowflakeConn, credType string) {
credentialsKey := driverName + ":" + credType
func (ssm *keyringSecureStorageManager) deleteCredential(tokenSpec *secureTokenSpec) {
credentialsKey := tokenSpec.buildKey()
if runtime.GOOS == "windows" {
ring, _ := keyring.Open(keyring.Config{
WinCredPrefix: strings.ToUpper(sc.cfg.Host),
ServiceName: strings.ToUpper(sc.cfg.User),
WinCredPrefix: strings.ToUpper(tokenSpec.host),
ServiceName: strings.ToUpper(tokenSpec.user),
})
err := ring.Remove(credentialsKey)
err := ring.Remove(string(credentialsKey))
if err != nil {
logger.Debugf("Failed to delete credentialsKey in Windows Credential Manager. Error: %v", err)
}
} else if runtime.GOOS == "darwin" {
credentialsKey = buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType)
ring, _ := keyring.Open(keyring.Config{
ServiceName: credentialsKey,
})
account := strings.ToUpper(sc.cfg.User)
account := strings.ToUpper(tokenSpec.user)
err := ring.Remove(account)
if err != nil {
logger.Debugf("Failed to delete credentialsKey in keychain. Error: %v", err)
}
}
}

func buildCredentialsKey(host, user, credType string) string {
func buildCredentialsKey(host, user string, credType tokenType) string {
host = strings.ToUpper(host)
user = strings.ToUpper(user)
credType = strings.ToUpper(credType)
return host + ":" + user + ":" + driverName + ":" + credType
credTypeStr := strings.ToUpper(string(credType))
return host + ":" + user + ":" + driverName + ":" + credTypeStr
}

type noopSecureStorageManager struct {
}

func newNoopSecureStorageManager() secureStorageManager {
func newNoopSecureStorageManager() *noopSecureStorageManager {
return &noopSecureStorageManager{}
}

func (ssm *noopSecureStorageManager) setCredential(sc *snowflakeConn, credType, token string) {
func (ssm *noopSecureStorageManager) setCredential(_ *secureTokenSpec, _ string) {
}

func (ssm *noopSecureStorageManager) getCredential(sc *snowflakeConn, credType string) {
func (ssm *noopSecureStorageManager) getCredential(_ *secureTokenSpec) string {
return ""
}

func (ssm *noopSecureStorageManager) deleteCredential(sc *snowflakeConn, credType string) { //TODO implement me
func (ssm *noopSecureStorageManager) deleteCredential(_ *secureTokenSpec) { //TODO implement me
}
Loading

0 comments on commit d8df82e

Please sign in to comment.