Skip to content

Commit 9e7f557

Browse files
committed
Updates for buffer management in RSocket
- Integration tests run with zero copy configuration. - RSocketBufferLeakTests has been added. - Updates in MessagingRSocket to ensure proper release See gh-21987
1 parent 23b39ad commit 9e7f557

File tree

10 files changed

+582
-58
lines changed

10 files changed

+582
-58
lines changed

spring-core/src/test/java/org/springframework/core/io/buffer/AbstractDataBufferAllocatingTestCase.java

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -105,13 +105,15 @@ protected Consumer<DataBuffer> stringConsumer(String expected) {
105105
*/
106106
protected void waitForDataBufferRelease(Duration duration) throws InterruptedException {
107107
Instant start = Instant.now();
108-
while (Instant.now().isBefore(start.plus(duration))) {
108+
while (true) {
109109
try {
110110
verifyAllocations();
111111
break;
112112
}
113113
catch (AssertionError ex) {
114-
// ignore;
114+
if (Instant.now().isAfter(start.plus(duration))) {
115+
throw ex;
116+
}
115117
}
116118
Thread.sleep(50);
117119
}

spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractMethodMessageHandler.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -396,10 +396,10 @@ private Match<T> getHandlerMethod(Message<?> message) {
396396
if (matches.size() > 1) {
397397
Match<T> secondBestMatch = matches.get(1);
398398
if (comparator.compare(bestMatch, secondBestMatch) == 0) {
399-
Method m1 = bestMatch.handlerMethod.getMethod();
400-
Method m2 = secondBestMatch.handlerMethod.getMethod();
399+
HandlerMethod m1 = bestMatch.handlerMethod;
400+
HandlerMethod m2 = secondBestMatch.handlerMethod;
401401
throw new IllegalStateException("Ambiguous handler methods mapped for destination '" +
402-
destination + "': {" + m1 + ", " + m2 + "}");
402+
destination + "': {" + m1.getShortLogMessage() + ", " + m2.getShortLogMessage() + "}");
403403
}
404404
}
405405
return bestMatch;

spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ private <T> Mono<T> retrieveMono(ResolvableType elementType) {
244244

245245
Decoder<?> decoder = strategies.decoder(elementType, dataMimeType);
246246
return (Mono<T>) decoder.decodeToMono(
247-
payloadMono.map(this::wrapPayloadData), elementType, dataMimeType, EMPTY_HINTS);
247+
payloadMono.map(this::retainDataAndReleasePayload), elementType, dataMimeType, EMPTY_HINTS);
248248
}
249249

250250
@SuppressWarnings("unchecked")
@@ -260,12 +260,12 @@ private <T> Flux<T> retrieveFlux(ResolvableType elementType) {
260260

261261
Decoder<?> decoder = strategies.decoder(elementType, dataMimeType);
262262

263-
return payloadFlux.map(this::wrapPayloadData).concatMap(dataBuffer ->
263+
return payloadFlux.map(this::retainDataAndReleasePayload).concatMap(dataBuffer ->
264264
(Mono<T>) decoder.decodeToMono(Mono.just(dataBuffer), elementType, dataMimeType, EMPTY_HINTS));
265265
}
266266

267-
private DataBuffer wrapPayloadData(Payload payload) {
268-
return PayloadUtils.wrapPayloadData(payload, strategies.dataBufferFactory());
267+
private DataBuffer retainDataAndReleasePayload(Payload payload) {
268+
return PayloadUtils.retainDataAndReleasePayload(payload, strategies.dataBufferFactory());
269269
}
270270
}
271271

spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketStrategies.java

+8-12
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@
2121
import java.util.List;
2222
import java.util.function.Consumer;
2323

24-
import io.netty.buffer.PooledByteBufAllocator;
25-
2624
import org.springframework.core.ReactiveAdapterRegistry;
2725
import org.springframework.core.codec.Decoder;
2826
import org.springframework.core.codec.Encoder;
2927
import org.springframework.core.io.buffer.DataBufferFactory;
30-
import org.springframework.core.io.buffer.NettyDataBufferFactory;
28+
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
3129
import org.springframework.lang.Nullable;
30+
import org.springframework.util.Assert;
3231

3332
/**
3433
* Default, package-private {@link RSocketStrategies} implementation.
@@ -88,11 +87,10 @@ static class DefaultRSocketStrategiesBuilder implements RSocketStrategies.Builde
8887

8988
private final List<Decoder<?>> decoders = new ArrayList<>();
9089

91-
@Nullable
92-
private ReactiveAdapterRegistry adapterRegistry;
90+
private ReactiveAdapterRegistry adapterRegistry = ReactiveAdapterRegistry.getSharedInstance();
9391

9492
@Nullable
95-
private DataBufferFactory bufferFactory;
93+
private DataBufferFactory dataBufferFactory;
9694

9795

9896
@Override
@@ -121,23 +119,21 @@ public Builder decoders(Consumer<List<Decoder<?>>> consumer) {
121119

122120
@Override
123121
public Builder reactiveAdapterStrategy(ReactiveAdapterRegistry registry) {
122+
Assert.notNull(registry, "ReactiveAdapterRegistry is required");
124123
this.adapterRegistry = registry;
125124
return this;
126125
}
127126

128127
@Override
129128
public Builder dataBufferFactory(DataBufferFactory bufferFactory) {
130-
this.bufferFactory = bufferFactory;
129+
this.dataBufferFactory = bufferFactory;
131130
return this;
132131
}
133132

134133
@Override
135134
public RSocketStrategies build() {
136-
return new DefaultRSocketStrategies(this.encoders, this.decoders,
137-
this.adapterRegistry != null ?
138-
this.adapterRegistry : ReactiveAdapterRegistry.getSharedInstance(),
139-
this.bufferFactory != null ? this.bufferFactory :
140-
new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT));
135+
return new DefaultRSocketStrategies(this.encoders, this.decoders, this.adapterRegistry,
136+
this.dataBufferFactory != null ? this.dataBufferFactory : new DefaultDataBufferFactory());
141137
}
142138
}
143139

spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessagingRSocket.java

+43-21
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package org.springframework.messaging.rsocket;
1717

18+
import java.util.concurrent.atomic.AtomicBoolean;
1819
import java.util.function.Function;
1920

2021
import io.rsocket.AbstractRSocket;
@@ -29,7 +30,7 @@
2930
import org.springframework.core.io.buffer.DataBuffer;
3031
import org.springframework.core.io.buffer.DataBufferFactory;
3132
import org.springframework.core.io.buffer.DataBufferUtils;
32-
import org.springframework.core.io.buffer.PooledDataBuffer;
33+
import org.springframework.core.io.buffer.NettyDataBuffer;
3334
import org.springframework.lang.Nullable;
3435
import org.springframework.messaging.Message;
3536
import org.springframework.messaging.MessageHeaders;
@@ -84,6 +85,9 @@ public Mono<Void> handleConnectionSetupPayload(ConnectionSetupPayload payload) {
8485
if (StringUtils.hasText(payload.dataMimeType())) {
8586
this.dataMimeType = MimeTypeUtils.parseMimeType(payload.dataMimeType());
8687
}
88+
// frameDecoder does not apply to connectionSetupPayload
89+
// so retain here since handle expects it..
90+
payload.retain();
8791
return handle(payload);
8892
}
8993

@@ -120,54 +124,72 @@ public Mono<Void> metadataPush(Payload payload) {
120124

121125

122126
private Mono<Void> handle(Payload payload) {
123-
Message<?> message = MessageBuilder.createMessage(
124-
Mono.fromCallable(() -> wrapPayloadData(payload)), createHeaders(payload, null));
127+
String destination = getDestination(payload);
128+
MessageHeaders headers = createHeaders(destination, null);
129+
DataBuffer dataBuffer = retainDataAndReleasePayload(payload);
130+
int refCount = refCount(dataBuffer);
131+
Message<?> message = MessageBuilder.createMessage(dataBuffer, headers);
132+
return Mono.defer(() -> this.handler.apply(message))
133+
.doFinally(s -> {
134+
if (refCount(dataBuffer) == refCount) {
135+
DataBufferUtils.release(dataBuffer);
136+
}
137+
});
138+
}
125139

126-
return this.handler.apply(message);
140+
private int refCount(DataBuffer dataBuffer) {
141+
return dataBuffer instanceof NettyDataBuffer ?
142+
((NettyDataBuffer) dataBuffer).getNativeBuffer().refCnt() : 1;
127143
}
128144

129145
private Flux<Payload> handleAndReply(Payload firstPayload, Flux<Payload> payloads) {
130146
MonoProcessor<Flux<Payload>> replyMono = MonoProcessor.create();
131-
Message<?> message = MessageBuilder.createMessage(
132-
payloads.map(this::wrapPayloadData).doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release),
133-
createHeaders(firstPayload, replyMono));
134-
135-
return this.handler.apply(message)
147+
String destination = getDestination(firstPayload);
148+
MessageHeaders headers = createHeaders(destination, replyMono);
149+
150+
AtomicBoolean read = new AtomicBoolean();
151+
Flux<DataBuffer> buffers = payloads.map(this::retainDataAndReleasePayload).doOnSubscribe(s -> read.set(true));
152+
Message<Flux<DataBuffer>> message = MessageBuilder.createMessage(buffers, headers);
153+
154+
return Mono.defer(() -> this.handler.apply(message))
155+
.doFinally(s -> {
156+
// Subscription should have happened by now due to ChannelSendOperator
157+
if (!read.get()) {
158+
buffers.subscribe(DataBufferUtils::release);
159+
}
160+
})
136161
.thenMany(Flux.defer(() -> replyMono.isTerminated() ?
137162
replyMono.flatMapMany(Function.identity()) :
138163
Mono.error(new IllegalStateException("Something went wrong: reply Mono not set"))));
139164
}
140165

141-
private MessageHeaders createHeaders(Payload payload, @Nullable MonoProcessor<?> replyMono) {
166+
private String getDestination(Payload payload) {
142167

143168
// TODO:
144169
// For now treat the metadata as a simple string with routing information.
145170
// We'll have to get more sophisticated once the routing extension is completed.
146171
// https://github.com/rsocket/rsocket-java/issues/568
147172

148-
MessageHeaderAccessor headers = new MessageHeaderAccessor();
173+
return payload.getMetadataUtf8();
174+
}
149175

150-
String destination = payload.getMetadataUtf8();
151-
headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, destination);
176+
private DataBuffer retainDataAndReleasePayload(Payload payload) {
177+
return PayloadUtils.retainDataAndReleasePayload(payload, this.strategies.dataBufferFactory());
178+
}
152179

180+
private MessageHeaders createHeaders(String destination, @Nullable MonoProcessor<?> replyMono) {
181+
MessageHeaderAccessor headers = new MessageHeaderAccessor();
182+
headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, destination);
153183
if (this.dataMimeType != null) {
154184
headers.setContentType(this.dataMimeType);
155185
}
156-
157186
headers.setHeader(RSocketRequesterMethodArgumentResolver.RSOCKET_REQUESTER_HEADER, this.requester);
158-
159187
if (replyMono != null) {
160188
headers.setHeader(RSocketPayloadReturnValueHandler.RESPONSE_HEADER, replyMono);
161189
}
162-
163190
DataBufferFactory bufferFactory = this.strategies.dataBufferFactory();
164191
headers.setHeader(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, bufferFactory);
165-
166192
return headers.getMessageHeaders();
167193
}
168194

169-
private DataBuffer wrapPayloadData(Payload payload) {
170-
return PayloadUtils.wrapPayloadData(payload, this.strategies.dataBufferFactory());
171-
}
172-
173195
}

spring-messaging/src/main/java/org/springframework/messaging/rsocket/PayloadUtils.java

+25-10
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
*/
1616
package org.springframework.messaging.rsocket;
1717

18+
import io.netty.buffer.ByteBuf;
19+
import io.rsocket.Frame;
1820
import io.rsocket.Payload;
1921
import io.rsocket.util.ByteBufPayload;
2022
import io.rsocket.util.DefaultPayload;
@@ -24,6 +26,7 @@
2426
import org.springframework.core.io.buffer.DefaultDataBuffer;
2527
import org.springframework.core.io.buffer.NettyDataBuffer;
2628
import org.springframework.core.io.buffer.NettyDataBufferFactory;
29+
import org.springframework.util.Assert;
2730

2831
/**
2932
* Static utility methods to create {@link Payload} from {@link DataBuffer}s
@@ -35,19 +38,31 @@
3538
abstract class PayloadUtils {
3639

3740
/**
38-
* Return the Payload data wrapped as DataBuffer. If the bufferFactory is
39-
* {@link NettyDataBufferFactory} the payload retained and sliced.
40-
* @param payload the input payload
41-
* @param bufferFactory the BufferFactory to use to wrap
42-
* @return the DataBuffer wrapper
41+
* Use this method to slice, retain and wrap the data portion of the
42+
* {@code Payload}, and also to release the {@code Payload}. This assumes
43+
* the Payload metadata has been read by now and ensures downstream code
44+
* need only be aware of {@code DataBuffer}s.
45+
* @param payload the payload to process
46+
* @param bufferFactory the DataBufferFactory to wrap with
47+
* @return the created {@code DataBuffer} instance
4348
*/
44-
public static DataBuffer wrapPayloadData(Payload payload, DataBufferFactory bufferFactory) {
45-
if (bufferFactory instanceof NettyDataBufferFactory) {
46-
return ((NettyDataBufferFactory) bufferFactory).wrap(payload.retain().sliceData());
47-
}
48-
else {
49+
public static DataBuffer retainDataAndReleasePayload(Payload payload, DataBufferFactory bufferFactory) {
50+
try {
51+
if (bufferFactory instanceof NettyDataBufferFactory) {
52+
ByteBuf byteBuf = payload.sliceData().retain();
53+
return ((NettyDataBufferFactory) bufferFactory).wrap(byteBuf);
54+
}
55+
56+
Assert.isTrue(!(payload instanceof ByteBufPayload) && !(payload instanceof Frame),
57+
"NettyDataBufferFactory expected, actual: " + bufferFactory.getClass().getSimpleName());
58+
4959
return bufferFactory.wrap(payload.getData());
5060
}
61+
finally {
62+
if (payload.refCnt() > 0) {
63+
payload.release();
64+
}
65+
}
5166
}
5267

5368
/**

spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketStrategies.java

+16-5
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,23 @@ interface Builder {
142142
Builder reactiveAdapterStrategy(ReactiveAdapterRegistry registry);
143143

144144
/**
145-
* Configure the DataBufferFactory to use for the allocation of buffers
146-
* when creating or responding requests.
147-
* <p>By default this is an instance of
145+
* Configure the DataBufferFactory to use for allocating buffers, for
146+
* example when preparing requests or when responding. The choice here
147+
* must be aligned with the frame decoder configured in
148+
* {@link io.rsocket.RSocketFactory}.
149+
* <p>By default this property is an instance of
150+
* {@link org.springframework.core.io.buffer.DefaultDataBufferFactory
151+
* DefaultDataBufferFactory} matching to the default frame decoder in
152+
* {@link io.rsocket.RSocketFactory} which copies the payload. This
153+
* comes at cost to performance but does not require reference counting
154+
* and eliminates possibility for memory leaks.
155+
* <p>To switch to a zero-copy strategy,
156+
* <a href="https://github.com/rsocket/rsocket-java#zero-copy">configure RSocket</a>
157+
* accordingly, and then configure this property with an instance of
148158
* {@link org.springframework.core.io.buffer.NettyDataBufferFactory
149-
* NettyDataBufferFactory} with {@link PooledByteBufAllocator#DEFAULT}.
150-
* @param bufferFactory the buffer factory to use
159+
* NettyDataBufferFactory} with a pooled allocator such as
160+
* {@link PooledByteBufAllocator#DEFAULT}.
161+
* @param bufferFactory the DataBufferFactory to use
151162
*/
152163
Builder dataBufferFactory(DataBufferFactory bufferFactory);
153164

0 commit comments

Comments
 (0)