Skip to content

Commit 2259b9d

Browse files
smyrickdariuszkuc
authored andcommitted
Fix subscription caching logic (#515)
* Refactor subscription caching logic Move the saving of subscriptions to a separate class so we can vefify the logic with unit tests and simplify the ApolloSubscriptionProtocolHandler. This also exposed a bug that we were not saving the operation subscriptions to be stopped properly. This is now covered by the unit tests * Clear cache of operations if no more left Prioritize clearing memory over saving the small amount of operation we have to perform is the session stays open but there is no active operations
1 parent c676c7f commit 2259b9d

File tree

3 files changed

+360
-55
lines changed

3 files changed

+360
-55
lines changed

graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/spring/execution/ApolloSubscriptionProtocolHandler.kt

+31-55
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.Client
2323
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ClientMessages.GQL_CONNECTION_TERMINATE
2424
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ClientMessages.GQL_START
2525
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ClientMessages.GQL_STOP
26-
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_COMPLETE
2726
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_ACK
2827
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_ERROR
2928
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_KEEP_ALIVE
@@ -32,12 +31,10 @@ import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage.Server
3231
import com.fasterxml.jackson.databind.ObjectMapper
3332
import com.fasterxml.jackson.module.kotlin.convertValue
3433
import com.fasterxml.jackson.module.kotlin.readValue
35-
import org.reactivestreams.Subscription
3634
import org.slf4j.LoggerFactory
3735
import org.springframework.web.reactive.socket.WebSocketSession
3836
import reactor.core.publisher.Flux
3937
import java.time.Duration
40-
import java.util.concurrent.ConcurrentHashMap
4138

4239
/**
4340
* Implementation of the `graphql-ws` protocol defined by Apollo
@@ -48,11 +45,7 @@ class ApolloSubscriptionProtocolHandler(
4845
private val subscriptionHandler: SubscriptionHandler,
4946
private val objectMapper: ObjectMapper
5047
) {
51-
// Sessions are saved by web socket session id
52-
private val activeKeepAliveSessions = ConcurrentHashMap<String, Subscription>()
53-
// Operations are saved by web socket session id, then operation id
54-
private val activeOperations = ConcurrentHashMap<String, ConcurrentHashMap<String, Subscription>>()
55-
48+
private val sessionState = ApolloSubscriptionSessionState()
5649
private val logger = LoggerFactory.getLogger(ApolloSubscriptionProtocolHandler::class.java)
5750
private val keepAliveMessage = SubscriptionOperationMessage(type = GQL_CONNECTION_KEEP_ALIVE.type)
5851
private val basicConnectionErrorMessage = SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type)
@@ -63,36 +56,26 @@ class ApolloSubscriptionProtocolHandler(
6356
try {
6457
val operationMessage: SubscriptionOperationMessage = objectMapper.readValue(payload)
6558

66-
return when (operationMessage.type) {
67-
GQL_CONNECTION_INIT.type -> {
68-
val flux = Flux.just(acknowledgeMessage)
69-
val keepAliveInterval = config.subscriptions.keepAliveInterval
70-
if (keepAliveInterval != null) {
71-
// Send the GQL_CONNECTION_KEEP_ALIVE message every interval until the connection is closed or terminated
72-
val keepAliveFlux = Flux.interval(Duration.ofMillis(keepAliveInterval))
73-
.map { keepAliveMessage }
74-
.doOnSubscribe {
75-
logger.debug("GraphQL subscription INIT, sessionId=${session.id} activeSessions=${activeKeepAliveSessions.count()}")
76-
activeKeepAliveSessions[session.id] = it
77-
}
78-
79-
return flux.concatWith(keepAliveFlux)
80-
}
59+
logger.debug("GraphQL subscription client message, sessionId=${session.id} operationMessage=$operationMessage")
8160

82-
return flux
61+
when (operationMessage.type) {
62+
GQL_CONNECTION_INIT.type -> {
63+
val ackowledgeMessageFlux = Flux.just(acknowledgeMessage)
64+
val keepAliveFlux = getKeepAliveFlux(session)
65+
return ackowledgeMessageFlux.concatWith(keepAliveFlux)
8366
}
84-
GQL_START.type -> startSubscription(operationMessage, session)
67+
GQL_START.type -> return startSubscription(operationMessage, session)
8568
GQL_STOP.type -> {
86-
stopSubscription(operationMessage, session)
69+
sessionState.stopOperation(session, operationMessage)
8770
return Flux.empty()
8871
}
8972
GQL_CONNECTION_TERMINATE.type -> {
90-
terminateSession(session)
73+
sessionState.terminateSession(session)
9174
return Flux.empty()
9275
}
9376
else -> {
9477
logger.error("Unknown subscription operation $operationMessage")
95-
stopSubscription(operationMessage, session)
78+
sessionState.stopOperation(session, operationMessage)
9679
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
9780
}
9881
}
@@ -102,6 +85,21 @@ class ApolloSubscriptionProtocolHandler(
10285
}
10386
}
10487

88+
/**
89+
* If the keep alive configuraation is set, send a message back to client at every interval until the session is terminated.
90+
* Otherwise just return empty flux to append to the acknowledge message.
91+
*/
92+
private fun getKeepAliveFlux(session: WebSocketSession): Flux<SubscriptionOperationMessage> {
93+
val keepAliveInterval: Long? = config.subscriptions.keepAliveInterval
94+
if (keepAliveInterval != null) {
95+
return Flux.interval(Duration.ofMillis(keepAliveInterval))
96+
.map { keepAliveMessage }
97+
.doOnSubscribe { sessionState.saveKeepAliveSubscription(session, it) }
98+
}
99+
100+
return Flux.empty()
101+
}
102+
105103
@Suppress("Detekt.TooGenericExceptionCaught")
106104
private fun startSubscription(operationMessage: SubscriptionOperationMessage, session: WebSocketSession): Flux<SubscriptionOperationMessage> {
107105
if (operationMessage.id == null) {
@@ -113,7 +111,7 @@ class ApolloSubscriptionProtocolHandler(
113111

114112
if (payload == null) {
115113
logger.error("GraphQL subscription payload was null instead of a GraphQLRequest object")
116-
stopSubscription(operationMessage, session)
114+
sessionState.stopOperation(session, operationMessage)
117115
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
118116
}
119117

@@ -127,35 +125,13 @@ class ApolloSubscriptionProtocolHandler(
127125
SubscriptionOperationMessage(type = GQL_DATA.type, id = operationMessage.id, payload = it)
128126
}
129127
}
130-
.concatWith(Flux.just(SubscriptionOperationMessage(type = GQL_COMPLETE.type, id = operationMessage.id)))
131-
.doOnSubscribe {
132-
logger.debug("GraphQL subscription START, sessionId=${session.id} operationId=${operationMessage.id}")
133-
activeOperations[session.id]?.put(operationMessage.id, it)
134-
}
135-
.doOnCancel { logger.debug("GraphQL subscription CANCEL, sessionId=${session.id} operationId=${operationMessage.id}") }
136-
.doOnComplete { logger.debug("GraphQL subscription COMPELTE, sessionId=${session.id} operationId=${operationMessage.id}") }
128+
.concatWith(Flux.just(SubscriptionOperationMessage(type = SubscriptionOperationMessage.ServerMessages.GQL_COMPLETE.type, id = operationMessage.id)))
129+
.doOnSubscribe { sessionState.saveOperation(session, operationMessage, it) }
137130
} catch (exception: Exception) {
138131
logger.error("Error running graphql subscription", exception)
139-
stopSubscription(operationMessage, session)
132+
// Do not terminate the session, just stop the operation messages
133+
sessionState.stopOperation(session, operationMessage)
140134
return Flux.just(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
141135
}
142136
}
143-
144-
private fun stopSubscription(operationMessage: SubscriptionOperationMessage, session: WebSocketSession) {
145-
logger.debug("GraphQL subscription STOP, sessionId=${session.id} operationId=${operationMessage.id}")
146-
if (operationMessage.id != null) {
147-
val operationsForSession = activeOperations[session.id]
148-
operationsForSession?.get(operationMessage.id)?.cancel()
149-
operationsForSession?.remove(operationMessage.id)
150-
}
151-
}
152-
153-
private fun terminateSession(session: WebSocketSession) {
154-
logger.debug("GraphQL subscription TERMINATE, sessionId=${session.id}")
155-
activeOperations[session.id]?.forEach { _, subscription -> subscription.cancel() }
156-
activeOperations.remove(session.id)
157-
activeKeepAliveSessions[session.id]?.cancel()
158-
activeKeepAliveSessions.remove(session.id)
159-
session.close()
160-
}
161137
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Copyright 2019 Expedia, Inc
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.expediagroup.graphql.spring.execution
18+
19+
import com.expediagroup.graphql.spring.model.SubscriptionOperationMessage
20+
import org.reactivestreams.Subscription
21+
import org.springframework.web.reactive.socket.WebSocketSession
22+
import java.util.concurrent.ConcurrentHashMap
23+
24+
internal class ApolloSubscriptionSessionState {
25+
26+
// Sessions are saved by web socket session id
27+
internal val activeKeepAliveSessions = ConcurrentHashMap<String, Subscription>()
28+
29+
// Operations are saved by web socket session id, then operation id
30+
internal val activeOperations = ConcurrentHashMap<String, ConcurrentHashMap<String, Subscription>>()
31+
32+
/**
33+
* Save the session that is sending keep alive messages.
34+
* This will override values without cancelling the subscription so it is the responsbility of the consumer to cancel.
35+
* These messages will be stopped on [terminateSession].
36+
*/
37+
fun saveKeepAliveSubscription(session: WebSocketSession, subscription: Subscription) {
38+
activeKeepAliveSessions[session.id] = subscription
39+
}
40+
41+
/**
42+
* Save the operation that is sending data to the client.
43+
* This will override values without cancelling the subscription so it is the responsbility of the consumer to cancel.
44+
* These messages will be stopped on [stopOperation].
45+
*/
46+
fun saveOperation(session: WebSocketSession, operationMessage: SubscriptionOperationMessage, subscription: Subscription) {
47+
if (operationMessage.id != null) {
48+
val operationsForSession: ConcurrentHashMap<String, Subscription> = activeOperations.getOrPut(session.id) { ConcurrentHashMap() }
49+
operationsForSession[operationMessage.id] = subscription
50+
}
51+
}
52+
53+
/**
54+
* Stop the subscription sending data. Does NOT terminate the session.
55+
*/
56+
fun stopOperation(session: WebSocketSession, operationMessage: SubscriptionOperationMessage) {
57+
if (operationMessage.id != null) {
58+
val operationsForSession = activeOperations[session.id]
59+
operationsForSession?.get(operationMessage.id)?.cancel()
60+
operationsForSession?.remove(operationMessage.id)
61+
62+
if (operationsForSession?.isEmpty() == true) {
63+
activeOperations.remove(session.id)
64+
}
65+
}
66+
}
67+
68+
/**
69+
* Terminate the session, cancelling the keep alive messages and all operations active for this session.
70+
*/
71+
fun terminateSession(session: WebSocketSession) {
72+
activeOperations[session.id]?.forEach { _, subscription -> subscription.cancel() }
73+
activeOperations.remove(session.id)
74+
activeKeepAliveSessions[session.id]?.cancel()
75+
activeKeepAliveSessions.remove(session.id)
76+
session.close()
77+
}
78+
}

0 commit comments

Comments
 (0)