Skip to content

Commit

Permalink
add support for ping message callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
emattheis committed Jan 3, 2025
1 parent 6d38d50 commit 80c8762
Show file tree
Hide file tree
Showing 18 changed files with 251 additions and 27 deletions.
45 changes: 31 additions & 14 deletions docs/src/main/asciidoc/websockets-next-reference.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ implementation("io.quarkus:quarkus-websockets-next")

Both the <<server-api>> and <<client-api>> define _endpoints_ that are used to consume and send messages.
The endpoints are implemented as CDI beans and support injection.
Endpoints declare <<callback-methods,_callback methods_>> annotated with `@OnTextMessage`, `@OnBinaryMessage`, `@OnPong`, `@OnOpen`, `@OnClose` and `@OnError`.
Endpoints declare <<callback-methods,_callback methods_>> annotated with `@OnTextMessage`, `@OnBinaryMessage`, `@OnPing`, `@OnPong`, `@OnOpen`, `@OnClose` and `@OnError`.
These methods are used to handle various WebSocket events.
Typically, a method annotated with `@OnTextMessage` is called when the connected client sends a message to the server and vice versa.

Expand Down Expand Up @@ -210,6 +210,7 @@ A WebSocket endpoint may declare:

* At most one `@OnTextMessage` method: Handles the text messages from the connected client/server.
* At most one `@OnBinaryMessage` method: Handles the binary messages from the connected client/server.
* At most one `@OnPingMessage` method: Handles the ping messages from the connected client/server.
* At most one `@OnPongMessage` method: Handles the pong messages from the connected client/server.
* At most one `@OnOpen` method: Invoked when a connection is opened.
* At most one `@OnClose` method: Executed when the connection is closed.
Expand Down Expand Up @@ -551,39 +552,55 @@ Item find(Item item) {
1. Specify the codec to use for the deserialization of the incoming message
2. Specify the codec to use for the serialization of the outgoing message

=== Ping/pong messages
=== Ping/Pong messages

A https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2[ping message] may serve as a keepalive or to verify the remote endpoint.
A https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3[pong message] is sent in response to a ping message and it must have an identical payload.

Server/client endpoints automatically respond to a ping message sent from the client/server.
In other words, there is no need for `@OnPingMessage` callback declared on an endpoint.
==== Sending ping messages

The server can send ping messages to a connected client.
`WebSocketConnection`/`WebSocketClientConnection` declare methods to send ping messages; there is a non-blocking variant: `sendPing(Buffer)` and a blocking variant: `sendPingAndAwait(Buffer)`.
By default, the ping messages are not sent automatically.
However, the configuration properties `quarkus.websockets-next.server.auto-ping-interval` and `quarkus.websockets-next.client.auto-ping-interval` can be used to set the interval after which, the server/client sends a ping message to a connected client/server automatically.
Ping messages are optional and not sent by default. However, server and client endpoints can be configured to automatically send ping messages on an interval.

[source,properties]
----
quarkus.websockets-next.server.auto-ping-interval=2 <1>
quarkus.websockets-next.client.auto-ping-interval=10 <2>
----
<1> Sends a ping message from the server to a connected client every 2 seconds.
<1> Sends a ping message from the server to each connected client every 2 seconds.
<2> Sends a ping message from all connected client instances to their remote servers every 10 seconds.

The `@OnPongMessage` annotation is used to define a callback that consumes pong messages sent from the client/server.
An endpoint must declare at most one method annotated with `@OnPongMessage`.
Servers and clients can send ping messages programmatically at any time using `WebSocketConnection` or `WebSocketClientConnection`.
There is a non-blocking variant: `Sender#sendPing(Buffer)` and a blocking variant: `Sender#sendPingAndAwait(Buffer)`.

==== Sending pong messages

Server and client endpoints will always respond to a ping message sent from the remote party with a corresponding pong message, using the application data from the ping message.
This behavior is built-in and requires no additional code or configuration.

Servers and clients can send unsolicited pong messages that may serve as a unidirectional heartbeat using `WebSocketConnection` or `WebSocketClientConnection`. There is a non-blocking variant: `Sender#sendPong(Buffer)` and a blocking variant: `Sender#sendPongAndAwait(Buffer)`.

==== Handling ping/pong messages

Because ping messages are handled automatically and pong messages require no response, it is not necessary to write handlers for these messages to comply with the WebSocket protocol.
However, it is sometime useful to know when ping or pong messages are received by an endpoint.

The `@OnPingMessage` and `@OnPongMessage` annotations can be used to define callbacks that consumes ping or pong messages sent from the remote party.
An endpoint may declare at most one callback method for each.
The callback method must return either `void` or `Uni<Void>` (or be a Kotlin `suspend` function returning `Unit`), and it must accept a single parameter of type `Buffer`.

[source,java]
----
@OnPingMessage
void ping(Buffer data) {
// an incoming ping that will automatically receive a pong
}
@OnPongMessage
void pong(Buffer data) {
// ....
// an incoming pong in response to the last ping sent
}
----

NOTE: The server/client can also send unsolicited pong messages that may serve as a unidirectional heartbeat. There is a non-blocking variant: `WebSocketConnection#sendPong(Buffer)` and also a blocking variant: `WebSocketConnection#sendPongAndAwait(Buffer)`.

[[inbound-processing-mode]]
=== Inbound processing mode

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ public Callback(Target target, AnnotationInstance annotation, BeanInfo bean, Met
this.messageType = MessageType.BINARY;
} else if (WebSocketDotNames.ON_TEXT_MESSAGE.equals(annotation.name())) {
this.messageType = MessageType.TEXT;
} else if (WebSocketDotNames.ON_PING_MESSAGE.equals(annotation.name())) {
this.messageType = MessageType.PING;
} else if (WebSocketDotNames.ON_PONG_MESSAGE.equals(annotation.name())) {
this.messageType = MessageType.PONG;
} else {
Expand Down Expand Up @@ -123,7 +125,7 @@ public boolean acceptsMessage() {
}

public boolean acceptsBinaryMessage() {
return messageType == MessageType.BINARY || messageType == MessageType.PONG;
return messageType == MessageType.BINARY || messageType == MessageType.PING || messageType == MessageType.PONG;
}

public boolean acceptsMulti() {
Expand Down Expand Up @@ -162,6 +164,7 @@ private DotName getCodec(String valueName) {

public enum MessageType {
NONE,
PING,
PONG,
TEXT,
BINARY
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ interface ParameterContext {
default boolean acceptsMessage() {
return WebSocketDotNames.ON_BINARY_MESSAGE.equals(callbackAnnotation().name())
|| WebSocketDotNames.ON_TEXT_MESSAGE.equals(callbackAnnotation().name())
|| WebSocketDotNames.ON_PING_MESSAGE.equals(callbackAnnotation().name())
|| WebSocketDotNames.ON_PONG_MESSAGE.equals(callbackAnnotation().name());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.quarkus.websockets.next.OnClose;
import io.quarkus.websockets.next.OnError;
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.OnPingMessage;
import io.quarkus.websockets.next.OnPongMessage;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.PathParam;
Expand Down Expand Up @@ -37,6 +38,7 @@ final class WebSocketDotNames {
static final DotName ON_OPEN = DotName.createSimple(OnOpen.class);
static final DotName ON_TEXT_MESSAGE = DotName.createSimple(OnTextMessage.class);
static final DotName ON_BINARY_MESSAGE = DotName.createSimple(OnBinaryMessage.class);
static final DotName ON_PING_MESSAGE = DotName.createSimple(OnPingMessage.class);
static final DotName ON_PONG_MESSAGE = DotName.createSimple(OnPongMessage.class);
static final DotName ON_CLOSE = DotName.createSimple(OnClose.class);
static final DotName ON_ERROR = DotName.createSimple(OnError.class);
Expand All @@ -57,5 +59,5 @@ final class WebSocketDotNames {
static final DotName TRANSACTIONAL = DotName.createSimple("jakarta.transaction.Transactional");

static final List<DotName> CALLBACK_ANNOTATIONS = List.of(ON_OPEN, ON_CLOSE, ON_BINARY_MESSAGE, ON_TEXT_MESSAGE,
ON_PONG_MESSAGE, ON_ERROR);
ON_PING_MESSAGE, ON_PONG_MESSAGE, ON_ERROR);
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ public final class WebSocketEndpointBuildItem extends MultiBuildItem {
public final Callback onOpen;
public final Callback onTextMessage;
public final Callback onBinaryMessage;
public final Callback onPingMessage;
public final Callback onPongMessage;
public final Callback onClose;
public final List<Callback> onErrors;

WebSocketEndpointBuildItem(boolean isClient, BeanInfo bean, String path, String id,
InboundProcessingMode inboundProcessingMode,
Callback onOpen, Callback onTextMessage, Callback onBinaryMessage, Callback onPongMessage, Callback onClose,
List<Callback> onErrors) {
Callback onOpen, Callback onTextMessage, Callback onBinaryMessage, Callback onPingMessage,
Callback onPongMessage, Callback onClose, List<Callback> onErrors) {
this.isClient = isClient;
this.bean = bean;
this.path = path;
Expand All @@ -42,6 +43,7 @@ public final class WebSocketEndpointBuildItem extends MultiBuildItem {
this.onOpen = onOpen;
this.onTextMessage = onTextMessage;
this.onBinaryMessage = onBinaryMessage;
this.onPingMessage = onPingMessage;
this.onPongMessage = onPongMessage;
this.onClose = onClose;
this.onErrors = onErrors;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,15 +356,19 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex,
WebSocketDotNames.ON_TEXT_MESSAGE, callbackArguments, transformedAnnotations, path);
Callback onBinaryMessage = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass,
WebSocketDotNames.ON_BINARY_MESSAGE, callbackArguments, transformedAnnotations, path);
Callback onPingMessage = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass,
WebSocketDotNames.ON_PING_MESSAGE, callbackArguments, transformedAnnotations, path,
this::validateOnPingMessage);
Callback onPongMessage = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass,
WebSocketDotNames.ON_PONG_MESSAGE, callbackArguments, transformedAnnotations, path,
this::validateOnPongMessage);
Callback onClose = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass,
WebSocketDotNames.ON_CLOSE, callbackArguments, transformedAnnotations, path,
this::validateOnClose);
if (onOpen == null && onTextMessage == null && onBinaryMessage == null && onPongMessage == null) {
if (onOpen == null && onTextMessage == null && onBinaryMessage == null && onPingMessage == null
&& onPongMessage == null) {
throw new WebSocketServerException(
"The endpoint must declare at least one method annotated with @OnTextMessage, @OnBinaryMessage, @OnPongMessage or @OnOpen: "
"The endpoint must declare at least one method annotated with @OnTextMessage, @OnBinaryMessage, @OnPingMessage, @OnPongMessage or @OnOpen: "
+ beanClass);
}
endpoints.produce(new WebSocketEndpointBuildItem(target == Target.CLIENT, bean, path, id,
Expand All @@ -373,6 +377,7 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex,
onOpen,
onTextMessage,
onBinaryMessage,
onPingMessage,
onPongMessage,
onClose,
findErrorHandlers(target, index, bean, beanClass, callbackArguments, transformedAnnotations, path)));
Expand Down Expand Up @@ -764,6 +769,26 @@ static String getPathPrefix(IndexView index, DotName enclosingClassName) {
return "";
}

private void validateOnPingMessage(Callback callback) {
if (KotlinUtils.isKotlinMethod(callback.method)) {
if (!callback.isReturnTypeVoid() && !callback.isKotlinSuspendFunctionReturningUnit()) {
throw new WebSocketServerException(
"@OnPingMessage callback must return Unit: " + callback.asString());
}
} else {
if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) {
throw new WebSocketServerException(
"@OnPingMessage callback must return void or Uni<Void>: " + callback.asString());
}
}
Type messageType = callback.argumentType(MessageCallbackArgument::isMessage);
if (messageType == null || !messageType.name().equals(WebSocketDotNames.BUFFER)) {
throw new WebSocketServerException(
"@OnPingMessage callback must accept exactly one message parameter of type io.vertx.core.buffer.Buffer: "
+ callback.asString());
}
}

private void validateOnPongMessage(Callback callback) {
if (KotlinUtils.isKotlinMethod(callback.method)) {
if (!callback.isReturnTypeVoid() && !callback.isKotlinSuspendFunctionReturningUnit()) {
Expand Down Expand Up @@ -919,6 +944,8 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint,
transformedAnnotations, index, globalErrorHandlers, invokerFactory, metricsSupportEnabled);
generateOnMessage(endpointCreator, constructor, endpoint, endpoint.onTextMessage,
transformedAnnotations, index, globalErrorHandlers, invokerFactory, metricsSupportEnabled);
generateOnMessage(endpointCreator, constructor, endpoint, endpoint.onPingMessage,
transformedAnnotations, index, globalErrorHandlers, invokerFactory, metricsSupportEnabled);
generateOnMessage(endpointCreator, constructor, endpoint, endpoint.onPongMessage,
transformedAnnotations, index, globalErrorHandlers, invokerFactory, metricsSupportEnabled);

Expand Down Expand Up @@ -1080,6 +1107,10 @@ private static void generateOnMessage(ClassCreator endpointCreator, MethodCreato
messageType = "Text";
methodParameterType = Object.class;
break;
case PING:
messageType = "Ping";
methodParameterType = Buffer.class;
break;
case PONG:
messageType = "Pong";
methodParameterType = Buffer.class;
Expand All @@ -1105,7 +1136,7 @@ private static void generateOnMessage(ClassCreator endpointCreator, MethodCreato
ExecutionModel.class);
onMessageExecutionModel.returnValue(onMessageExecutionModel.load(callback.executionModel));

if (callback.acceptsMulti() && callback.messageType != MessageType.PONG) {
if (callback.acceptsMulti() && (callback.messageType != MessageType.PING && callback.messageType != MessageType.PONG)) {
Type multiItemType = callback.messageParamType().asParameterizedType().arguments().get(0);
MethodCreator consumedMultiType = endpointCreator.getMethodCreator("consumed" + messageType + "MultiType",
java.lang.reflect.Type.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ private List<Map<String, Object>> createEndpointsJson(List<WebSocketEndpointBuil
addCallback(endpoint.onOpen, callbacks);
addCallback(endpoint.onBinaryMessage, callbacks);
addCallback(endpoint.onTextMessage, callbacks);
addCallback(endpoint.onPingMessage, callbacks);
addCallback(endpoint.onPongMessage, callbacks);
addCallback(endpoint.onClose, callbacks);
for (Callback c : endpoint.onErrors) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnClose;
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.OnPingMessage;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.PathParam;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.WebSocketClient;
import io.quarkus.websockets.next.WebSocketClientConnection;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.WebSocketConnector;
import io.vertx.core.buffer.Buffer;

public class ClientEndpointTest {

Expand All @@ -49,6 +52,7 @@ void testClient() throws InterruptedException {
assertEquals("Lu=", connection.pathParam("name"));
connection.sendTextAndAwait("Hi!");

assertTrue(ClientEndpoint.PING_LATCH.await(5, TimeUnit.SECONDS));
assertTrue(ClientEndpoint.MESSAGE_LATCH.await(5, TimeUnit.SECONDS));
assertEquals("Lu=:Hello Lu=!", ClientEndpoint.MESSAGES.get(0));
assertEquals("Lu=:Hi!", ClientEndpoint.MESSAGES.get(1));
Expand All @@ -61,10 +65,16 @@ void testClient() throws InterruptedException {
@WebSocket(path = "/endpoint/{name}")
public static class ServerEndpoint {

private final Buffer ping = Buffer.buffer("ping");

static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1);

@Inject
WebSocketConnection connection;

@OnOpen
String open(@PathParam String name) {
connection.sendPingAndAwait(ping);
return "Hello " + name + "!";
}

Expand All @@ -83,6 +93,8 @@ void close() {
@WebSocketClient(path = "/endpoint/{name}")
public static class ClientEndpoint {

static final CountDownLatch PING_LATCH = new CountDownLatch(1);

static final CountDownLatch MESSAGE_LATCH = new CountDownLatch(2);

static final List<String> MESSAGES = new CopyOnWriteArrayList<>();
Expand All @@ -98,6 +110,11 @@ void onMessage(@PathParam String name, String message, WebSocketClientConnection
MESSAGE_LATCH.countDown();
}

@OnPingMessage
void onPing(Buffer message) {
PING_LATCH.countDown();
}

@OnClose
void close() {
CLOSED_LATCH.countDown();
Expand Down
Loading

0 comments on commit 80c8762

Please sign in to comment.