Skip to content

Commit

Permalink
Make solver a parameter to be passed in
Browse files Browse the repository at this point in the history
  • Loading branch information
edw-defang committed May 15, 2024
1 parent 69b2270 commit 9b8c09f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 37 deletions.
14 changes: 2 additions & 12 deletions acme/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ import (

"github.com/defang-io/cloudacme/aws/acm"
"github.com/defang-io/cloudacme/aws/alb"
"github.com/defang-io/cloudacme/solver"
"github.com/mholt/acmez"
"go.uber.org/zap"
)

var ChallengeSolver acmez.Solver
var logger *zap.Logger

func init() {
Expand All @@ -28,7 +26,7 @@ func init() {
}
}

func UpdateAcmeCertificate(ctx context.Context, albArn, domain string) error {
func UpdateAcmeCertificate(ctx context.Context, albArn, domain string, solver acmez.Solver) error {
accountKey, err := getAccountKey()
if err != nil {
return fmt.Errorf("failed to get account key: %w", err)
Expand Down Expand Up @@ -77,20 +75,12 @@ func UpdateAcmeCertificate(ctx context.Context, albArn, domain string) error {
acmeDirectory = DefaultAcmeDirectory
}

cs := ChallengeSolver
if cs == nil {
cs = solver.AlbHttp01Solver{
AlbArn: albArn,
Domains: []string{domain},
Logger: logger,
}
}
acmeClient := Acme{
Directory: acmeDirectory,
AccountKey: accountKey,
Logger: logger,
AlbArn: albArn,
HttpSolver: cs,
HttpSolver: solver,
}

key, chain, err := acmeClient.GetCertificate(ctx, []string{domain})
Expand Down
23 changes: 17 additions & 6 deletions cmd/lambda/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
awsalb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types"
"github.com/defang-io/cloudacme/acme"
"github.com/defang-io/cloudacme/aws/alb"
"github.com/defang-io/cloudacme/solver"
)

var version = "dev" // to be set by ldflags
Expand Down Expand Up @@ -48,7 +49,12 @@ func HandleALBEvent(ctx context.Context, evt events.ALBTargetGroupRequest) (*eve
}

host := evt.Headers["host"]
if err := acme.UpdateAcmeCertificate(ctx, albArn, host); err != nil {
albSolver := solver.AlbHttp01Solver{
AlbArn: albArn,
Domains: []string{host},
}

if err := acme.UpdateAcmeCertificate(ctx, albArn, host, albSolver); err != nil {
return nil, fmt.Errorf("failed to update certificate: %w", err)
}

Expand All @@ -57,11 +63,11 @@ func HandleALBEvent(ctx context.Context, evt events.ALBTargetGroupRequest) (*eve
PathPattern: []string{"/"},
}

if err := removeHttpRule(ctx, albArn, cond); err != nil {
if err := RemoveHttpRule(ctx, albArn, cond); err != nil {
return nil, fmt.Errorf("failed to remove http rule: %w", err)
}

validationCtx, cancel := context.WithTimeout(ctx, 2*time.Minute)
validationCtx, cancel := context.WithTimeout(ctx, 10*time.Minute)
defer cancel()
if err := validateCertAttached(validationCtx, host); err != nil {
return nil, fmt.Errorf("failed to validate certificate: %w", err)
Expand Down Expand Up @@ -91,7 +97,7 @@ func validateCertAttached(ctx context.Context, domain string) error {
if _, err := http.DefaultClient.Do(req); err != nil {
var tlsErr *tls.CertificateVerificationError
if errors.As(err, &tlsErr) {
log.Printf("ssl cert for %v is still not valid", tlsErr)
log.Printf("ssl cert for %v is still not valid: %v", domain, tlsErr)
continue
}
return fmt.Errorf("failed https request to domain %v: %w", domain, err)
Expand All @@ -101,7 +107,7 @@ func validateCertAttached(ctx context.Context, domain string) error {
}
}

func removeHttpRule(ctx context.Context, albArn string, ruleCond alb.RuleCondition) error {
func RemoveHttpRule(ctx context.Context, albArn string, ruleCond alb.RuleCondition) error {
listener, err := alb.GetListener(ctx, albArn, awsalb.ProtocolEnumHttp, 80)
if err != nil {
return fmt.Errorf("cannot get http listener: %w", err)
Expand Down Expand Up @@ -133,7 +139,12 @@ func getHttpsRedirectURL(evt events.ALBTargetGroupRequest) string {
func HandleEventBridgeEvent(ctx context.Context, evt CertificateRenewalEvent) error {
log.Printf("Handling Certificate Renewal Event: %+v", evt)

if err := acme.UpdateAcmeCertificate(ctx, evt.AlbArn, evt.Domain); err != nil {
albSolver := solver.AlbHttp01Solver{
AlbArn: evt.AlbArn,
Domains: []string{evt.Domain},
}

if err := acme.UpdateAcmeCertificate(ctx, evt.AlbArn, evt.Domain, albSolver); err != nil {
return fmt.Errorf("failed to renew certificate: %w", err)
}

Expand Down
27 changes: 8 additions & 19 deletions solver/albhttp01solver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"time"

"github.com/defang-io/cloudacme/aws/alb"
awsalb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types"
"github.com/defang-io/cloudacme/aws/alb"
"github.com/mholt/acmez/acme"
"go.uber.org/zap"
)
Expand All @@ -24,9 +25,7 @@ type AlbHttp01Solver struct {
}

func (s AlbHttp01Solver) Present(ctx context.Context, chal acme.Challenge) error {
if s.Logger != nil {
s.Logger.Info("Presenting challenge", zap.Strings("domains", s.Domains), zap.String("path", chal.HTTP01ResourcePath()))
}
s.Logger.Info("Presenting challenge", zap.Strings("domains", s.Domains), zap.String("path", chal.HTTP01ResourcePath()))
listener, err := alb.GetListener(ctx, s.AlbArn, awsalb.ProtocolEnumHttp, 80)
if err != nil {
return fmt.Errorf("cannot get http listener: %w", err)
Expand All @@ -44,9 +43,7 @@ func (s AlbHttp01Solver) Present(ctx context.Context, chal acme.Challenge) error
}

func (s AlbHttp01Solver) CleanUp(ctx context.Context, chal acme.Challenge) error {
if s.Logger != nil {
s.Logger.Info("Cleaning up challenge", zap.Strings("domains", s.Domains), zap.String("path", chal.HTTP01ResourcePath()))
}
log.Printf("Cleaning up challenge for domains %v at path %v", s.Domains, chal.HTTP01ResourcePath())
listener, err := alb.GetListener(ctx, s.AlbArn, awsalb.ProtocolEnumHttp, 80)
if err != nil {
return fmt.Errorf("cannot get http listener: %w", err)
Expand All @@ -59,19 +56,15 @@ func (s AlbHttp01Solver) CleanUp(ctx context.Context, chal acme.Challenge) error

err = alb.DeleteListenerPathRule(ctx, *listener.ListenerArn, ruleCond)
if errors.Is(err, alb.ErrRuleNotFound) {
if s.Logger != nil {
s.Logger.Info("Challenge rule not found, skipping cleanup alb rule", zap.String("path", chal.HTTP01ResourcePath()))
}
log.Printf("Challenge rule not found, skipping cleanup alb rule for path: %v", chal.HTTP01ResourcePath())
} else if err != nil {
return fmt.Errorf("failed to delete listener static rule: %v", err)
}
return nil
}

func (s AlbHttp01Solver) Wait(ctx context.Context, chal acme.Challenge) error {
if s.Logger != nil {
s.Logger.Info("Waiting for challenge", zap.Strings("domains", s.Domains), zap.String("path", chal.HTTP01ResourcePath()))
}
log.Printf("Waiting for challenge for domains %v at path %v", s.Domains, chal.HTTP01ResourcePath())
timeout := s.WaitTimeout
if timeout == 0 {
timeout = DefaultWaitTimeout
Expand All @@ -81,16 +74,12 @@ func (s AlbHttp01Solver) Wait(ctx context.Context, chal acme.Challenge) error {
defer cancel()
for _, domain := range s.Domains {
chkUrl := "http://" + domain + chal.HTTP01ResourcePath()
if s.Logger != nil {
s.Logger.Info("Checking URL", zap.String("url", chkUrl))
}
log.Printf("Checking URL %v", chkUrl)
if err := checkUrl(chkCtx, chkUrl, chal.KeyAuthorization); err != nil {
return fmt.Errorf("failed waiting for challenge: %w", err)
}
}
if s.Logger != nil {
s.Logger.Info("Challenge is ready", zap.Strings("domains", s.Domains), zap.String("path", chal.HTTP01ResourcePath()))
}
log.Printf("Challenge is ready for domains %v, at path %v", s.Domains, chal.HTTP01ResourcePath())
return nil
}

Expand Down

0 comments on commit 9b8c09f

Please sign in to comment.