@@ -3,63 +3,68 @@ package server
33import (
44 "context"
55 "encoding/json"
6+ "errors"
67 "fmt"
78 "net/http"
8- "sync"
99
1010 "github.com/aws/aws-sdk-go/aws"
1111 "github.com/localstack/lambda-runtime-init/internal/localstack"
1212 log "github.com/sirupsen/logrus"
1313 "go.amzn.com/lambda/core/directinvoke"
14+ "go.amzn.com/lambda/core/statejson"
1415 "go.amzn.com/lambda/interop"
1516 "go.amzn.com/lambda/metering"
1617 "go.amzn.com/lambda/rapi/model"
1718 "go.amzn.com/lambda/rapidcore"
1819 "golang.org/x/sync/errgroup"
1920)
2021
21- type CustomInteropServer struct {
22+ type LocalStackInteropsServer struct {
2223 * rapidcore.Server
2324 localStackAdapter * localstack.LocalStackClient
24- mutex * sync.Mutex
2525}
2626
27- func NewInteropServer (server * rapidcore.Server , ls * localstack.LocalStackClient ) * CustomInteropServer {
28- return & CustomInteropServer {
27+ func NewInteropServer (server * rapidcore.Server , ls * localstack.LocalStackClient ) * LocalStackInteropsServer {
28+ return & LocalStackInteropsServer {
2929 Server : server ,
3030 localStackAdapter : ls ,
31- mutex : & sync.Mutex {},
3231 }
3332}
3433
35- func (c * CustomInteropServer ) Invoke ( responseWriter http.ResponseWriter , invoke * interop.Invoke ) error {
34+ func (c * LocalStackInteropsServer ) Execute ( ctx context. Context , responseWriter http.ResponseWriter , invoke * interop.Invoke ) error {
3635 ctx , cancel := context .WithTimeout (context .Background (), c .Server .GetInvokeTimeout ())
3736 defer cancel ()
3837
39- if err := c .reserveForInvoke (ctx , invoke ); err != nil {
38+ if err := c .reserve (ctx , invoke ); err != nil {
4039 return err
4140 }
4241
43- return c .executeInvoke (ctx , responseWriter , invoke )
42+ if err := c .executeInvoke (ctx , responseWriter , invoke ); err != nil {
43+ return err
44+ }
45+
46+ return nil
4447}
4548
46- func (c * CustomInteropServer ) executeInvoke (ctx context.Context , responseWriter http.ResponseWriter , invoke * interop.Invoke ) error {
49+ func (c * LocalStackInteropsServer ) Invoke (responseWriter http.ResponseWriter , invoke * interop.Invoke ) error {
50+ return c .Execute (context .Background (), responseWriter , invoke )
51+ }
52+
53+ func (c * LocalStackInteropsServer ) executeInvoke (ctx context.Context , responseWriter http.ResponseWriter , invoke * interop.Invoke ) error {
4754 g , gCtx := errgroup .WithContext (ctx )
4855
4956 g .Go (func () error {
5057 isDirect := directinvoke .MaxDirectResponseSize > interop .MaxPayloadSize
51- if err := c .Server .FastInvoke (responseWriter , invoke , isDirect ); err != nil {
52- log .Debugf ("FastInvoke() error: %s" , err )
58+ err := c .Server .FastInvoke (responseWriter , invoke , isDirect )
59+ if err != nil {
60+ log .WithError (err ).Debug ("FastInvoke() failed" )
5361 }
54- return nil
62+ return err
5563 })
5664
5765 g .Go (func () error {
58- _ , err := c .Server .AwaitRelease ()
59- if err != nil {
60- return c .handleReleaseError (err )
61- }
62- return nil
66+ _ , err := c .AwaitRelease ()
67+ return err
6368 })
6469
6570 done := make (chan error , 1 )
@@ -71,11 +76,17 @@ func (c *CustomInteropServer) executeInvoke(ctx context.Context, responseWriter
7176 case err := <- done :
7277 return err
7378 case <- gCtx .Done ():
74- return c .handleTimeout ()
79+ if errors .Is (gCtx .Err (), context .DeadlineExceeded ) {
80+ if _ , resetErr := c .Server .Reset ("Timeout" , 2000 ); resetErr != nil {
81+ log .WithError (resetErr ).Errorf ("Reset failed" )
82+ }
83+ return rapidcore .ErrInvokeTimeout
84+ }
85+ return nil
7586 }
7687}
7788
78- func (c * CustomInteropServer ) reserveForInvoke (ctx context.Context , invoke * interop.Invoke ) error {
89+ func (c * LocalStackInteropsServer ) reserve (ctx context.Context , invoke * interop.Invoke ) error {
7990 reserveResp , err := c .Server .Reserve (invoke .ID , invoke .TraceID , invoke .LambdaSegmentID )
8091 if err != nil {
8192 return err
@@ -89,12 +100,18 @@ func (c *CustomInteropServer) reserveForInvoke(ctx context.Context, invoke *inte
89100 switch err {
90101 case rapidcore .ErrInitDoneFailed :
91102 if _ , resetErr := c .Server .Reset ("InitFailed" , 2000 ); resetErr != nil {
92- log .Errorf ( "Reset failed: %v" , resetErr )
103+ log .WithError ( resetErr ). Debug ( "Reset failed" )
93104 }
94105
95- if _ , reserveErr := c .Server .Reserve (invoke .ID , invoke .TraceID , invoke .LambdaSegmentID ); reserveErr != nil {
96- return reserveErr
106+ if _ , err := c .Server .Reserve (invoke .ID , invoke .TraceID , invoke .LambdaSegmentID ); err != nil {
107+ return err
97108 }
109+
110+ // If the original INIT failed, let's do another wait since we've triggered a RESERVE
111+ if err := c .Server .AwaitInitialized (); err != nil {
112+ return err
113+ }
114+
98115 return nil
99116 default :
100117 return err
@@ -104,31 +121,25 @@ func (c *CustomInteropServer) reserveForInvoke(ctx context.Context, invoke *inte
104121 return nil
105122}
106123
107- func (c * CustomInteropServer ) handleReleaseError (err error ) error {
124+ func (c * LocalStackInteropsServer ) AwaitRelease () (* statejson.ReleaseResponse , error ) {
125+ resp , err := c .Server .AwaitRelease ()
108126 switch err {
109- case rapidcore .ErrReleaseReservationDone :
110- return nil
127+ case rapidcore .ErrReleaseReservationDone , nil :
128+ return resp , nil
111129 case rapidcore .ErrInitDoneFailed , rapidcore .ErrInvokeDoneFailed :
112130 if _ , resetErr := c .Server .Reset ("ReleaseFail" , 2000 ); resetErr != nil {
113131 log .Errorf ("Reset failed: %v" , resetErr )
114132 }
115- return err
133+ return nil , err
116134 default :
117135 if _ , resetErr := c .Server .Reset ("UnexpectedError" , 2000 ); resetErr != nil {
118136 log .Errorf ("Reset failed: %v" , resetErr )
119137 }
120- return err
121- }
122- }
123-
124- func (c * CustomInteropServer ) handleTimeout () error {
125- if _ , resetErr := c .Server .Reset ("Timeout" , 2000 ); resetErr != nil {
126- log .Errorf ("Reset failed: %v" , resetErr )
138+ return nil , err
127139 }
128- return rapidcore .ErrInvokeTimeout
129140}
130141
131- func (c * CustomInteropServer ) SendInitErrorResponse (resp * interop.ErrorInvokeResponse ) error {
142+ func (c * LocalStackInteropsServer ) SendInitErrorResponse (resp * interop.ErrorInvokeResponse ) error {
132143 errResp := & model.ErrorResponse {}
133144 err := json .Unmarshal (resp .Payload , errResp )
134145 if err != nil {
0 commit comments