Skip to content

Commit 144d05e

Browse files
vbabaninstIncMale
andauthored
Add connection timeout to TLS Channel (#1686)
JAVA-5856 --------- Co-authored-by: Valentin Kovalenko <[email protected]>
1 parent 457a294 commit 144d05e

File tree

5 files changed

+262
-44
lines changed

5 files changed

+262
-44
lines changed

driver-core/src/main/com/mongodb/internal/TimeoutSettings.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,11 @@ public TimeoutSettings withReadTimeoutMS(final long readTimeoutMS) {
165165
maxTimeMS, maxCommitTimeMS, wTimeoutMS, maxWaitTimeMS);
166166
}
167167

168+
public TimeoutSettings withConnectTimeoutMS(final long connectTimeoutMS) {
169+
return new TimeoutSettings(generationId, timeoutMS, serverSelectionTimeoutMS, connectTimeoutMS, readTimeoutMS, maxAwaitTimeMS,
170+
maxTimeMS, maxCommitTimeMS, wTimeoutMS, maxWaitTimeMS);
171+
}
172+
168173
public TimeoutSettings withServerSelectionTimeoutMS(final long serverSelectionTimeoutMS) {
169174
return new TimeoutSettings(timeoutMS, serverSelectionTimeoutMS, connectTimeoutMS, readTimeoutMS, maxAwaitTimeMS,
170175
maxTimeMS, maxCommitTimeMS, wTimeoutMS, maxWaitTimeMS);

driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java

Lines changed: 96 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import java.net.StandardSocketOptions;
4141
import java.nio.ByteBuffer;
4242
import java.nio.channels.CompletionHandler;
43+
import java.nio.channels.InterruptedByTimeoutException;
4344
import java.nio.channels.SelectionKey;
4445
import java.nio.channels.Selector;
4546
import java.nio.channels.SocketChannel;
@@ -49,6 +50,7 @@
4950
import java.util.concurrent.ExecutorService;
5051
import java.util.concurrent.Future;
5152
import java.util.concurrent.TimeUnit;
53+
import java.util.concurrent.atomic.AtomicReference;
5254

5355
import static com.mongodb.assertions.Assertions.assertTrue;
5456
import static com.mongodb.assertions.Assertions.isTrue;
@@ -97,21 +99,40 @@ public void close() {
9799
group.shutdown();
98100
}
99101

102+
/**
103+
* Monitors `OP_CONNECT` events for socket connections.
104+
*/
100105
private static class SelectorMonitor implements Closeable {
101106

102-
private static final class Pair {
107+
static final class SocketRegistration {
103108
private final SocketChannel socketChannel;
104-
private final Runnable attachment;
109+
private final AtomicReference<Runnable> afterConnectAction;
105110

106-
private Pair(final SocketChannel socketChannel, final Runnable attachment) {
111+
SocketRegistration(final SocketChannel socketChannel, final Runnable afterConnectAction) {
107112
this.socketChannel = socketChannel;
108-
this.attachment = attachment;
113+
this.afterConnectAction = new AtomicReference<>(afterConnectAction);
114+
}
115+
116+
boolean tryCancelPendingConnection() {
117+
return tryTakeAction() != null;
118+
}
119+
120+
void runAfterConnectActionIfNotCanceled() {
121+
Runnable afterConnectActionToExecute = tryTakeAction();
122+
if (afterConnectActionToExecute != null) {
123+
afterConnectActionToExecute.run();
124+
}
125+
}
126+
127+
@Nullable
128+
private Runnable tryTakeAction() {
129+
return afterConnectAction.getAndSet(null);
109130
}
110131
}
111132

112133
private final Selector selector;
113134
private volatile boolean isClosed;
114-
private final ConcurrentLinkedDeque<Pair> pendingRegistrations = new ConcurrentLinkedDeque<>();
135+
private final ConcurrentLinkedDeque<SocketRegistration> pendingRegistrations = new ConcurrentLinkedDeque<>();
115136

116137
SelectorMonitor() {
117138
try {
@@ -127,17 +148,14 @@ void start() {
127148
while (!isClosed) {
128149
try {
129150
selector.select();
130-
131151
for (SelectionKey selectionKey : selector.selectedKeys()) {
132152
selectionKey.cancel();
133-
Runnable runnable = (Runnable) selectionKey.attachment();
134-
runnable.run();
153+
((SocketRegistration) selectionKey.attachment()).runAfterConnectActionIfNotCanceled();
135154
}
136155

137-
for (Iterator<Pair> iter = pendingRegistrations.iterator(); iter.hasNext();) {
138-
Pair pendingRegistration = iter.next();
139-
pendingRegistration.socketChannel.register(selector, SelectionKey.OP_CONNECT,
140-
pendingRegistration.attachment);
156+
for (Iterator<SocketRegistration> iter = pendingRegistrations.iterator(); iter.hasNext();) {
157+
SocketRegistration pendingRegistration = iter.next();
158+
pendingRegistration.socketChannel.register(selector, SelectionKey.OP_CONNECT, pendingRegistration);
141159
iter.remove();
142160
}
143161
} catch (Exception e) {
@@ -156,8 +174,8 @@ void start() {
156174
selectorThread.start();
157175
}
158176

159-
void register(final SocketChannel channel, final Runnable attachment) {
160-
pendingRegistrations.add(new Pair(channel, attachment));
177+
void register(final SocketRegistration registration) {
178+
pendingRegistrations.add(registration);
161179
selector.wakeup();
162180
}
163181

@@ -200,44 +218,79 @@ public void openAsync(final OperationContext operationContext, final AsyncComple
200218
if (getSettings().getSendBufferSize() > 0) {
201219
socketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize());
202220
}
203-
221+
//getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception.
222+
int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs();
204223
socketChannel.connect(getSocketAddresses(getServerAddress(), inetAddressResolver).get(0));
224+
SelectorMonitor.SocketRegistration socketRegistration = new SelectorMonitor.SocketRegistration(
225+
socketChannel, () -> initializeTslChannel(handler, socketChannel));
205226

206-
selectorMonitor.register(socketChannel, () -> {
207-
try {
208-
if (!socketChannel.finishConnect()) {
209-
throw new MongoSocketOpenException("Failed to finish connect", getServerAddress());
210-
}
227+
if (connectTimeoutMs > 0) {
228+
scheduleTimeoutInterruption(handler, socketRegistration, connectTimeoutMs);
229+
}
230+
selectorMonitor.register(socketRegistration);
231+
} catch (IOException e) {
232+
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
233+
} catch (Throwable t) {
234+
handler.failed(t);
235+
}
236+
}
211237

212-
SSLEngine sslEngine = getSslContext().createSSLEngine(getServerAddress().getHost(),
213-
getServerAddress().getPort());
214-
sslEngine.setUseClientMode(true);
238+
private void scheduleTimeoutInterruption(final AsyncCompletionHandler<Void> handler,
239+
final SelectorMonitor.SocketRegistration socketRegistration,
240+
final int connectTimeoutMs) {
241+
group.getTimeoutExecutor().schedule(() -> {
242+
if (socketRegistration.tryCancelPendingConnection()) {
243+
closeAndTimeout(handler, socketRegistration.socketChannel);
244+
}
245+
}, connectTimeoutMs, TimeUnit.MILLISECONDS);
246+
}
215247

216-
SSLParameters sslParameters = sslEngine.getSSLParameters();
217-
enableSni(getServerAddress().getHost(), sslParameters);
248+
private void closeAndTimeout(final AsyncCompletionHandler<Void> handler, final SocketChannel socketChannel) {
249+
// We check if this stream was closed before timeout exception.
250+
boolean streamClosed = isClosed();
251+
InterruptedByTimeoutException timeoutException = new InterruptedByTimeoutException();
252+
try {
253+
socketChannel.close();
254+
} catch (Exception e) {
255+
timeoutException.addSuppressed(e);
256+
}
218257

219-
if (!sslSettings.isInvalidHostNameAllowed()) {
220-
enableHostNameVerification(sslParameters);
221-
}
222-
sslEngine.setSSLParameters(sslParameters);
258+
if (streamClosed) {
259+
handler.completed(null);
260+
} else {
261+
handler.failed(new MongoSocketOpenException("Exception opening socket", getAddress(), timeoutException));
262+
}
263+
}
223264

224-
BufferAllocator bufferAllocator = new BufferProviderAllocator();
265+
private void initializeTslChannel(final AsyncCompletionHandler<Void> handler, final SocketChannel socketChannel) {
266+
try {
267+
if (!socketChannel.finishConnect()) {
268+
throw new MongoSocketOpenException("Failed to finish connect", getServerAddress());
269+
}
225270

226-
TlsChannel tlsChannel = ClientTlsChannel.newBuilder(socketChannel, sslEngine)
227-
.withEncryptedBufferAllocator(bufferAllocator)
228-
.withPlainBufferAllocator(bufferAllocator)
229-
.build();
271+
SSLEngine sslEngine = getSslContext().createSSLEngine(getServerAddress().getHost(),
272+
getServerAddress().getPort());
273+
sslEngine.setUseClientMode(true);
230274

231-
// build asynchronous channel, based in the TLS channel and associated with the global group.
232-
setChannel(new AsynchronousTlsChannelAdapter(new AsynchronousTlsChannel(group, tlsChannel, socketChannel)));
275+
SSLParameters sslParameters = sslEngine.getSSLParameters();
276+
enableSni(getServerAddress().getHost(), sslParameters);
233277

234-
handler.completed(null);
235-
} catch (IOException e) {
236-
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
237-
} catch (Throwable t) {
238-
handler.failed(t);
239-
}
240-
});
278+
if (!sslSettings.isInvalidHostNameAllowed()) {
279+
enableHostNameVerification(sslParameters);
280+
}
281+
sslEngine.setSSLParameters(sslParameters);
282+
283+
BufferAllocator bufferAllocator = new BufferProviderAllocator();
284+
285+
TlsChannel tlsChannel = ClientTlsChannel.newBuilder(socketChannel, sslEngine)
286+
.withEncryptedBufferAllocator(bufferAllocator)
287+
.withPlainBufferAllocator(bufferAllocator)
288+
.build();
289+
290+
// build asynchronous channel, based in the TLS channel and associated with the global group.
291+
setChannel(new AsynchronousTlsChannelAdapter(new AsynchronousTlsChannel(group, tlsChannel, socketChannel)));
292+
293+
handler.completed(null);
241294
} catch (IOException e) {
242295
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
243296
} catch (Throwable t) {

driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,4 +823,13 @@ public long getCurrentWriteCount() {
823823
public long getCurrentRegistrationCount() {
824824
return registrations.mappingCount();
825825
}
826+
827+
/**
828+
* Returns the timeout executor used by this channel group.
829+
*
830+
* @return the timeout executor
831+
*/
832+
public ScheduledThreadPoolExecutor getTimeoutExecutor() {
833+
return timeoutExecutor;
834+
}
826835
}
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mongodb.internal.connection;
18+
19+
import com.mongodb.MongoSocketOpenException;
20+
import com.mongodb.ServerAddress;
21+
import com.mongodb.connection.SocketSettings;
22+
import com.mongodb.connection.SslSettings;
23+
import com.mongodb.internal.TimeoutContext;
24+
import com.mongodb.internal.TimeoutSettings;
25+
import org.junit.jupiter.params.ParameterizedTest;
26+
import org.junit.jupiter.params.provider.ValueSource;
27+
import org.mockito.MockedStatic;
28+
import org.mockito.Mockito;
29+
import org.mockito.invocation.InvocationOnMock;
30+
import org.mockito.stubbing.Answer;
31+
32+
import java.io.IOException;
33+
import java.net.ServerSocket;
34+
import java.nio.channels.InterruptedByTimeoutException;
35+
import java.nio.channels.SocketChannel;
36+
import java.util.concurrent.TimeUnit;
37+
38+
import static com.mongodb.internal.connection.OperationContext.simpleOperationContext;
39+
import static java.lang.String.format;
40+
import static java.util.concurrent.TimeUnit.MILLISECONDS;
41+
import static org.junit.jupiter.api.Assertions.assertFalse;
42+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
43+
import static org.junit.jupiter.api.Assertions.assertNotNull;
44+
import static org.junit.jupiter.api.Assertions.assertThrows;
45+
import static org.junit.jupiter.api.Assertions.assertTrue;
46+
import static org.junit.jupiter.api.Assertions.fail;
47+
import static org.mockito.Mockito.atLeast;
48+
import static org.mockito.Mockito.verify;
49+
50+
class TlsChannelStreamFunctionalTest {
51+
private static final SslSettings SSL_SETTINGS = SslSettings.builder().enabled(true).build();
52+
private static final String UNREACHABLE_PRIVATE_IP_ADDRESS = "10.255.255.1";
53+
private static final int UNREACHABLE_PORT = 65333;
54+
55+
@ParameterizedTest
56+
@ValueSource(ints = {500, 1000, 2000})
57+
void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires(final int connectTimeoutMs) throws IOException {
58+
//given
59+
try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver());
60+
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) {
61+
SingleResultSpyCaptor<SocketChannel> singleResultSpyCaptor = new SingleResultSpyCaptor<>();
62+
socketChannelMockedStatic.when(SocketChannel::open).thenAnswer(singleResultSpyCaptor);
63+
64+
StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder()
65+
.connectTimeout(connectTimeoutMs, TimeUnit.MILLISECONDS)
66+
.build(), SSL_SETTINGS);
67+
68+
Stream stream = streamFactory.create(new ServerAddress(UNREACHABLE_PRIVATE_IP_ADDRESS, UNREACHABLE_PORT));
69+
long connectOpenStart = System.nanoTime();
70+
71+
//when
72+
OperationContext operationContext = createOperationContext(connectTimeoutMs);
73+
MongoSocketOpenException mongoSocketOpenException = assertThrows(MongoSocketOpenException.class, () ->
74+
stream.open(operationContext));
75+
76+
//then
77+
long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - connectOpenStart);
78+
// Allow for some timing imprecision due to test overhead.
79+
int maximumAcceptableTimeoutOvershoot = 300;
80+
81+
assertInstanceOf(InterruptedByTimeoutException.class, mongoSocketOpenException.getCause());
82+
assertFalse(connectTimeoutMs > elapsedMs,
83+
format("Connection timed-out sooner than expected. ConnectTimeoutMS: %d, elapsedMs: %d", connectTimeoutMs, elapsedMs));
84+
assertTrue(elapsedMs - connectTimeoutMs <= maximumAcceptableTimeoutOvershoot,
85+
format("Connection timeout overshoot time %d ms should be within %d ms", elapsedMs - connectTimeoutMs,
86+
maximumAcceptableTimeoutOvershoot));
87+
88+
SocketChannel actualSpySocketChannel = singleResultSpyCaptor.getResult();
89+
assertNotNull(actualSpySocketChannel, "SocketChannel was not opened");
90+
verify(actualSpySocketChannel, atLeast(1)).close();
91+
}
92+
}
93+
94+
@ParameterizedTest
95+
@ValueSource(ints = {0, 500, 1000, 2000})
96+
void shouldEstablishConnection(final int connectTimeoutMs) throws IOException, InterruptedException {
97+
//given
98+
try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver());
99+
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class);
100+
ServerSocket serverSocket = new ServerSocket(0, 1)) {
101+
SingleResultSpyCaptor<SocketChannel> singleResultSpyCaptor = new SingleResultSpyCaptor<>();
102+
socketChannelMockedStatic.when(SocketChannel::open).thenAnswer(singleResultSpyCaptor);
103+
104+
StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder()
105+
.connectTimeout(connectTimeoutMs, TimeUnit.MILLISECONDS)
106+
.build(), SSL_SETTINGS);
107+
108+
Stream stream = streamFactory.create(new ServerAddress(serverSocket.getInetAddress(), serverSocket.getLocalPort()));
109+
try {
110+
//when
111+
stream.open(createOperationContext(connectTimeoutMs));
112+
113+
//then
114+
SocketChannel actualSpySocketChannel = singleResultSpyCaptor.getResult();
115+
assertNotNull(actualSpySocketChannel, "SocketChannel was not opened");
116+
assertTrue(actualSpySocketChannel.isConnected());
117+
118+
// Wait to verify that socket was not closed by timeout.
119+
MILLISECONDS.sleep(connectTimeoutMs * 2L);
120+
assertTrue(actualSpySocketChannel.isConnected());
121+
assertFalse(stream.isClosed());
122+
} finally {
123+
stream.close();
124+
}
125+
}
126+
}
127+
128+
private static final class SingleResultSpyCaptor<T> implements Answer<T> {
129+
private volatile T result = null;
130+
131+
public T getResult() {
132+
return result;
133+
}
134+
135+
@Override
136+
public T answer(final InvocationOnMock invocationOnMock) throws Throwable {
137+
if (result != null) {
138+
fail(invocationOnMock.getMethod().getName() + " was called more then once");
139+
}
140+
@SuppressWarnings("unchecked")
141+
T returnedValue = (T) invocationOnMock.callRealMethod();
142+
result = Mockito.spy(returnedValue);
143+
return result;
144+
}
145+
}
146+
147+
private static OperationContext createOperationContext(final int connectTimeoutMs) {
148+
return simpleOperationContext(new TimeoutContext(TimeoutSettings.DEFAULT.withConnectTimeoutMS(connectTimeoutMs)));
149+
}
150+
}

driver-core/src/test/unit/com/mongodb/internal/TimeoutSettingsTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,11 @@ Collection<DynamicTest> timeoutSettingsTest() {
5353
.withMaxAwaitTimeMS(11)
5454
.withMaxCommitMS(999L)
5555
.withReadTimeoutMS(11_000)
56+
.withConnectTimeoutMS(500)
5657
.withWTimeoutMS(222L);
5758
assertAll(
5859
() -> assertEquals(30_000, timeoutSettings.getServerSelectionTimeoutMS()),
59-
() -> assertEquals(10_000, timeoutSettings.getConnectTimeoutMS()),
60+
() -> assertEquals(500, timeoutSettings.getConnectTimeoutMS()),
6061
() -> assertEquals(11_000, timeoutSettings.getReadTimeoutMS()),
6162
() -> assertEquals(100, timeoutSettings.getTimeoutMS()),
6263
() -> assertEquals(111, timeoutSettings.getMaxTimeMS()),

0 commit comments

Comments
 (0)