Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,12 +500,10 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *ClientSt
s.ctx = ctx
s.trReader = transportReader{
reader: recvBufferReader{
ctx: s.ctx,
ctxDone: s.ctx.Done(),
recv: &s.buf,
closeStream: func(err error) {
s.Close(err)
},
ctx: s.ctx,
ctxDone: s.ctx.Done(),
recv: &s.buf,
clientStream: s,
},
windowHandler: s,
}
Expand Down
22 changes: 11 additions & 11 deletions internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ func (b *recvBuffer) get() <-chan recvMsg {
// recvBufferReader implements io.Reader interface to read the data from
// recvBuffer.
type recvBufferReader struct {
_ noCopy
closeStream func(error) // Closes the client transport stream with the given error and nil trailer metadata.
ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer
last mem.Buffer // Stores the remaining data in the previous calls.
err error
_ noCopy
clientStream *ClientStream // The client transport stream is closed with the given error and nil trailer metadata.
ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer
last mem.Buffer // Stores the remaining data in the previous calls.
err error
}

func (r *recvBufferReader) ReadMessageHeader(header []byte) (n int, err error) {
Expand All @@ -140,7 +140,7 @@ func (r *recvBufferReader) ReadMessageHeader(header []byte) (n int, err error) {
n, r.last = mem.ReadUnsafe(header, r.last)
return n, nil
}
if r.closeStream != nil {
if r.clientStream != nil {
n, r.err = r.readMessageHeaderClient(header)
} else {
n, r.err = r.readMessageHeader(header)
Expand All @@ -165,7 +165,7 @@ func (r *recvBufferReader) Read(n int) (buf mem.Buffer, err error) {
}
return buf, nil
}
if r.closeStream != nil {
if r.clientStream != nil {
buf, r.err = r.readClient(n)
} else {
buf, r.err = r.read(n)
Expand Down Expand Up @@ -210,7 +210,7 @@ func (r *recvBufferReader) readMessageHeaderClient(header []byte) (n int, err er
// TODO: delaying ctx error seems like a unnecessary side effect. What
// we really want is to mark the stream as done, and return ctx error
// faster.
r.closeStream(ContextErr(r.ctx.Err()))
r.clientStream.Close(ContextErr(r.ctx.Err()))
m := <-r.recv.get()
return r.readMessageHeaderAdditional(m, header)
case m := <-r.recv.get():
Expand All @@ -237,7 +237,7 @@ func (r *recvBufferReader) readClient(n int) (buf mem.Buffer, err error) {
// TODO: delaying ctx error seems like a unnecessary side effect. What
// we really want is to mark the stream as done, and return ctx error
// faster.
r.closeStream(ContextErr(r.ctx.Err()))
r.clientStream.Close(ContextErr(r.ctx.Err()))
m := <-r.recv.get()
return r.readAdditional(m, n)
case m := <-r.recv.get():
Expand Down
3 changes: 0 additions & 3 deletions preloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ func (p *PreparedMsg) Encode(s Stream, msg any) error {
}

// check if the context has the relevant information to prepareMsg
if rpcInfo.preloaderInfo == nil {
return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo is nil")
}
if rpcInfo.preloaderInfo.codec == nil {
return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo.codec is nil")
}
Expand Down
4 changes: 2 additions & 2 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxR
// Information about RPC
type rpcInfo struct {
failfast bool
preloaderInfo *compressorInfo
preloaderInfo compressorInfo
}

// Information about Preloader
Expand All @@ -980,7 +980,7 @@ type rpcInfoContextKey struct{}
func newContextWithRPCInfo(ctx context.Context, failfast bool, codec baseCodec, cp Compressor, comp encoding.Compressor) context.Context {
return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{
failfast: failfast,
preloaderInfo: &compressorInfo{
preloaderInfo: compressorInfo{
codec: codec,
cp: cp,
comp: comp,
Expand Down
10 changes: 6 additions & 4 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
return cc.NewStream(ctx, desc, method, opts...)
}

var emptyMethodConfig = serviceconfig.MethodConfig{}

func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
// Start tracking the RPC for idleness purposes. This is where a stream is
// created for both streaming and unary RPCs, and hence is a good place to
Expand Down Expand Up @@ -217,7 +219,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
return nil, err
}

var mc serviceconfig.MethodConfig
mc := &emptyMethodConfig
var onCommit func()
newStream := func(ctx context.Context, done func()) (iresolver.ClientStream, error) {
return newClientStreamWithParams(ctx, desc, cc, method, mc, onCommit, done, nameResolutionDelayed, opts...)
Expand All @@ -240,7 +242,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if rpcConfig.Context != nil {
ctx = rpcConfig.Context
}
mc = rpcConfig.MethodConfig
mc = &rpcConfig.MethodConfig
onCommit = rpcConfig.OnCommitted
if rpcConfig.Interceptor != nil {
rpcInfo.Context = nil
Expand All @@ -258,7 +260,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
return newStream(ctx, func() {})
}

func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, mc serviceconfig.MethodConfig, onCommit, doneFunc func(), nameResolutionDelayed bool, opts ...CallOption) (_ iresolver.ClientStream, err error) {
func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, mc *serviceconfig.MethodConfig, onCommit, doneFunc func(), nameResolutionDelayed bool, opts ...CallOption) (_ iresolver.ClientStream, err error) {
callInfo := defaultCallInfo()
if mc.WaitForReady != nil {
callInfo.failFast = !*mc.WaitForReady
Expand Down Expand Up @@ -325,7 +327,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client
cs := &clientStream{
callHdr: callHdr,
ctx: ctx,
methodConfig: &mc,
methodConfig: mc,
opts: opts,
callInfo: callInfo,
cc: cc,
Expand Down