Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix Netty deprecation warnings in transport-netty4 module ([#20233](https://github.com/opensearch-project/OpenSearch/pull/20233))
- Fix snapshot restore when an index sort is present ([#20284](https://github.com/opensearch-project/OpenSearch/pull/20284))
- Fix SearchPhaseExecutionException to properly initCause ([#20320](https://github.com/opensearch-project/OpenSearch/pull/20320))
- Fixes and refactoring in stream transport to make it more robust ([#20359](https://github.com/opensearch-project/OpenSearch/pull/20359))

### Dependencies
- Bump `com.google.auth:google-auth-library-oauth2-http` from 1.38.0 to 1.41.0 ([#20183](https://github.com/opensearch-project/OpenSearch/pull/20183))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ public static List<Setting<?>> getSettings() {
ARROW_ENABLE_DEBUG_ALLOCATOR,
ARROW_ENABLE_UNSAFE_MEMORY_ACCESS,
ARROW_SSL_ENABLE,
FLIGHT_EVENT_LOOP_THREADS
FLIGHT_EVENT_LOOP_THREADS,
FLIGHT_THREAD_POOL_MIN_SIZE
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.opensearch.arrow.flight.stats.FlightCallTracker;
import org.opensearch.arrow.flight.stats.FlightStatsCollector;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
Expand Down Expand Up @@ -45,6 +44,7 @@
*/
class FlightClientChannel implements TcpChannel {
private static final Logger logger = LogManager.getLogger(FlightClientChannel.class);
private static final AtomicLong GLOBAL_CHANNEL_COUNTER = new AtomicLong();
private final AtomicLong correlationIdGenerator = new AtomicLong();
private final FlightClient client;
private final DiscoveryNode node;
Expand Down Expand Up @@ -112,6 +112,11 @@ public FlightClientChannel(
this.closeListeners = new CopyOnWriteArrayList<>();
this.stats = new ChannelStats();
this.isClosed = false;
// Initialize with timestamp + global counter to ensure uniqueness with multiple channels
// Upper bits: timestamp, lower 20 bits: channel ID
long channelId = GLOBAL_CHANNEL_COUNTER.incrementAndGet() & 0xFFFFF; // 20 bits for channel ID
long initialValue = (System.currentTimeMillis() << 20) | channelId;
this.correlationIdGenerator.set(initialValue);
if (statsCollector != null) {
statsCollector.incrementClientChannelsActive();
}
Expand Down Expand Up @@ -229,7 +234,8 @@ public void sendMessage(long requestId, BytesReference reference, ActionListener
config
);

processStreamResponse(streamResponse);
// Open stream and prefetch first batch, invoke handler when ready
openStreamAndInvokeHandler(streamResponse);
listener.onResponse(null);
} catch (Exception e) {
if (callTracker != null) {
Expand All @@ -244,39 +250,45 @@ public void sendMessage(BytesReference reference, ActionListener<Void> listener)
throw new IllegalStateException("sendMessage must be accompanied with requestId for FlightClientChannel, use the right variant.");
}

private void processStreamResponse(FlightTransportResponse<?> streamResponse) {
try {
executeWithThreadContext(streamResponse);
} catch (Exception e) {
handleStreamException(streamResponse, e);
}
}

@SuppressWarnings({ "unchecked", "rawtypes" })
private void executeWithThreadContext(FlightTransportResponse<?> streamResponse) {
final ThreadContext threadContext = threadPool.getThreadContext();
final String executor = streamResponse.getHandler().executor();
private void openStreamAndInvokeHandler(FlightTransportResponse<?> streamResponse) {
TransportResponseHandler handler = streamResponse.getHandler();
String executor = handler.executor();

if (ThreadPool.Names.SAME.equals(executor)) {
executeHandler(threadContext, streamResponse);
} else {
threadPool.executor(executor).execute(() -> executeHandler(threadContext, streamResponse));
logger.warn("Stream transport handler using SAME executor, which may cause blocking behavior");
}
}

@SuppressWarnings({ "unchecked", "rawtypes" })
private void executeHandler(ThreadContext threadContext, FlightTransportResponse<?> streamResponse) {
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
Header header = streamResponse.getHeader();
if (header == null) {
throw new StreamException(StreamErrorCode.INTERNAL, "Header is null");
var threadContext = threadPool.getThreadContext();
CompletableFuture<Header> future = new CompletableFuture<>();
streamResponse.openAndPrefetchAsync(future);

future.whenComplete((header, error) -> {
if (error != null) {
handleStreamException(streamResponse, error instanceof Exception ? (Exception) error : new Exception(error));
return;
}
TransportResponseHandler handler = streamResponse.getHandler();
threadContext.setHeaders(header.getHeaders());
handler.handleStreamResponse(streamResponse);
} catch (Exception e) {
cleanupStreamResponse(streamResponse);
throw e;
}

Runnable task = () -> {
try (var ignored = threadContext.stashContext()) {
if (header == null) {
cleanupStreamResponse(streamResponse);
throw new StreamException(StreamErrorCode.INTERNAL, "Header is null");
}
threadContext.setHeaders(header.getHeaders());
handler.handleStreamResponse(streamResponse);
} catch (Exception e) {
cleanupStreamResponse(streamResponse);
throw e;
}
};

if (ThreadPool.Names.SAME.equals(executor)) {
task.run();
} else {
threadPool.executor(executor).execute(task);
}
});
Comment on lines +272 to +291
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Exception thrown inside executor task may go unhandled.

When header == null (line 274-277) or an exception occurs (line 280-283), the exception is re-thrown from within the Runnable. For the SAME executor case (line 287), this propagates up. For the thread pool executor case (line 289), the exception may be swallowed or only logged by the executor's uncaught exception handler.

Consider wrapping the task execution in a try-catch that routes exceptions through handleStreamException for consistent error handling:

🔎 Suggested fix
             Runnable task = () -> {
                 try (var ignored = threadContext.stashContext()) {
                     if (header == null) {
                         cleanupStreamResponse(streamResponse);
-                        throw new StreamException(StreamErrorCode.INTERNAL, "Header is null");
+                        handleStreamException(streamResponse, new StreamException(StreamErrorCode.INTERNAL, "Header is null"));
+                        return;
                     }
                     threadContext.setHeaders(header.getHeaders());
                     handler.handleStreamResponse(streamResponse);
                 } catch (Exception e) {
                     cleanupStreamResponse(streamResponse);
-                    throw e;
+                    handleStreamException(streamResponse, e);
                 }
             };

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In
plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightClientChannel.java
around lines 272 to 291, exceptions thrown inside the Runnable (e.g., when
header == null or other errors) are re-thrown and may be unhandled when executed
on a thread pool; wrap task execution so all exceptions are caught and delegated
to the existing handleStreamException mechanism (or equivalent error handler)
instead of re-throwing: modify the Runnable to catch Throwable, call
cleanupStreamResponse as needed, and invoke
handleStreamException(streamResponse, throwable) (or the proper handler) before
returning; also ensure the executor invocation (both SAME and
threadPool.executor(executor).execute(...)) does not re-throw but relies on this
catch-and-handle behavior.

}

private void cleanupStreamResponse(StreamTransportResponse<?> streamResponse) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ private void processCompleteTask(BatchTask task) {
}

try {
flightChannel.completeStream();
flightChannel.completeStream(getHeaderBuffer(task.requestId(), task.nodeVersion(), task.features()));
messageListener.onResponseSent(task.requestId(), task.action(), TransportResponse.Empty.INSTANCE);
} catch (Exception e) {
messageListener.onResponseSent(task.requestId(), task.action(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;

Expand All @@ -49,7 +48,7 @@ class FlightServerChannel implements TcpChannel {
private final InetSocketAddress remoteAddress;
private final List<ActionListener<Void>> closeListeners = Collections.synchronizedList(new ArrayList<>());
private final ServerHeaderMiddleware middleware;
private volatile Optional<VectorSchemaRoot> root = Optional.empty();
private volatile VectorSchemaRoot root = null;
private final FlightCallTracker callTracker;
private volatile boolean cancelled = false;
private final ExecutorService executor;
Expand All @@ -63,13 +62,10 @@ public FlightServerChannel(
) {
this.serverStreamListener = serverStreamListener;
this.serverStreamListener.setUseZeroCopy(true);
this.serverStreamListener.setOnCancelHandler(new Runnable() {
@Override
public void run() {
cancelled = true;
callTracker.recordCallEnd(StreamErrorCode.CANCELLED.name());
close();
}
this.serverStreamListener.setOnCancelHandler(() -> {
cancelled = true;
callTracker.recordCallEnd(StreamErrorCode.CANCELLED.name());
close();
});
this.allocator = allocator;
this.middleware = middleware;
Expand All @@ -83,7 +79,7 @@ public BufferAllocator getAllocator() {
return allocator;
}

Optional<VectorSchemaRoot> getRoot() {
VectorSchemaRoot getRoot() {
return root;
}

Expand All @@ -108,20 +104,20 @@ public void sendBatch(ByteBuffer header, VectorStreamOutput output) {
}
long batchStartTime = System.nanoTime();
// Only set for the first batch
if (root.isEmpty()) {
if (root == null) {
middleware.setHeader(header);
root = Optional.of(output.getRoot());
serverStreamListener.start(root.get());
root = output.getRoot();
serverStreamListener.start(root);
} else {
root = Optional.of(output.getRoot());
root = output.getRoot();
// placeholder to clear and fill the root with data for the next batch
}

// we do not want to close the root right after putNext() call as we do not know the status of it whether
// its transmitted at transport; we close them all at complete stream. TODO: optimize this behaviour
serverStreamListener.putNext();
if (callTracker != null) {
long rootSize = FlightUtils.calculateVectorSchemaRootSize(root.get());
long rootSize = FlightUtils.calculateVectorSchemaRootSize(root);
callTracker.recordBatchSent(rootSize, System.nanoTime() - batchStartTime);
}
}
Expand All @@ -130,11 +126,15 @@ public void sendBatch(ByteBuffer header, VectorStreamOutput output) {
* Completes the streaming response and closes all pending roots.
*
*/
public void completeStream() {
public void completeStream(ByteBuffer header) {
try {
if (!open.get()) {
throw new IllegalStateException("FlightServerChannel already closed.");
}
if (root == null) {
// Set header if no batches were sent
middleware.setHeader(header);
}
serverStreamListener.completed();
} finally {
callTracker.recordCallEnd(StreamErrorCode.OK.name());
Expand Down Expand Up @@ -210,7 +210,9 @@ public void close() {
return;
}
open.set(false);
root.ifPresent(VectorSchemaRoot::close);
if (root != null) {
root.close();
}
notifyCloseListeners();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,8 @@ public List<Setting<?>> getSettings() {
ServerComponents.SETTING_FLIGHT_PORTS,
ServerComponents.SETTING_FLIGHT_HOST,
ServerComponents.SETTING_FLIGHT_BIND_HOST,
ServerComponents.SETTING_FLIGHT_PUBLISH_HOST
ServerComponents.SETTING_FLIGHT_PUBLISH_HOST,
ServerComponents.SETTING_FLIGHT_PUBLISH_PORT
)
) {
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -132,7 +133,7 @@ public FlightTransport(
this.sslContextProvider = sslContextProvider;
this.statsCollector = statsCollector;
this.bossEventLoopGroup = createEventLoopGroup("os-grpc-boss-ELG", 1);
this.workerEventLoopGroup = createEventLoopGroup("os-grpc-worker-ELG", Runtime.getRuntime().availableProcessors() * 2);
this.workerEventLoopGroup = createEventLoopGroup("os-grpc-worker-ELG", Runtime.getRuntime().availableProcessors());
this.serverExecutor = threadPool.executor(ServerConfig.GRPC_EXECUTOR_THREAD_POOL_NAME);
this.clientExecutor = threadPool.executor(ServerConfig.FLIGHT_CLIENT_THREAD_POOL_NAME);
this.threadPool = threadPool;
Expand Down Expand Up @@ -409,7 +410,9 @@ protected InboundHandler createInboundHandler(
}

private EventLoopGroup createEventLoopGroup(String name, int threads) {
return new NioEventLoopGroup(threads);
AtomicInteger threadCounter = new AtomicInteger(0);
ThreadFactory threadFactory = r -> new Thread(r, name + "-" + threadCounter.incrementAndGet());
return new NioEventLoopGroup(threads, threadFactory);
}

private void gracefullyShutdownELG(EventLoopGroup group, String name) {
Expand Down
Loading
Loading