Skip to content

Commit a5a8cb7

Browse files
committed
fix(gRPC): retrieve status or biz error for non-ServerStreaming
1 parent 7018a57 commit a5a8cb7

File tree

4 files changed

+203
-4
lines changed

4 files changed

+203
-4
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
run: |
2020
cd ..
2121
rm -rf kitex-tests
22-
git clone --depth=1 https://github.com/cloudwego/kitex-tests.git
22+
git clone -b fix/grpc_clientStreaming_recv_err --depth=1 https://github.com/DMwangnima/kitex-tests.git
2323
cd kitex-tests
2424
KITEX_TOOL_USE_PROTOC=0 ./run.sh ${{github.workspace}}
2525
cd ${{github.workspace}}

pkg/remote/codec/grpc/grpc.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"encoding/binary"
2222
"errors"
2323
"fmt"
24+
"io"
2425

2526
"github.com/cloudwego/fastpb"
2627

@@ -39,7 +40,10 @@ import (
3940

4041
const dataFrameHeaderLen = 5
4142

42-
var ErrInvalidPayload = errors.New("grpc invalid payload")
43+
var (
44+
ErrInvalidPayload = errors.New("grpc invalid payload")
45+
errWrongGRPCImplementation = errors.New("KITEX: grpc client streaming protocol violation: get <nil>, want <EOF>")
46+
)
4347

4448
// gogoproto generate
4549
type marshaler interface {
@@ -197,6 +201,22 @@ func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remo
197201

198202
func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remote.ByteBuffer) (err error) {
199203
d, err := decodeGRPCFrame(ctx, in)
204+
// For ClientStreaming, server may return an err(e.g. status) as trailer frame after calling SendAndClose.
205+
// We need to receive this trailer frame.
206+
if message.RPCRole() == remote.Client && message.RPCInfo().Invocation().StreamingMode() == serviceinfo.StreamingClient && err == nil {
207+
// Receive trailer frame
208+
// If err == nil, wrong gRPC protocol implementation.
209+
// If err == io.EOF, it means server returns nil, just ignore io.EOF.
210+
// If err != io.EOF, it means server returns status err or BizStatusErr, or other gRPC transport error came out,
211+
// we need to throw it to users.
212+
_, err = decodeGRPCFrame(ctx, in)
213+
if err == nil {
214+
return errWrongGRPCImplementation
215+
}
216+
if err == io.EOF {
217+
err = nil
218+
}
219+
}
200220
if rpcStats := rpcinfo.AsMutableRPCStats(message.RPCInfo().Stats()); rpcStats != nil {
201221
// record recv size, even when err != nil (0 is recorded to the lastRecvSize)
202222
rpcStats.IncrRecvSize(uint64(len(d)))

pkg/remote/codec/grpc/grpc_test.go

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
/*
2+
* Copyright 2025 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package grpc
18+
19+
import (
20+
"context"
21+
"io"
22+
"testing"
23+
24+
"github.com/golang/mock/gomock"
25+
26+
mocksremote "github.com/cloudwego/kitex/internal/mocks/remote"
27+
"github.com/cloudwego/kitex/internal/test"
28+
"github.com/cloudwego/kitex/pkg/remote"
29+
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes"
30+
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status"
31+
"github.com/cloudwego/kitex/pkg/rpcinfo"
32+
"github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo"
33+
"github.com/cloudwego/kitex/pkg/serviceinfo"
34+
)
35+
36+
func Test_grpcCodec_Decode(t *testing.T) {
37+
codec := NewGRPCCodec()
38+
ctrl := gomock.NewController(t)
39+
defer ctrl.Finish()
40+
41+
testcases := []struct {
42+
desc string
43+
role remote.RPCRole
44+
mode serviceinfo.StreamingMode
45+
getByteBufferFunc func() remote.ByteBuffer
46+
expectErr error
47+
}{
48+
{
49+
desc: "client-side ClientStreaming decodes first grpc frame failed",
50+
role: remote.Client,
51+
mode: serviceinfo.StreamingClient,
52+
getByteBufferFunc: func() remote.ByteBuffer {
53+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
54+
mockIn.EXPECT().Next(5).Return(nil, status.Err(codes.Internal, "test")).Times(1)
55+
return mockIn
56+
},
57+
expectErr: status.Err(codes.Internal, "test"),
58+
},
59+
{
60+
desc: "client-side ClientStreaming decodes second grpc frame successfully => wrong gRPC protocol implementation on the server side",
61+
role: remote.Client,
62+
mode: serviceinfo.StreamingClient,
63+
getByteBufferFunc: func() remote.ByteBuffer {
64+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
65+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
66+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
67+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
68+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
69+
return mockIn
70+
},
71+
expectErr: errWrongGRPCImplementation,
72+
},
73+
{
74+
desc: "client-side ClientStreaming decodes second grpc frame getting io.EOF => normal exit on the server side",
75+
role: remote.Client,
76+
mode: serviceinfo.StreamingClient,
77+
getByteBufferFunc: func() remote.ByteBuffer {
78+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
79+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
80+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
81+
mockIn.EXPECT().Next(5).Return(nil, io.EOF).Times(1)
82+
return mockIn
83+
},
84+
expectErr: ErrInvalidPayload,
85+
},
86+
{
87+
desc: "client-side ClientStreaming decodes second grpc frame getting gRPC errors",
88+
role: remote.Client,
89+
mode: serviceinfo.StreamingClient,
90+
getByteBufferFunc: func() remote.ByteBuffer {
91+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
92+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
93+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
94+
mockIn.EXPECT().Next(5).Return(nil, status.Err(codes.Internal, "gRPC errors")).Times(1)
95+
return mockIn
96+
},
97+
expectErr: status.Err(codes.Internal, "gRPC errors"),
98+
},
99+
{
100+
desc: "client-side ServerStreaming decodes",
101+
role: remote.Client,
102+
mode: serviceinfo.StreamingServer,
103+
getByteBufferFunc: func() remote.ByteBuffer {
104+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
105+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
106+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
107+
return mockIn
108+
},
109+
expectErr: ErrInvalidPayload,
110+
},
111+
{
112+
desc: "client-side BidiStreaming decodes",
113+
role: remote.Client,
114+
mode: serviceinfo.StreamingBidirectional,
115+
getByteBufferFunc: func() remote.ByteBuffer {
116+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
117+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
118+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
119+
return mockIn
120+
},
121+
expectErr: ErrInvalidPayload,
122+
},
123+
{
124+
desc: "client-side Unary decodes",
125+
role: remote.Client,
126+
mode: serviceinfo.StreamingUnary,
127+
getByteBufferFunc: func() remote.ByteBuffer {
128+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
129+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
130+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
131+
return mockIn
132+
},
133+
expectErr: ErrInvalidPayload,
134+
},
135+
{
136+
desc: "client-side None decodes",
137+
role: remote.Client,
138+
mode: serviceinfo.StreamingNone,
139+
getByteBufferFunc: func() remote.ByteBuffer {
140+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
141+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
142+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
143+
return mockIn
144+
},
145+
expectErr: ErrInvalidPayload,
146+
},
147+
{
148+
desc: "server-side decodes",
149+
role: remote.Server,
150+
getByteBufferFunc: func() remote.ByteBuffer {
151+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
152+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
153+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
154+
return mockIn
155+
},
156+
expectErr: ErrInvalidPayload,
157+
},
158+
}
159+
mockServiceName := "grpcService"
160+
mockMethod := "InvokeClientStreaming"
161+
for _, tc := range testcases {
162+
t.Run(tc.desc, func(t *testing.T) {
163+
inv := rpcinfo.NewInvocation(mockServiceName, mockMethod)
164+
inv.SetStreamingMode(tc.mode)
165+
cfg := rpcinfo.NewRPCConfig()
166+
// avoid unmarshal
167+
rpcinfo.AsMutableRPCConfig(cfg).SetPayloadCodec(serviceinfo.PayloadCodec(-1))
168+
ri := rpcinfo.NewRPCInfo(
169+
rpcinfo.EmptyEndpointInfo(),
170+
remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{ServiceName: mockServiceName}, mockMethod).ImmutableView(),
171+
inv, cfg, rpcinfo.NewRPCStats())
172+
ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
173+
mockMsg := remote.NewMessage(nil, ri, remote.Stream, tc.role)
174+
mockIn := tc.getByteBufferFunc()
175+
err := codec.Decode(ctx, mockMsg, mockIn)
176+
test.DeepEqual(t, err, tc.expectErr)
177+
})
178+
}
179+
}

pkg/remote/trans/nphttp2/stream.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func (s *serverStream) SetTrailer(tl streaming.Trailer) error {
121121
func (s *serverStream) RecvMsg(ctx context.Context, m interface{}) error {
122122
ri := s.rpcInfo
123123

124-
msg := remote.NewMessage(m, ri, remote.Stream, remote.Client)
124+
msg := remote.NewMessage(m, ri, remote.Stream, remote.Server)
125125
defer msg.Recycle()
126126

127127
_, err := s.handler.Read(s.ctx, s.conn, msg)
@@ -133,7 +133,7 @@ func (s *serverStream) RecvMsg(ctx context.Context, m interface{}) error {
133133
func (s *serverStream) SendMsg(ctx context.Context, m interface{}) error {
134134
ri := s.rpcInfo
135135

136-
msg := remote.NewMessage(m, ri, remote.Stream, remote.Client)
136+
msg := remote.NewMessage(m, ri, remote.Stream, remote.Server)
137137
defer msg.Recycle()
138138

139139
_, err := s.handler.Write(s.ctx, s.conn, msg)

0 commit comments

Comments
 (0)