Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for regional STS endpoint(s) #55

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions credentials_getter.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ type STSCredentialsGetter struct {
}

// NewSTSCredentialsGetter initializes a new STS based credentials fetcher.
func NewSTSCredentialsGetter(sess *session.Session, baseRoleARN, baseRoleARNPrefix string, configs ...*aws.Config) *STSCredentialsGetter {
func NewSTSCredentialsGetter(sess *session.Session, baseRoleARN, baseRoleARNPrefix string, config *aws.Config) *STSCredentialsGetter {
return &STSCredentialsGetter{
svc: sts.New(sess, configs...),
svc: sts.New(sess, config),
baseRoleARN: baseRoleARN,
baseRoleARNPrefix: baseRoleARNPrefix,
}
Expand Down
2 changes: 1 addition & 1 deletion credentials_getter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (sts *mockSTSAPI) AssumeRole(*sts.AssumeRoleInput) (*sts.AssumeRoleOutput,

func TestGet(t *testing.T) {
sess := session.New(&aws.Config{Region: aws.String("region")})
getter := NewSTSCredentialsGetter(sess, "", "")
getter := NewSTSCredentialsGetter(sess, "", "", &aws.Config{})
getter.svc = &mockSTSAPI{
err: nil,
assumeRoleResp: &sts.AssumeRoleOutput{
Expand Down
79 changes: 68 additions & 11 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"fmt"
"net/url"
"os"
"os/signal"
Expand All @@ -11,6 +12,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
log "github.com/sirupsen/logrus"
"github.com/zalando-incubator/kube-aws-iam-controller/pkg/clientset"
Expand All @@ -27,17 +29,28 @@ const (

var (
config struct {
Debug bool
Interval time.Duration
RefreshLimit time.Duration
EventQueueSize int
BaseRoleARN string
APIServer *url.URL
Namespace string
AssumeRole string
Debug bool
Interval time.Duration
RefreshLimit time.Duration
EventQueueSize int
BaseRoleARN string
APIServer *url.URL
Namespace string
AssumeRole string
UseRegionalEndpoint bool
}
)

type STSEndpoint struct {
Endpoint string
}

// NewSTSCredentialsGetter initializes a new STS based credentials fetcher.
func NewSTSEndpoint() *STSEndpoint {
return &STSEndpoint{
Endpoint: "sts.amazonaws.com",
}
}
func main() {
kingpin.Flag("debug", "Enable debug logging.").BoolVar(&config.Debug)
kingpin.Flag("interval", "Interval between syncing secrets.").
Expand All @@ -53,6 +66,7 @@ func main() {
kingpin.Flag("namespace", "Limit the controller to a certain namespace.").
Default(v1.NamespaceAll).StringVar(&config.Namespace)
kingpin.Flag("apiserver", "API server url.").URLVar(&config.APIServer)
kingpin.Flag("use-regional-sts-endpoint", "Use the regional sts endpoint if AWS_REGION is set").BoolVar(&config.UseRegionalEndpoint)
kingpin.Parse()

if config.Debug {
Expand Down Expand Up @@ -90,17 +104,21 @@ func main() {
}
log.Debugf("Parsed Base Role ARN prefix: %s", baseRoleARNPrefix)

awsConfigs := make([]*aws.Config, 0, 1)
awsConfigs := aws.NewConfig()
if config.AssumeRole != "" {
if !strings.HasPrefix(config.AssumeRole, baseRoleARNPrefix) {
config.AssumeRole = config.BaseRoleARN + config.AssumeRole
}
log.Infof("Using custom Assume Role: %s", config.AssumeRole)
creds := stscreds.NewCredentials(awsSess, config.AssumeRole)
awsConfigs = append(awsConfigs, &aws.Config{Credentials: creds})
awsConfigs = awsConfigs.WithCredentials(creds)
}

credsGetter := NewSTSCredentialsGetter(awsSess, config.BaseRoleARN, baseRoleARNPrefix, awsConfigs...)
if config.UseRegionalEndpoint {
awsConfigs = awsConfigs.WithEndpointResolver(NewSTSEndpoint())
}

credsGetter := NewSTSCredentialsGetter(awsSess, config.BaseRoleARN, baseRoleARNPrefix, awsConfigs)

podsEventCh := make(chan *PodEvent, config.EventQueueSize)

Expand Down Expand Up @@ -131,6 +149,45 @@ func main() {
controller.Run(ctx)
}

// GetEndpointFromRegion formas a standard sts endpoint url given a region
func GetEndpointFromRegion(region string) string {
endpoint := fmt.Sprintf("https://sts.%s.amazonaws.com", region)
if strings.HasPrefix(region, "cn-") {
endpoint = fmt.Sprintf("https://sts.%s.amazonaws.com.cn", region)
}
log.Debugf("Using Regional STS Endpoint: %s", endpoint)
return endpoint
}

// IsValidRegion tests for a vaild region name
func IsValidRegion(promisedLand string) bool {
partitions := endpoints.DefaultResolver().(endpoints.EnumPartitions).Partitions()
for _, p := range partitions {
for region := range p.Regions() {
if promisedLand == region {
return true
}
}
}
return false
}

// EndpointFor implements the endpoints.Resolver interface for use with sts
func (e *STSEndpoint) EndpointFor(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
// only for sts service
if service == "sts" {
// only if a valid region is explicitly set
if IsValidRegion(region) {
e.Endpoint = GetEndpointFromRegion(region)
return endpoints.ResolvedEndpoint{
URL: e.Endpoint,
SigningRegion: region,
}, nil
}
}
return endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
}

// handleSigterm handles SIGTERM signal sent to the process.
func handleSigterm(cancelFunc func()) {
signals := make(chan os.Signal, 1)
Expand Down