Skip to content

Commit

Permalink
chore: add unit tests to headerforwarder package (#512)
Browse files Browse the repository at this point in the history
* add unit tests for HeaderForwarder class

* add unit tests for response_writer

* make add-license

* make format

* fix linting
  • Loading branch information
potterbm-cb authored Nov 27, 2024
1 parent 07dc44a commit 06d0fe9
Show file tree
Hide file tree
Showing 4 changed files with 533 additions and 6 deletions.
27 changes: 21 additions & 6 deletions headerforwarder/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,14 @@ func (hf *HeaderForwarder) rememberHeaders(req *http.Request, resp *http.Respons

// Only remember interesting headers
for _, interestingHeader := range hf.interestingHeaders {
headersToRemember.Set(interestingHeader, resp.Header.Get(interestingHeader))
values := resp.Header.Values(interestingHeader)

// Only remember the header if it is not empty
if len(values) > 0 {
for _, v := range values {
headersToRemember.Add(interestingHeader, v)
}
}
}

hf.incomingHeaderLock.Lock()
Expand Down Expand Up @@ -176,7 +183,7 @@ func (hf *HeaderForwarder) rememberMetadata(ctx context.Context, resp metadata.M

for _, interestingHeader := range hf.interestingHeaders {
for _, value := range resp.Get(strings.ToLower(interestingHeader)) {
headersToRemember.Set(interestingHeader, value)
headersToRemember.Add(interestingHeader, value)
}
}

Expand All @@ -187,7 +194,7 @@ func (hf *HeaderForwarder) rememberMetadata(ctx context.Context, resp metadata.M

// GetResponseHeaders returns any headers that should be returned to a rosetta response. These
// consist of native node response headers/metadata that were remembered for a request ID.
func (hf *HeaderForwarder) getResponseHeaders(rosettaRequestID string) (http.Header, bool) {
func (hf *HeaderForwarder) GetResponseHeaders(rosettaRequestID string) (http.Header, bool) {
hf.incomingHeaderLock.RLock()
headers, ok := hf.incomingHeaders[rosettaRequestID]
hf.incomingHeaderLock.RUnlock()
Expand Down Expand Up @@ -223,7 +230,7 @@ func (hf *HeaderForwarder) HeaderForwarderHandler(next http.Handler) http.Handle
wrappedResponseWriter := NewResponseWriter(
w,
RosettaIDFromRequest(requestWithID),
hf.getResponseHeaders,
hf.GetResponseHeaders,
)

// Serve the request
Expand Down Expand Up @@ -256,9 +263,17 @@ func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error)
return resp, err
}

func (hf *HeaderForwarder) UnaryClientInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
func (hf *HeaderForwarder) UnaryClientInterceptor(
ctx context.Context,
method string,
req,
reply interface{},
cc *grpc.ClientConn,
invoker grpc.UnaryInvoker,
opts ...grpc.CallOption,
) error {
// Capture incoming headers from the grpc call
var header metadata.MD
header := make(metadata.MD)
opts = append(opts, grpc.Header(&header))

// Get outgoing headers from the request ID in context
Expand Down
Loading

0 comments on commit 06d0fe9

Please sign in to comment.