diff --git a/pkg/token/token.go b/pkg/token/token.go index 4e981f43a..6749d7c76 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -335,8 +335,16 @@ func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, return Token{}, err } + // Fetch the timestamp when the credentials we're going to use for signing will not be valid anymore + // This operation is potentially racey, but the worst case is that we expire a token early + // Not all credential providers support this, so we ignore any returned errors + credentialsExpiration, _ := request.Config.Credentials.ExpiresAt() + // Set token expiration to 1 minute before the presigned URL expires for some cushion tokenExpiration := g.nowFunc().Local().Add(presignedURLExpiration - 1*time.Minute) + if !credentialsExpiration.IsZero() && credentialsExpiration.Before(tokenExpiration) { + tokenExpiration = credentialsExpiration.Add(-1 * time.Minute) + } // TODO: this may need to be a constant-time base64 encoding return Token{v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedURLString)), tokenExpiration}, nil } diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index 5a1594d77..4883c0f32 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -590,6 +590,10 @@ func Test_getDefaultHostNameForRegion(t *testing.T) { func TestGetWithSTS(t *testing.T) { clusterID := "test-cluster" + // Example non-real credentials + decodedAkid, _ := base64.StdEncoding.DecodeString("QVNJQVIyVEc0NFY2QVMzWlpFN0M=") + decodedSk, _ := base64.StdEncoding.DecodeString("NEtENWNudEdjVm1MV1JkRjV3dk5SdXpOTDVReG1wNk9LVlk2RnovUQ==") + cases := []struct { name string creds *credentials.Credentials @@ -598,23 +602,34 @@ func TestGetWithSTS(t *testing.T) { wantErr error }{ { - "Non-zero time", - // Example non-real credentials - func() *credentials.Credentials { - decodedAkid, _ := base64.StdEncoding.DecodeString("QVNJQVIyVEc0NFY2QVMzWlpFN0M=") - decodedSk, _ := base64.StdEncoding.DecodeString("NEtENWNudEdjVm1MV1JkRjV3dk5SdXpOTDVReG1wNk9LVlk2RnovUQ==") - return credentials.NewStaticCredentials( - string(decodedAkid), - string(decodedSk), - "", - ) - }(), - time.Unix(1682640000, 0), - Token{ + name: "Non-zero time", + creds: credentials.NewStaticCredentials( + string(decodedAkid), + string(decodedSk), + "", + ), + nowTime: time.Unix(1682640000, 0), + want: Token{ Token: "k8s-aws-v1.aHR0cHM6Ly9zdHMudXMtd2VzdC0yLmFtYXpvbmF3cy5jb20vP0FjdGlvbj1HZXRDYWxsZXJJZGVudGl0eSZWZXJzaW9uPTIwMTEtMDYtMTUmWC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BU0lBUjJURzQ0VjZBUzNaWkU3QyUyRjIwMjMwNDI4JTJGdXMtd2VzdC0yJTJGc3RzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyMzA0MjhUMDAwMDAwWiZYLUFtei1FeHBpcmVzPTYwJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCUzQngtazhzLWF3cy1pZCZYLUFtei1TaWduYXR1cmU9ZTIxMWRiYTc3YWJhOWRjNDRiMGI2YmUzOGI4ZWFhZDA5MjU5OWM1MTU3ZjYzMTQ0NDRjNWI5ZDg1NzQ3ZjVjZQ", Expiration: time.Unix(1682640000, 0).Local().Add(time.Minute * 14), }, - nil, + wantErr: nil, + }, + { + name: "Signing creds expire before token", + creds: credentials.NewCredentials(&fakeCredentialProvider{ + value: credentials.Value{ + AccessKeyID: string(decodedAkid), + SecretAccessKey: string(decodedSk), + }, + expiresAt: time.Unix(1682640000, 0).Local().Add(time.Minute * 10), + }), + nowTime: time.Unix(1682640000, 0), + want: Token{ + Token: "k8s-aws-v1.aHR0cHM6Ly9zdHMudXMtd2VzdC0yLmFtYXpvbmF3cy5jb20vP0FjdGlvbj1HZXRDYWxsZXJJZGVudGl0eSZWZXJzaW9uPTIwMTEtMDYtMTUmWC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BU0lBUjJURzQ0VjZBUzNaWkU3QyUyRjIwMjMwNDI4JTJGdXMtd2VzdC0yJTJGc3RzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyMzA0MjhUMDAwMDAwWiZYLUFtei1FeHBpcmVzPTYwJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCUzQngtazhzLWF3cy1pZCZYLUFtei1TaWduYXR1cmU9ZTIxMWRiYTc3YWJhOWRjNDRiMGI2YmUzOGI4ZWFhZDA5MjU5OWM1MTU3ZjYzMTQ0NDRjNWI5ZDg1NzQ3ZjVjZQ", + Expiration: time.Unix(1682640000, 0).Local().Add(time.Minute * 9), + }, + wantErr: nil, }, } @@ -647,6 +662,25 @@ func TestGetWithSTS(t *testing.T) { } } +type fakeCredentialProvider struct { + value credentials.Value + expiresAt time.Time +} + +func (f *fakeCredentialProvider) Retrieve() (credentials.Value, error) { + return f.value, nil +} + +func (f *fakeCredentialProvider) IsExpired() bool { + return false +} + +var _ credentials.Expirer = (*fakeCredentialProvider)(nil) + +func (f *fakeCredentialProvider) ExpiresAt() time.Time { + return f.expiresAt +} + func TestGetStsRegion(t *testing.T) { tests := []struct { host string