From a64b3f3fd2ff4fc9664680b3f15d6a364cb9d8ab Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Wed, 31 Jan 2024 18:26:41 -0800 Subject: [PATCH] Refactor async executor service dependencies using guice framework Signed-off-by: Vamsi Manohar --- .../org/opensearch/sql/plugin/SQLPlugin.java | 101 +------------ .../client/EMRServerlessClientFactory.java | 10 ++ .../EMRServerlessClientFactoryImpl.java | 60 ++++++++ .../dispatcher/SparkQueryDispatcher.java | 21 +-- .../execution/session/SessionManager.java | 14 +- ...ransportCancelAsyncQueryRequestAction.java | 29 +++- ...ransportCreateAsyncQueryRequestAction.java | 29 +++- .../TransportGetAsyncQueryResultAction.java | 28 +++- .../config/AsyncExecutorServiceModule.java | 143 ++++++++++++++++++ ...AsyncQueryExecutorServiceImplSpecTest.java | 49 ++++-- .../AsyncQueryExecutorServiceSpec.java | 19 ++- .../AsyncQueryGetResultSpecTest.java | 4 +- .../spark/asyncquery/IndexQuerySpecTest.java | 125 ++++++++++++--- .../dispatcher/SparkQueryDispatcherTest.java | 5 +- .../session/InteractiveSessionTest.java | 12 +- .../execution/session/SessionManagerTest.java | 8 +- .../execution/statement/StatementTest.java | 25 ++- ...portCancelAsyncQueryRequestActionTest.java | 12 +- ...portCreateAsyncQueryRequestActionTest.java | 12 +- ...ransportGetAsyncQueryResultActionTest.java | 12 +- 20 files changed, 533 insertions(+), 185 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index f0689a0966..cdb68bfabb 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -5,21 +5,14 @@ package org.opensearch.sql.plugin; -import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; +import static java.util.Collections.singletonList; import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; -import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.services.emrserverless.AWSEMRServerless; -import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import java.security.AccessController; -import java.security.PrivilegedAction; import java.time.Clock; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.function.Supplier; @@ -68,7 +61,6 @@ import org.opensearch.sql.datasources.transport.*; import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.executor.AsyncRestExecutor; -import org.opensearch.sql.legacy.metrics.GaugeMetric; import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.legacy.plugin.RestSqlAction; import org.opensearch.sql.legacy.plugin.RestSqlStatsAction; @@ -86,22 +78,7 @@ import org.opensearch.sql.plugin.transport.TransportPPLQueryAction; import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory; -import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; -import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; -import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService; -import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService; -import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.client.EmrServerlessClientImpl; import org.opensearch.sql.spark.cluster.ClusterManagerEventListener; -import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; -import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; -import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; -import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; -import org.opensearch.sql.spark.execution.session.SessionManager; -import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; -import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; -import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; import org.opensearch.sql.spark.storage.SparkStorageFactory; import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; @@ -127,7 +104,6 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin { private NodeClient client; private DataSourceServiceImpl dataSourceService; - private AsyncQueryExecutorService asyncQueryExecutorService; private Injector injector; public String name() { @@ -223,23 +199,6 @@ public Collection createComponents( dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata()); LocalClusterState.state().setClusterService(clusterService); LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings); - SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier = - new SparkExecutionEngineConfigSupplierImpl(pluginSettings); - SparkExecutionEngineConfig sparkExecutionEngineConfig = - sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); - if (StringUtils.isEmpty(sparkExecutionEngineConfig.getRegion())) { - LOGGER.warn( - String.format( - "Async Query APIs are disabled as %s is not configured properly in cluster settings. " - + "Please configure and restart the domain to enable Async Query APIs", - SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue())); - this.asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl(); - } else { - this.asyncQueryExecutorService = - createAsyncQueryExecutorService( - sparkExecutionEngineConfigSupplier, sparkExecutionEngineConfig); - } - ModulesBuilder modules = new ModulesBuilder(); modules.add(new OpenSearchPluginModule()); modules.add( @@ -260,13 +219,12 @@ public Collection createComponents( OpenSearchSettings.RESULT_INDEX_TTL_SETTING, OpenSearchSettings.AUTO_INDEX_MANAGEMENT_ENABLED_SETTING, environment.settings()); - return ImmutableList.of( - dataSourceService, asyncQueryExecutorService, clusterManagerEventListener, pluginSettings); + return ImmutableList.of(dataSourceService, clusterManagerEventListener, pluginSettings); } @Override public List> getExecutorBuilders(Settings settings) { - return Collections.singletonList( + return singletonList( new FixedExecutorBuilder( settings, AsyncRestExecutor.SQL_WORKER_THREAD_POOL_NAME, @@ -318,57 +276,4 @@ private DataSourceServiceImpl createDataSourceService() { dataSourceMetadataStorage, dataSourceUserAuthorizationHelper); } - - private AsyncQueryExecutorService createAsyncQueryExecutorService( - SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier, - SparkExecutionEngineConfig sparkExecutionEngineConfig) { - StateStore stateStore = new StateStore(client, clusterService); - registerStateStoreMetrics(stateStore); - AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService(stateStore); - EMRServerlessClient emrServerlessClient = - createEMRServerlessClient(sparkExecutionEngineConfig.getRegion()); - JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - this.dataSourceService, - new DataSourceUserAuthorizationHelperImpl(client), - jobExecutionResponseReader, - new FlintIndexMetadataReaderImpl(client), - client, - new SessionManager(stateStore, emrServerlessClient, pluginSettings), - new DefaultLeaseManager(pluginSettings, stateStore), - stateStore); - return new AsyncQueryExecutorServiceImpl( - asyncQueryJobMetadataStorageService, - sparkQueryDispatcher, - sparkExecutionEngineConfigSupplier); - } - - private void registerStateStoreMetrics(StateStore stateStore) { - GaugeMetric activeSessionMetric = - new GaugeMetric<>( - "active_async_query_sessions_count", - StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE)); - GaugeMetric activeStatementMetric = - new GaugeMetric<>( - "active_async_query_statements_count", - StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE)); - Metrics.getInstance().registerMetric(activeSessionMetric); - Metrics.getInstance().registerMetric(activeStatementMetric); - } - - private EMRServerlessClient createEMRServerlessClient(String region) { - return AccessController.doPrivileged( - (PrivilegedAction) - () -> { - AWSEMRServerless awsemrServerless = - AWSEMRServerlessClientBuilder.standard() - .withRegion(region) - .withCredentials(new DefaultAWSCredentialsProviderChain()) - .build(); - return new EmrServerlessClientImpl(awsemrServerless); - }); - } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java new file mode 100644 index 0000000000..c4dded9ee6 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java @@ -0,0 +1,10 @@ +package org.opensearch.sql.spark.client; + +/** Interface for EMRServerlessClientFactory. */ +public interface EMRServerlessClientFactory { + + /** + * @return {@link EMRServerlessClient} + */ + EMRServerlessClient getClient(); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java new file mode 100644 index 0000000000..921362e7bd --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java @@ -0,0 +1,60 @@ +package org.opensearch.sql.spark.client; + +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; + +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder; +import java.security.AccessController; +import java.security.PrivilegedAction; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; + +@RequiredArgsConstructor +public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactory { + + private final SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier; + private EMRServerlessClient emrServerlessClient; + private String region; + + @Override + public EMRServerlessClient getClient() { + SparkExecutionEngineConfig sparkExecutionEngineConfig = + this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); + validateSparkExecutionEngineConfig(sparkExecutionEngineConfig); + if (isNewClientCreationRequired(sparkExecutionEngineConfig.getRegion())) { + region = sparkExecutionEngineConfig.getRegion(); + this.emrServerlessClient = createEMRServerlessClient(this.region); + } + return this.emrServerlessClient; + } + + private boolean isNewClientCreationRequired(String region) { + return !region.equals(this.region) || emrServerlessClient == null; + } + + private void validateSparkExecutionEngineConfig( + SparkExecutionEngineConfig sparkExecutionEngineConfig) { + if (sparkExecutionEngineConfig == null || sparkExecutionEngineConfig.getRegion() == null) { + throw new IllegalArgumentException( + String.format( + "Async Query APIs are disabled as %s is not configured in cluster settings. Please" + + " configure the setting and restart the domain to enable Async Query APIs", + SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue())); + } + } + + private EMRServerlessClient createEMRServerlessClient(String awsRegion) { + return AccessController.doPrivileged( + (PrivilegedAction) + () -> { + AWSEMRServerless awsemrServerless = + AWSEMRServerlessClientBuilder.standard() + .withRegion(awsRegion) + .withCredentials(new DefaultAWSCredentialsProviderChain()) + .build(); + return new EmrServerlessClientImpl(awsemrServerless); + }); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 0aa183335e..4a790c311f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -8,8 +8,6 @@ import java.util.HashMap; import java.util.Map; import lombok.AllArgsConstructor; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.json.JSONObject; import org.opensearch.client.Client; import org.opensearch.sql.datasource.DataSourceService; @@ -18,6 +16,8 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -35,13 +35,12 @@ @AllArgsConstructor public class SparkQueryDispatcher { - private static final Logger LOG = LogManager.getLogger(); public static final String INDEX_TAG_KEY = "index"; public static final String DATASOURCE_TAG_KEY = "datasource"; public static final String CLUSTER_NAME_TAG_KEY = "domain_ident"; public static final String JOB_TYPE_TAG_KEY = "type"; - private EMRServerlessClient emrServerlessClient; + private EMRServerlessClientFactory emrServerlessClientFactory; private DataSourceService dataSourceService; @@ -60,10 +59,10 @@ public class SparkQueryDispatcher { private StateStore stateStore; public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { + EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); DataSourceMetadata dataSourceMetadata = this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); - AsyncQueryHandler asyncQueryHandler = sessionManager.isEnabled() ? new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager) @@ -83,7 +82,7 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) contextBuilder.indexQueryDetails(indexQueryDetails); if (IndexQueryActionType.DROP.equals(indexQueryDetails.getIndexQueryActionType())) { - asyncQueryHandler = createIndexDMLHandler(); + asyncQueryHandler = createIndexDMLHandler(emrServerlessClient); } else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType()) && indexQueryDetails.isAutoRefresh()) { asyncQueryHandler = @@ -98,12 +97,15 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) return asyncQueryHandler.submit(dispatchQueryRequest, contextBuilder.build()); } + private void createEmrServerlessClient(SparkExecutionEngineConfig sparkExecutionEngineConfig) {} + public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { + EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); if (asyncQueryJobMetadata.getSessionId() != null) { return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager) .getQueryResponse(asyncQueryJobMetadata); } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { - return createIndexDMLHandler().getQueryResponse(asyncQueryJobMetadata); + return createIndexDMLHandler(emrServerlessClient).getQueryResponse(asyncQueryJobMetadata); } else { return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager) .getQueryResponse(asyncQueryJobMetadata); @@ -111,12 +113,13 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) } public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); AsyncQueryHandler queryHandler; if (asyncQueryJobMetadata.getSessionId() != null) { queryHandler = new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager); } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { - queryHandler = createIndexDMLHandler(); + queryHandler = createIndexDMLHandler(emrServerlessClient); } else { queryHandler = new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager); @@ -124,7 +127,7 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { return queryHandler.cancelJob(asyncQueryJobMetadata); } - private IndexDMLHandler createIndexDMLHandler() { + private IndexDMLHandler createIndexDMLHandler(EMRServerlessClient emrServerlessClient) { return new IndexDMLHandler( emrServerlessClient, dataSourceService, diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index c3d5807305..e441492c20 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -10,7 +10,7 @@ import java.util.Optional; import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.utils.RealTimeProvider; @@ -21,13 +21,15 @@ */ public class SessionManager { private final StateStore stateStore; - private final EMRServerlessClient emrServerlessClient; + private final EMRServerlessClientFactory emrServerlessClientFactory; private Settings settings; public SessionManager( - StateStore stateStore, EMRServerlessClient emrServerlessClient, Settings settings) { + StateStore stateStore, + EMRServerlessClientFactory emrServerlessClientFactory, + Settings settings) { this.stateStore = stateStore; - this.emrServerlessClient = emrServerlessClient; + this.emrServerlessClientFactory = emrServerlessClientFactory; this.settings = settings; } @@ -36,7 +38,7 @@ public Session createSession(CreateSessionRequest request) { InteractiveSession.builder() .sessionId(newSessionId(request.getDatasourceName())) .stateStore(stateStore) - .serverlessClient(emrServerlessClient) + .serverlessClient(emrServerlessClientFactory.getClient()) .build(); session.open(request); return session; @@ -68,7 +70,7 @@ public Optional getSession(SessionId sid, String dataSourceName) { InteractiveSession.builder() .sessionId(sid) .stateStore(stateStore) - .serverlessClient(emrServerlessClient) + .serverlessClient(emrServerlessClientFactory.getClient()) .sessionModel(model.get()) .sessionInactivityTimeoutMilli( settings.getSettingValue(SESSION_INACTIVITY_TIMEOUT_MILLIS)) diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java index 232a280db5..4fe2b6bf3d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java @@ -10,9 +10,17 @@ import org.opensearch.action.ActionType; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.inject.Injector; +import org.opensearch.common.inject.ModulesBuilder; import org.opensearch.core.action.ActionListener; -import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; +import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -22,7 +30,7 @@ public class TransportCancelAsyncQueryRequestAction extends HandledTransportAction { public static final String NAME = "cluster:admin/opensearch/ql/async_query/delete"; - private final AsyncQueryExecutorServiceImpl asyncQueryExecutorService; + private final Injector injector; public static final ActionType ACTION_TYPE = new ActionType<>(NAME, CancelAsyncQueryActionResponse::new); @@ -30,9 +38,20 @@ public class TransportCancelAsyncQueryRequestAction public TransportCancelAsyncQueryRequestAction( TransportService transportService, ActionFilters actionFilters, - AsyncQueryExecutorServiceImpl asyncQueryExecutorService) { + NodeClient client, + ClusterService clusterService, + DataSourceServiceImpl dataSourceService) { super(NAME, transportService, actionFilters, CancelAsyncQueryActionRequest::new); - this.asyncQueryExecutorService = asyncQueryExecutorService; + ModulesBuilder modulesBuilder = new ModulesBuilder(); + modulesBuilder.add( + b -> { + b.bind(NodeClient.class).toInstance(client); + b.bind(org.opensearch.sql.common.setting.Settings.class) + .toInstance(new OpenSearchSettings(clusterService.getClusterSettings())); + b.bind(DataSourceService.class).toInstance(dataSourceService); + }); + modulesBuilder.add(new AsyncExecutorServiceModule()); + this.injector = modulesBuilder.createInjector(); } @Override @@ -41,6 +60,8 @@ protected void doExecute( CancelAsyncQueryActionRequest request, ActionListener listener) { try { + AsyncQueryExecutorService asyncQueryExecutorService = + injector.getInstance(AsyncQueryExecutorService.class); String jobId = asyncQueryExecutorService.cancelQuery(request.getQueryId()); listener.onResponse( new CancelAsyncQueryActionResponse( diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java index 991eafdad9..aa5b4ee4a5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java @@ -10,13 +10,20 @@ import org.opensearch.action.ActionType; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.inject.Injector; +import org.opensearch.common.inject.ModulesBuilder; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; -import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; +import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule; import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -24,8 +31,7 @@ public class TransportCreateAsyncQueryRequestAction extends HandledTransportAction { - - private final AsyncQueryExecutorService asyncQueryExecutorService; + private final Injector injector; public static final String NAME = "cluster:admin/opensearch/ql/async_query/create"; public static final ActionType ACTION_TYPE = @@ -35,9 +41,20 @@ public class TransportCreateAsyncQueryRequestAction public TransportCreateAsyncQueryRequestAction( TransportService transportService, ActionFilters actionFilters, - AsyncQueryExecutorServiceImpl jobManagementService) { + NodeClient client, + ClusterService clusterService, + DataSourceServiceImpl dataSourceService) { super(NAME, transportService, actionFilters, CreateAsyncQueryActionRequest::new); - this.asyncQueryExecutorService = jobManagementService; + ModulesBuilder modulesBuilder = new ModulesBuilder(); + modulesBuilder.add( + b -> { + b.bind(NodeClient.class).toInstance(client); + b.bind(org.opensearch.sql.common.setting.Settings.class) + .toInstance(new OpenSearchSettings(clusterService.getClusterSettings())); + b.bind(DataSourceService.class).toInstance(dataSourceService); + }); + modulesBuilder.add(new AsyncExecutorServiceModule()); + this.injector = modulesBuilder.createInjector(); } @Override @@ -47,6 +64,8 @@ protected void doExecute( ActionListener listener) { try { CreateAsyncQueryRequest createAsyncQueryRequest = request.getCreateAsyncQueryRequest(); + AsyncQueryExecutorService asyncQueryExecutorService = + injector.getInstance(AsyncQueryExecutorService.class); CreateAsyncQueryResponse createAsyncQueryResponse = asyncQueryExecutorService.createAsyncQuery(createAsyncQueryRequest); String responseContent = diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java index 5c784cf04c..3662b33a6a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java @@ -10,15 +10,22 @@ import org.opensearch.action.ActionType; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.inject.Injector; +import org.opensearch.common.inject.ModulesBuilder; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; import org.opensearch.sql.executor.pagination.Cursor; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; import org.opensearch.sql.protocol.response.format.ResponseFormatter; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; -import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; +import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule; import org.opensearch.sql.spark.transport.format.AsyncQueryResultResponseFormatter; import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionRequest; import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse; @@ -29,7 +36,7 @@ public class TransportGetAsyncQueryResultAction extends HandledTransportAction< GetAsyncQueryResultActionRequest, GetAsyncQueryResultActionResponse> { - private final AsyncQueryExecutorService asyncQueryExecutorService; + private final Injector injector; public static final String NAME = "cluster:admin/opensearch/ql/async_query/result"; public static final ActionType ACTION_TYPE = @@ -39,9 +46,20 @@ public class TransportGetAsyncQueryResultAction public TransportGetAsyncQueryResultAction( TransportService transportService, ActionFilters actionFilters, - AsyncQueryExecutorServiceImpl jobManagementService) { + NodeClient client, + ClusterService clusterService, + DataSourceServiceImpl dataSourceService) { super(NAME, transportService, actionFilters, GetAsyncQueryResultActionRequest::new); - this.asyncQueryExecutorService = jobManagementService; + ModulesBuilder modulesBuilder = new ModulesBuilder(); + modulesBuilder.add( + b -> { + b.bind(NodeClient.class).toInstance(client); + b.bind(org.opensearch.sql.common.setting.Settings.class) + .toInstance(new OpenSearchSettings(clusterService.getClusterSettings())); + b.bind(DataSourceService.class).toInstance(dataSourceService); + }); + modulesBuilder.add(new AsyncExecutorServiceModule()); + this.injector = modulesBuilder.createInjector(); } @Override @@ -51,6 +69,8 @@ protected void doExecute( ActionListener listener) { try { String jobId = request.getQueryId(); + AsyncQueryExecutorService asyncQueryExecutorService = + injector.getInstance(AsyncQueryExecutorService.class); AsyncQueryExecutionResponse asyncQueryExecutionResponse = asyncQueryExecutorService.getAsyncQueryResults(jobId); ResponseFormatter formatter = diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java new file mode 100644 index 0000000000..ede3fde069 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -0,0 +1,143 @@ +package org.opensearch.sql.spark.transport.config; + +import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE; + +import lombok.RequiredArgsConstructor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.AbstractModule; +import org.opensearch.common.inject.Provides; +import org.opensearch.common.inject.Singleton; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; +import org.opensearch.sql.legacy.metrics.GaugeMetric; +import org.opensearch.sql.legacy.metrics.Metrics; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService; +import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; +import org.opensearch.sql.spark.client.EMRServerlessClientFactoryImpl; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; +import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; +import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; + +@RequiredArgsConstructor +public class AsyncExecutorServiceModule extends AbstractModule { + + private static final Logger LOG = LogManager.getLogger(AsyncExecutorServiceModule.class); + + @Override + protected void configure() {} + + @Provides + public AsyncQueryExecutorService asyncQueryExecutorService( + AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService, + SparkQueryDispatcher sparkQueryDispatcher, + SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier) { + return new AsyncQueryExecutorServiceImpl( + asyncQueryJobMetadataStorageService, + sparkQueryDispatcher, + sparkExecutionEngineConfigSupplier); + } + + @Provides + public AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService( + StateStore stateStore) { + return new OpensearchAsyncQueryJobMetadataStorageService(stateStore); + } + + @Provides + @Singleton + public StateStore stateStore(NodeClient client, ClusterService clusterService) { + StateStore stateStore = new StateStore(client, clusterService); + registerStateStoreMetrics(stateStore); + return stateStore; + } + + @Provides + public SparkQueryDispatcher sparkQueryDispatcher( + EMRServerlessClientFactory emrServerlessClientFactory, + DataSourceService dataSourceService, + DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper, + JobExecutionResponseReader jobExecutionResponseReader, + FlintIndexMetadataReaderImpl flintIndexMetadataReader, + NodeClient client, + SessionManager sessionManager, + DefaultLeaseManager defaultLeaseManager, + StateStore stateStore, + ClusterService clusterService) { + return new SparkQueryDispatcher( + emrServerlessClientFactory, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader, + flintIndexMetadataReader, + client, + sessionManager, + defaultLeaseManager, + stateStore); + } + + @Provides + public SessionManager sessionManager( + StateStore stateStore, + EMRServerlessClientFactory emrServerlessClientFactory, + Settings settings) { + return new SessionManager(stateStore, emrServerlessClientFactory, settings); + } + + @Provides + public DefaultLeaseManager defaultLeaseManager(Settings settings, StateStore stateStore) { + return new DefaultLeaseManager(settings, stateStore); + } + + @Provides + public EMRServerlessClientFactory createEMRServerlessClientFactory( + SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier) { + return new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier); + } + + @Provides + public SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier(Settings settings) { + return new SparkExecutionEngineConfigSupplierImpl(settings); + } + + @Provides + @Singleton + public FlintIndexMetadataReaderImpl flintIndexMetadataReader(NodeClient client) { + return new FlintIndexMetadataReaderImpl(client); + } + + @Provides + public JobExecutionResponseReader jobExecutionResponseReader(NodeClient client) { + return new JobExecutionResponseReader(client); + } + + @Provides + public DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper( + NodeClient client) { + return new DataSourceUserAuthorizationHelperImpl(client); + } + + private void registerStateStoreMetrics(StateStore stateStore) { + GaugeMetric activeSessionMetric = + new GaugeMetric<>( + "active_async_query_sessions_count", + StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE)); + GaugeMetric activeStatementMetric = + new GaugeMetric<>( + "active_async_query_statements_count", + StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE)); + Metrics.getInstance().registerMetric(activeSessionMetric); + Metrics.getInstance().registerMetric(activeStatementMetric); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 011d97dcdf..33fec89e26 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -32,6 +32,7 @@ import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.session.SessionId; import org.opensearch.sql.spark.execution.session.SessionState; import org.opensearch.sql.spark.execution.statement.StatementModel; @@ -46,8 +47,9 @@ public class AsyncQueryExecutorServiceImplSpecTest extends AsyncQueryExecutorSer @Disabled("batch query is unsupported") public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // disable session enableSession(false); @@ -74,8 +76,9 @@ public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { @Disabled("batch query is unsupported") public void sessionLimitNotImpactBatchQuery() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // disable session enableSession(false); @@ -96,8 +99,9 @@ public void sessionLimitNotImpactBatchQuery() { @Disabled("batch query is unsupported") public void createAsyncQueryCreateJobWithCorrectParameters() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); enableSession(false); CreateAsyncQueryResponse response = @@ -129,8 +133,9 @@ public void createAsyncQueryCreateJobWithCorrectParameters() { @Test public void withSessionCreateAsyncQueryThenGetResultThenCancel() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // 1. create async query. CreateAsyncQueryResponse response = @@ -156,8 +161,9 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() { @Test public void reuseSessionWhenCreateAsyncQuery() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -207,8 +213,9 @@ public void reuseSessionWhenCreateAsyncQuery() { @Disabled("batch query is unsupported") public void batchQueryHasTimeout() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); enableSession(false); CreateAsyncQueryResponse response = @@ -221,8 +228,9 @@ public void batchQueryHasTimeout() { @Test public void interactiveQueryNoTimeout() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -255,8 +263,9 @@ public void datasourceWithBasicAuth() { properties, null)); LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -274,8 +283,9 @@ public void datasourceWithBasicAuth() { @Test public void withSessionCreateAsyncQueryFailed() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -322,8 +332,9 @@ public void withSessionCreateAsyncQueryFailed() { @Test public void createSessionMoreThanLimitFailed() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -351,8 +362,9 @@ public void createSessionMoreThanLimitFailed() { @Test public void recreateSessionIfNotReady() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -388,8 +400,9 @@ public void recreateSessionIfNotReady() { @Test public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -426,8 +439,9 @@ public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { @Test public void recreateSessionIfStale() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -480,8 +494,9 @@ public void recreateSessionIfStale() { @Test public void submitQueryInInvalidSessionWillCreateNewSession() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -516,8 +531,9 @@ public void datasourceNameIncludeUppercase() { null)); LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -536,8 +552,9 @@ public void datasourceNameIncludeUppercase() { @Test public void concurrentSessionLimitIsDomainLevel() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // only allow one session in domain. setSessionLimit(1); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index c7054dd200..c9b4b6fc88 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -59,6 +59,7 @@ import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; @@ -195,27 +196,27 @@ private DataSourceServiceImpl createDataSourceService() { } protected AsyncQueryExecutorService createAsyncQueryExecutorService( - EMRServerlessClient emrServerlessClient) { + EMRServerlessClientFactory emrServerlessClientFactory) { return createAsyncQueryExecutorService( - emrServerlessClient, new JobExecutionResponseReader(client)); + emrServerlessClientFactory, new JobExecutionResponseReader(client)); } /** Pass a custom response reader which can mock interaction between PPL plugin and EMR-S job. */ protected AsyncQueryExecutorService createAsyncQueryExecutorService( - EMRServerlessClient emrServerlessClient, + EMRServerlessClientFactory emrServerlessClientFactory, JobExecutionResponseReader jobExecutionResponseReader) { StateStore stateStore = new StateStore(client, clusterService); AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = new OpensearchAsyncQueryJobMetadataStorageService(stateStore); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - emrServerlessClient, + emrServerlessClientFactory, this.dataSourceService, new DataSourceUserAuthorizationHelperImpl(client), jobExecutionResponseReader, new FlintIndexMetadataReaderImpl(client), client, - new SessionManager(stateStore, emrServerlessClient, pluginSettings), + new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings), new DefaultLeaseManager(pluginSettings, stateStore), stateStore); return new AsyncQueryExecutorServiceImpl( @@ -271,6 +272,14 @@ public void setJobState(JobRunState jobState) { } } + public static class LocalEMRServerlessClientFactory implements EMRServerlessClientFactory { + + @Override + public EMRServerlessClient getClient() { + return new LocalEMRSClient(); + } + } + public SparkExecutionEngineConfig sparkExecutionEngineConfig() { return new SparkExecutionEngineConfig("appId", "us-west-2", "roleArn", "", "myCluster"); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java index 2ddfe77868..ab6439492a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java @@ -26,6 +26,7 @@ import org.opensearch.sql.protocol.response.format.ResponseFormatter; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statement.StatementModel; import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.execution.statestore.StateStore; @@ -411,9 +412,10 @@ private class AssertionHelper { private Interaction interaction; AssertionHelper(String query, LocalEMRSClient emrClient) { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrClient; this.queryService = createAsyncQueryExecutorService( - emrClient, + emrServerlessClientFactory, /* * Custom reader that intercepts get results call and inject extra steps defined in * current interaction. Intercept both get methods for different query handler which diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java index 49ac538e65..844567f4f5 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java @@ -12,6 +12,8 @@ import org.junit.Assert; import org.junit.Test; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexType; import org.opensearch.sql.spark.leasemanager.ConcurrencyLimitExceededException; @@ -72,9 +74,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -120,8 +128,15 @@ public CancelJobRunResult cancelJobRun(String applicationId, String jobId) { throw new IllegalArgumentException("Job run is not in a cancellable state"); } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -157,8 +172,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Running")); } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -193,9 +215,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -248,8 +276,15 @@ public CancelJobRunResult cancelJobRun(String applicationId, String jobId) { throw new IllegalArgumentException("Job run is not in a cancellable state"); } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -290,8 +325,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Running")); } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -331,8 +373,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -380,8 +429,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -424,8 +480,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -468,8 +531,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -517,8 +587,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -565,8 +642,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -586,8 +670,9 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { @Test public void concurrentRefreshJobLimitNotApplied() { + EMRServerlessClientFactory emrServerlessClientFactory = new LocalEMRServerlessClientFactory(); AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(new LocalEMRSClient()); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index COVERING.createIndex(); @@ -607,8 +692,9 @@ public void concurrentRefreshJobLimitNotApplied() { @Test public void concurrentRefreshJobLimitAppliedToDDLWithAuthRefresh() { + EMRServerlessClientFactory emrServerlessClientFactory = new LocalEMRServerlessClientFactory(); AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(new LocalEMRSClient()); + createAsyncQueryExecutorService(emrServerlessClientFactory); setConcurrentRefreshJob(1); @@ -633,8 +719,9 @@ public void concurrentRefreshJobLimitAppliedToDDLWithAuthRefresh() { @Test public void concurrentRefreshJobLimitAppliedToRefresh() { + EMRServerlessClientFactory emrServerlessClientFactory = new LocalEMRServerlessClientFactory(); AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(new LocalEMRSClient()); + createAsyncQueryExecutorService(emrServerlessClientFactory); setConcurrentRefreshJob(1); @@ -658,9 +745,9 @@ public void concurrentRefreshJobLimitAppliedToRefresh() { @Test public void concurrentRefreshJobLimitNotAppliedToDDL() { String query = "CREATE INDEX covering ON mys3.default.http_logs(l_orderkey, l_quantity)"; - + EMRServerlessClientFactory emrServerlessClientFactory = new LocalEMRServerlessClientFactory(); AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(new LocalEMRSClient()); + createAsyncQueryExecutorService(emrServerlessClientFactory); setConcurrentRefreshJob(1); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 4205102cb1..2a499e7d30 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -62,6 +62,7 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -82,6 +83,7 @@ public class SparkQueryDispatcherTest { @Mock private EMRServerlessClient emrServerlessClient; + @Mock private EMRServerlessClientFactory emrServerlessClientFactory; @Mock private DataSourceService dataSourceService; @Mock private JobExecutionResponseReader jobExecutionResponseReader; @Mock private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper; @@ -112,7 +114,7 @@ public class SparkQueryDispatcherTest { void setUp() { sparkQueryDispatcher = new SparkQueryDispatcher( - emrServerlessClient, + emrServerlessClientFactory, dataSourceService, dataSourceUserAuthorizationHelper, jobExecutionResponseReader, @@ -121,6 +123,7 @@ void setUp() { sessionManager, leaseManager, stateStore); + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); } @Test diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index d670fc4ca8..338da431fb 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -23,6 +23,7 @@ import org.opensearch.action.delete.DeleteRequest; import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.test.OpenSearchIntegTestCase; @@ -117,8 +118,9 @@ public void closeNotExistSession() { @Test public void sessionManagerCreateSession() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); TestSession testSession = testSession(session, stateStore); @@ -127,7 +129,9 @@ public void sessionManagerCreateSession() { @Test public void sessionManagerGetSession() { - SessionManager sessionManager = new SessionManager(stateStore, emrsClient, sessionSetting()); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + SessionManager sessionManager = + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()); Session session = sessionManager.createSession(createSessionRequest()); Optional managerSession = sessionManager.getSession(session.getSessionId()); @@ -137,7 +141,9 @@ public void sessionManagerGetSession() { @Test public void sessionManagerGetSessionNotExist() { - SessionManager sessionManager = new SessionManager(stateStore, emrsClient, sessionSetting()); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + SessionManager sessionManager = + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()); Optional managerSession = sessionManager.getSession(SessionId.newSessionId("no-exist")); diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java index 44dd5c3a57..d021bc7248 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -14,17 +14,19 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statestore.StateStore; @ExtendWith(MockitoExtension.class) public class SessionManagerTest { @Mock private StateStore stateStore; - @Mock private EMRServerlessClient emrClient; + + @Mock private EMRServerlessClientFactory emrServerlessClientFactory; @Test public void sessionEnable() { - Assertions.assertTrue(new SessionManager(stateStore, emrClient, sessionSetting()).isEnabled()); + Assertions.assertTrue( + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()).isEnabled()); } public static org.opensearch.sql.common.setting.Settings sessionSetting() { diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index 97f38d37a7..3a69fa01d7 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -24,6 +24,7 @@ import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.session.InteractiveSessionTest; import org.opensearch.sql.spark.execution.session.Session; import org.opensearch.sql.spark.execution.session.SessionId; @@ -258,8 +259,9 @@ public void cancelRunningStatementSuccess() { @Test public void submitStatementInRunningSession() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); // App change state to running @@ -271,8 +273,9 @@ public void submitStatementInRunningSession() { @Test public void submitStatementInNotStartedState() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); StatementId statementId = session.submit(queryRequest()); @@ -281,8 +284,9 @@ public void submitStatementInNotStartedState() { @Test public void failToSubmitStatementInDeadState() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.DEAD); @@ -297,8 +301,9 @@ public void failToSubmitStatementInDeadState() { @Test public void failToSubmitStatementInFailState() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.FAIL); @@ -313,8 +318,9 @@ public void failToSubmitStatementInFailState() { @Test public void newStatementFieldAssert() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); StatementId statementId = session.submit(queryRequest()); Optional statement = session.get(statementId); @@ -331,8 +337,9 @@ public void newStatementFieldAssert() { @Test public void failToSubmitStatementInDeletedSession() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); // other's delete session @@ -347,8 +354,9 @@ public void failToSubmitStatementInDeletedSession() { @Test public void getStatementSuccess() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); // App change state to running updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); @@ -362,8 +370,9 @@ public void getStatementSuccess() { @Test public void getStatementNotExist() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); // App change state to running updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java index 2ff76b9b57..50c3206b1f 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java @@ -22,7 +22,10 @@ import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; @@ -43,12 +46,19 @@ public class TransportCancelAsyncQueryRequestActionTest { private ArgumentCaptor deleteJobActionResponseArgumentCaptor; @Captor private ArgumentCaptor exceptionArgumentCaptor; + @Mock private NodeClient client; + @Mock private ClusterService clusterService; + @Mock private DataSourceServiceImpl dataSourceService; @BeforeEach public void setUp() { action = new TransportCancelAsyncQueryRequestAction( - transportService, new ActionFilters(new HashSet<>()), asyncQueryExecutorService); + transportService, + new ActionFilters(new HashSet<>()), + client, + clusterService, + dataSourceService); } @Test diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java index 36060d3850..0fad5d8bde 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java @@ -24,7 +24,10 @@ import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; @@ -42,6 +45,9 @@ public class TransportCreateAsyncQueryRequestActionTest { @Mock private AsyncQueryExecutorServiceImpl jobExecutorService; @Mock private Task task; @Mock private ActionListener actionListener; + @Mock private NodeClient client; + @Mock private ClusterService clusterService; + @Mock private DataSourceServiceImpl dataSourceService; @Captor private ArgumentCaptor createJobActionResponseArgumentCaptor; @@ -52,7 +58,11 @@ public class TransportCreateAsyncQueryRequestActionTest { public void setUp() { action = new TransportCreateAsyncQueryRequestAction( - transportService, new ActionFilters(new HashSet<>()), jobExecutorService); + transportService, + new ActionFilters(new HashSet<>()), + client, + clusterService, + dataSourceService); } @Test diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java index 34f10b0083..ce9859d326 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java @@ -28,7 +28,10 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; @@ -51,12 +54,19 @@ public class TransportGetAsyncQueryResultActionTest { private ArgumentCaptor createJobActionResponseArgumentCaptor; @Captor private ArgumentCaptor exceptionArgumentCaptor; + @Mock private NodeClient client; + @Mock private ClusterService clusterService; + @Mock private DataSourceServiceImpl dataSourceService; @BeforeEach public void setUp() { action = new TransportGetAsyncQueryResultAction( - transportService, new ActionFilters(new HashSet<>()), jobExecutorService); + transportService, + new ActionFilters(new HashSet<>()), + client, + clusterService, + dataSourceService); } @Test