Skip to content

Commit b246b4e

Browse files
ostermanclaudeaknysh
authored
fix: MFA authentication by preventing session token keyring overwrite (#1757)
* fix: Prevent session tokens from overwriting long-lived credentials in keyring Session tokens (temporary credentials with SessionToken field) should not be cached in the keyring as they overwrite long-lived credentials needed for subsequent authentication attempts. Session tokens are already persisted to provider-specific storage (AWS files, etc.) and loaded via LoadCredentials(). This fixes the issue where users had to run `atmos auth user configure` repeatedly because authentication would fail after the first attempt. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]> * fix: Display credential expiration times in user's local timezone Convert parsed credential expiration times from UTC to local timezone before returning them for display. This ensures consistency across all time displays in auth commands (whoami, list) and matches user expectations. Fixed in: - pkg/auth/types/aws_credentials.go: Convert AWS credential expiration to local time - pkg/auth/types/github_oidc_credentials.go: Convert OIDC token expiration to local time 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]> --------- Co-authored-by: Claude <[email protected]> Co-authored-by: Andriy Knysh <[email protected]>
1 parent 7b37808 commit b246b4e

File tree

5 files changed

+209
-6
lines changed

5 files changed

+209
-6
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,6 @@ performance-optimization/
9999
*.bak
100100
*.go.bak
101101
**/*.go.bak
102+
103+
# Scratch/temporary analysis files
104+
.scratch/

pkg/auth/manager.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,17 @@ func (m *manager) getChainStepName(index int) string {
910910
return "unknown"
911911
}
912912

913+
// isSessionToken checks if credentials are temporary session tokens.
914+
// Session tokens are identified by the presence of a SessionToken field.
915+
// These should not be cached in keyring as they overwrite long-lived credentials.
916+
func isSessionToken(creds types.ICredentials) bool {
917+
if awsCreds, ok := creds.(*types.AWSCredentials); ok {
918+
return awsCreds.SessionToken != ""
919+
}
920+
// Add other credential types as needed.
921+
return false
922+
}
923+
913924
// authenticateIdentityChain performs sequential authentication through an identity chain.
914925
func (m *manager) authenticateIdentityChain(ctx context.Context, startIndex int, initialCreds types.ICredentials) (types.ICredentials, error) {
915926
log.Debug("Authenticating identity chain", "chainLength", len(m.chain), "startIndex", startIndex, "chain", m.chain)
@@ -936,11 +947,19 @@ func (m *manager) authenticateIdentityChain(ctx context.Context, startIndex int,
936947

937948
currentCreds = nextCreds
938949

939-
// Cache credentials for this level.
940-
if err := m.credentialStore.Store(identityStep, currentCreds); err != nil {
941-
log.Debug("Failed to cache credentials", "identityStep", identityStep, "error", err)
950+
// Cache credentials for this level, but skip session tokens.
951+
// Session tokens are already persisted to provider-specific storage (e.g., AWS files)
952+
// and can be loaded via identity.LoadCredentials().
953+
// Caching session tokens in keyring would overwrite long-lived credentials
954+
// that are needed for subsequent authentication attempts.
955+
if isSessionToken(currentCreds) {
956+
log.Debug("Skipping keyring cache for session tokens", "identityStep", identityStep)
942957
} else {
943-
log.Debug("Cached credentials", "identityStep", identityStep)
958+
if err := m.credentialStore.Store(identityStep, currentCreds); err != nil {
959+
log.Debug("Failed to cache credentials", "identityStep", identityStep, "error", err)
960+
} else {
961+
log.Debug("Cached credentials", "identityStep", identityStep)
962+
}
944963
}
945964

946965
log.Debug("Chained identity", "from", m.getChainStepName(i-1), "to", identityStep)
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/cloudposse/atmos/pkg/auth/types"
11+
"github.com/cloudposse/atmos/pkg/schema"
12+
)
13+
14+
// TestSessionTokenDoesNotOverwriteLongLivedCredentialsInKeyring verifies that session tokens
15+
// (temporary credentials) do NOT overwrite long-lived credentials in the keyring.
16+
//
17+
// Intended behavior:
18+
// - Session tokens should NOT be cached in keyring
19+
// - Long-lived credentials should remain in keyring unchanged
20+
// - Session tokens should only exist in provider-specific storage (AWS files, etc.)
21+
//
22+
// This ensures users don't need to reconfigure credentials after authentication,
23+
// as long-lived credentials remain available for generating new session tokens.
24+
func TestSessionTokenDoesNotOverwriteLongLivedCredentialsInKeyring(t *testing.T) {
25+
ctx := context.Background()
26+
27+
// Step 1: Store long-lived AWS credentials in keyring.
28+
// This simulates what `atmos auth user configure` does.
29+
longLivedCreds := &types.AWSCredentials{
30+
AccessKeyID: "AKIA_LONG_LIVED_KEY",
31+
SecretAccessKey: "long_lived_secret_key",
32+
MfaArn: "arn:aws:iam::123456789012:mfa/test-user",
33+
SessionDuration: "12h",
34+
Region: "us-east-1",
35+
// NO SessionToken - this is a long-lived credential
36+
}
37+
38+
// Create test credential store with long-lived credentials.
39+
// Pre-populate keyring with long-lived credentials for both provider and identity.
40+
// This simulates what `atmos auth user configure` does.
41+
store := &testStore{
42+
data: map[string]any{
43+
"test-provider": longLivedCreds, // Provider credentials (long-lived)
44+
"test-identity": longLivedCreds, // Identity credentials (long-lived) - should NOT be overwritten
45+
},
46+
expired: map[string]bool{
47+
"test-provider": false,
48+
"test-identity": false,
49+
},
50+
}
51+
52+
// Step 2: Create a mock identity that returns session tokens.
53+
// This simulates what AWS user identity does after calling STS GetSessionToken.
54+
sessionCreds := &types.AWSCredentials{
55+
AccessKeyID: "ASIA_SESSION_KEY",
56+
SecretAccessKey: "session_secret_key",
57+
SessionToken: "session_token_12345", // Session tokens have this field
58+
Region: "us-east-1",
59+
Expiration: "2099-12-31T23:59:59Z",
60+
}
61+
62+
mockIdentity := &mockIdentityReturningSessionTokens{
63+
sessionCreds: sessionCreds,
64+
}
65+
66+
// Step 3: Create auth manager with minimal setup.
67+
authConfig := &schema.AuthConfig{
68+
Providers: map[string]schema.Provider{
69+
"test-provider": {
70+
Kind: "test",
71+
},
72+
},
73+
Identities: map[string]schema.Identity{
74+
"test-identity": {
75+
Kind: "test",
76+
Via: &schema.IdentityVia{
77+
Provider: "test-provider",
78+
},
79+
},
80+
},
81+
}
82+
83+
m := &manager{
84+
config: authConfig,
85+
credentialStore: store,
86+
providers: map[string]types.Provider{},
87+
identities: map[string]types.Identity{
88+
"test-identity": mockIdentity,
89+
},
90+
chain: []string{"test-provider", "test-identity"},
91+
}
92+
93+
// Step 4: Authenticate through the identity chain.
94+
// This returns session tokens but should NOT cache them in keyring.
95+
returnedCreds, err := m.authenticateIdentityChain(ctx, 1, longLivedCreds)
96+
require.NoError(t, err, "authenticateIdentityChain should succeed")
97+
98+
// Verify that session tokens were returned (this is correct and expected).
99+
awsCreds, ok := returnedCreds.(*types.AWSCredentials)
100+
require.True(t, ok, "Returned credentials should be AWS credentials")
101+
assert.NotEmpty(t, awsCreds.SessionToken, "Should return session tokens")
102+
assert.Equal(t, "ASIA_SESSION_KEY", awsCreds.AccessKeyID, "Should return session access key")
103+
104+
// Step 5: Verify keyring STILL contains long-lived credentials (INTENDED BEHAVIOR).
105+
// Session tokens should NOT have been cached in keyring.
106+
retrievedCreds, err := store.Retrieve("test-identity")
107+
require.NoError(t, err, "Should retrieve credentials from keyring")
108+
109+
retrievedAWSCreds, ok := retrievedCreds.(*types.AWSCredentials)
110+
require.True(t, ok, "Retrieved credentials should be AWS credentials")
111+
112+
// INTENDED BEHAVIOR: Keyring should NOT contain session tokens.
113+
assert.Empty(t, retrievedAWSCreds.SessionToken,
114+
"Keyring should NOT contain session tokens - they should only be in provider storage (AWS files)")
115+
116+
// INTENDED BEHAVIOR: Keyring should STILL contain long-lived credentials.
117+
assert.Equal(t, "AKIA_LONG_LIVED_KEY", retrievedAWSCreds.AccessKeyID,
118+
"Keyring should preserve long-lived access key")
119+
120+
assert.Equal(t, "long_lived_secret_key", retrievedAWSCreds.SecretAccessKey,
121+
"Keyring should preserve long-lived secret key")
122+
123+
// INTENDED BEHAVIOR: Keyring should preserve MFA ARN and session duration.
124+
assert.Equal(t, "arn:aws:iam::123456789012:mfa/test-user", retrievedAWSCreds.MfaArn,
125+
"Keyring should preserve MFA ARN for future session token generation")
126+
127+
assert.Equal(t, "12h", retrievedAWSCreds.SessionDuration,
128+
"Keyring should preserve session duration for future session token generation")
129+
}
130+
131+
// mockIdentityReturningSessionTokens is a mock identity that returns session tokens.
132+
// This simulates what AWS user identity does after calling STS GetSessionToken.
133+
type mockIdentityReturningSessionTokens struct {
134+
sessionCreds *types.AWSCredentials
135+
}
136+
137+
func (m *mockIdentityReturningSessionTokens) Kind() string {
138+
return "test"
139+
}
140+
141+
func (m *mockIdentityReturningSessionTokens) Authenticate(ctx context.Context, baseCreds types.ICredentials) (types.ICredentials, error) {
142+
// Return session tokens (credentials WITH SessionToken field).
143+
// This simulates what generateSessionToken() does in AWS user identity.
144+
return m.sessionCreds, nil
145+
}
146+
147+
func (m *mockIdentityReturningSessionTokens) Validate() error {
148+
return nil
149+
}
150+
151+
func (m *mockIdentityReturningSessionTokens) Environment() (map[string]string, error) {
152+
return map[string]string{}, nil
153+
}
154+
155+
func (m *mockIdentityReturningSessionTokens) PrepareEnvironment(ctx context.Context, environ map[string]string) (map[string]string, error) {
156+
return environ, nil
157+
}
158+
159+
func (m *mockIdentityReturningSessionTokens) PostAuthenticate(ctx context.Context, params *types.PostAuthenticateParams) error {
160+
return nil
161+
}
162+
163+
func (m *mockIdentityReturningSessionTokens) CredentialsExist() (bool, error) {
164+
return true, nil
165+
}
166+
167+
func (m *mockIdentityReturningSessionTokens) LoadCredentials(ctx context.Context) (types.ICredentials, error) {
168+
// Return session credentials from "storage" (simulates loading from AWS files).
169+
return m.sessionCreds, nil
170+
}
171+
172+
func (m *mockIdentityReturningSessionTokens) Logout(ctx context.Context) error {
173+
return nil
174+
}
175+
176+
func (m *mockIdentityReturningSessionTokens) GetProviderName() (string, error) {
177+
return "test-provider", nil
178+
}

pkg/auth/types/aws_credentials.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ func (c *AWSCredentials) GetExpiration() (*time.Time, error) {
4545
if err != nil {
4646
return nil, fmt.Errorf("%w: failed parsing AWS credential expiration: %w", errUtils.ErrInvalidAuthConfig, err)
4747
}
48-
return &expTime, nil
48+
// Convert to local timezone for display to user.
49+
localTime := expTime.Local()
50+
return &localTime, nil
4951
}
5052

5153
// BuildWhoamiInfo implements ICredentials for AWSCredentials.

pkg/auth/types/github_oidc_credentials.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ func (c *OIDCCredentials) GetExpiration() (*time.Time, error) {
4949
if claims.Exp == 0 {
5050
return nil, nil
5151
}
52-
t := time.Unix(claims.Exp, 0).UTC()
52+
// Convert to local timezone for display to user.
53+
t := time.Unix(claims.Exp, 0).Local()
5354
return &t, nil
5455
}
5556

0 commit comments

Comments
 (0)