Skip to content

Commit 0d37566

Browse files
artembilangaryrussell
authored andcommitted
GH-2744: ScatterGather: reinstate request headers
Fixes #2744 When we get scattering results, there is no reason to keep internal headers any more. * Fix `ScatterGatherHandler` to modify scattering result messages to reinstate headers from original request message. This way we are able to re-throw an exception from the gatherer to the caller. * Fix typos and language in Docs
1 parent d38db25 commit 0d37566

File tree

4 files changed

+81
-36
lines changed

4 files changed

+81
-36
lines changed

spring-integration-core/src/main/java/org/springframework/integration/scattergather/ScatterGatherHandler.java

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.springframework.beans.factory.BeanFactory;
2121
import org.springframework.beans.factory.BeanInitializationException;
2222
import org.springframework.context.Lifecycle;
23+
import org.springframework.integration.channel.ChannelInterceptorAware;
2324
import org.springframework.integration.channel.FixedSubscriberChannel;
2425
import org.springframework.integration.channel.QueueChannel;
2526
import org.springframework.integration.channel.ReactiveStreamsSubscribableChannel;
@@ -30,14 +31,14 @@
3031
import org.springframework.integration.endpoint.PollingConsumer;
3132
import org.springframework.integration.endpoint.ReactiveStreamsConsumer;
3233
import org.springframework.integration.handler.AbstractReplyProducingMessageHandler;
33-
import org.springframework.integration.support.channel.HeaderChannelRegistry;
3434
import org.springframework.messaging.Message;
3535
import org.springframework.messaging.MessageChannel;
3636
import org.springframework.messaging.MessageDeliveryException;
3737
import org.springframework.messaging.MessageHandler;
3838
import org.springframework.messaging.MessageHeaders;
3939
import org.springframework.messaging.PollableChannel;
4040
import org.springframework.messaging.SubscribableChannel;
41+
import org.springframework.messaging.support.ChannelInterceptor;
4142
import org.springframework.util.Assert;
4243
import org.springframework.util.ClassUtils;
4344

@@ -66,8 +67,6 @@ public class ScatterGatherHandler extends AbstractReplyProducingMessageHandler i
6667

6768
private AbstractEndpoint gatherEndpoint;
6869

69-
private HeaderChannelRegistry replyChannelRegistry;
70-
7170

