Skip to content

Commit cb9408e

Browse files
committed
fix: ttstream cancelFunc panic
1 parent cabe684 commit cb9408e

File tree

7 files changed

+117
-32
lines changed

7 files changed

+117
-32
lines changed

pkg/remote/trans/ttstream/exception.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ var (
3434

3535
errBizCancel = newException("user code invoking stream RPC with context processed by context.WithCancel or context.WithTimeout, then invoking cancel() actively",
3636
kerrors.ErrStreamingCanceled, 12007)
37-
errBizCancelWithCause = newException("user code canceled with cancelCause(error)", kerrors.ErrStreamingCanceled, 12008)
38-
errDownstreamCancel = newException("canceled by downstream", kerrors.ErrStreamingCanceled, 12009)
39-
errUpstreamCancel = newException("canceled by upstream", kerrors.ErrStreamingCanceled, 12010)
40-
errInternalCancel = newException("internal canceled", kerrors.ErrStreamingCanceled, 12011)
37+
errBizCancelWithCause = newException("user code canceled with cancelCause(error)", kerrors.ErrStreamingCanceled, 12008)
38+
errDownstreamCancel = newException("canceled by downstream", kerrors.ErrStreamingCanceled, 12009)
39+
errUpstreamCancel = newException("canceled by upstream", kerrors.ErrStreamingCanceled, 12010)
40+
errInternalCancel = newException("internal canceled", kerrors.ErrStreamingCanceled, 12011)
41+
errBizHandlerReturnCancel = newException("canceled by business handler returning", kerrors.ErrStreamingCanceled, 12012)
42+
errConnectionClosedCancel = newException("canceled by connection closed", kerrors.ErrStreamingCanceled, 12013)
4143
)
4244

