Skip to content

Commit 7c7e069

Browse files
authored
add unwrap stream interceptor (#6813)
Signed-off-by: yeya24 <[email protected]>
1 parent fd1dca4 commit 7c7e069

File tree

3 files changed

+136
-0
lines changed

3 files changed

+136
-0
lines changed

pkg/util/grpcclient/instrumentation.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ func Instrument(requestDuration *prometheus.HistogramVec) ([]grpc.UnaryClientInt
1919
cortexmiddleware.PrometheusGRPCUnaryInstrumentation(requestDuration),
2020
}, []grpc.StreamClientInterceptor{
2121
grpcutil.HTTPHeaderPropagationStreamClientInterceptor,
22+
unwrapErrorStreamClientInterceptor(),
2223
otgrpc.OpenTracingStreamClientInterceptor(opentracing.GlobalTracer()),
2324
middleware.StreamClientUserHeaderInterceptor,
2425
cortexmiddleware.PrometheusGRPCStreamInstrumentation(requestDuration),

pkg/util/grpcclient/unwrap.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package grpcclient
2+
3+
import (
4+
"context"
5+
"errors"
6+
7+
"google.golang.org/grpc"
8+
)
9+
10+
// unwrapErrorStreamClientInterceptor unwraps errors wrapped by OpenTracingStreamClientInterceptor
11+
func unwrapErrorStreamClientInterceptor() grpc.StreamClientInterceptor {
12+
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
13+
stream, err := streamer(ctx, desc, cc, method, opts...)
14+
if err != nil {
15+
return nil, err
16+
}
17+
18+
return &unwrapErrorClientStream{
19+
ClientStream: stream,
20+
}, nil
21+
}
22+
}
23+
24+
type unwrapErrorClientStream struct {
25+
grpc.ClientStream
26+
}
27+
28+
func (s *unwrapErrorClientStream) RecvMsg(m interface{}) error {
29+
err := s.ClientStream.RecvMsg(m)
30+
if err != nil {
31+
// Try to unwrap the error to get the original error
32+
if wrappedErr := errors.Unwrap(err); wrappedErr != nil {
33+
return wrappedErr
34+
}
35+
}
36+
return err
37+
}

pkg/util/grpcclient/unwrap_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package grpcclient
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
8+
otgrpc "github.com/opentracing-contrib/go-grpc"
9+
"github.com/opentracing/opentracing-go"
10+
"github.com/opentracing/opentracing-go/mocktracer"
11+
"github.com/stretchr/testify/require"
12+
"google.golang.org/grpc"
13+
"google.golang.org/grpc/metadata"
14+
)
15+
16+
type mockClientStream struct {
17+
recvErr error
18+
}
19+
20+
func (m *mockClientStream) RecvMsg(msg interface{}) error {
21+
return m.recvErr
22+
}
23+
24+
func (m *mockClientStream) Header() (metadata.MD, error) {
25+
return nil, nil
26+
}
27+
28+
func (m *mockClientStream) Trailer() metadata.MD {
29+
return nil
30+
}
31+
32+
func (m *mockClientStream) CloseSend() error {
33+
return nil
34+
}
35+
36+
func (m *mockClientStream) Context() context.Context {
37+
return context.Background()
38+
}
39+
40+
func (m *mockClientStream) SendMsg(interface{}) error {
41+
return nil
42+
}
43+
44+
func TestUnwrapErrorStreamClientInterceptor(t *testing.T) {
45+
// Create a mock tracer
46+
tracer := mocktracer.New()
47+
opentracing.SetGlobalTracer(tracer)
48+
49+
originalErr := errors.New("original error")
50+
// Create a mock stream that returns the original error
51+
mockStream := &mockClientStream{
52+
recvErr: originalErr,
53+
}
54+
55+
// Create a mock streamer that returns our mock stream
56+
mockStreamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
57+
return mockStream, nil
58+
}
59+
60+
// Create the interceptor chain
61+
otStreamInterceptor := otgrpc.OpenTracingStreamClientInterceptor(tracer)
62+
interceptors := []grpc.StreamClientInterceptor{
63+
unwrapErrorStreamClientInterceptor(),
64+
otStreamInterceptor,
65+
}
66+
67+
// Chain the interceptors
68+
chainedStreamer := mockStreamer
69+
for i := len(interceptors) - 1; i >= 0; i-- {
70+
chainedStreamer = func(interceptor grpc.StreamClientInterceptor, next grpc.Streamer) grpc.Streamer {
71+
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
72+
return interceptor(ctx, desc, cc, method, next, opts...)
73+
}
74+
}(interceptors[i], chainedStreamer)
75+
}
76+
77+
// Call the chained streamer
78+
ctx := context.Background()
79+
stream, err := chainedStreamer(ctx, &grpc.StreamDesc{}, nil, "test")
80+
require.NoError(t, err)
81+
var msg interface{}
82+
err = stream.RecvMsg(&msg)
83+
require.Error(t, err)
84+
require.EqualError(t, err, originalErr.Error())
85+
86+
// Only wrap OpenTracingStreamClientInterceptor.
87+
chainedStreamerWithoutUnwrapErr := func(interceptor grpc.StreamClientInterceptor, next grpc.Streamer) grpc.Streamer {
88+
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
89+
return interceptor(ctx, desc, cc, method, next, opts...)
90+
}
91+
}(otStreamInterceptor, mockStreamer)
92+
stream, err = chainedStreamerWithoutUnwrapErr(ctx, &grpc.StreamDesc{}, nil, "test")
93+
require.NoError(t, err)
94+
err = stream.RecvMsg(&msg)
95+
require.Error(t, err)
96+
// Error is wrapped by OpenTracingStreamClientInterceptor and not unwrapped.
97+
require.Contains(t, err.Error(), "failed to receive message")
98+
}

0 commit comments

Comments
 (0)