7271
public ScatterGatherHandler(MessageHandler scatterer, MessageHandler gatherer) {
7372
this(new FixedSubscriberChannel(scatterer), gatherer);
@@ -134,52 +133,64 @@ else if (this.gatherChannel instanceof ReactiveStreamsSubscribableChannel) {
134133
((MessageProducer) this.gatherer)
135134
.setOutputChannel(new FixedSubscriberChannel(message -> {
136135
MessageHeaders headers = message.getHeaders();
137-
if (headers.containsKey(GATHER_RESULT_CHANNEL)) {
138-
Object gatherResultChannel = headers.get(GATHER_RESULT_CHANNEL);
139-
if (gatherResultChannel instanceof MessageChannel) {
140-
messagingTemplate.send((MessageChannel) gatherResultChannel, message);
141-
}
142-
else if (gatherResultChannel instanceof String) {
143-
messagingTemplate.send((String) gatherResultChannel, message);
144-
}
136+
MessageChannel gatherResultChannel = headers.get(GATHER_RESULT_CHANNEL, MessageChannel.class);
137+
if (gatherResultChannel != null) {
138+
this.messagingTemplate.send(gatherResultChannel, message);
145139
}
146140
else {
147141
throw new MessageDeliveryException(message,
148-
"The 'gatherResultChannel' header is required to delivery gather result.");
142+
"The 'gatherResultChannel' header is required to deliver the gather result.");
149143
}
150144
}));
151-
152-
this.replyChannelRegistry =
153-
beanFactory.getBean(IntegrationContextUtils.INTEGRATION_HEADER_CHANNEL_REGISTRY_BEAN_NAME,
154-
HeaderChannelRegistry.class);
155145
}
156146

157147
@Override
158148
protected Object handleRequestMessage(Message<?> requestMessage) {
159149
PollableChannel gatherResultChannel = new QueueChannel();
160150

161-
Object gatherResultChannelName = this.replyChannelRegistry.channelToChannelName(gatherResultChannel);
151+
MessageChannel replyChannel = this.gatherChannel;
152+
153+
if (replyChannel instanceof ChannelInterceptorAware) {
154+
((ChannelInterceptorAware) replyChannel)
155+
.addInterceptor(0,
156+
new ChannelInterceptor() {
157+
158+
@Override
159+
public Message<?> preSend(Message<?> message, MessageChannel channel) {
160+
return enhanceScatterReplyMessage(message, gatherResultChannel, requestMessage);
161+
}
162+
163+
});
164+
}
165+
else {
166+
replyChannel =
167+
new FixedSubscriberChannel(message ->
168+
this.messagingTemplate.send(this.gatherChannel,
169+
enhanceScatterReplyMessage(message, gatherResultChannel, requestMessage)));
170+
}
162171

163172
Message<?> scatterMessage =
164173
getMessageBuilderFactory()
165174
.fromMessage(requestMessage)
166-
.setHeader(GATHER_RESULT_CHANNEL, gatherResultChannelName)
167-
.setReplyChannel(this.gatherChannel)
175+
.setReplyChannel(replyChannel)
168176
.setErrorChannelName(this.errorChannelName)
169177
.build();
170178

171179
this.messagingTemplate.send(this.scatterChannel, scatterMessage);
172180

173-
Message<?> gatherResult = gatherResultChannel.receive(this.gatherTimeout);
174-
if (gatherResult != null) {
175-
return getMessageBuilderFactory()
176-
.fromMessage(gatherResult)
177-
.removeHeader(GATHER_RESULT_CHANNEL)
178-
.setHeader(MessageHeaders.REPLY_CHANNEL, requestMessage.getHeaders().getReplyChannel())
179-
.setHeader(MessageHeaders.ERROR_CHANNEL, requestMessage.getHeaders().getErrorChannel());
180-
}
181+
return gatherResultChannel.receive(this.gatherTimeout);
182+
}
183+
184+
private Message<?> enhanceScatterReplyMessage(Message<?> message, PollableChannel gatherResultChannel,
185+
Message<?> requestMessage) {
181186

182-
return null;
187+
MessageHeaders requestMessageHeaders = requestMessage.getHeaders();
188+
return getMessageBuilderFactory()
189+
.fromMessage(message)
190+
.setHeader(GATHER_RESULT_CHANNEL, gatherResultChannel)
191+
.setHeader(MessageHeaders.REPLY_CHANNEL, requestMessageHeaders.getReplyChannel())
192+
.setHeader(MessageHeaders.ERROR_CHANNEL, requestMessageHeaders.getErrorChannel())
193+
.build();
183194
}
184195

185196
@Override
@@ -201,11 +212,11 @@ public boolean isRunning() {
201212
return this.gatherEndpoint == null || this.gatherEndpoint.isRunning();
202213
}
203214

204-
private void checkClass(Class<?> gathererClass, String className, String type) throws LinkageError {
215+
private static void checkClass(Class<?> gathererClass, String className, String type) throws LinkageError {
205216
try {
206217
Class<?> clazz = ClassUtils.forName(className, ClassUtils.getDefaultClassLoader());
207-
Assert.isAssignable(clazz, gathererClass, () -> "the '" + type + "' must be an " + className + " " +
208-
"instance");
218+
Assert.isAssignable(clazz, gathererClass,
219+
() -> "the '" + type + "' must be an " + className + " " + "instance");
209220
}
210221
catch (ClassNotFoundException e) {
211222
throw new IllegalStateException("The class for '" + className + "' cannot be loaded", e);

spring-integration-core/src/test/java/org/springframework/integration/dsl/routers/RouterTests.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.integration.dsl.routers;
1818

19+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
1920
import static org.hamcrest.Matchers.containsString;
2021
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
2122
import static org.hamcrest.Matchers.instanceOf;
@@ -29,6 +30,7 @@
2930
import java.util.Arrays;
3031
import java.util.List;
3132
import java.util.concurrent.atomic.AtomicReference;
33+
import java.util.function.Function;
3234
import java.util.stream.Collectors;
3335

3436
import org.junit.Test;
@@ -590,6 +592,16 @@ public void testScatterGatherWithExecutorChannelSubFlow() {
590592
assertThat(((List) payload).get(1), instanceOf(RuntimeException.class));
591593
}
592594

595+
@Autowired
596+
@Qualifier("propagateErrorFromGatherer.gateway")
597+
private Function<Object, ?> propagateErrorFromGathererGateway;
598+
599+
@Test
600+
public void propagateErrorFromGatherer() {
601+
assertThatThrownBy(() -> propagateErrorFromGathererGateway.apply("bar"))
602+
.hasMessage("intentional");
603+
}
604+
593605
@Configuration
594606
@EnableIntegration
595607
@EnableMessageHistory({ "recipientListOrder*", "recipient1*", "recipient2*" })
@@ -881,6 +893,22 @@ public Message<?> processAsyncScatterError(MessagingException payload) {
881893
.build();
882894
}
883895

896+
@Bean
897+
public IntegrationFlow propagateErrorFromGatherer(TaskExecutor taskExecutor) {
898+
return IntegrationFlows.from(Function.class)
899+
.scatterGather(s -> s
900+
.applySequence(true)
901+
.recipientFlow(subFlow -> subFlow
902+
.channel(c -> c.executor(taskExecutor))
903+
.transform(p -> "foo")),
904+
g -> g
905+
.outputProcessor(group -> {
906+
throw new RuntimeException("intentional");
907+
}),
908+
sg -> sg.gatherTimeout(100))
909+
.get();
910+
}
911+
884912
}
885913

886914
private static class RoutingTestBean {

spring-integration-jmx/src/test/java/org/springframework/integration/monitor/ScatterGatherHandlerIntegrationTests.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2014-2015 the original author or authors.
2+
* Copyright 2014-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@
2424

2525
import java.util.Arrays;
2626
import java.util.List;
27-
import java.util.concurrent.Executors;
27+
import java.util.concurrent.Executor;
2828

2929
import org.junit.Test;
3030
import org.junit.runner.RunWith;
@@ -278,8 +278,8 @@ public MessageChannel gatherChannel() {
278278
}
279279

280280
@Bean
281-
public SubscribableChannel scatterAuctionWithGatherChannel() {
282-
PublishSubscribeChannel channel = new PublishSubscribeChannel(Executors.newCachedThreadPool());
281+
public SubscribableChannel scatterAuctionWithGatherChannel(Executor executor) {
282+
PublishSubscribeChannel channel = new PublishSubscribeChannel(executor);
283283
channel.setApplySequence(true);
284284
return channel;
285285
}
@@ -296,7 +296,8 @@ public MessageHandler gatherer2() {
296296
@Bean
297297
@ServiceActivator(inputChannel = "inputAuctionWithGatherChannel")
298298
public MessageHandler scatterGatherAuctionWithGatherChannel() {
299-
ScatterGatherHandler handler = new ScatterGatherHandler(scatterAuctionWithGatherChannel(), gatherer2());
299+
ScatterGatherHandler handler =
300+
new ScatterGatherHandler(scatterAuctionWithGatherChannel(null), gatherer2());
300301
handler.setGatherChannel(gatherChannel());
301302
handler.setOutputChannel(output());
302303
return handler;

src/reference/asciidoc/scatter-gather.adoc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,8 @@ public Message<?> processAsyncScatterError(MessagingException payload) {
206206
To produce a proper reply, we have to copy headers (including `replyChannel` and `errorChannel`) from the `failedMessage` of the `MessagingException` that has been sent to the `scatterGatherErrorChannel` by the `MessagePublishingErrorHandler`.
207207
This way the target exception is returned to the gatherer of the `ScatterGatherHandler` for reply messages group completion.
208208
Such an exception `payload` can be filtered out in the `MessageGroupProcessor` of the gatherer or processed other way downstream, after the scatter-gather endpoint.
209+
210+
NOTE: Before sending scattering results to the gatherer, `ScatterGatherHandler` reinstates the request message headers, including reply and error channels if any.
211+
This way errors from the `AggregatingMessageHandler` are going to be propagated to the caller, even if an async hand off is applied in scatter recipient subflows.
212+
In this case a reasonable, finite `gatherTimeout` must be configured for the `ScatterGatherHandler`.
213+
Otherwise it is going to be blocked waiting for a reply from the gatherer forever, by default.

0 commit comments

Comments
 (0)