diff --git a/CHANGELOG.md b/CHANGELOG.md index 356142ccf98aa..4c2d120d6e94a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Fix Netty deprecation warnings in transport-netty4 module ([#20233](https://github.com/opensearch-project/OpenSearch/pull/20233)) - Fix snapshot restore when an index sort is present ([#20284](https://github.com/opensearch-project/OpenSearch/pull/20284)) - Fix SearchPhaseExecutionException to properly initCause ([#20320](https://github.com/opensearch-project/OpenSearch/pull/20320)) +- Fixes and refactoring in stream transport to make it more robust ([#20359](https://github.com/opensearch-project/OpenSearch/pull/20359)) ### Dependencies - Bump `com.google.auth:google-auth-library-oauth2-http` from 1.38.0 to 1.41.0 ([#20183](https://github.com/opensearch-project/OpenSearch/pull/20183)) diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/ServerConfig.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/ServerConfig.java index 2aadae9542845..2e8b687e6c7a6 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/ServerConfig.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/ServerConfig.java @@ -204,7 +204,8 @@ public static List> getSettings() { ARROW_ENABLE_DEBUG_ALLOCATOR, ARROW_ENABLE_UNSAFE_MEMORY_ACCESS, ARROW_SSL_ENABLE, - FLIGHT_EVENT_LOOP_THREADS + FLIGHT_EVENT_LOOP_THREADS, + FLIGHT_THREAD_POOL_MIN_SIZE ) ); } diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightClientChannel.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightClientChannel.java index c20ace2d41b74..25b7738a65c8a 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightClientChannel.java +++ 
b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightClientChannel.java @@ -16,7 +16,6 @@ import org.opensearch.arrow.flight.stats.FlightCallTracker; import org.opensearch.arrow.flight.stats.FlightStatsCollector; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; @@ -45,6 +44,7 @@ */ class FlightClientChannel implements TcpChannel { private static final Logger logger = LogManager.getLogger(FlightClientChannel.class); + private static final AtomicLong GLOBAL_CHANNEL_COUNTER = new AtomicLong(); private final AtomicLong correlationIdGenerator = new AtomicLong(); private final FlightClient client; private final DiscoveryNode node; @@ -112,6 +112,11 @@ public FlightClientChannel( this.closeListeners = new CopyOnWriteArrayList<>(); this.stats = new ChannelStats(); this.isClosed = false; + // Initialize with timestamp + global counter to ensure uniqueness with multiple channels + // Upper bits: timestamp, lower 20 bits: channel ID + long channelId = GLOBAL_CHANNEL_COUNTER.incrementAndGet() & 0xFFFFF; // 20 bits for channel ID + long initialValue = (System.currentTimeMillis() << 20) | channelId; + this.correlationIdGenerator.set(initialValue); if (statsCollector != null) { statsCollector.incrementClientChannelsActive(); } @@ -229,7 +234,8 @@ public void sendMessage(long requestId, BytesReference reference, ActionListener config ); - processStreamResponse(streamResponse); + // Open stream and prefetch first batch, invoke handler when ready + openStreamAndInvokeHandler(streamResponse); listener.onResponse(null); } catch (Exception e) { if (callTracker != null) { @@ -244,39 +250,45 @@ public void sendMessage(BytesReference reference, ActionListener listener) throw new IllegalStateException("sendMessage must be 
accompanied with requestId for FlightClientChannel, use the right variant."); } - private void processStreamResponse(FlightTransportResponse streamResponse) { - try { - executeWithThreadContext(streamResponse); - } catch (Exception e) { - handleStreamException(streamResponse, e); - } - } - @SuppressWarnings({ "unchecked", "rawtypes" }) - private void executeWithThreadContext(FlightTransportResponse streamResponse) { - final ThreadContext threadContext = threadPool.getThreadContext(); - final String executor = streamResponse.getHandler().executor(); + private void openStreamAndInvokeHandler(FlightTransportResponse streamResponse) { + TransportResponseHandler handler = streamResponse.getHandler(); + String executor = handler.executor(); + if (ThreadPool.Names.SAME.equals(executor)) { - executeHandler(threadContext, streamResponse); - } else { - threadPool.executor(executor).execute(() -> executeHandler(threadContext, streamResponse)); + logger.warn("Stream transport handler using SAME executor, which may cause blocking behavior"); } - } - @SuppressWarnings({ "unchecked", "rawtypes" }) - private void executeHandler(ThreadContext threadContext, FlightTransportResponse streamResponse) { - try (ThreadContext.StoredContext ignored = threadContext.stashContext()) { - Header header = streamResponse.getHeader(); - if (header == null) { - throw new StreamException(StreamErrorCode.INTERNAL, "Header is null"); + var threadContext = threadPool.getThreadContext(); + CompletableFuture
future = new CompletableFuture<>(); + streamResponse.openAndPrefetchAsync(future); + + future.whenComplete((header, error) -> { + if (error != null) { + handleStreamException(streamResponse, error instanceof Exception ? (Exception) error : new Exception(error)); + return; } - TransportResponseHandler handler = streamResponse.getHandler(); - threadContext.setHeaders(header.getHeaders()); - handler.handleStreamResponse(streamResponse); - } catch (Exception e) { - cleanupStreamResponse(streamResponse); - throw e; - } + + Runnable task = () -> { + try (var ignored = threadContext.stashContext()) { + if (header == null) { + cleanupStreamResponse(streamResponse); + throw new StreamException(StreamErrorCode.INTERNAL, "Header is null"); + } + threadContext.setHeaders(header.getHeaders()); + handler.handleStreamResponse(streamResponse); + } catch (Exception e) { + cleanupStreamResponse(streamResponse); + throw e; + } + }; + + if (ThreadPool.Names.SAME.equals(executor)) { + task.run(); + } else { + threadPool.executor(executor).execute(task); + } + }); } private void cleanupStreamResponse(StreamTransportResponse streamResponse) { diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightOutboundHandler.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightOutboundHandler.java index b340b98492f01..a4a76d518bad2 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightOutboundHandler.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightOutboundHandler.java @@ -215,7 +215,7 @@ private void processCompleteTask(BatchTask task) { } try { - flightChannel.completeStream(); + flightChannel.completeStream(getHeaderBuffer(task.requestId(), task.nodeVersion(), task.features())); messageListener.onResponseSent(task.requestId(), task.action(), TransportResponse.Empty.INSTANCE); } catch (Exception e) { 
messageListener.onResponseSent(task.requestId(), task.action(), e); diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightServerChannel.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightServerChannel.java index cef4ed0d99c92..ad14291437ad8 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightServerChannel.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightServerChannel.java @@ -28,7 +28,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicBoolean; @@ -49,7 +48,7 @@ class FlightServerChannel implements TcpChannel { private final InetSocketAddress remoteAddress; private final List> closeListeners = Collections.synchronizedList(new ArrayList<>()); private final ServerHeaderMiddleware middleware; - private volatile Optional root = Optional.empty(); + private volatile VectorSchemaRoot root = null; private final FlightCallTracker callTracker; private volatile boolean cancelled = false; private final ExecutorService executor; @@ -63,13 +62,10 @@ public FlightServerChannel( ) { this.serverStreamListener = serverStreamListener; this.serverStreamListener.setUseZeroCopy(true); - this.serverStreamListener.setOnCancelHandler(new Runnable() { - @Override - public void run() { - cancelled = true; - callTracker.recordCallEnd(StreamErrorCode.CANCELLED.name()); - close(); - } + this.serverStreamListener.setOnCancelHandler(() -> { + cancelled = true; + callTracker.recordCallEnd(StreamErrorCode.CANCELLED.name()); + close(); }); this.allocator = allocator; this.middleware = middleware; @@ -83,7 +79,7 @@ public BufferAllocator getAllocator() { return allocator; } - Optional getRoot() { + VectorSchemaRoot getRoot() { return root; } @@ -108,12 +104,12 @@ public void 
sendBatch(ByteBuffer header, VectorStreamOutput output) { } long batchStartTime = System.nanoTime(); // Only set for the first batch - if (root.isEmpty()) { + if (root == null) { middleware.setHeader(header); - root = Optional.of(output.getRoot()); - serverStreamListener.start(root.get()); + root = output.getRoot(); + serverStreamListener.start(root); } else { - root = Optional.of(output.getRoot()); + root = output.getRoot(); // placeholder to clear and fill the root with data for the next batch } @@ -121,7 +117,7 @@ public void sendBatch(ByteBuffer header, VectorStreamOutput output) { // its transmitted at transport; we close them all at complete stream. TODO: optimize this behaviour serverStreamListener.putNext(); if (callTracker != null) { - long rootSize = FlightUtils.calculateVectorSchemaRootSize(root.get()); + long rootSize = FlightUtils.calculateVectorSchemaRootSize(root); callTracker.recordBatchSent(rootSize, System.nanoTime() - batchStartTime); } } @@ -130,11 +126,15 @@ public void sendBatch(ByteBuffer header, VectorStreamOutput output) { * Completes the streaming response and closes all pending roots. 
* */ - public void completeStream() { + public void completeStream(ByteBuffer header) { try { if (!open.get()) { throw new IllegalStateException("FlightServerChannel already closed."); } + if (root == null) { + // Set header if no batches were sent + middleware.setHeader(header); + } serverStreamListener.completed(); } finally { callTracker.recordCallEnd(StreamErrorCode.OK.name()); @@ -210,7 +210,9 @@ public void close() { return; } open.set(false); - root.ifPresent(VectorSchemaRoot::close); + if (root != null) { + root.close(); + } notifyCloseListeners(); } diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightStreamPlugin.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightStreamPlugin.java index 10550d9730a40..dd7cb34a9db34 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightStreamPlugin.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightStreamPlugin.java @@ -364,7 +364,8 @@ public List> getSettings() { ServerComponents.SETTING_FLIGHT_PORTS, ServerComponents.SETTING_FLIGHT_HOST, ServerComponents.SETTING_FLIGHT_BIND_HOST, - ServerComponents.SETTING_FLIGHT_PUBLISH_HOST + ServerComponents.SETTING_FLIGHT_PUBLISH_HOST, + ServerComponents.SETTING_FLIGHT_PUBLISH_PORT ) ) { { diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightTransport.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightTransport.java index 6ad73468dbe19..d42a93b425cb9 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightTransport.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightTransport.java @@ -64,6 +64,7 @@ import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import 
java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -132,7 +133,7 @@ public FlightTransport( this.sslContextProvider = sslContextProvider; this.statsCollector = statsCollector; this.bossEventLoopGroup = createEventLoopGroup("os-grpc-boss-ELG", 1); - this.workerEventLoopGroup = createEventLoopGroup("os-grpc-worker-ELG", Runtime.getRuntime().availableProcessors() * 2); + this.workerEventLoopGroup = createEventLoopGroup("os-grpc-worker-ELG", Runtime.getRuntime().availableProcessors()); this.serverExecutor = threadPool.executor(ServerConfig.GRPC_EXECUTOR_THREAD_POOL_NAME); this.clientExecutor = threadPool.executor(ServerConfig.FLIGHT_CLIENT_THREAD_POOL_NAME); this.threadPool = threadPool; @@ -409,7 +410,9 @@ protected InboundHandler createInboundHandler( } private EventLoopGroup createEventLoopGroup(String name, int threads) { - return new NioEventLoopGroup(threads); + AtomicInteger threadCounter = new AtomicInteger(0); + ThreadFactory threadFactory = r -> new Thread(r, name + "-" + threadCounter.incrementAndGet()); + return new NioEventLoopGroup(threads, threadFactory); } private void gracefullyShutdownELG(EventLoopGroup group, String name) { diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightTransportResponse.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightTransportResponse.java index 048c1875b5679..011ac0e877e1d 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightTransportResponse.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightTransportResponse.java @@ -27,40 +27,34 @@ import java.io.IOException; import java.util.Objects; +import java.util.concurrent.CompletableFuture; import static org.opensearch.arrow.flight.transport.ClientHeaderMiddleware.CORRELATION_ID_KEY; /** 
 - Arrow Flight implementation of streaming transport responses. - - * <p>
Handles streaming responses from Arrow Flight servers with lazy batch processing. - * Headers are extracted when first accessed, and responses are deserialized on demand. + * Streaming transport response implementation using Arrow Flight. + * Manages Flight stream lifecycle with lazy initialization and prefetching support. */ class FlightTransportResponse implements StreamTransportResponse { private static final Logger logger = LogManager.getLogger(FlightTransportResponse.class); - private final FlightStream flightStream; + private final FlightClient flightClient; + private final Ticket ticket; + private final FlightCallHeaders callHeaders; private final NamedWriteableRegistry namedWriteableRegistry; private final HeaderContext headerContext; - private final long correlationId; + private final TransportResponseHandler handler; private final FlightTransportConfig config; + private final long correlationId; - private final TransportResponseHandler handler; - private boolean isClosed; - - // Stream state - private VectorSchemaRoot currentRoot; - private Header currentHeader; - private boolean streamInitialized = false; - private boolean streamExhausted = false; - private boolean firstResponseConsumed = false; - private StreamException initializationException; - private long currentBatchSize; - - /** - * Creates a new Flight transport response. 
- */ - public FlightTransportResponse( + private volatile FlightStream flightStream; + private volatile long currentBatchSize; + private volatile boolean firstBatchConsumed; + private volatile boolean closed; + private volatile boolean prefetchStarted; + private volatile Header initialHeader; + + FlightTransportResponse( TransportResponseHandler handler, long correlationId, FlightClient flightClient, @@ -69,94 +63,94 @@ public FlightTransportResponse( NamedWriteableRegistry namedWriteableRegistry, FlightTransportConfig config ) { - this.handler = handler; + this.handler = Objects.requireNonNull(handler); this.correlationId = correlationId; - this.headerContext = Objects.requireNonNull(headerContext, "headerContext must not be null"); - this.namedWriteableRegistry = namedWriteableRegistry; - this.config = config; - // Initialize Flight stream with correlation ID header - FlightCallHeaders callHeaders = new FlightCallHeaders(); - callHeaders.insert(CORRELATION_ID_KEY, String.valueOf(correlationId)); - HeaderCallOption callOptions = new HeaderCallOption(callHeaders); - this.flightStream = flightClient.getStream(ticket, callOptions); - - this.isClosed = false; + this.flightClient = Objects.requireNonNull(flightClient); + this.headerContext = Objects.requireNonNull(headerContext); + this.ticket = Objects.requireNonNull(ticket); + this.namedWriteableRegistry = Objects.requireNonNull(namedWriteableRegistry); + this.config = Objects.requireNonNull(config); + this.callHeaders = new FlightCallHeaders(); + this.callHeaders.insert(CORRELATION_ID_KEY, String.valueOf(correlationId)); } - /** - * Gets the header for the current batch. - * If no batch has been fetched yet, fetches the first batch to extract headers. - */ - public Header getHeader() { - ensureOpen(); - initializeStreamIfNeeded(); - return currentHeader; + void openAndPrefetchAsync(CompletableFuture

future) { + if (prefetchStarted) return; + + synchronized (this) { + if (prefetchStarted) return; + if (closed) { + future.completeExceptionally(new StreamException(StreamErrorCode.UNAVAILABLE, "Stream is closed")); + return; + } + + prefetchStarted = true; + + Thread.ofVirtual().start(() -> { + try { + long start = System.nanoTime(); + flightStream = flightClient.getStream(ticket, new HeaderCallOption(callHeaders)); + if ((System.nanoTime() - start) / 1_000_000 > 10) { + logger.debug( + "FlightClient.getStream() for correlationId: {} took {}ms", + correlationId, + (System.nanoTime() - start) / 1_000_000 + ); + } + + flightStream.next(); + initialHeader = headerContext.getHeader(correlationId); + future.complete(initialHeader); + } catch (FlightRuntimeException e) { + future.completeExceptionally(FlightErrorMapper.fromFlightException(e)); + } catch (Exception e) { + future.completeExceptionally(new StreamException(StreamErrorCode.INTERNAL, "Stream open/prefetch failed", e)); + } + }); + } + } + + TransportResponseHandler getHandler() { + return handler; } - /** - * Gets the next response from the stream. 
- */ @Override public T nextResponse() { - ensureOpen(); - initializeStreamIfNeeded(); - - if (streamExhausted) { - if (initializationException != null) { - throw initializationException; - } - return null; - } + if (closed) throw new StreamException(StreamErrorCode.UNAVAILABLE, "Stream is closed"); + if (flightStream == null) throw new IllegalStateException("openAndPrefetch() must be called first"); long startTime = System.currentTimeMillis(); try { - if (!firstResponseConsumed) { - // First call - use the batch we already fetched during initialization - firstResponseConsumed = true; - return deserializeResponse(); - } - - if (flightStream.next()) { - currentRoot = flightStream.getRoot(); - currentHeader = headerContext.getHeader(correlationId); - // Capture the batch size before deserialization - currentBatchSize = FlightUtils.calculateVectorSchemaRootSize(currentRoot); - return deserializeResponse(); - } else { - streamExhausted = true; - return null; + boolean hasNext = firstBatchConsumed ? 
flightStream.next() : (firstBatchConsumed = true); + if (!hasNext) return null; + + VectorSchemaRoot root = flightStream.getRoot(); + currentBatchSize = FlightUtils.calculateVectorSchemaRootSize(root); + try (VectorStreamInput input = new VectorStreamInput(root, namedWriteableRegistry)) { + input.setVersion(initialHeader.getVersion()); + return handler.read(input); } } catch (FlightRuntimeException e) { - streamExhausted = true; throw FlightErrorMapper.fromFlightException(e); - } catch (Exception e) { - streamExhausted = true; - throw new StreamException(StreamErrorCode.INTERNAL, "Failed to fetch next batch", e); + } catch (IOException e) { + throw new StreamException(StreamErrorCode.INTERNAL, "Failed to deserialize batch", e); } finally { - logSlowOperation(startTime); + long took = System.currentTimeMillis() - startTime; + if (took > config.getSlowLogThreshold().millis()) { + logger.warn("Flight stream next() took [{}ms], exceeding threshold [{}ms]", took, config.getSlowLogThreshold().millis()); + } } } - /** - * Gets the size of the current batch in bytes. - * - * @return the size in bytes, or 0 if no batch is available - */ - public long getCurrentBatchSize() { + long getCurrentBatchSize() { return currentBatchSize; } - /** - * Cancels the Flight stream. - */ @Override public void cancel(String reason, Throwable cause) { - if (isClosed) { - return; - } + if (closed) return; try { - flightStream.cancel(reason, cause); - logger.debug("Cancelled flight stream: {}", reason); + if (flightStream != null) flightStream.cancel(reason, cause); } catch (Exception e) { logger.warn("Error cancelling flight stream", e); } finally { @@ -164,88 +158,17 @@ public void cancel(String reason, Throwable cause) { } } - /** - * Closes the Flight stream and releases resources. 
- */ @Override public void close() { - if (isClosed) { - return; - } - try { - if (currentRoot != null) { - currentRoot.close(); - currentRoot = null; + if (closed) return; + closed = true; + + if (flightStream != null) { + try { + flightStream.close(); + } catch (IllegalStateException ignore) {} catch (Exception e) { + throw new StreamException(StreamErrorCode.INTERNAL, "Error closing flight stream", e); } - flightStream.close(); - } catch (IllegalStateException ignore) { - // this is fine if the allocator is already closed - } catch (Exception e) { - throw new StreamException(StreamErrorCode.INTERNAL, "Error while closing flight stream", e); - } finally { - isClosed = true; - } - } - - public TransportResponseHandler getHandler() { - return handler; - } - - /** - * Initializes the stream by fetching the first batch to extract headers. - */ - private synchronized void initializeStreamIfNeeded() { - if (streamInitialized || streamExhausted) { - return; - } - long startTime = System.currentTimeMillis(); - try { - if (flightStream.next()) { - currentRoot = flightStream.getRoot(); - currentHeader = headerContext.getHeader(correlationId); - // Capture the batch size before deserialization - currentBatchSize = FlightUtils.calculateVectorSchemaRootSize(currentRoot); - streamInitialized = true; - } else { - streamExhausted = true; - } - } catch (FlightRuntimeException e) { - // TODO maybe add a check - handshake and validate if node is connected - // Try to get headers even if stream failed - currentHeader = headerContext.getHeader(correlationId); - streamExhausted = true; - initializationException = FlightErrorMapper.fromFlightException(e); - logger.warn("Stream initialization failed", e); - } catch (Exception e) { - // Try to get headers even if stream failed - currentHeader = headerContext.getHeader(correlationId); - streamExhausted = true; - initializationException = new StreamException(StreamErrorCode.INTERNAL, "Stream initialization failed", e); - 
logger.warn("Stream initialization failed", e); - } finally { - logSlowOperation(startTime); - } - } - - private T deserializeResponse() { - try (VectorStreamInput input = new VectorStreamInput(currentRoot, namedWriteableRegistry)) { - return handler.read(input); - } catch (IOException e) { - throw new StreamException(StreamErrorCode.INTERNAL, "Failed to deserialize response", e); - } - } - - private void ensureOpen() { - if (isClosed) { - throw new StreamException(StreamErrorCode.UNAVAILABLE, "Stream is closed"); - } - } - - private void logSlowOperation(long startTime) { - long took = System.currentTimeMillis() - startTime; - long thresholdMs = config.getSlowLogThreshold().millis(); - if (took > thresholdMs) { - logger.warn("Flight stream next() took [{}ms], exceeding threshold [{}ms]", took, thresholdMs); } } } diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/VectorStreamOutput.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/VectorStreamOutput.java index 09e9e4a54c6c9..9755c4c77dbda 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/VectorStreamOutput.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/VectorStreamOutput.java @@ -18,45 +18,60 @@ import java.io.IOException; import java.util.List; -import java.util.Optional; class VectorStreamOutput extends StreamOutput { private int row = 0; private final VarBinaryVector vector; - private Optional root = Optional.empty(); + private VectorSchemaRoot root; + private final byte[] tempBuffer = new byte[8192]; + private int tempBufferPos = 0; - public VectorStreamOutput(BufferAllocator allocator, Optional root) { - if (root.isPresent()) { - vector = (VarBinaryVector) root.get().getVector(0); + public VectorStreamOutput(BufferAllocator allocator, VectorSchemaRoot root) { + if (root != null) { + vector = (VarBinaryVector) root.getVector(0); this.root = root; } else { 
Field field = new Field("0", new FieldType(true, new ArrowType.Binary(), null, null), null); vector = (VarBinaryVector) field.createVector(allocator); + // Pre-allocate with reasonable capacity to avoid repeated allocations + vector.setInitialCapacity(16); + vector.allocateNew(); } - vector.allocateNew(); } @Override - public void writeByte(byte b) throws IOException { - vector.setInitialCapacity(row + 1); - vector.setSafe(row++, new byte[] { b }); + public void writeByte(byte b) { + // Buffer small writes to reduce vector operations + if (tempBufferPos >= tempBuffer.length) { + flushTempBuffer(); + } + tempBuffer[tempBufferPos++] = b; } @Override - public void writeBytes(byte[] b, int offset, int length) throws IOException { - vector.setInitialCapacity(row + 1); + public void writeBytes(byte[] b, int offset, int length) { if (length == 0) { return; } if (b.length < (offset + length)) { throw new IllegalArgumentException("Illegal offset " + offset + "/length " + length + " for byte[] of length " + b.length); } + if (tempBufferPos > 0) { + flushTempBuffer(); + } vector.setSafe(row++, b, offset, length); } + private void flushTempBuffer() { + if (tempBufferPos > 0) { + vector.setSafe(row++, tempBuffer, 0, tempBufferPos); + tempBufferPos = 0; + } + } + @Override - public void flush() throws IOException { + public void flush() { } @@ -67,17 +82,19 @@ public void close() throws IOException { } @Override - public void reset() throws IOException { + public void reset() { row = 0; + tempBufferPos = 0; vector.clear(); } public VectorSchemaRoot getRoot() { + flushTempBuffer(); vector.setValueCount(row); - if (!root.isPresent()) { - root = Optional.of(new VectorSchemaRoot(List.of(vector))); + if (root == null) { + root = new VectorSchemaRoot(List.of(vector)); } - root.get().setRowCount(row); - return root.get(); + root.setRowCount(row); + return root; } } diff --git 
a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/transport/ArrowStreamSerializationTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/transport/ArrowStreamSerializationTests.java index 843ddbcc1e385..16024496216da 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/transport/ArrowStreamSerializationTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/transport/ArrowStreamSerializationTests.java @@ -23,7 +23,6 @@ import java.io.IOException; import java.util.Arrays; import java.util.Collections; -import java.util.Optional; public class ArrowStreamSerializationTests extends OpenSearchTestCase { private NamedWriteableRegistry registry; @@ -51,7 +50,7 @@ public void tearDown() throws Exception { public void testInternalAggregationSerializationDeserialization() throws IOException { StringTerms original = createTestStringTerms(); - try (VectorStreamOutput output = new VectorStreamOutput(allocator, Optional.empty())) { + try (VectorStreamOutput output = new VectorStreamOutput(allocator, null)) { output.writeNamedWriteable(original); VectorSchemaRoot unifiedRoot = output.getRoot(); diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/transport/FlightClientChannelTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/transport/FlightClientChannelTests.java index ebdafdf0225bd..2d513c9bd9cf9 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/transport/FlightClientChannelTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/transport/FlightClientChannelTests.java @@ -555,7 +555,10 @@ public void handleStreamResponse(StreamTransportResponse streamRes } @Override - public void handleException(TransportException exp) {} + public void handleException(TransportException exp) { + handlerException.set(exp); + handlerLatch.countDown(); + } @Override public String 
executor() { diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/transport/FlightTransportTestBase.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/transport/FlightTransportTestBase.java index a9a4d19f7e9a1..12a05d94ee64d 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/transport/FlightTransportTestBase.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/transport/FlightTransportTestBase.java @@ -17,6 +17,7 @@ import org.opensearch.common.network.NetworkService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.PageCacheRecycler; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -41,6 +42,7 @@ import java.util.Collections; import java.util.concurrent.atomic.AtomicInteger; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -103,7 +105,9 @@ public void setUp() throws Exception { ); flightTransport.start(); TransportService transportService = mock(TransportService.class); - when(transportService.getTaskManager()).thenReturn(mock(TaskManager.class)); + TaskManager taskManager = mock(TaskManager.class); + when(taskManager.taskExecutionStarted(any())).thenReturn(mock(ThreadContext.StoredContext.class)); + when(transportService.getTaskManager()).thenReturn(taskManager); streamTransportService = spy( new StreamTransportService( settings,