diff --git a/secure_storage_manager.go b/secure_storage_manager.go index 0b4f89a54..6caaa62a2 100644 --- a/secure_storage_manager.go +++ b/secure_storage_manager.go @@ -5,14 +5,11 @@ package gosnowflake import ( "encoding/json" "fmt" + "github.com/99designs/keyring" "os" "path/filepath" "runtime" "strings" - "sync" - "time" - - "github.com/99designs/keyring" ) type tokenType string @@ -78,23 +75,17 @@ func newSecureStorageManager() secureStorageManager { } type fileBasedSecureStorageManager struct { - credCacheFilePath string - localCredCache map[string]string - credCacheLock sync.RWMutex + credDirPath string } func newFileBasedSecureStorageManager() (*fileBasedSecureStorageManager, error) { - ssm := &fileBasedSecureStorageManager{ - localCredCache: map[string]string{}, - credCacheLock: sync.RWMutex{}, + credDirPath := buildCredCacheDirPath() + if credDirPath == "" { + return nil, fmt.Errorf("failed to build cache dir path") } - credCacheDir := ssm.buildCredCacheDirPath() - if err := ssm.createCacheDir(credCacheDir); err != nil { - return nil, err + ssm := &fileBasedSecureStorageManager{ + credDirPath: credDirPath, } - credCacheFilePath := filepath.Join(credCacheDir, credCacheFileName) - logger.Infof("Credentials cache path: %v", credCacheFilePath) - ssm.credCacheFilePath = credCacheFilePath return ssm, nil } @@ -140,7 +131,7 @@ func lookupCacheDir(envVar string, pathSegments ...string) (string, error) { } } - if fileInfo.Mode()&os.ModePerm != 0o700 { + if fileInfo.Mode().Perm() != 0o700 { err := os.Chmod(cacheDir, 0o700) if err != nil { return "", fmt.Errorf("failed to chmod cache directory. %v, err: %w", cacheDir, err) @@ -150,7 +141,7 @@ func lookupCacheDir(envVar string, pathSegments ...string) (string, error) { return cacheDir, nil } -func (ssm *fileBasedSecureStorageManager) buildCredCacheDirPath() string { +func buildCredCacheDirPath() string { type cacheDirConf struct { envVar string pathSegments []string @@ -174,126 +165,155 @@ func (ssm *fileBasedSecureStorageManager) buildCredCacheDirPath() string { } func (ssm *fileBasedSecureStorageManager) setCredential(tokenSpec *secureTokenSpec, value string) { - if value == "" { - logger.Debug("no token provided") - } else { - credentialsKey := tokenSpec.buildKey() - ssm.credCacheLock.Lock() - defer ssm.credCacheLock.Unlock() - ssm.localCredCache[credentialsKey] = value - - j, err := json.Marshal(ssm.localCredCache) - if err != nil { - logger.Warnf("failed to convert credential to JSON.") - return - } + credentialsKey := tokenSpec.buildKey() + err := ssm.lockFile() + if err != nil { + logger.Warnf("Set credential failed. Unable to lock cache. %v", err) + return + } + defer ssm.unlockFile() - logger.Debugf("writing credential cache file. %v\n", ssm.credCacheFilePath) - credCacheLockFileName := ssm.credCacheFilePath + ".lck" - logger.Debugf("Creating lock file. %v", credCacheLockFileName) - err = os.Mkdir(credCacheLockFileName, 0600) + credCache, err := ssm.readTemporaryCacheFile() + if err != nil { + logger.Warnf("Set credential failed. Unable to read cache. %v", err) + return + } - switch { - case os.IsExist(err): - statinfo, err := os.Stat(credCacheLockFileName) - if err != nil { - logger.Debugf("failed to write credential cache file. file: %v, err: %v. ignored.\n", ssm.credCacheFilePath, err) - return - } - if time.Since(statinfo.ModTime()) < 15*time.Minute { - logger.Debugf("other process locks the cache file. %v. ignored.\n", ssm.credCacheFilePath) - return - } - if err = os.Remove(credCacheLockFileName); err != nil { - logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) - return - } - if err = os.Mkdir(credCacheLockFileName, 0600); err != nil { - logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) - return - } - } - defer os.RemoveAll(credCacheLockFileName) + credCache["tokens"][credentialsKey] = value - if err = os.WriteFile(ssm.credCacheFilePath, j, 0644); err != nil { - logger.Debugf("Failed to write the cache file. File: %v err: %v.", ssm.credCacheFilePath, err) - } + err = ssm.writeTemporaryCacheFile(credCache) + if err != nil { + logger.Warnf("Set credential failed. Unable to write cache. %v", err) + return } + + return +} + +func (ssm *fileBasedSecureStorageManager) lockFile() error { + // TODO Implement locks + return nil +} + +func (ssm *fileBasedSecureStorageManager) unlockFile() { + // TODO Implement locks } func (ssm *fileBasedSecureStorageManager) getCredential(tokenSpec *secureTokenSpec) string { credentialsKey := tokenSpec.buildKey() - ssm.credCacheLock.Lock() - defer ssm.credCacheLock.Unlock() - localCredCache := ssm.readTemporaryCacheFile() - cred := localCredCache[credentialsKey] + credCache := map[string]map[string]string{} + + err := ssm.lockFile() + if err != nil { + logger.Warn("Failed to lock credential cache file.") + return "" + } + + credCache, err = ssm.readTemporaryCacheFile() + ssm.unlockFile() + if err != nil { + logger.Warnf("Failed to read temporary cache file. %v.\n", err) + return "" + } + + cred := credCache["tokens"][credentialsKey] if cred != "" { logger.Debug("Successfully read token. Returning as string") } else { logger.Debug("Returned credential is empty") } + return cred } -func (ssm *fileBasedSecureStorageManager) readTemporaryCacheFile() map[string]string { - jsonData, err := os.ReadFile(ssm.credCacheFilePath) +func (ssm *fileBasedSecureStorageManager) credFilePath() string { + return filepath.Join(ssm.credDirPath, credCacheFileName) +} + +func (ssm *fileBasedSecureStorageManager) ensurePermissions() error { + dirInfo, err := os.Stat(ssm.credDirPath) if err != nil { - logger.Debugf("Failed to read credential file: %v", err) - return nil + return err } - err = json.Unmarshal([]byte(jsonData), &ssm.localCredCache) + + if dirInfo.Mode().Perm() != 0o700 { + return fmt.Errorf("incorrect permissions(%o, expected 700) for %s.", dirInfo.Mode().Perm(), ssm.credDirPath) + } + + fileInfo, err := os.Stat(ssm.credFilePath()) if err != nil { - logger.Debugf("failed to read JSON. Err: %v", err) - return nil + return err + } + + if fileInfo.Mode().Perm() != 0o600 { + logger.Debugf("Incorrect permissions(%o, expected 600) for credential file.", fileInfo.Mode().Perm()) + err := os.Chmod(ssm.credFilePath(), 0o600) + if err != nil { + return fmt.Errorf("Failed to chmod credential file: %v", err) + } + logger.Debug("Successfully fixed credential file permissions.") } - return ssm.localCredCache + return nil +} + +func (ssm *fileBasedSecureStorageManager) readTemporaryCacheFile() (map[string]map[string]string, error) { + err := ssm.ensurePermissions() + if err != nil { + return nil, err + } + + jsonData, err := os.ReadFile(ssm.credFilePath()) + if err != nil { + return nil, fmt.Errorf("failed to read credential cache file: %w", err) + } + + credentialsMap := map[string]map[string]string{} + err = json.Unmarshal([]byte(jsonData), &credentialsMap) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal credential cache file: %w", err) + } + + return credentialsMap, nil } func (ssm *fileBasedSecureStorageManager) deleteCredential(tokenSpec *secureTokenSpec) { - ssm.credCacheLock.Lock() - defer ssm.credCacheLock.Unlock() credentialsKey := tokenSpec.buildKey() - delete(ssm.localCredCache, credentialsKey) - j, err := json.Marshal(ssm.localCredCache) + err := ssm.lockFile() if err != nil { - logger.Warnf("failed to convert credential to JSON.") + logger.Warnf("Set credential failed. Unable to lock cache. %v", err) return } - ssm.writeTemporaryCacheFile(j) -} + defer ssm.unlockFile() -func (ssm *fileBasedSecureStorageManager) writeTemporaryCacheFile(input []byte) { - logger.Debugf("writing credential cache file. %v\n", ssm.credCacheFilePath) - credCacheLockFileName := ssm.credCacheFilePath + ".lck" - err := os.Mkdir(credCacheLockFileName, 0600) - logger.Debugf("Creating lock file. %v", credCacheLockFileName) + credCache, err := ssm.readTemporaryCacheFile() + if err != nil { + logger.Warnf("Set credential failed. Unable to read cache. %v", err) + return + } - switch { - case os.IsExist(err): - statinfo, err := os.Stat(credCacheLockFileName) - if err != nil { - logger.Debugf("failed to write credential cache file. file: %v, err: %v. ignored.\n", ssm.credCacheFilePath, err) - return - } - if time.Since(statinfo.ModTime()) < 15*time.Minute { - logger.Debugf("other process locks the cache file. %v. ignored.\n", ssm.credCacheFilePath) - return - } - if err = os.Remove(credCacheLockFileName); err != nil { - logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) - return - } - if err = os.Mkdir(credCacheLockFileName, 0600); err != nil { - logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) - return - } + delete(credCache["tokens"], credentialsKey) + + err = ssm.writeTemporaryCacheFile(credCache) + if err != nil { + logger.Warnf("Set credential failed. Unable to write cache. %v", err) + return + } + + return +} + +func (ssm *fileBasedSecureStorageManager) writeTemporaryCacheFile(cache map[string]map[string]string) error { + bytes, err := json.Marshal(cache) + if err != nil { + return fmt.Errorf("failed to marshal credential cache map. %w", err) } - defer os.RemoveAll(credCacheLockFileName) - if err = os.WriteFile(ssm.credCacheFilePath, input, 0644); err != nil { - logger.Debugf("Failed to write the cache file. File: %v err: %v.", ssm.credCacheFilePath, err) + err = os.WriteFile(ssm.credFilePath(), bytes, 0600) + if err != nil { + return fmt.Errorf("failed to write the credential cache file: %w", err) } + return nil } type keyringSecureStorageManager struct {