Skip to content

Commit e598034

Browse files
authored
feat: add ability to propagate HTTP headers with callbacks (#435)
Adds a mechanism for propgating HTTP headers from the original GraphQL subscription request to the subsequent callback responses. ```java @bean public CallbackWebGraphQLInterceptor callbackGraphQlInterceptor( SubscriptionCallbackHandler callbackHandler) { var headers = Set.of("X-Custom-Header", "X-Custom-Header2"); return new CallbackWebGraphQLInterceptor(callbackHandler, headers); } ```
1 parent 3b7d5b4 commit e598034

File tree

8 files changed

+169
-39
lines changed

8 files changed

+169
-39
lines changed

spring-subscription-callback/README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ public class GraphQLConfiguration {
9393
// to allow users to still apply other interceptors that handle common stuff (e.g. extracting
9494
// auth headers, etc).
9595
// You can override this behavior by specifying custom order.
96+
// You can also provide a set of HTTP headers that should be propagated to the callback responses.
9697
@Bean
9798
public CallbackWebGraphQLInterceptor callbackGraphQlInterceptor(
9899
SubscriptionCallbackHandler callbackHandler) {
@@ -113,6 +114,8 @@ public class GraphQLConfiguration {
113114
}
114115
```
115116

117+
### Custom Scheduler
118+
116119
By default, subscription and heartbeat stream will be executed in non-blocking way on [bounded elastic](https://projectreactor.io/docs/core/release/api/reactor/core/scheduler/Schedulers.html#boundedElastic--)
117120
scheduler. If you need more granular control over the underlying scheduler, you can configure callback handler to run on
118121
your provided scheduler.
@@ -124,3 +127,17 @@ public SubscriptionCallbackHandler callbackHandler(ExecutionGraphQlService graph
124127
return new SubscriptionCallbackHandler(graphQlService, customScheduler);
125128
}
126129
```
130+
131+
### Propagating HTTP Headers
132+
133+
Subscription callback interceptor can be configured with a list of header names to propagate from the
134+
original GraphQL subscription request.
135+
136+
```java
137+
@Bean
138+
public CallbackWebGraphQLInterceptor callbackGraphQlInterceptor(
139+
SubscriptionCallbackHandler callbackHandler) {
140+
var headers = Set.of("X-Custom-Header");
141+
return new CallbackWebGraphQLInterceptor(callbackHandler, headers);
142+
}
143+
```

spring-subscription-callback/src/main/java/com/apollographql/subscription/CallbackWebGraphQLInterceptor.java

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
import com.apollographql.subscription.callback.SubscriptionCallbackHandler;
88
import graphql.ExecutionResult;
99
import graphql.GraphQLError;
10+
import java.util.HashMap;
11+
import java.util.HashSet;
12+
import java.util.List;
13+
import java.util.Map;
14+
import java.util.Set;
1015
import org.apache.commons.logging.Log;
1116
import org.apache.commons.logging.LogFactory;
1217
import org.springframework.core.Ordered;
@@ -20,13 +25,15 @@
2025
import reactor.core.publisher.Mono;
2126

2227
/**
23-
* Interceptor that provides support for Apollo Subscription Callback Protocol. This interceptor
24-
* defaults to {@link Ordered#LOWEST_PRECEDENCE} order as it should run last in chain to allow users
25-
* to still apply other interceptors that handle common stuff (e.g. extracting auth headers, etc).
26-
* You can override this behavior by specifying custom order.
28+
* Interceptor that provides support for Apollo Subscription Callback Protocol. HTTP headers from
29+
* the original subscription GraphQL operation can be propagated to the subsequent callback
30+
* responses by specifying a Set of contextual header names. This interceptor defaults to {@link
31+
* Ordered#LOWEST_PRECEDENCE} order as it should run last in the chain to allow users to still apply
32+
* other interceptors that handle common stuff (e.g. extracting auth headers, etc). You can override
33+
* this behavior by specifying custom order.
2734
*
2835
* @see <a
29-
* href="https://www.apollographql.com/docs/router/executing-operations/subscription-callback-protocol">Subscription
36+
* href="https://www.apollographql.com/docs/graphos/routing/operations/subscriptions/callback-protocol">Subscription
3037
* Callback Protocol</a>
3138
*/
3239
public class CallbackWebGraphQLInterceptor implements WebGraphQlInterceptor, Ordered {
@@ -35,15 +42,29 @@ public class CallbackWebGraphQLInterceptor implements WebGraphQlInterceptor, Ord
3542

3643
private final SubscriptionCallbackHandler subscriptionCallbackHandler;
3744
private final int order;
45+
private final Set<String> contextualHeaders;
3846

3947
public CallbackWebGraphQLInterceptor(SubscriptionCallbackHandler subscriptionCallbackHandler) {
4048
this(subscriptionCallbackHandler, LOWEST_PRECEDENCE);
4149
}
4250

4351
public CallbackWebGraphQLInterceptor(
4452
SubscriptionCallbackHandler subscriptionCallbackHandler, int order) {
53+
this(subscriptionCallbackHandler, order, new HashSet<>());
54+
}
55+
56+
public CallbackWebGraphQLInterceptor(
57+
SubscriptionCallbackHandler subscriptionCallbackHandler, Set<String> contextualHeaders) {
58+
this(subscriptionCallbackHandler, LOWEST_PRECEDENCE, contextualHeaders);
59+
}
60+
61+
public CallbackWebGraphQLInterceptor(
62+
SubscriptionCallbackHandler subscriptionCallbackHandler,
63+
int order,
64+
Set<String> contextualHeaders) {
4565
this.subscriptionCallbackHandler = subscriptionCallbackHandler;
4666
this.order = order;
67+
this.contextualHeaders = contextualHeaders;
4768
}
4869

4970
@Override
@@ -58,8 +79,10 @@ public Mono<WebGraphQlResponse> intercept(WebGraphQlRequest request, Chain chain
5879
if (logger.isDebugEnabled()) {
5980
logger.debug("Starting subscription using callback: " + callback);
6081
}
82+
var callbackWithContextualData =
83+
callback.withContext(parseContextualHeaders(request));
6184
return this.subscriptionCallbackHandler
62-
.handleSubscriptionUsingCallback(request, callback)
85+
.handleSubscriptionUsingCallback(request, callbackWithContextualData)
6386
.map(response -> callbackResponse(request, response));
6487
})
6588
.onErrorResume(
@@ -104,4 +127,20 @@ private WebGraphQlResponse errorCallbackResponse(WebGraphQlRequest request) {
104127
public int getOrder() {
105128
return order;
106129
}
130+
131+
private Map<String, List<String>> parseContextualHeaders(WebGraphQlRequest request) {
132+
if (!this.contextualHeaders.isEmpty()) {
133+
Map<String, List<String>> contextualHeaders = new HashMap<>();
134+
var headers = request.getHeaders();
135+
for (String header : this.contextualHeaders) {
136+
if (headers.containsKey(header)) {
137+
var headerValue = request.getHeaders().get(header);
138+
contextualHeaders.put(header, headerValue);
139+
}
140+
}
141+
return contextualHeaders;
142+
} else {
143+
return new HashMap<>();
144+
}
145+
}
107146
}

spring-subscription-callback/src/main/java/com/apollographql/subscription/callback/SubscriptionCallback.java

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import com.apollographql.subscription.exception.CallbackExtensionNotSpecifiedException;
44
import com.apollographql.subscription.exception.InvalidCallbackExtensionException;
5+
import java.util.HashMap;
6+
import java.util.List;
57
import java.util.Map;
68
import org.jetbrains.annotations.NotNull;
79
import reactor.core.publisher.Mono;
@@ -13,12 +15,29 @@
1315
* @param subscription_id The generated unique ID for the subscription operation
1416
* @param verifier A string that Emitter will include in all HTTP callback requests to verify its
1517
* identity
18+
* @param heartbeatIntervalMs Interval between heartbeats in milliseconds
19+
* @param context Contextual data that should be returned with callbacks as HTTP headers
1620
*/
1721
public record SubscriptionCallback(
1822
@NotNull String callback_url,
1923
@NotNull String subscription_id,
2024
@NotNull String verifier,
21-
int heartbeatIntervalMs) {
25+
int heartbeatIntervalMs,
26+
@NotNull Map<String, List<String>> context) {
27+
28+
public SubscriptionCallback(
29+
@NotNull String callback_url,
30+
@NotNull String subscription_id,
31+
@NotNull String verifier,
32+
int heartbeatIntervalMs) {
33+
this(callback_url, subscription_id, verifier, heartbeatIntervalMs, new HashMap<>());
34+
}
35+
36+
@NotNull
37+
public SubscriptionCallback withContext(@NotNull Map<String, List<String>> context) {
38+
return new SubscriptionCallback(
39+
callback_url, subscription_id, verifier, heartbeatIntervalMs, context);
40+
}
2241

2342
public static String SUBSCRIPTION_EXTENSION = "subscription";
2443
public static String CALLBACK_URL = "callbackUrl";

spring-subscription-callback/src/main/java/com/apollographql/subscription/callback/SubscriptionCallbackHandler.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ public Mono<ExecutionResult> handleSubscriptionUsingCallback(
5959
.post()
6060
.header("Content-Type", MediaType.APPLICATION_JSON_VALUE)
6161
.header(SUBSCRIPTION_PROTOCOL_HEADER, SUBSCRIPTION_PROTOCOL_HEADER_VALUE)
62+
.headers(httpHeaders -> httpHeaders.putAll(callback.context()))
6263
.bodyValue(checkMessage)
6364
.exchangeToMono(
6465
checkResponse -> {
@@ -149,6 +150,7 @@ protected Flux<SubscritionCallbackMessage> startSubscription(
149150
.post()
150151
.header("Content-Type", MediaType.APPLICATION_JSON_VALUE)
151152
.header(SUBSCRIPTION_PROTOCOL_HEADER, SUBSCRIPTION_PROTOCOL_HEADER_VALUE)
153+
.headers(httpHeaders -> httpHeaders.putAll(callback.context()))
152154
.bodyValue(message)
153155
.exchangeToMono(
154156
(routerResponse) -> {
@@ -184,6 +186,7 @@ private Flux<SubscritionCallbackMessage> heartbeatFlux(
184186
.post()
185187
.header("Content-Type", MediaType.APPLICATION_JSON_VALUE)
186188
.header(SUBSCRIPTION_PROTOCOL_HEADER, SUBSCRIPTION_PROTOCOL_HEADER_VALUE)
189+
.headers(httpHeaders -> httpHeaders.putAll(callback.context()))
187190
.bodyValue(heartbeat)
188191
.exchangeToFlux(
189192
(heartBeatResponse) -> {

spring-subscription-callback/src/test/java/com/apollographql/subscription/CallbackGraphQLHttpHandlerAbstrIT.java

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@
44
import static com.apollographql.subscription.callback.SubscriptionCallbackHandler.SUBSCRIPTION_PROTOCOL_HEADER;
55
import static com.apollographql.subscription.callback.SubscriptionCallbackHandler.SUBSCRIPTION_PROTOCOL_HEADER_VALUE;
66

7+
import com.jayway.jsonpath.JsonPath;
78
import java.io.IOException;
9+
import java.util.List;
810
import java.util.Map;
911
import java.util.UUID;
12+
import java.util.concurrent.TimeUnit;
1013
import okhttp3.mockwebserver.MockResponse;
1114
import okhttp3.mockwebserver.MockWebServer;
15+
import org.junit.Assert;
1216
import org.springframework.graphql.test.tester.WebSocketGraphQlTester;
1317
import org.springframework.http.HttpStatus;
1418
import org.springframework.http.MediaType;
@@ -18,6 +22,8 @@
1822
import reactor.test.StepVerifier;
1923

2024
public abstract class CallbackGraphQLHttpHandlerAbstrIT {
25+
public static String TEST_HEADER_NAME = "X-Test-Header";
26+
public static String TEST_HEADER_VALUE = "junitTEST123";
2127

2228
public void verifyQueriesWorks(WebTestClient testClient) {
2329
var graphQLRequest = Map.of("query", "query { hello }");
@@ -59,13 +65,23 @@ public void verifyPostSubscriptionsWithoutCallbackDontWork(WebTestClient testCli
5965
.is4xxClientError();
6066
}
6167

62-
public void verifySuccessfulCallbackInit(WebTestClient testClient) {
68+
public void verifySuccessfulCallbackSubscription(WebTestClient testClient) {
69+
verifyCallbackSubscription(testClient, false);
70+
}
71+
72+
public void verifySuccessfulCallbackSubscriptionWithHeaders(WebTestClient testClient) {
73+
verifyCallbackSubscription(testClient, true);
74+
}
75+
76+
private void verifyCallbackSubscription(WebTestClient testClient, boolean testHeaders) {
6377
try (MockWebServer server = new MockWebServer()) {
6478
server.enqueue(
6579
new MockResponse()
6680
.setResponseCode(HttpStatus.NO_CONTENT.value())
6781
.setHeader(SUBSCRIPTION_PROTOCOL_HEADER, SUBSCRIPTION_PROTOCOL_HEADER_VALUE));
6882
server.enqueue(new MockResponse().setResponseCode(HttpStatus.NO_CONTENT.value()));
83+
server.enqueue(new MockResponse().setResponseCode(HttpStatus.NO_CONTENT.value()));
84+
server.enqueue(new MockResponse().setResponseCode(HttpStatus.NO_CONTENT.value()));
6985
server.enqueue(new MockResponse().setResponseCode(HttpStatus.NOT_FOUND.value()));
7086

7187
// Start the server.
@@ -79,14 +95,49 @@ public void verifySuccessfulCallbackInit(WebTestClient testClient) {
7995
.post()
8096
.uri("/graphql")
8197
.contentType(MediaType.APPLICATION_JSON)
98+
.headers(
99+
headers -> {
100+
if (testHeaders) {
101+
headers.put(TEST_HEADER_NAME, List.of(TEST_HEADER_VALUE));
102+
}
103+
})
82104
.bodyValue(graphQLRequest)
83105
.exchange()
84106
.expectStatus()
85107
.is2xxSuccessful()
86108
.expectHeader()
87109
.valueEquals(SUBSCRIPTION_PROTOCOL_HEADER, SUBSCRIPTION_PROTOCOL_HEADER_VALUE);
110+
111+
var checkMessageRaw = server.takeRequest(1, TimeUnit.SECONDS);
112+
Assert.assertNotNull(checkMessageRaw);
113+
if (testHeaders) {
114+
var headerValue = checkMessageRaw.getHeader(TEST_HEADER_NAME);
115+
Assert.assertEquals(TEST_HEADER_VALUE, headerValue);
116+
}
117+
var checkMessage = JsonPath.parse(checkMessageRaw.getBody().inputStream());
118+
Assert.assertEquals(subscriptionId, checkMessage.read("$.id"));
119+
120+
for (int i : List.of(1, 2, 3)) {
121+
var nextMessageRaw = server.takeRequest(1, TimeUnit.SECONDS);
122+
Assert.assertNotNull(nextMessageRaw);
123+
if (testHeaders) {
124+
var headerValue = checkMessageRaw.getHeader(TEST_HEADER_NAME);
125+
Assert.assertEquals(TEST_HEADER_VALUE, headerValue);
126+
}
127+
var nextMessage = JsonPath.parse(nextMessageRaw.getBody().inputStream());
128+
Assert.assertEquals(subscriptionId, nextMessage.read("$.id"));
129+
Assert.assertEquals(Integer.valueOf(i), nextMessage.read("$.payload.data.counter"));
130+
}
131+
132+
var completeMessageRaw = server.takeRequest(1, TimeUnit.SECONDS);
133+
Assert.assertNotNull(completeMessageRaw);
134+
var completeMessage = JsonPath.parse(completeMessageRaw.getBody().inputStream());
135+
Assert.assertEquals(subscriptionId, completeMessage.read("$.id"));
136+
Assert.assertEquals("complete", completeMessage.read("$.action"));
88137
} catch (IOException e) {
89138
// failed to close the server
139+
} catch (InterruptedException e) {
140+
Assert.fail("Failed to read callback data. Exception: " + e.getMessage());
90141
}
91142
}
92143

Original file line numberDiff line numberDiff line change
@@ -1,31 +1,18 @@
11
package com.apollographql.subscription.configuration;
22

3-
import com.apollographql.federation.graphqljava.Federation;
4-
import com.apollographql.federation.graphqljava._Entity;
5-
import graphql.schema.DataFetcher;
6-
import java.util.List;
7-
import java.util.Map;
8-
import java.util.stream.Collectors;
93
import org.springframework.boot.autoconfigure.graphql.GraphQlSourceBuilderCustomizer;
104
import org.springframework.context.annotation.Bean;
11-
import org.springframework.graphql.execution.ClassNameTypeResolver;
5+
import org.springframework.graphql.data.federation.FederationSchemaFactory;
126

137
public class TestGraphQLConfiguration {
14-
// configure federation
8+
159
@Bean
16-
public GraphQlSourceBuilderCustomizer federationTransform() {
17-
DataFetcher<?> entityDataFetcher =
18-
env -> {
19-
List<Map<String, Object>> representations = env.getArgument(_Entity.argumentName);
20-
return representations.stream().map(representation -> null).collect(Collectors.toList());
21-
};
10+
public GraphQlSourceBuilderCustomizer customizer(FederationSchemaFactory factory) {
11+
return builder -> builder.schemaFactory(factory::createGraphQLSchema);
12+
}
2213

23-
return builder ->
24-
builder.schemaFactory(
25-
(registry, wiring) ->
26-
Federation.transform(registry, wiring)
27-
.fetchEntities(entityDataFetcher)
28-
.resolveEntityType(new ClassNameTypeResolver())
29-
.build());
14+
@Bean
15+
public FederationSchemaFactory federationSchemaFactory() {
16+
return new FederationSchemaFactory();
3017
}
3118
}

spring-subscription-callback/src/test/java/com/apollographql/subscription/webflux/CallbackGraphQLWebfluxHttpHandlerIT.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import com.apollographql.subscription.CallbackWebGraphQLInterceptor;
55
import com.apollographql.subscription.callback.SubscriptionCallbackHandler;
66
import com.apollographql.subscription.configuration.TestGraphQLConfiguration;
7+
import java.util.Set;
78
import org.junit.jupiter.api.Test;
89
import org.springframework.beans.factory.annotation.Autowired;
910
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
@@ -39,7 +40,8 @@ public SubscriptionCallbackHandler callbackHandler(ExecutionGraphQlService graph
3940
@Bean
4041
public CallbackWebGraphQLInterceptor callbackGraphQlInterceptor(
4142
SubscriptionCallbackHandler callbackHandler) {
42-
return new CallbackWebGraphQLInterceptor(callbackHandler);
43+
var headers = Set.of(TEST_HEADER_NAME);
44+
return new CallbackWebGraphQLInterceptor(callbackHandler, headers);
4345
}
4446
}
4547

@@ -58,13 +60,18 @@ public void webSocketSubscription_works() {
5860
}
5961

6062
@Test
61-
public void postSubscription_withoutCallback_returns404() {
62-
verifyPostSubscriptionsWithoutCallbackDontWork(testClient);
63+
public void callbackSubscription_works() {
64+
verifySuccessfulCallbackSubscription(testClient);
6365
}
6466

6567
@Test
66-
public void callback_initSuccessful_returns200() {
67-
verifySuccessfulCallbackInit(testClient);
68+
public void callbackSubscription_withHeaders_works() {
69+
verifySuccessfulCallbackSubscriptionWithHeaders(testClient);
70+
}
71+
72+
@Test
73+
public void postSubscription_withoutCallback_returns404() {
74+
verifyPostSubscriptionsWithoutCallbackDontWork(testClient);
6875
}
6976

7077
@Test

0 commit comments

Comments
 (0)