Skip to content

Commit

Permalink
Add outgoing header forwarding
Browse files Browse the repository at this point in the history
  • Loading branch information
potterbm-cb committed Nov 7, 2024
1 parent b0d488a commit 98f4705
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 25 deletions.
10 changes: 10 additions & 0 deletions headerforwarder/context_headers.go → headerforwarder/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ func RosettaIDFromRequest(r *http.Request) string {
}
}

// RequestWithRequestID adds a unique ID to the request context. A new request is returned that contains the
// new context
func RequestWithRequestID(req *http.Request) *http.Request {
ctx := req.Context()
ctxWithID := ContextWithRosettaID(ctx)
requestWithID := req.WithContext(ctxWithID)

return requestWithID
}

func ContextWithOutgoingHeaders(ctx context.Context, headers http.Header) context.Context {
return context.WithValue(ctx, outgoingHeadersKey, headers)
}
Expand Down
79 changes: 54 additions & 25 deletions headerforwarder/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,23 @@ import (
//
// TODO: this should expire entries after a certain amount of time
type HeaderForwarder struct {
requestHeaders map[string]http.Header
incomingHeaders map[string]http.Header
outgoingHeaders map[string]http.Header
interestingHeaders []string
actualTransport http.RoundTripper
}

func NewHeaderForwarder(
interestingHeaders []string,
transport http.RoundTripper,
// outgoingContextFromRequest func(r *http.Request) context.Context,
) (*HeaderForwarder, error) {
if len(interestingHeaders) == 0 {
return nil, fmt.Errorf("must provide at least one interesting header")
}

return &HeaderForwarder{
requestHeaders: make(map[string]http.Header),
incomingHeaders: make(map[string]http.Header),
outgoingHeaders: make(map[string]http.Header),
interestingHeaders: interestingHeaders,
actualTransport: transport,
}, nil
Expand All @@ -61,14 +62,18 @@ func (hf *HeaderForwarder) WithTransport(transport http.RoundTripper) *HeaderFor
return hf
}

// RequestWithRequestID adds a unique ID to the request context. A new request is returned that contains the
// new context
func (hf *HeaderForwarder) RequestWithRequestID(req *http.Request) *http.Request {
func (hf *HeaderForwarder) captureOutgoingHeaders(req *http.Request) {
ctx := req.Context()
ctxWithID := ContextWithRosettaID(ctx)
requestWithID := req.WithContext(ctxWithID)
rosettaRequestID := RosettaIDFromContext(ctx)

hf.outgoingHeaders[rosettaRequestID] = make(http.Header)

return requestWithID
// Only capture interesting headers
for _, interestingHeader := range hf.interestingHeaders {
if _, requestHasHeader := req.Header[http.CanonicalHeaderKey(interestingHeader)]; requestHasHeader {
hf.outgoingHeaders[rosettaRequestID].Set(interestingHeader, req.Header.Get(interestingHeader))
}
}
}

// shouldRememberHeaders reports whether response headers should be remembered for a
Expand Down Expand Up @@ -108,7 +113,7 @@ func (hf *HeaderForwarder) rememberHeaders(req *http.Request, resp *http.Respons

// For multiple requests with the same rosetta ID, we want to remember all of the headers
// For repeated response headers, later values will overwrite earlier ones
headersToRemember, exists := hf.requestHeaders[rosettaRequestID]
headersToRemember, exists := hf.incomingHeaders[rosettaRequestID]
if !exists {
headersToRemember = make(http.Header)
}
Expand All @@ -118,12 +123,12 @@ func (hf *HeaderForwarder) rememberHeaders(req *http.Request, resp *http.Respons
headersToRemember.Set(interestingHeader, resp.Header.Get(interestingHeader))
}

hf.requestHeaders[rosettaRequestID] = headersToRemember
hf.incomingHeaders[rosettaRequestID] = headersToRemember
}

// shouldRememberMetadata reports whether response metadata should be remembered for a grpc unary
// RPC call. Response metadata will only be remembered if it contains any of the interesting headers.
func (hf *HeaderForwarder) shouldRememberMetadata(ctx context.Context, req any, resp metadata.MD) bool {
func (hf *HeaderForwarder) shouldRememberMetadata(ctx context.Context, resp metadata.MD) bool {
rosettaID := RosettaIDFromContext(ctx)
if rosettaID == "" {
return false
Expand All @@ -142,33 +147,36 @@ func (hf *HeaderForwarder) shouldRememberMetadata(ctx context.Context, req any,

// rememberMetadata saves the native node response metadata. The response object is metadata retrieved
// from a native node GRPC unary RPC call.
func (hf *HeaderForwarder) rememberMetadata(ctx context.Context, req any, resp metadata.MD) {
func (hf *HeaderForwarder) rememberMetadata(ctx context.Context, resp metadata.MD) {
rosettaID := RosettaIDFromContext(ctx)

// For multiple requests with the same rosetta ID, we want to remember all of the headers
// For repeated response headers, later values will overwrite earlier ones
headersToRemember, exists := hf.requestHeaders[rosettaID]
headersToRemember, exists := hf.incomingHeaders[rosettaID]
if !exists {
headersToRemember = make(http.Header)
}

for _, interestingHeader := range hf.interestingHeaders {
for _, value := range resp.Get(interestingHeader) {
for _, value := range resp.Get(strings.ToLower(interestingHeader)) {
headersToRemember.Set(interestingHeader, value)
}
}

hf.requestHeaders[rosettaID] = headersToRemember
hf.incomingHeaders[rosettaID] = headersToRemember
}

// GetResponseHeaders returns any headers that should be returned to a rosetta response. These
// consist of native node response headers/metadata that were remembered for a request ID.
func (hf *HeaderForwarder) getResponseHeaders(rosettaRequestID string) (http.Header, bool) {
headers, ok := hf.requestHeaders[rosettaRequestID]
headers, ok := hf.incomingHeaders[rosettaRequestID]

// Delete the headers from the map after they are retrieved
// This is safe to call even if the key doesn't exist
delete(hf.requestHeaders, rosettaRequestID)
delete(hf.incomingHeaders, rosettaRequestID)

// Also delete the outgoing headers from the map since we are done with them
delete(hf.outgoingHeaders, rosettaRequestID)

return headers, ok
}
Expand All @@ -179,25 +187,36 @@ func (hf *HeaderForwarder) getResponseHeaders(rosettaRequestID string) (http.Hea
func (hf *HeaderForwarder) HeaderForwarderHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// add a unique ID to the request context, and make a new request for it
requestWithID := hf.RequestWithRequestID(r)
requestWithID := RequestWithRequestID(r)

// Capture outgoing interesting headers
hf.captureOutgoingHeaders(requestWithID)

// Serve the request
// NOTE: for servers using github.com/coinbase/mesh-geth-sdk, ResponseWriter::WriteHeader() WILL
// be called here, so we can't set headers after this happens. We include a wrapper around the
// be called internally, so we can't set headers after this happens. We include a wrapper around the
// response writer that allows us to set headers just before WriteHeader is called
wrappedResponseWriter := NewResponseWriter(
w,
RosettaIDFromRequest(requestWithID),
hf.getResponseHeaders,
)

// Serve the request
next.ServeHTTP(wrappedResponseWriter, requestWithID)
})
}

// RoundTrip implements http.RoundTripper and will be used to construct an http Client which
// saves the native node response headers if necessary.
func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO: add outgoing headers to the request
// add outgoing headers to the request
if outgoingHeaders, ok := hf.outgoingHeaders[RosettaIDFromRequest(req)]; ok {
for header, values := range outgoingHeaders {
for _, value := range values {
req.Header.Add(header, value)
}
}
}

resp, err := hf.actualTransport.RoundTrip(req)

Expand All @@ -209,14 +228,24 @@ func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error)
}

func (hf *HeaderForwarder) UnaryClientInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
// append a header DialOption to the request
// Capture incoming headers from the grpc call
var header metadata.MD
opts = append(opts, grpc.Header(&header))

// Add outgoing headers to the context
if outgoingHeaders, ok := hf.outgoingHeaders[RosettaIDFromContext(ctx)]; ok {
for header, values := range outgoingHeaders {
for _, value := range values {
ctx = metadata.AppendToOutgoingContext(ctx, strings.ToLower(header), value)
}
}
}

// Invoke the grpc call
err := invoker(ctx, method, req, reply, cc, opts...)

if hf.shouldRememberMetadata(ctx, req, header) {
hf.rememberMetadata(ctx, req, header)
if hf.shouldRememberMetadata(ctx, header) {
hf.rememberMetadata(ctx, header)
}

return err
Expand Down

0 comments on commit 98f4705

Please sign in to comment.