diff --git a/secure_storage_manager.go b/secure_storage_manager.go index 6be71a579..e27add725 100644 --- a/secure_storage_manager.go +++ b/secure_storage_manager.go @@ -4,15 +4,15 @@ package gosnowflake import ( "encoding/json" + "errors" "fmt" + "github.com/99designs/keyring" + "golang.org/x/sys/unix" "os" "path/filepath" "runtime" "strings" - "sync" "time" - - "github.com/99designs/keyring" ) type tokenType string @@ -23,7 +23,6 @@ const ( ) const ( - driverName = "SNOWFLAKE-GO-DRIVER" credCacheDirEnv = "SF_TEMPORARY_CREDENTIAL_CACHE_DIR" credCacheFileName = "temporary_credential.json" ) @@ -78,23 +77,17 @@ func newSecureStorageManager() secureStorageManager { } type fileBasedSecureStorageManager struct { - credCacheFilePath string - localCredCache map[string]string - credCacheLock sync.RWMutex + credDirPath string } func newFileBasedSecureStorageManager() (*fileBasedSecureStorageManager, error) { - ssm := &fileBasedSecureStorageManager{ - localCredCache: map[string]string{}, - credCacheLock: sync.RWMutex{}, + credDirPath := buildCredCacheDirPath() + if credDirPath == "" { + return nil, fmt.Errorf("failed to build cache dir path") } - credCacheDir := ssm.buildCredCacheDirPath() - if err := ssm.createCacheDir(credCacheDir); err != nil { - return nil, err + ssm := &fileBasedSecureStorageManager{ + credDirPath: credDirPath, } - credCacheFilePath := filepath.Join(credCacheDir, credCacheFileName) - logger.Infof("Credentials cache path: %v", credCacheFilePath) - ssm.credCacheFilePath = credCacheFilePath return ssm, nil } @@ -109,141 +102,268 @@ func (ssm *fileBasedSecureStorageManager) createCacheDir(credCacheDir string) er return err } -func (ssm *fileBasedSecureStorageManager) buildCredCacheDirPath() string { - credCacheDir := os.Getenv(credCacheDirEnv) - if credCacheDir != "" { - return credCacheDir +func lookupCacheDir(envVar string, pathSegments ...string) (string, error) { + envVal := os.Getenv(envVar) + if envVal == "" { + return "", fmt.Errorf("environment variable %s not set", envVar) } - home := os.Getenv("HOME") - if home == "" { - logger.Info("HOME is blank") - return "" + + fileInfo, err := os.Stat(envVal) + if err != nil { + return "", fmt.Errorf("failed to stat %s=%s, due to %w", envVar, envVal, err) } - credCacheDir = filepath.Join(home, ".cache", "snowflake") - return credCacheDir + + if !fileInfo.IsDir() { + return "", fmt.Errorf("environment variable %s=%s is not a directory", envVar, envVal) + } + + cacheDir := envVal + + if len(pathSegments) > 0 { + for _, pathSegment := range pathSegments { + err := os.Mkdir(pathSegment, os.ModePerm) + if err != nil { + return "", fmt.Errorf("failed to create cache directory. %v, err: %w", pathSegment, err) + } + cacheDir = filepath.Join(cacheDir, pathSegment) + } + fileInfo, err = os.Stat(cacheDir) + if err != nil { + return "", fmt.Errorf("failed to stat %s=%s, due to %w", envVar, cacheDir, err) + } + } + + if fileInfo.Mode().Perm() != 0o700 { + err := os.Chmod(cacheDir, 0o700) + if err != nil { + return "", fmt.Errorf("failed to chmod cache directory. %v, err: %w", cacheDir, err) + } + } + + return cacheDir, nil +} + +func buildCredCacheDirPath() string { + type cacheDirConf struct { + envVar string + pathSegments []string + } + confs := []cacheDirConf{ + {envVar: credCacheDirEnv, pathSegments: []string{}}, + {envVar: "XDG_CACHE_DIR", pathSegments: []string{"snowflake"}}, + {envVar: "HOME", pathSegments: []string{".cache", "snowflake"}}, + } + for _, conf := range confs { + path, err := lookupCacheDir(conf.envVar, conf.pathSegments...) + if err != nil { + logger.Debugf("Skipping %s in cache directory lookup due to %w", conf.envVar, err) + } else { + logger.Infof("Using %s as cache directory", path) + return path + } + } + + return "" +} + +func (ssm *fileBasedSecureStorageManager) getTokens(data map[string]any) map[string]interface{} { + val, ok := data["tokens"] + emptyMap := map[string]interface{}{} + if !ok { + data["tokens"] = emptyMap + return emptyMap + } + + tokens, ok := val.(map[string]interface{}) + if !ok { + data["tokens"] = emptyMap + return emptyMap + } + + return tokens } func (ssm *fileBasedSecureStorageManager) setCredential(tokenSpec *secureTokenSpec, value string) { - if value == "" { - logger.Debug("no token provided") - } else { - credentialsKey := tokenSpec.buildKey() - ssm.credCacheLock.Lock() - defer ssm.credCacheLock.Unlock() - ssm.localCredCache[credentialsKey] = value + credentialsKey := tokenSpec.buildKey() + err := ssm.lockFile() + if err != nil { + logger.Warnf("Set credential failed. Unable to lock cache. %v", err) + return + } + defer ssm.unlockFile() + + credCache := ssm.readTemporaryCacheFile() + ssm.getTokens(credCache)[credentialsKey] = value + + err = ssm.writeTemporaryCacheFile(credCache) + if err != nil { + logger.Warnf("Set credential failed. Unable to write cache. %v", err) + return + } + + return +} - j, err := json.Marshal(ssm.localCredCache) +func (ssm *fileBasedSecureStorageManager) lockPath() string { + return filepath.Join(ssm.credDirPath, credCacheFileName+".lck") +} + +func (ssm *fileBasedSecureStorageManager) lockFile() error { + const NUM_RETRIES = 10 + const RETRY_INTERVAL = 100 * time.Millisecond + lockPath := ssm.lockPath() + locked := false + for i := 0; i < NUM_RETRIES; i++ { + err := os.Mkdir(lockPath, 0o700) if err != nil { - logger.Warnf("failed to convert credential to JSON.") - return + if errors.Is(err, os.ErrExist) { + time.Sleep(RETRY_INTERVAL) + continue + } + return fmt.Errorf("failed to create cache lock: %v, err: %v", lockPath, err) } + locked = true + break + } - logger.Debugf("writing credential cache file. %v\n", ssm.credCacheFilePath) - credCacheLockFileName := ssm.credCacheFilePath + ".lck" - logger.Debugf("Creating lock file. %v", credCacheLockFileName) - err = os.Mkdir(credCacheLockFileName, 0600) + if !locked { + logger.Warnf("failed to lock cache lock. lockPath: %v.", lockPath) + var stat unix.Stat_t + err := unix.Stat(lockPath, &stat) + if err != nil { + return fmt.Errorf("failed to stat %v and determine if lock is stale. err: %v", lockPath, err) + } - switch { - case os.IsExist(err): - statinfo, err := os.Stat(credCacheLockFileName) + if stat.Ctim.Nano()+time.Second.Nanoseconds() < time.Now().UnixNano() { + err := os.Remove(lockPath) if err != nil { - logger.Debugf("failed to write credential cache file. file: %v, err: %v. ignored.\n", ssm.credCacheFilePath, err) - return - } - if time.Since(statinfo.ModTime()) < 15*time.Minute { - logger.Debugf("other process locks the cache file. %v. ignored.\n", ssm.credCacheFilePath) - return + return fmt.Errorf("failed to remove %v while trying to remove stale lock. err: %v", lockPath, err) } - if err = os.Remove(credCacheLockFileName); err != nil { - logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) - return - } - if err = os.Mkdir(credCacheLockFileName, 0600); err != nil { - logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) - return + err = os.Mkdir(lockPath, 0o700) + if err != nil { + return fmt.Errorf("failed to recreate cache lock after removing stale lock. %v, err: %v", lockPath, err) } } - defer os.RemoveAll(credCacheLockFileName) + } + return nil +} - if err = os.WriteFile(ssm.credCacheFilePath, j, 0644); err != nil { - logger.Debugf("Failed to write the cache file. File: %v err: %v.", ssm.credCacheFilePath, err) - } +func (ssm *fileBasedSecureStorageManager) unlockFile() { + lockPath := ssm.lockPath() + err := os.Remove(lockPath) + if err != nil { + logger.Warnf("Failed to unlock cache lock: %v. %v", lockPath, err) } } func (ssm *fileBasedSecureStorageManager) getCredential(tokenSpec *secureTokenSpec) string { credentialsKey := tokenSpec.buildKey() - ssm.credCacheLock.Lock() - defer ssm.credCacheLock.Unlock() - localCredCache := ssm.readTemporaryCacheFile() - cred := localCredCache[credentialsKey] - if cred != "" { - logger.Debug("Successfully read token. Returning as string") - } else { - logger.Debug("Returned credential is empty") + err := ssm.lockFile() + if err != nil { + logger.Warn("Failed to lock credential cache file.") + return "" } - return cred + + credCache := ssm.readTemporaryCacheFile() + ssm.unlockFile() + cred, ok := ssm.getTokens(credCache)[credentialsKey] + if !ok { + return "" + } + + credStr, ok := cred.(string) + if !ok { + return "" + } + + return credStr +} + +func (ssm *fileBasedSecureStorageManager) credFilePath() string { + return filepath.Join(ssm.credDirPath, credCacheFileName) } -func (ssm *fileBasedSecureStorageManager) readTemporaryCacheFile() map[string]string { - jsonData, err := os.ReadFile(ssm.credCacheFilePath) +func (ssm *fileBasedSecureStorageManager) ensurePermissions() error { + dirInfo, err := os.Stat(ssm.credDirPath) if err != nil { - logger.Debugf("Failed to read credential file: %v", err) - return nil + return err + } + + if dirInfo.Mode().Perm() != 0o700 { + return fmt.Errorf("incorrect permissions(%o, expected 700) for %s.", dirInfo.Mode().Perm(), ssm.credDirPath) } - err = json.Unmarshal([]byte(jsonData), &ssm.localCredCache) + + fileInfo, err := os.Stat(ssm.credFilePath()) if err != nil { - logger.Debugf("failed to read JSON. Err: %v", err) - return nil + return err } - return ssm.localCredCache + if fileInfo.Mode().Perm() != 0o600 { + logger.Debugf("Incorrect permissions(%o, expected 600) for credential file.", fileInfo.Mode().Perm()) + err := os.Chmod(ssm.credFilePath(), 0o600) + if err != nil { + return fmt.Errorf("Failed to chmod credential file: %v", err) + } + logger.Debug("Successfully fixed credential file permissions.") + } + + return nil +} + +func (ssm *fileBasedSecureStorageManager) readTemporaryCacheFile() map[string]any { + err := ssm.ensurePermissions() + if err != nil { + logger.Warnf("Failed to ensure permission for temporary cache file. %v.\n", err) + return map[string]any{} + } + + jsonData, err := os.ReadFile(ssm.credFilePath()) + if err != nil { + logger.Warnf("Failed to read credential cache file. %v.\n", err) + return map[string]any{} + } + + credentialsMap := map[string]any{} + err = json.Unmarshal([]byte(jsonData), &credentialsMap) + if err != nil { + logger.Warnf("Failed to unmarshal credential cache file. %v.\n", err) + } + + return credentialsMap } func (ssm *fileBasedSecureStorageManager) deleteCredential(tokenSpec *secureTokenSpec) { - ssm.credCacheLock.Lock() - defer ssm.credCacheLock.Unlock() credentialsKey := tokenSpec.buildKey() - delete(ssm.localCredCache, credentialsKey) - j, err := json.Marshal(ssm.localCredCache) + err := ssm.lockFile() if err != nil { - logger.Warnf("failed to convert credential to JSON.") + logger.Warnf("Set credential failed. Unable to lock cache. %v", err) return } - ssm.writeTemporaryCacheFile(j) -} + defer ssm.unlockFile() -func (ssm *fileBasedSecureStorageManager) writeTemporaryCacheFile(input []byte) { - logger.Debugf("writing credential cache file. %v\n", ssm.credCacheFilePath) - credCacheLockFileName := ssm.credCacheFilePath + ".lck" - err := os.Mkdir(credCacheLockFileName, 0600) - logger.Debugf("Creating lock file. %v", credCacheLockFileName) + credCache := ssm.readTemporaryCacheFile() + delete(ssm.getTokens(credCache), credentialsKey) - switch { - case os.IsExist(err): - statinfo, err := os.Stat(credCacheLockFileName) - if err != nil { - logger.Debugf("failed to write credential cache file. file: %v, err: %v. ignored.\n", ssm.credCacheFilePath, err) - return - } - if time.Since(statinfo.ModTime()) < 15*time.Minute { - logger.Debugf("other process locks the cache file. %v. ignored.\n", ssm.credCacheFilePath) - return - } - if err = os.Remove(credCacheLockFileName); err != nil { - logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) - return - } - if err = os.Mkdir(credCacheLockFileName, 0600); err != nil { - logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) - return - } + err = ssm.writeTemporaryCacheFile(credCache) + if err != nil { + logger.Warnf("Set credential failed. Unable to write cache. %v", err) + return } - defer os.RemoveAll(credCacheLockFileName) - if err = os.WriteFile(ssm.credCacheFilePath, input, 0644); err != nil { - logger.Debugf("Failed to write the cache file. File: %v err: %v.", ssm.credCacheFilePath, err) + return +} + +func (ssm *fileBasedSecureStorageManager) writeTemporaryCacheFile(cache map[string]any) error { + bytes, err := json.Marshal(cache) + if err != nil { + return fmt.Errorf("failed to marshal credential cache map. %w", err) + } + + err = os.WriteFile(ssm.credFilePath(), bytes, 0600) + if err != nil { + return fmt.Errorf("failed to write the credential cache file: %w", err) } + return nil } type keyringSecureStorageManager struct { @@ -342,10 +462,8 @@ func (ssm *keyringSecureStorageManager) deleteCredential(tokenSpec *secureTokenS } func buildCredentialsKey(host, user string, credType tokenType) string { - host = strings.ToUpper(host) - user = strings.ToUpper(user) - credTypeStr := strings.ToUpper(string(credType)) - return host + ":" + user + ":" + driverName + ":" + credTypeStr + credTypeStr := string(credType) + return host + ":" + user + ":" + credTypeStr } type noopSecureStorageManager struct { @@ -362,5 +480,5 @@ func (ssm *noopSecureStorageManager) getCredential(_ *secureTokenSpec) string { return "" } -func (ssm *noopSecureStorageManager) deleteCredential(_ *secureTokenSpec) { //TODO implement me +func (ssm *noopSecureStorageManager) deleteCredential(_ *secureTokenSpec) { } diff --git a/secure_storage_manager_test.go b/secure_storage_manager_test.go index 22299708b..6b09e0828 100644 --- a/secure_storage_manager_test.go +++ b/secure_storage_manager_test.go @@ -3,9 +3,47 @@ package gosnowflake import ( + "os" "testing" ) +type EnvOverride struct { + env string + oldValue string +} + +func (e *EnvOverride) rollback() { + if e.oldValue != "" { + os.Setenv(e.env, e.oldValue) + } else { + os.Unsetenv(e.env) + } +} + +func override_env(env string, value string) EnvOverride { + oldValue := os.Getenv(env) + os.Setenv(env, value) + return EnvOverride{env, oldValue} +} + +func TestSnowflakeFileBasedSecureStorageManager(t *testing.T) { + //skipOnNonLinux(t, "Not supported on non-linux") + os.Mkdir("./testdata", 0777) + credCacheDirEnvOverride := override_env(credCacheDirEnv, "./testdata") + defer credCacheDirEnvOverride.rollback() + fbss, err := newFileBasedSecureStorageManager() + if err != nil { + t.Fatal(err) + } + + tokenSpec := newMfaTokenSpec("host.xd", "johndoe") + cred := "token123" + fbss.setCredential(tokenSpec, cred) + assertEqualE(t, fbss.getCredential(tokenSpec), cred) + fbss.deleteCredential(tokenSpec) + assertEqualE(t, fbss.getCredential(tokenSpec), "") +} + func TestSetAndGetCredentialMfa(t *testing.T) { for _, tokenSpec := range []*secureTokenSpec{ newMfaTokenSpec("testhost", "testuser"), diff --git a/util_test.go b/util_test.go index f67912dbb..6f686b6ca 100644 --- a/util_test.go +++ b/util_test.go @@ -404,6 +404,12 @@ func skipOnMac(t *testing.T, reason string) { } } +func skipOnNonLinux(t *testing.T, reason string) { + if runtime.GOOS != "linux" { + t.Skip("skipped on non-linux OS: " + reason) + } +} + func randomString(n int) string { r := rand.New(rand.NewSource(time.Now().UnixNano())) alpha := []rune("abcdefghijklmnopqrstuvwxyz")