Skip to content

Commit 3cc5304

Browse files
committed
fix rebase conflict
Signed-off-by: Jiaping Zeng <[email protected]>
1 parent 9c1aa50 commit 3cc5304

File tree

2 files changed

+28
-11
lines changed

2 files changed

+28
-11
lines changed

common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import java.io.IOException;
1919
import java.lang.reflect.Constructor;
20-
import java.lang.reflect.Type;
2120
import java.util.ArrayList;
2221
import java.util.Arrays;
2322
import java.util.HashMap;
@@ -1202,9 +1201,9 @@ public void testSerializeFloatNaNAndInfinity_BecomesNull_InPojo() {
12021201
@Test
12031202
public void testDeserializeScientificNotation_ToFloatAndPrimitive() {
12041203
String jsonObj = "{\"fObj\":1.23e-5}";
1205-
Type mapType = new TypeToken<Map<String, Float>>() {
1204+
java.lang.reflect.Type mapType = new com.google.gson.reflect.TypeToken<java.util.Map<String, Float>>() {
12061205
}.getType();
1207-
Map<String, Float> m = StringUtils.PLAIN_NUMBER_GSON.fromJson(jsonObj, mapType);
1206+
java.util.Map<String, Float> m = StringUtils.PLAIN_NUMBER_GSON.fromJson(jsonObj, mapType);
12081207
assertEquals(1.23e-5f, m.get("fObj"), 1e-9f);
12091208

12101209
String jsonArr = "[4.56e1]";
@@ -1216,9 +1215,9 @@ public void testDeserializeScientificNotation_ToFloatAndPrimitive() {
12161215
public void testDeserializeNullFloat_ToNull() {
12171216
String json = "{\"fObj\":null,\"fPrim\":1.0}";
12181217

1219-
Type mapType = new TypeToken<Map<String, JsonElement>>() {
1218+
java.lang.reflect.Type mapType = new TypeToken<java.util.Map<String, JsonElement>>() {
12201219
}.getType();
1221-
Map<String, JsonElement> m = StringUtils.PLAIN_NUMBER_GSON.fromJson(json, mapType);
1220+
java.util.Map<String, JsonElement> m = StringUtils.PLAIN_NUMBER_GSON.fromJson(json, mapType);
12221221

12231222
assertTrue(m.containsKey("fObj"));
12241223
assertTrue(m.get("fObj").isJsonNull());

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.Locale;
1919
import java.util.Map;
2020
import java.util.concurrent.CompletableFuture;
21+
import java.util.concurrent.atomic.AtomicReference;
2122

2223
import org.apache.commons.text.StringEscapeUtils;
2324
import org.apache.logging.log4j.Logger;
@@ -40,6 +41,8 @@
4041
import org.opensearch.transport.StreamTransportService;
4142
import org.opensearch.transport.client.Client;
4243

44+
import com.google.common.annotations.VisibleForTesting;
45+
4346
import lombok.Getter;
4447
import lombok.Setter;
4548
import lombok.extern.log4j.Log4j2;
@@ -70,19 +73,18 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor {
7073
@Getter
7174
private MLGuard mlGuard;
7275

73-
private SdkAsyncHttpClient httpClient;
76+
private final AtomicReference<SdkAsyncHttpClient> httpClientRef = new AtomicReference<>();
7477

7578
@Setter
7679
@Getter
7780
private StreamTransportService streamTransportService;
7881

82+
@Setter
83+
private boolean connectorPrivateIpEnabled;
84+
7985
public AwsConnectorExecutor(Connector connector) {
8086
super.initialize(connector);
8187
this.connector = (AwsConnector) connector;
82-
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
83-
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
84-
Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
85-
this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection);
8688
}
8789

8890
@Override
@@ -129,7 +131,8 @@ public void invokeRemoteService(
129131
)
130132
)
131133
.build();
132-
AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
134+
AccessController
135+
.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> getHttpClient().execute(executeRequest));
133136
} catch (RuntimeException exception) {
134137
log.error("Failed to execute {} in aws connector: {}", action, exception.getMessage(), exception);
135138
actionListener.onFailure(exception);
@@ -180,4 +183,19 @@ private void validateLLMInterface(String llmInterface) {
180183
throw new IllegalArgumentException(String.format("Unsupported llm interface: %s", llmInterface));
181184
}
182185
}
186+
187+
@VisibleForTesting
188+
protected SdkAsyncHttpClient getHttpClient() {
189+
if (httpClientRef.get() == null) {
190+
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
191+
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
192+
Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
193+
this.httpClientRef
194+
.compareAndSet(
195+
null,
196+
MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled)
197+
);
198+
}
199+
return httpClientRef.get();
200+
}
183201
}

0 commit comments

Comments
 (0)