4345
const (
@@ -103,8 +105,10 @@ func (e *Exception) Error() string {
103105
}
104106

105107
func (e *Exception) withCause(cause error) *Exception {
106-
e.cause = cause
107-
e.bitSet |= setCause
108+
if cause != nil {
109+
e.cause = cause
110+
e.bitSet |= setCause
111+
}
108112
return e
109113
}
110114

pkg/remote/trans/ttstream/exception_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ func TestErrors(t *testing.T) {
3939
test.Assert(t, errors.Is(appErr, errApplicationException), appErr)
4040
test.Assert(t, !errors.Is(appErr, kerrors.ErrStreamingProtocol), appErr)
4141
test.Assert(t, strings.Contains(appErr.Error(), causeErr.Error()))
42+
43+
newExWithNilErr := errIllegalFrame.newBuilder().withCause(nil)
44+
test.Assert(t, !newExWithNilErr.isCauseSet(), newExWithNilErr)
45+
test.Assert(t, newExWithNilErr.cause == nil, newExWithNilErr)
4246
}
4347

4448
func TestCommonParentKerror(t *testing.T) {

pkg/remote/trans/ttstream/stream_server.go

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func newServerStream(ctx context.Context, writer streamWriter, smeta streamFrame
5050
type serverStream struct {
5151
*stream
5252
state int32
53-
cancelFunc cancelWithReason
53+
cancelFunc cancelWithReason // use serverStream.cancel to prevent atomic.Value panic caused by passing different error types
5454
}
5555

5656
func (s *serverStream) SetHeader(hd streaming.Header) error {
@@ -121,7 +121,10 @@ func (s *serverStream) SendMsg(ctx context.Context, res any) error {
121121
// CloseSend on serverStream is called after the server handler has returned
122122
// after CloseSend, the stream cannot be accessed again
123123
func (s *serverStream) CloseSend(exception error) error {
124-
return s.close(exception, true)
124+
if s.close(errBizHandlerReturnCancel) {
125+
return nil
126+
}
127+
return s.sendTrailer(exception)
125128
}
126129

127130
// closeRecv called only when server receiving Trailer Frame
@@ -132,22 +135,24 @@ func (s *serverStream) closeRecv(exception error) error {
132135
return nil
133136
}
134137

135-
func (s *serverStream) close(exception error, sendTrailer bool) error {
138+
// cancel ensures the ex passed in cancelFunc is of type *Exception
139+
func (s *serverStream) cancel(ex *Exception) {
140+
s.cancelFunc(ex)
141+
}
142+
143+
func (s *serverStream) close(exception *Exception) (isClosed bool) {
136144
// support cascading cancel
137145
// we must cancel the ctx first before changing the state
138146
// otherwise Recv/Send will not get the expected exception
139-
s.cancelFunc(exception)
147+
s.cancel(exception)
140148
if atomic.SwapInt32(&s.state, streamStateInactive) == streamStateInactive {
141-
return nil
149+
return true
142150
}
143151

144152
s.reader.close(exception)
145153
s.runCloseCallback(exception)
146154

147-
if sendTrailer {
148-
return s.sendTrailer(exception)
149-
}
150-
return nil
155+
return false
151156
}
152157

153158
// === serverStream onRead callback
@@ -210,15 +215,16 @@ func (s *serverStream) onReadRstFrame(fr *Frame) (err error) {
210215
}
211216

212217
// when receiving a rst frame, we should close the stream; there is no need to send a rst frame back
213-
return s.close(rstEx, false)
218+
s.close(rstEx)
219+
return nil
214220
}
215221

216222
// closeTest is only used in unit tests to mock a Proxy Egress/Ingress sending a Rst Frame to downstream and upstream
217223
func (s *serverStream) closeTest(exception error, cancelPath string) error {
218224
// support cascading cancel
219225
// we must cancel the ctx first before changing the state
220226
// otherwise Recv/Send will not get the expected exception
221-
s.cancelFunc(exception)
227+
s.cancel(errInternalCancel.newBuilder().withCause(exception))
222228
if atomic.SwapInt32(&s.state, streamStateInactive) == streamStateInactive {
223229
return nil
224230
}

pkg/remote/trans/ttstream/stream_server_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,22 @@ func newTestServerStream() *serverStream {
3939
func Test_serverStreamStateChange(t *testing.T) {
4040
t.Run("serverStream close, then RecvMsg/SendMsg returning exception", func(t *testing.T) {
4141
srvSt := newTestServerStream()
42-
testException := errors.New("test")
43-
err := srvSt.close(testException, false)
44-
test.Assert(t, err == nil, err)
42+
testException := errInternalCancel.newBuilder().withCause(errors.New("test"))
43+
isClosed := srvSt.close(testException)
44+
test.Assert(t, !isClosed)
4545
rErr := srvSt.RecvMsg(srvSt.ctx, nil)
4646
test.Assert(t, rErr == testException, rErr)
4747
sErr := srvSt.SendMsg(srvSt.ctx, nil)
4848
test.Assert(t, sErr == testException, sErr)
4949
})
5050
t.Run("serverStream close twice with different exception, RecvMsg/SendMsg returning the first time exception", func(t *testing.T) {
5151
srvSt := newTestServerStream()
52-
testException1 := errors.New("test1")
53-
err := srvSt.close(testException1, false)
54-
test.Assert(t, err == nil, err)
55-
testException2 := errors.New("test2")
56-
err = srvSt.close(testException2, false)
57-
test.Assert(t, err == nil, err)
52+
testException1 := errInternalCancel.newBuilder().withCause(errors.New("test1"))
53+
isClosed := srvSt.close(testException1)
54+
test.Assert(t, !isClosed)
55+
testException2 := errInternalCancel.newBuilder().withCause(errors.New("test2"))
56+
isClosed = srvSt.close(testException2)
57+
test.Assert(t, isClosed)
5858

5959
rErr := srvSt.RecvMsg(srvSt.ctx, nil)
6060
test.Assert(t, rErr == testException1, rErr)
@@ -65,11 +65,11 @@ func Test_serverStreamStateChange(t *testing.T) {
6565
srvSt := newTestServerStream()
6666
var wg sync.WaitGroup
6767
wg.Add(2)
68-
testException := errors.New("test")
68+
testException := errInternalCancel.newBuilder().withCause(errors.New("test1"))
6969
go func() {
7070
time.Sleep(100 * time.Millisecond)
71-
err := srvSt.close(testException, false)
72-
test.Assert(t, err == nil, err)
71+
isClosed := srvSt.close(testException)
72+
test.Assert(t, !isClosed)
7373
}()
7474
go func() {
7575
defer wg.Done()

pkg/remote/trans/ttstream/transport_client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func (t *clientTransport) readFrame(reader bufiox.Reader) error {
172172
var ok bool
173173
s, ok = t.loadStream(fr.sid)
174174
if !ok {
175-
klog.Errorf("transport[%s] read a unknown stream: frame[%s]", t.Addr(), fr)
175+
klog.Debugf("transport[%s] read a unknown stream: frame[%s]", t.Addr(), fr)
176176
// ignore unknown stream error
177177
err = nil
178178
} else {

pkg/remote/trans/ttstream/transport_server.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,10 @@ func (t *serverTransport) Close(exception error) (err error) {
103103
}
104104
klog.Debugf("server transport[%s] is closing", t.Addr())
105105
// close streams first
106+
ex := errConnectionClosedCancel.newBuilder().withCause(exception)
106107
t.streams.Range(func(key, value any) bool {
107108
s := value.(*serverStream)
108-
_ = s.close(exception, false)
109+
_ = s.close(ex)
109110
return true
110111
})
111112
// then close stream and frame pipes
@@ -163,7 +164,9 @@ func (t *serverTransport) readFrame(reader bufiox.Reader) error {
163164
var ok bool
164165
s, ok = t.loadStream(fr.sid)
165166
if !ok {
166-
klog.Errorf("transport[%s] read a unknown stream: frame[%s]", t.Addr(), fr.String())
167+
// there is a race condition: the server handler may return while the client concurrently sends a rst frame.
168+
// in that case serverTransport cannot find the target stream when it receives the rst frame.
169+
klog.Debugf("transport[%s] read a unknown stream: frame[%s]", t.Addr(), fr.String())
167170
// ignore unknown stream error
168171
err = nil
169172
} else {

pkg/remote/trans/ttstream/transport_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,74 @@ func TestTransportException(t *testing.T) {
257257
test.Assert(t, errors.Is(err, errIllegalFrame), err)
258258
}
259259

260+
func TestTransportClose(t *testing.T) {
261+
cfd, sfd := netpoll.GetSysFdPairs()
262+
cconn, err := netpoll.NewFDConnection(cfd)
263+
test.Assert(t, err == nil, err)
264+
sconn, err := netpoll.NewFDConnection(sfd)
265+
test.Assert(t, err == nil, err)
266+
267+
intHeader := make(IntHeader)
268+
intHeader[0] = "test"
269+
strHeader := make(streaming.Header)
270+
strHeader["key"] = "val"
271+
ctrans := newClientTransport(cconn, nil)
272+
defer ctrans.Close(nil)
273+
ctx := context.Background()
274+
cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: "Bidi"})
275+
err = ctrans.WriteStream(ctx, cs, intHeader, strHeader)
276+
test.Assert(t, err == nil, err)
277+
strans := newServerTransport(sconn)
278+
ss, err := strans.ReadStream(context.Background())
279+
test.Assert(t, err == nil, err)
280+
281+
var wg sync.WaitGroup
282+
wg.Add(1)
283+
// client
284+
go func() {
285+
defer wg.Done()
286+
for {
287+
req := new(testRequest)
288+
req.B = "hello"
289+
sErr := cs.SendMsg(context.Background(), req)
290+
if sErr != nil {
291+
test.Assert(t, errors.Is(sErr, errIllegalFrame), sErr)
292+
return
293+
}
294+
295+
res := new(testResponse)
296+
rErr := cs.RecvMsg(context.Background(), res)
297+
if rErr != nil {
298+
test.Assert(t, errors.Is(rErr, errIllegalFrame), rErr)
299+
return
300+
}
301+
test.DeepEqual(t, req.B, res.B)
302+
}
303+
}()
304+
305+
go func() {
306+
time.Sleep(100 * time.Millisecond)
307+
strans.Close(nil)
308+
}()
309+
// server
310+
for {
311+
req := new(testRequest)
312+
rErr := ss.RecvMsg(context.Background(), req)
313+
if rErr != nil {
314+
test.Assert(t, errors.Is(rErr, errConnectionClosedCancel), rErr)
315+
break
316+
}
317+
res := new(testResponse)
318+
res.B = req.B
319+
sErr := ss.SendMsg(context.Background(), res)
320+
if sErr != nil {
321+
test.Assert(t, errors.Is(sErr, errConnectionClosedCancel), sErr)
322+
break
323+
}
324+
}
325+
wg.Wait()
326+
}
327+
260328
func TestStreamID(t *testing.T) {
261329
oriId := atomic.LoadInt32(&clientStreamID)
262330
atomic.StoreInt32(&clientStreamID, math.MaxInt32-1)

0 commit comments

Comments
 (0)