diff --git a/headerforwarder/context_headers.go b/headerforwarder/context_headers.go index 8d3d1f6f..b47eaa7b 100644 --- a/headerforwarder/context_headers.go +++ b/headerforwarder/context_headers.go @@ -25,6 +25,8 @@ type contextKey string const requestIDKey = contextKey("request_id") +const outgoingHeadersKey = contextKey("outgoing_headers") + func ContextWithRosettaID(ctx context.Context) context.Context { return context.WithValue(ctx, requestIDKey, uuid.NewString()) } @@ -46,3 +48,16 @@ func RosettaIDFromRequest(r *http.Request) string { return "" } } + +func ContextWithOutgoingHeaders(ctx context.Context, headers http.Header) context.Context { + return context.WithValue(ctx, outgoingHeadersKey, headers) +} + +func OutgoingHeadersFromContext(ctx context.Context) http.Header { + switch val := ctx.Value(outgoingHeadersKey).(type) { + case http.Header: + return val + default: + return nil + } +} diff --git a/headerforwarder/forwarder.go b/headerforwarder/forwarder.go index 7896bccd..28ff4959 100644 --- a/headerforwarder/forwarder.go +++ b/headerforwarder/forwarder.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "net/http" + "strings" "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -39,9 +40,11 @@ type HeaderForwarder struct { actualTransport http.RoundTripper } -// TODO: make transport an optional parameter, add "WithTransport" style functions to make it easier -// to add the actual RPC clients to this struct -func NewHeaderForwarder(interestingHeaders []string, transport http.RoundTripper) (*HeaderForwarder, error) { +func NewHeaderForwarder( + interestingHeaders []string, + transport http.RoundTripper, + // outgoingContextFromRequest func(r *http.Request) context.Context, +) (*HeaderForwarder, error) { if len(interestingHeaders) == 0 { return nil, fmt.Errorf("must provide at least one interesting header") } @@ -103,8 +106,14 @@ func (hf *HeaderForwarder) rememberHeaders(req *http.Request, resp *http.Respons ctx := req.Context() rosettaRequestID := RosettaIDFromContext(ctx) + // For multiple requests with the same rosetta ID, we want to remember all of the headers + // For repeated response headers, later values will overwrite earlier ones + headersToRemember, exists := hf.requestHeaders[rosettaRequestID] + if !exists { + headersToRemember = make(http.Header) + } + // Only remember interesting headers - headersToRemember := make(http.Header) for _, interestingHeader := range hf.interestingHeaders { headersToRemember.Set(interestingHeader, resp.Header.Get(interestingHeader)) } @@ -121,8 +130,9 @@ func (hf *HeaderForwarder) shouldRememberMetadata(ctx context.Context, req any, } // If any of the interesting headers are in the response metadata, remember it + // grpc metadata uses lowercase keys rather than http canonicalized keys for _, interestingHeader := range hf.interestingHeaders { - if _, responseHasHeader := resp[http.CanonicalHeaderKey(interestingHeader)]; responseHasHeader { + if _, responseHasHeader := resp[strings.ToLower(interestingHeader)]; responseHasHeader { return true } } @@ -135,7 +145,13 @@ func (hf *HeaderForwarder) shouldRememberMetadata(ctx context.Context, req any, func (hf *HeaderForwarder) rememberMetadata(ctx context.Context, req any, resp metadata.MD) { rosettaID := RosettaIDFromContext(ctx) - headersToRemember := make(http.Header) + // For multiple requests with the same rosetta ID, we want to remember all of the headers + // For repeated response headers, later values will overwrite earlier ones + headersToRemember, exists := hf.requestHeaders[rosettaID] + if !exists { + headersToRemember = make(http.Header) + } + for _, interestingHeader := range hf.interestingHeaders { for _, value := range resp.Get(interestingHeader) { headersToRemember.Set(interestingHeader, value) @@ -162,7 +178,6 @@ func (hf *HeaderForwarder) getResponseHeaders(rosettaRequestID string) (http.Hea // those headers on the response func (hf *HeaderForwarder) HeaderForwarderHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Println("HeaderForwarder Handler") // add a unique ID to the request context, and make a new request for it requestWithID := hf.RequestWithRequestID(r) @@ -182,10 +197,9 @@ func (hf *HeaderForwarder) HeaderForwarderHandler(next http.Handler) http.Handle // RoundTrip implements http.RoundTripper and will be used to construct an http Client which // saves the native node response headers if necessary. func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error) { - fmt.Println("HeaderForwarder RoundTrip") - resp, err := hf.actualTransport.RoundTrip(req) + // TODO: add outgoing headers to the request - fmt.Println("HeaderForwarder RoundTrip: response headers", resp.Header) + resp, err := hf.actualTransport.RoundTrip(req) if err == nil && hf.shouldRememberHeaders(req, resp) { hf.rememberHeaders(req, resp) @@ -195,22 +209,15 @@ func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error) } func (hf *HeaderForwarder) UnaryClientInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - fmt.Println("HeaderForwarder grpc interceptor") - - fmt.Println("request id: ", RosettaIDFromContext(ctx)) - // append a header DialOption to the request - var responseMD metadata.MD - opts = append(opts, grpc.Header(&responseMD)) + var header metadata.MD + opts = append(opts, grpc.Header(&header)) err := invoker(ctx, method, req, reply, cc, opts...) - if hf.shouldRememberMetadata(ctx, req, responseMD) { - hf.rememberMetadata(ctx, req, responseMD) + if hf.shouldRememberMetadata(ctx, req, header) { + hf.rememberMetadata(ctx, req, header) } - // get headers from response - fmt.Println("HeaderForwarder grpc interceptor: headers from response", responseMD) - return err }