Skip to content

Commit

Permalink
Merge pull request #101 from Appliscale/mfa-configuration
Browse files Browse the repository at this point in the history
mfa configuration #91
  • Loading branch information
Piotr Figwer authored Jun 18, 2018
2 parents d4f0d32 + 191727b commit f88b9b7
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 80 deletions.
4 changes: 3 additions & 1 deletion configuration/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@ func postProcessing(config *Configuration, cliArguments cliparser.CliArguments)
if *cliArguments.Profile != "" {
config.DefaultProfile = *cliArguments.Profile
}
if *cliArguments.MFA != config.DefaultDecisionForMFA {
if *cliArguments.MFA {
config.DefaultDecisionForMFA = *cliArguments.MFA
} else {
*cliArguments.MFA = config.DefaultDecisionForMFA
}
if *cliArguments.DurationForMFA > 0 {
config.DefaultDurationForMFA = *cliArguments.DurationForMFA
Expand Down
149 changes: 83 additions & 66 deletions mysession/mysession.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,34 @@ package mysession

import (
"errors"
"github.com/Appliscale/perun/cliparser"
"github.com/Appliscale/perun/context"
"github.com/Appliscale/perun/utilities"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/go-ini/ini"
"os"
"os/user"
"time"
)

const dateFormat = "2006-01-02 15:04:05 MST"

func InitializeSession(context *context.Context) *session.Session {
tokenError := UpdateSessionToken(context.Config.DefaultProfile, context.Config.DefaultRegion, context.Config.DefaultDurationForMFA, context)
if tokenError != nil {
context.Logger.Error(tokenError.Error())
os.Exit(1)
}
session, sessionError := CreateSession(context, context.Config.DefaultProfile, &context.Config.DefaultRegion)
if sessionError != nil {
context.Logger.Error(sessionError.Error())
os.Exit(1)
}
return session
}

func CreateSession(context *context.Context, profile string, region *string) (*session.Session, error) {
context.Logger.Info("Profile: " + profile)
context.Logger.Info("Region: " + *region)
Expand All @@ -34,89 +50,90 @@ func CreateSession(context *context.Context, profile string, region *string) (*s
}

func UpdateSessionToken(profile string, region string, defaultDuration int64, context *context.Context) error {
user, userError := user.Current()
if userError != nil {
return userError
}
if *context.CliArguments.MFA || *context.CliArguments.Mode == cliparser.MfaMode {
user, userError := user.Current()
if userError != nil {
return userError
}

credentialsFilePath := user.HomeDir + "/.aws/credentials"
configuration, loadCredentialsError := ini.Load(credentialsFilePath)
if loadCredentialsError != nil {
return loadCredentialsError
}
credentialsFilePath := user.HomeDir + "/.aws/credentials"
configuration, loadCredentialsError := ini.Load(credentialsFilePath)
if loadCredentialsError != nil {
return loadCredentialsError
}

section, sectionError := configuration.GetSection(profile)
if sectionError != nil {
section, sectionError = configuration.NewSection(profile)
section, sectionError := configuration.GetSection(profile)
if sectionError != nil {
return sectionError
section, sectionError = configuration.NewSection(profile)
if sectionError != nil {
return sectionError
}
}
}

profileLongTerm := profile + "-long-term"
sectionLongTerm, profileLongTermError := configuration.GetSection(profileLongTerm)
if profileLongTermError != nil {
return profileLongTermError
}

sessionToken := section.Key("aws_session_token")
expiration := section.Key("expiration")

expirationDate, dataError := time.Parse(dateFormat, section.Key("expiration").Value())
if dataError == nil {
context.Logger.Info("Session token will expire in " + utilities.TruncateDuration(time.Since(expirationDate)).String() + " (" + expirationDate.Format(dateFormat) + ")")
}
profileLongTerm := profile + "-long-term"
sectionLongTerm, profileLongTermError := configuration.GetSection(profileLongTerm)
if profileLongTermError != nil {
return profileLongTermError
}

mfaDevice := sectionLongTerm.Key("mfa_serial").Value()
if mfaDevice == "" {
return errors.New("There is no mfa_serial for the profile " + profileLongTerm)
}
sessionToken := section.Key("aws_session_token")
expiration := section.Key("expiration")

if sessionToken.Value() == "" || expiration.Value() == "" || time.Since(expirationDate).Nanoseconds() > 0 {
session, sessionError := session.NewSessionWithOptions(
session.Options{
Config: aws.Config{
Region: &region,
},
Profile: profileLongTerm,
})
if sessionError != nil {
return sessionError
expirationDate, dataError := time.Parse(dateFormat, section.Key("expiration").Value())
if dataError == nil {
context.Logger.Info("Session token will expire in " + utilities.TruncateDuration(time.Since(expirationDate)).String() + " (" + expirationDate.Format(dateFormat) + ")")
}

var tokenCode string
sessionError = context.Logger.GetInput("MFA token code", &tokenCode)
if sessionError != nil {
return sessionError
mfaDevice := sectionLongTerm.Key("mfa_serial").Value()
if mfaDevice == "" {
return errors.New("There is no mfa_serial for the profile " + profileLongTerm + ". If you haven't used --mfa option you can change the default decision for MFA in the configuration file")
}

var duration int64
if defaultDuration == 0 {
sessionError = context.Logger.GetInput("Duration", &duration)
if sessionToken.Value() == "" || expiration.Value() == "" || time.Since(expirationDate).Nanoseconds() > 0 {
session, sessionError := session.NewSessionWithOptions(
session.Options{
Config: aws.Config{
Region: &region,
},
Profile: profileLongTerm,
})
if sessionError != nil {
return sessionError
}
} else {
duration = defaultDuration
}

stsSession := sts.New(session)
newToken, tokenError := stsSession.GetSessionToken(&sts.GetSessionTokenInput{
DurationSeconds: &duration,
SerialNumber: aws.String(mfaDevice),
TokenCode: &tokenCode,
})
if tokenError != nil {
return tokenError
}
var tokenCode string
sessionError = context.Logger.GetInput("MFA token code", &tokenCode)
if sessionError != nil {
return sessionError
}

section.Key("aws_access_key_id").SetValue(*newToken.Credentials.AccessKeyId)
section.Key("aws_secret_access_key").SetValue(*newToken.Credentials.SecretAccessKey)
sessionToken.SetValue(*newToken.Credentials.SessionToken)
section.Key("expiration").SetValue(newToken.Credentials.Expiration.Format(dateFormat))
var duration int64
if defaultDuration == 0 {
sessionError = context.Logger.GetInput("Duration", &duration)
if sessionError != nil {
return sessionError
}
} else {
duration = defaultDuration
}

configuration.SaveTo(credentialsFilePath)
}
stsSession := sts.New(session)
newToken, tokenError := stsSession.GetSessionToken(&sts.GetSessionTokenInput{
DurationSeconds: &duration,
SerialNumber: aws.String(mfaDevice),
TokenCode: &tokenCode,
})
if tokenError != nil {
return tokenError
}

section.Key("aws_access_key_id").SetValue(*newToken.Credentials.AccessKeyId)
section.Key("aws_secret_access_key").SetValue(*newToken.Credentials.SecretAccessKey)
sessionToken.SetValue(*newToken.Credentials.SessionToken)
section.Key("expiration").SetValue(newToken.Credentials.Expiration.Format(dateFormat))

configuration.SaveTo(credentialsFilePath)
}
}
return nil
}
21 changes: 8 additions & 13 deletions stack/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/cloudformation"
"io/ioutil"
"os"
)

// This function gets template and name of stack. It creates "CreateStackInput" structure.
Expand Down Expand Up @@ -50,14 +51,7 @@ func createStack(templateStruct cloudformation.CreateStackInput, session *sessio
func NewStack(context *context.Context) {
template, stackName := getTemplateFromFile(context)
templateStruct := createStackInput(context, &template, &stackName)
tokenError := mysession.UpdateSessionToken(context.Config.DefaultProfile, context.Config.DefaultRegion, context.Config.DefaultDurationForMFA, context)
if tokenError != nil {
context.Logger.Error(tokenError.Error())
}
session, createSessionError := mysession.CreateSession(context, context.Config.DefaultProfile, &context.Config.DefaultRegion)
if createSessionError != nil {
context.Logger.Error(createSessionError.Error())
}
session := mysession.InitializeSession(context)
createStackError := createStack(templateStruct, session)
if createStackError != nil {
context.Logger.Error(createStackError.Error())
Expand All @@ -67,12 +61,13 @@ func NewStack(context *context.Context) {
// This function bases on "DeleteStackInput" structure and destroys stack. It uses "StackName" to choose which stack will be destroy. Before that it creates session.
func DestroyStack(context *context.Context) {
delStackInput := deleteStackInput(context)
session, sessionError := mysession.CreateSession(context, context.Config.DefaultProfile, &context.Config.DefaultRegion)
if sessionError != nil {
context.Logger.Error(sessionError.Error())
}
session := mysession.InitializeSession(context)
api := cloudformation.New(session)
api.DeleteStack(&delStackInput)
_, err := api.DeleteStack(&delStackInput)
if err != nil {
context.Logger.Error(err.Error())
os.Exit(1)
}
}

// This function gets "StackName" from Stack in CliArguments and creates "DeleteStackInput" structure.
Expand Down

0 comments on commit f88b9b7

Please sign in to comment.