From 90ed36a0cd4c258d242e35857e79c9e5af69ef53 Mon Sep 17 00:00:00 2001 From: Rishabh Maurya Date: Wed, 20 Nov 2024 11:13:28 -0800 Subject: [PATCH] implement cancellation logic for StreamProducer's BatchedJob --- .../opensearch/flight/ProxyStreamProducer.java | 18 ++++++++++++++++++ .../opensearch/arrow/StreamManagerWrapper.java | 12 ++++++++++++ .../search/query/StreamSearchPhase.java | 16 ++++++++++++++++ 3 files changed, 46 insertions(+) diff --git a/modules/arrow-flight-rpc/src/main/java/org/opensearch/flight/ProxyStreamProducer.java b/modules/arrow-flight-rpc/src/main/java/org/opensearch/flight/ProxyStreamProducer.java index db7538d535401..e016e3dd87ea0 100644 --- a/modules/arrow-flight-rpc/src/main/java/org/opensearch/flight/ProxyStreamProducer.java +++ b/modules/arrow-flight-rpc/src/main/java/org/opensearch/flight/ProxyStreamProducer.java @@ -14,6 +14,8 @@ import org.opensearch.arrow.StreamProducer; import org.opensearch.arrow.StreamTicket; +import java.io.IOException; + /** * ProxyStreamProvider acts as forward proxy for FlightStream. * It creates a BatchedJob to handle the streaming of data from the remote FlightStream. @@ -38,6 +40,15 @@ public BatchedJob createJob(BufferAllocator allocator) { return new ProxyBatchedJob(remoteStream); } + @Override + public void close() { + try { + remoteStream.close(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + public static class ProxyBatchedJob implements BatchedJob { private final FlightStream remoteStream; @@ -66,5 +77,12 @@ public void onCancel() { throw new RuntimeException(e); } } + + @Override + public boolean isCancelled() { + // Proxy streams don't have any business logic to set this flag; + // they piggyback on the remote stream getting cancelled.
+ return false; + } } } diff --git a/server/src/main/java/org/opensearch/arrow/StreamManagerWrapper.java b/server/src/main/java/org/opensearch/arrow/StreamManagerWrapper.java index 91f0a3c898fb0..10a300063049e 100644 --- a/server/src/main/java/org/opensearch/arrow/StreamManagerWrapper.java +++ b/server/src/main/java/org/opensearch/arrow/StreamManagerWrapper.java @@ -19,6 +19,8 @@ import org.opensearch.tasks.TaskAwareRequest; import org.opensearch.tasks.TaskManager; +import java.io.IOException; + public class StreamManagerWrapper implements StreamManager { private final StreamManager streamManager; @@ -96,6 +98,11 @@ public int estimatedRowCount() { return streamProducer.estimatedRowCount(); } + @Override + public void close() throws IOException { + streamProducer.close(); + } + static class BatchedJobTaskWrapper implements BatchedJob, TaskAwareRequest { private final BatchedJob batchedJob; private final TaskManager taskManager; @@ -140,6 +147,11 @@ public void onCancel() { batchedJob.onCancel(); } + @Override + public boolean isCancelled() { + return batchedJob.isCancelled(); + } + @Override public void setParentTask(TaskId taskId) { this.parentTaskId = taskId; diff --git a/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java b/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java index 676f32a27fa20..70e78b9bfea16 100644 --- a/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java +++ b/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java @@ -16,6 +16,7 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.search.Collector; import org.apache.lucene.search.Query; +import org.opensearch.OpenSearchException; import org.opensearch.arrow.StreamManager; import org.opensearch.arrow.StreamProducer; import org.opensearch.arrow.StreamTicket; @@ -112,7 +113,14 @@ private boolean searchWithCollector( if (streamManager == null) { throw new RuntimeException("StreamManager not setup"); } + 
final boolean[] isCancelled = {false}; StreamTicket ticket = streamManager.registerStream(new StreamProducer() { + + @Override + public void close() { + isCancelled[0] = true; + } + @Override public BatchedJob createJob(BufferAllocator allocator) { return new BatchedJob() { @@ -123,6 +131,9 @@ public void run(VectorSchemaRoot root, StreamProducer.FlushSignal flushSignal) { Collector collector = QueryCollectorContext.createQueryCollector(collectors); final ArrowDocIdCollector arrowDocIdCollector = new ArrowDocIdCollector(collector, root, flushSignal, 1000); try { + searcher.addQueryCancellation(() -> {if (isCancelled[0]) { + throw new OpenSearchException("Stream for query results cancelled."); + }}); searcher.search(query, arrowDocIdCollector); } catch (EarlyTerminatingCollector.EarlyTerminationException e) { // EarlyTerminationException is not caught in ContextIndexSearcher to allow force termination of @@ -153,7 +164,12 @@ public void run(VectorSchemaRoot root, StreamProducer.FlushSignal flushSignal) { @Override public void onCancel() { + isCancelled[0] = true; + } + @Override + public boolean isCancelled() { + return searchContext.isCancelled() || isCancelled[0]; } }; }