diff --git a/configuration/configuration.go b/configuration/configuration.go index 031cea9..3d3d7eb 100644 --- a/configuration/configuration.go +++ b/configuration/configuration.go @@ -93,8 +93,10 @@ func postProcessing(config *Configuration, cliArguments cliparser.CliArguments) if *cliArguments.Profile != "" { config.DefaultProfile = *cliArguments.Profile } - if *cliArguments.MFA != config.DefaultDecisionForMFA { + if *cliArguments.MFA { config.DefaultDecisionForMFA = *cliArguments.MFA + } else { + *cliArguments.MFA = config.DefaultDecisionForMFA } if *cliArguments.DurationForMFA > 0 { config.DefaultDurationForMFA = *cliArguments.DurationForMFA diff --git a/mysession/mysession.go b/mysession/mysession.go index 3032763..0327f61 100644 --- a/mysession/mysession.go +++ b/mysession/mysession.go @@ -2,18 +2,34 @@ package mysession import ( "errors" + "github.com/Appliscale/perun/cliparser" "github.com/Appliscale/perun/context" "github.com/Appliscale/perun/utilities" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sts" "github.com/go-ini/ini" + "os" "os/user" "time" ) const dateFormat = "2006-01-02 15:04:05 MST" +func InitializeSession(context *context.Context) *session.Session { + tokenError := UpdateSessionToken(context.Config.DefaultProfile, context.Config.DefaultRegion, context.Config.DefaultDurationForMFA, context) + if tokenError != nil { + context.Logger.Error(tokenError.Error()) + os.Exit(1) + } + session, sessionError := CreateSession(context, context.Config.DefaultProfile, &context.Config.DefaultRegion) + if sessionError != nil { + context.Logger.Error(sessionError.Error()) + os.Exit(1) + } + return session +} + func CreateSession(context *context.Context, profile string, region *string) (*session.Session, error) { context.Logger.Info("Profile: " + profile) context.Logger.Info("Region: " + *region) @@ -34,89 +50,90 @@ func CreateSession(context *context.Context, profile string, region *string) (*s } func UpdateSessionToken(profile string, region string, defaultDuration int64, context *context.Context) error { - user, userError := user.Current() - if userError != nil { - return userError - } + if *context.CliArguments.MFA || *context.CliArguments.Mode == cliparser.MfaMode { + user, userError := user.Current() + if userError != nil { + return userError + } - credentialsFilePath := user.HomeDir + "/.aws/credentials" - configuration, loadCredentialsError := ini.Load(credentialsFilePath) - if loadCredentialsError != nil { - return loadCredentialsError - } + credentialsFilePath := user.HomeDir + "/.aws/credentials" + configuration, loadCredentialsError := ini.Load(credentialsFilePath) + if loadCredentialsError != nil { + return loadCredentialsError + } - section, sectionError := configuration.GetSection(profile) - if sectionError != nil { - section, sectionError = configuration.NewSection(profile) + section, sectionError := configuration.GetSection(profile) if sectionError != nil { - return sectionError + section, sectionError = configuration.NewSection(profile) + if sectionError != nil { + return sectionError + } } - } - - profileLongTerm := profile + "-long-term" - sectionLongTerm, profileLongTermError := configuration.GetSection(profileLongTerm) - if profileLongTermError != nil { - return profileLongTermError - } - sessionToken := section.Key("aws_session_token") - expiration := section.Key("expiration") - - expirationDate, dataError := time.Parse(dateFormat, section.Key("expiration").Value()) - if dataError == nil { - context.Logger.Info("Session token will expire in " + utilities.TruncateDuration(time.Since(expirationDate)).String() + " (" + expirationDate.Format(dateFormat) + ")") - } + profileLongTerm := profile + "-long-term" + sectionLongTerm, profileLongTermError := configuration.GetSection(profileLongTerm) + if profileLongTermError != nil { + return profileLongTermError + } - mfaDevice := sectionLongTerm.Key("mfa_serial").Value() - if mfaDevice == "" { - return errors.New("There is no mfa_serial for the profile " + profileLongTerm) - } + sessionToken := section.Key("aws_session_token") + expiration := section.Key("expiration") - if sessionToken.Value() == "" || expiration.Value() == "" || time.Since(expirationDate).Nanoseconds() > 0 { - session, sessionError := session.NewSessionWithOptions( - session.Options{ - Config: aws.Config{ - Region: ®ion, - }, - Profile: profileLongTerm, - }) - if sessionError != nil { - return sessionError + expirationDate, dataError := time.Parse(dateFormat, section.Key("expiration").Value()) + if dataError == nil { + context.Logger.Info("Session token will expire in " + utilities.TruncateDuration(time.Since(expirationDate)).String() + " (" + expirationDate.Format(dateFormat) + ")") } - var tokenCode string - sessionError = context.Logger.GetInput("MFA token code", &tokenCode) - if sessionError != nil { - return sessionError + mfaDevice := sectionLongTerm.Key("mfa_serial").Value() + if mfaDevice == "" { + return errors.New("There is no mfa_serial for the profile " + profileLongTerm + ". If you haven't used --mfa option you can change the default decision for MFA in the configuration file") } - var duration int64 - if defaultDuration == 0 { - sessionError = context.Logger.GetInput("Duration", &duration) + if sessionToken.Value() == "" || expiration.Value() == "" || time.Since(expirationDate).Nanoseconds() > 0 { + session, sessionError := session.NewSessionWithOptions( + session.Options{ + Config: aws.Config{ + Region: ®ion, + }, + Profile: profileLongTerm, + }) if sessionError != nil { return sessionError } - } else { - duration = defaultDuration - } - stsSession := sts.New(session) - newToken, tokenError := stsSession.GetSessionToken(&sts.GetSessionTokenInput{ - DurationSeconds: &duration, - SerialNumber: aws.String(mfaDevice), - TokenCode: &tokenCode, - }) - if tokenError != nil { - return tokenError - } + var tokenCode string + sessionError = context.Logger.GetInput("MFA token code", &tokenCode) + if sessionError != nil { + return sessionError + } - section.Key("aws_access_key_id").SetValue(*newToken.Credentials.AccessKeyId) - section.Key("aws_secret_access_key").SetValue(*newToken.Credentials.SecretAccessKey) - sessionToken.SetValue(*newToken.Credentials.SessionToken) - section.Key("expiration").SetValue(newToken.Credentials.Expiration.Format(dateFormat)) + var duration int64 + if defaultDuration == 0 { + sessionError = context.Logger.GetInput("Duration", &duration) + if sessionError != nil { + return sessionError + } + } else { + duration = defaultDuration + } - configuration.SaveTo(credentialsFilePath) - } + stsSession := sts.New(session) + newToken, tokenError := stsSession.GetSessionToken(&sts.GetSessionTokenInput{ + DurationSeconds: &duration, + SerialNumber: aws.String(mfaDevice), + TokenCode: &tokenCode, + }) + if tokenError != nil { + return tokenError + } + section.Key("aws_access_key_id").SetValue(*newToken.Credentials.AccessKeyId) + section.Key("aws_secret_access_key").SetValue(*newToken.Credentials.SecretAccessKey) + sessionToken.SetValue(*newToken.Credentials.SessionToken) + section.Key("expiration").SetValue(newToken.Credentials.Expiration.Format(dateFormat)) + + configuration.SaveTo(credentialsFilePath) + } + } return nil } diff --git a/stack/stack.go b/stack/stack.go index f395e41..fd8e52a 100644 --- a/stack/stack.go +++ b/stack/stack.go @@ -6,6 +6,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/cloudformation" "io/ioutil" + "os" ) // This function gets template and name of stack. It creates "CreateStackInput" structure. @@ -50,14 +51,7 @@ func createStack(templateStruct cloudformation.CreateStackInput, session *sessio func NewStack(context *context.Context) { template, stackName := getTemplateFromFile(context) templateStruct := createStackInput(context, &template, &stackName) - tokenError := mysession.UpdateSessionToken(context.Config.DefaultProfile, context.Config.DefaultRegion, context.Config.DefaultDurationForMFA, context) - if tokenError != nil { - context.Logger.Error(tokenError.Error()) - } - session, createSessionError := mysession.CreateSession(context, context.Config.DefaultProfile, &context.Config.DefaultRegion) - if createSessionError != nil { - context.Logger.Error(createSessionError.Error()) - } + session := mysession.InitializeSession(context) createStackError := createStack(templateStruct, session) if createStackError != nil { context.Logger.Error(createStackError.Error()) @@ -67,12 +61,13 @@ func NewStack(context *context.Context) { // This function bases on "DeleteStackInput" structure and destroys stack. It uses "StackName" to choose which stack will be destroy. Before that it creates session. func DestroyStack(context *context.Context) { delStackInput := deleteStackInput(context) - session, sessionError := mysession.CreateSession(context, context.Config.DefaultProfile, &context.Config.DefaultRegion) - if sessionError != nil { - context.Logger.Error(sessionError.Error()) - } + session := mysession.InitializeSession(context) api := cloudformation.New(session) - api.DeleteStack(&delStackInput) + _, err := api.DeleteStack(&delStackInput) + if err != nil { + context.Logger.Error(err.Error()) + os.Exit(1) + } } // This function gets "StackName" from Stack in CliArguments and creates "DeleteStackInput" structure.