From 1f711f5bb2fccdbc056edf711363f03fcb903aa7 Mon Sep 17 00:00:00 2001 From: Rishabh Maurya Date: Fri, 11 Oct 2024 11:19:10 -0700 Subject: [PATCH] Added more tests and code refactor --- libs/arrow/build.gradle | 4 - .../opensearch/arrow/ArrowStreamProvider.java | 13 +- .../flight/BaseBackpressureStrategy.java | 35 ++ .../opensearch/flight/BaseFlightProducer.java | 52 +- .../org/opensearch/flight/FlightService.java | 8 +- .../flight/FlightStreamManager.java | 19 +- .../opensearch/flight/FlightStreamPlugin.java | 19 +- .../flight/BaseFlightProducerTests.java | 444 ++++++++++++++++++ .../opensearch/flight/FlightServiceTests.java | 65 +++ .../flight/FlightStreamManagerTests.java | 79 ++++ .../flight/FlightStreamPluginTests.java | 72 +++ .../arrow/query/ArrowDocIdCollector.java | 4 +- .../plugins/StreamManagerPlugin.java | 6 + .../search/query/StreamSearchPhase.java | 5 + 14 files changed, 798 insertions(+), 27 deletions(-) create mode 100644 modules/arrow-flight/src/main/java/org/opensearch/flight/BaseBackpressureStrategy.java create mode 100644 modules/arrow-flight/src/test/java/org/opensearch/flight/BaseFlightProducerTests.java create mode 100644 modules/arrow-flight/src/test/java/org/opensearch/flight/FlightServiceTests.java create mode 100644 modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamManagerTests.java create mode 100644 modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamPluginTests.java diff --git a/libs/arrow/build.gradle b/libs/arrow/build.gradle index 69dc9c79d9367..3a4275410e80a 100644 --- a/libs/arrow/build.gradle +++ b/libs/arrow/build.gradle @@ -41,10 +41,6 @@ dependencies { implementation "org.yaml:snakeyaml:${versions.snakeyaml}" implementation "io.projectreactor.tools:blockhound:1.0.9.RELEASE" - // implementation 'net.sf.jopt-simple:jopt-simple:5.0.4' - // implementation "org.apache.logging.log4j:log4j-api:${versions.log4j}" - // implementation "org.apache.logging.log4j:log4j-core:${versions.log4j}" - 
testImplementation "com.carrotsearch.randomizedtesting:randomizedtesting-runner:${versions.randomizedrunner}" testImplementation "junit:junit:${versions.junit}" testImplementation "org.hamcrest:hamcrest:${versions.hamcrest}" diff --git a/libs/arrow/src/main/java/org/opensearch/arrow/ArrowStreamProvider.java b/libs/arrow/src/main/java/org/opensearch/arrow/ArrowStreamProvider.java index 35fc71956e2f2..4d4639858a3df 100644 --- a/libs/arrow/src/main/java/org/opensearch/arrow/ArrowStreamProvider.java +++ b/libs/arrow/src/main/java/org/opensearch/arrow/ArrowStreamProvider.java @@ -45,6 +45,14 @@ interface Task { */ void run(VectorSchemaRoot root, FlushSignal flushSignal); + /** + * Called when the task is canceled. + * This method is used to clean up resources or cancel ongoing operations. + * This may be called from a different thread than the one used for run(). It might be possible that run() + * thread is busy when onCancel() is called and wakes up later. In such cases, ensure that run() terminates early + * and cleans up resources. + */ + void onCancel(); } /** @@ -54,7 +62,10 @@ interface Task { interface FlushSignal { /** * Waits for the consumption of the current data to complete. + * This method blocks until the consumption is complete or a timeout occurs. + * + * @param timeout The maximum time to wait for consumption (in milliseconds). 
*/ - void awaitConsumption(); + void awaitConsumption(int timeout); } } diff --git a/modules/arrow-flight/src/main/java/org/opensearch/flight/BaseBackpressureStrategy.java b/modules/arrow-flight/src/main/java/org/opensearch/flight/BaseBackpressureStrategy.java new file mode 100644 index 0000000000000..ac500e18a54ff --- /dev/null +++ b/modules/arrow-flight/src/main/java/org/opensearch/flight/BaseBackpressureStrategy.java @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.flight; + +import org.apache.arrow.flight.BackpressureStrategy; + +public class BaseBackpressureStrategy extends BackpressureStrategy.CallbackBackpressureStrategy { + private final Runnable readyCallback; + private final Runnable cancelCallback; + + BaseBackpressureStrategy(Runnable readyCallback, Runnable cancelCallback) { + this.readyCallback = readyCallback; + this.cancelCallback = cancelCallback; + } + + /** Callback to execute when the listener is ready. */ + protected void readyCallback() { + if (readyCallback != null) { + readyCallback.run(); + } + } + + /** Callback to execute when the listener is cancelled. 
*/ + protected void cancelCallback() { + if (cancelCallback != null) { + cancelCallback.run(); + } + } +} diff --git a/modules/arrow-flight/src/main/java/org/opensearch/flight/BaseFlightProducer.java b/modules/arrow-flight/src/main/java/org/opensearch/flight/BaseFlightProducer.java index fb9320ce832b8..75cab25cd452d 100644 --- a/modules/arrow-flight/src/main/java/org/opensearch/flight/BaseFlightProducer.java +++ b/modules/arrow-flight/src/main/java/org/opensearch/flight/BaseFlightProducer.java @@ -18,16 +18,39 @@ import org.opensearch.arrow.StreamManager; import org.opensearch.arrow.StreamTicket; +import java.util.function.Supplier; +/** + * BaseFlightProducer extends NoOpFlightProducer to provide stream management functionality + * for Arrow Flight in OpenSearch. This class handles the retrieval and streaming of data + * based on provided tickets, managing backpressure, and coordinating between the stream + * provider and the server stream listener. + */ public class BaseFlightProducer extends NoOpFlightProducer { private final StreamManager streamManager; private final BufferAllocator allocator; + /** + * Constructs a new BaseFlightProducer. + * + * @param streamManager The StreamManager to handle stream operations, including + * retrieving and removing streams based on tickets. + * @param allocator The BufferAllocator for memory management in Arrow operations. + */ public BaseFlightProducer(StreamManager streamManager, BufferAllocator allocator) { this.streamManager = streamManager; this.allocator = allocator; } + /** + * Handles the retrieval and streaming of data based on the provided ticket. + * This method orchestrates the entire process of setting up the stream, + * managing backpressure, and handling data flow to the client. 
+ * + * @param context The call context (unused in this implementation) + * @param ticket The ticket containing stream information + * @param listener The server stream listener to handle the data flow + */ @Override public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { StreamTicket streamTicket = new StreamTicket(ticket.getBytes()) {}; @@ -37,30 +60,37 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l listener.error(CallStatus.NOT_FOUND.withDescription("Stream not found").toRuntimeException()); return; } - BackpressureStrategy backpressureStrategy = new BackpressureStrategy.CallbackBackpressureStrategy(); - backpressureStrategy.register(listener); ArrowStreamProvider.Task task = provider.create(allocator); - VectorSchemaRoot root = task.init(allocator); - listener.start(root); - ArrowStreamProvider.FlushSignal flushSignal = () -> { - BackpressureStrategy.WaitResult result = backpressureStrategy.waitForListener(1000); + if (context.isCancelled()) { + task.onCancel(); + listener.error(CallStatus.CANCELLED.cause()); + return; + } + listener.setOnCancelHandler(task::onCancel); + BackpressureStrategy backpressureStrategy = new BaseBackpressureStrategy(null, task::onCancel); + backpressureStrategy.register(listener); + ArrowStreamProvider.FlushSignal flushSignal = (timeout) -> { + BackpressureStrategy.WaitResult result = backpressureStrategy.waitForListener(timeout); if (result.equals(BackpressureStrategy.WaitResult.READY)) { listener.putNext(); } else if (result.equals(BackpressureStrategy.WaitResult.TIMEOUT)) { listener.error(CallStatus.TIMED_OUT.cause()); - throw new RuntimeException("Timeout waiting for listener" + result); + throw new RuntimeException("Stream deadline exceeded for consumption"); + } else if (result.equals(BackpressureStrategy.WaitResult.CANCELLED)) { + task.onCancel(); + listener.error(CallStatus.CANCELLED.cause()); + throw new RuntimeException("Stream cancelled by client"); } 
else { listener.error(CallStatus.INTERNAL.toRuntimeException()); throw new RuntimeException("Error while waiting for client: " + result); } }; - try { + try(VectorSchemaRoot root = task.init(allocator)) { + listener.start(root); task.run(root, flushSignal); - } finally { - root.close(); } } catch (Exception e) { - listener.error(CallStatus.INTERNAL.toRuntimeException().initCause(e)); + listener.error(CallStatus.INTERNAL.withDescription(e.getMessage()).withCause(e).cause()); throw e; } finally { listener.completed(); diff --git a/modules/arrow-flight/src/main/java/org/opensearch/flight/FlightService.java b/modules/arrow-flight/src/main/java/org/opensearch/flight/FlightService.java index 1f4cf23c8dcd7..a05642999f52f 100644 --- a/modules/arrow-flight/src/main/java/org/opensearch/flight/FlightService.java +++ b/modules/arrow-flight/src/main/java/org/opensearch/flight/FlightService.java @@ -23,7 +23,11 @@ import org.opensearch.common.settings.Settings; import java.io.IOException; - +/** + * FlightService manages the Arrow Flight server and client for OpenSearch. + * It handles the initialization, startup, and shutdown of the Flight server and client, + * as well as managing the stream operations through a FlightStreamManager. 
+ */ @ExperimentalApi public class FlightService extends AbstractLifecycleComponent { @@ -75,7 +79,7 @@ public class FlightService extends AbstractLifecycleComponent { public static final Setting NETTY_ALLOCATOR_NUM_DIRECT_ARENAS = Setting.intSetting( "io.netty.allocator.numDirectArenas", - 1, + 1, // TODO - 2 * the number of available processors 1, Property.NodeScope ); diff --git a/modules/arrow-flight/src/main/java/org/opensearch/flight/FlightStreamManager.java b/modules/arrow-flight/src/main/java/org/opensearch/flight/FlightStreamManager.java index 4c61e5ad5e8f3..8cb9287a18bdc 100644 --- a/modules/arrow-flight/src/main/java/org/opensearch/flight/FlightStreamManager.java +++ b/modules/arrow-flight/src/main/java/org/opensearch/flight/FlightStreamManager.java @@ -17,18 +17,35 @@ import java.util.UUID; +/** + * FlightStreamManager is a concrete implementation of StreamManager that provides + * an abstraction layer for managing Arrow Flight streams in OpenSearch. + * It encapsulates the details of Flight client operations, allowing consumers to + * work with streams without direct exposure to Flight internals. + */ public class FlightStreamManager extends StreamManager { private final FlightClient flightClient; + /** + * Constructs a new FlightStreamManager. + * + * @param flightClient The FlightClient instance used for stream operations. + */ public FlightStreamManager(FlightClient flightClient) { super(); this.flightClient = flightClient; } + /** + * Retrieves a VectorSchemaRoot for a given stream ticket. + * @param ticket The StreamTicket identifying the desired stream. + * @return The VectorSchemaRoot associated with the given ticket. 
+ */ @Override public VectorSchemaRoot getVectorSchemaRoot(StreamTicket ticket) { - // TODO: for remote streams + // TODO: for remote streams, register streams in cluster state with node details + // maintain flightClient for all nodes in the cluster to serve the stream FlightStream stream = flightClient.getStream(new Ticket(ticket.getBytes())); return stream.getRoot(); } diff --git a/modules/arrow-flight/src/main/java/org/opensearch/flight/FlightStreamPlugin.java b/modules/arrow-flight/src/main/java/org/opensearch/flight/FlightStreamPlugin.java index 9821ad2174400..6024fc426187e 100644 --- a/modules/arrow-flight/src/main/java/org/opensearch/flight/FlightStreamPlugin.java +++ b/modules/arrow-flight/src/main/java/org/opensearch/flight/FlightStreamPlugin.java @@ -30,11 +30,18 @@ import java.util.List; import java.util.function.Supplier; -import static org.opensearch.flight.FlightService.*; +import static org.opensearch.flight.FlightService.ARROW_ALLOCATION_MANAGER_TYPE; +import static org.opensearch.flight.FlightService.ARROW_ENABLE_NULL_CHECK_FOR_GET; +import static org.opensearch.flight.FlightService.ARROW_ENABLE_UNSAFE_MEMORY_ACCESS; +import static org.opensearch.flight.FlightService.NETTY_ALLOCATOR_NUM_DIRECT_ARENAS; +import static org.opensearch.flight.FlightService.NETTY_NO_UNSAFE; +import static org.opensearch.flight.FlightService.NETTY_TRY_REFLECTION_SET_ACCESSIBLE; +import static org.opensearch.flight.FlightService.NETTY_TRY_UNSAFE; + public class FlightStreamPlugin extends Plugin implements StreamManagerPlugin { - private FlightService flightService; + private final FlightService flightService; public FlightStreamPlugin(Settings settings) { this.flightService = new FlightService(settings); @@ -68,10 +75,10 @@ public List> getSettings() { ARROW_ALLOCATION_MANAGER_TYPE, ARROW_ENABLE_NULL_CHECK_FOR_GET, NETTY_TRY_REFLECTION_SET_ACCESSIBLE, - FlightService.ARROW_ENABLE_UNSAFE_MEMORY_ACCESS, - FlightService.NETTY_ALLOCATOR_NUM_DIRECT_ARENAS, - 
FlightService.NETTY_NO_UNSAFE, - FlightService.NETTY_TRY_UNSAFE + ARROW_ENABLE_UNSAFE_MEMORY_ACCESS, + NETTY_ALLOCATOR_NUM_DIRECT_ARENAS, + NETTY_NO_UNSAFE, + NETTY_TRY_UNSAFE ); } } diff --git a/modules/arrow-flight/src/test/java/org/opensearch/flight/BaseFlightProducerTests.java b/modules/arrow-flight/src/test/java/org/opensearch/flight/BaseFlightProducerTests.java new file mode 100644 index 0000000000000..bd0cd88d8104e --- /dev/null +++ b/modules/arrow-flight/src/test/java/org/opensearch/flight/BaseFlightProducerTests.java @@ -0,0 +1,444 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.flight; + +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.opensearch.arrow.ArrowStreamProvider; +import org.opensearch.arrow.StreamManager; +import org.opensearch.arrow.StreamTicket; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.mockito.Mockito.*; + +public class BaseFlightProducerTests extends OpenSearchTestCase { + + private BaseFlightProducer baseFlightProducer; + private StreamManager streamManager; + private ArrowStreamProvider arrowStreamProvider; + private ArrowStreamProvider.Task task; + + @Override + public void setUp() throws Exception { + super.setUp(); + streamManager = mock(StreamManager.class); + BufferAllocator allocator = mock(BufferAllocator.class); + arrowStreamProvider = 
mock(ArrowStreamProvider.class); + task = mock(ArrowStreamProvider.Task.class); + baseFlightProducer = new BaseFlightProducer(streamManager, allocator); + } + + private static class TestServerStreamListener implements FlightProducer.ServerStreamListener { + private final CountDownLatch completionLatch = new CountDownLatch(1); + private final AtomicInteger putNextCount = new AtomicInteger(0); + private final AtomicBoolean isCancelled = new AtomicBoolean(false); + private Throwable error; + private final AtomicBoolean dataConsumed = new AtomicBoolean(false); + private final AtomicBoolean ready = new AtomicBoolean(false); + private Runnable onReadyHandler; + private Runnable onCancelHandler; + + @Override + public void putNext() { + assertFalse(dataConsumed.get()); + putNextCount.incrementAndGet(); + dataConsumed.set(true); + } + + @Override + public boolean isReady() { + return ready.get(); + } + + public void setReady(boolean val) { + ready.set(val); + if (this.onReadyHandler != null) { + this.onReadyHandler.run(); + } + } + + @Override + public void start(VectorSchemaRoot root) { + // No-op for this test + } + + @Override + public void start(VectorSchemaRoot root, DictionaryProvider dictionaries, IpcOption option) { + } + + @Override + public void putNext(ArrowBuf metadata) { + putNext(); + } + + @Override + public void putMetadata(ArrowBuf metadata) { + + } + + @Override + public void completed() { + completionLatch.countDown(); + } + + @Override + public void error(Throwable t) { + error = t; + completionLatch.countDown(); + } + + @Override + public boolean isCancelled() { + return isCancelled.get(); + } + + @Override + public void setOnReadyHandler(Runnable handler) { + this.onReadyHandler = handler; + } + + @Override + public void setOnCancelHandler(Runnable handler) { + this.onCancelHandler = handler; + } + + public void resetConsumptionLatch() { + dataConsumed.set(false); + } + + public boolean getDataConsumed() { + return dataConsumed.get(); + } + + public 
int getPutNextCount() { + return putNextCount.get(); + } + + public Throwable getError() { + return error; + } + + public void cancel() { + isCancelled.set(true); + if (this.onCancelHandler != null) { + this.onCancelHandler.run(); + } + } + } + + public void testGetStream_SuccessfulFlow() throws Exception { + Ticket ticket = new Ticket(new byte[]{1, 2, 3}); + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + + when(streamManager.getStream(any(StreamTicket.class))).thenReturn(arrowStreamProvider); + when(arrowStreamProvider.create(any(BufferAllocator.class))).thenReturn(task); + when(task.init(any(BufferAllocator.class))).thenReturn(root); + + AtomicInteger flushCount = new AtomicInteger(0); + TestServerStreamListener listener = new TestServerStreamListener(); + doAnswer(invocation -> { + ArrowStreamProvider.FlushSignal flushSignal = invocation.getArgument(1); + for (int i = 0; i < 3; i++) { + Thread clientThread = new Thread(() -> { + listener.setReady(false); + listener.setReady(true); + }); + listener.setReady(false); + clientThread.start(); + flushSignal.awaitConsumption(100); + assertTrue(listener.getDataConsumed()); + flushCount.incrementAndGet(); + listener.resetConsumptionLatch(); + } + return null; + }).when(task).run(any(VectorSchemaRoot.class), any(ArrowStreamProvider.FlushSignal.class)); + baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener); + + assertNull(listener.getError()); + assertEquals(3, listener.getPutNextCount()); + assertEquals(3, flushCount.get()); + + verify(streamManager).removeStream(any(StreamTicket.class)); + verify(root).close(); + } + + public void testGetStream_WithSlowClient() throws Exception { + Ticket ticket = new Ticket(new byte[]{1, 2, 3}); + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + + when(streamManager.getStream(any(StreamTicket.class))).thenReturn(arrowStreamProvider); + when(arrowStreamProvider.create(any(BufferAllocator.class))).thenReturn(task); + 
when(task.init(any(BufferAllocator.class))).thenReturn(root); + + AtomicInteger flushCount = new AtomicInteger(0); + TestServerStreamListener listener = new TestServerStreamListener(); + + doAnswer(invocation -> { + ArrowStreamProvider.FlushSignal flushSignal = invocation.getArgument(1); + for (int i = 0; i < 5; i++) { + Thread clientThread = new Thread(() -> { + try { + listener.setReady(false); + Thread.sleep(100); + listener.setReady(true); + } catch (InterruptedException e) { + e.printStackTrace(); + } + }); + listener.setReady(false); + clientThread.start(); + flushSignal.awaitConsumption(300); // waiting for consumption for more than client thread sleep + assertTrue(listener.getDataConsumed()); + flushCount.incrementAndGet(); + listener.resetConsumptionLatch(); + } + return null; + }).when(task).run(any(), any()); + + baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener); + + assertNull(listener.getError()); + assertEquals(5, listener.getPutNextCount()); + assertEquals(5, flushCount.get()); + + verify(streamManager).removeStream(any(StreamTicket.class)); + verify(root).close(); + } + + public void testGetStream_WithSlowClientTimeout() throws Exception { + Ticket ticket = new Ticket(new byte[]{1, 2, 3}); + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + + when(streamManager.getStream(any(StreamTicket.class))).thenReturn(arrowStreamProvider); + when(arrowStreamProvider.create(any(BufferAllocator.class))).thenReturn(task); + when(task.init(any(BufferAllocator.class))).thenReturn(root); + + AtomicInteger flushCount = new AtomicInteger(0); + TestServerStreamListener listener = new TestServerStreamListener(); + doAnswer(invocation -> { + ArrowStreamProvider.FlushSignal flushSignal = invocation.getArgument(1); + for (int i = 0; i < 5; i++) { + Thread clientThread = new Thread(() -> { + try { + listener.setReady(false); + Thread.sleep(400); + listener.setReady(true); + } catch (InterruptedException e) { + 
e.printStackTrace(); + } + }); + listener.setReady(false); + clientThread.start(); + flushSignal.awaitConsumption(100); // waiting for consumption for less than client thread sleep + assertTrue(listener.getDataConsumed()); + flushCount.incrementAndGet(); + listener.resetConsumptionLatch(); + } + return null; + }).when(task).run(any(), any()); + + assertThrows(RuntimeException.class, () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener)); + + assertNotNull(listener.getError()); + assertEquals("Stream deadline exceeded for consumption", listener.getError().getMessage()); + assertEquals(0, listener.getPutNextCount()); + assertEquals(0, flushCount.get()); + + verify(streamManager).removeStream(any(StreamTicket.class)); + verify(root).close(); + } + + public void testGetStream_WithClientCancel() throws Exception { + Ticket ticket = new Ticket(new byte[]{1, 2, 3}); + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + + when(streamManager.getStream(any(StreamTicket.class))).thenReturn(arrowStreamProvider); + when(arrowStreamProvider.create(any(BufferAllocator.class))).thenReturn(task); + when(task.init(any(BufferAllocator.class))).thenReturn(root); + + AtomicInteger flushCount = new AtomicInteger(0); + TestServerStreamListener listener = new TestServerStreamListener(); + doAnswer(invocation -> { + ArrowStreamProvider.FlushSignal flushSignal = invocation.getArgument(1); + for (int i = 0; i < 5; i++) { + int finalI = i; + Thread clientThread = new Thread(() -> { + if (finalI == 4) { + listener.cancel(); + } else { + listener.setReady(false); + listener.setReady(true); + } + }); + listener.setReady(false); + clientThread.start(); + flushSignal.awaitConsumption(100); // waiting for consumption for less than client thread sleep + assertTrue(listener.getDataConsumed()); + flushCount.incrementAndGet(); + listener.resetConsumptionLatch(); + } + return null; + }).when(task).run(any(), any()); + + 
assertThrows(RuntimeException.class, () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener)); + assertNotNull(listener.getError()); + assertEquals("Stream cancelled by client", listener.getError().getMessage()); + assertEquals(4, listener.getPutNextCount()); + assertEquals(4, flushCount.get()); + + verify(streamManager).removeStream(any(StreamTicket.class)); + verify(root).close(); + } + + public void testGetStream_WithUnresponsiveClient() throws Exception { + Ticket ticket = new Ticket(new byte[]{1, 2, 3}); + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + + when(streamManager.getStream(any(StreamTicket.class))).thenReturn(arrowStreamProvider); + when(arrowStreamProvider.create(any(BufferAllocator.class))).thenReturn(task); + when(task.init(any(BufferAllocator.class))).thenReturn(root); + + AtomicInteger flushCount = new AtomicInteger(0); + TestServerStreamListener listener = new TestServerStreamListener(); + doAnswer(invocation -> { + ArrowStreamProvider.FlushSignal flushSignal = invocation.getArgument(1); + for (int i = 0; i < 5; i++) { + Thread clientThread = new Thread(() -> { + listener.setReady(false); + // not setting ready to simulate unresponsive behaviour + }); + listener.setReady(false); + clientThread.start(); + flushSignal.awaitConsumption(100); // waiting for consumption for less than client thread sleep + assertTrue(listener.getDataConsumed()); + flushCount.incrementAndGet(); + listener.resetConsumptionLatch(); + } + return null; + }).when(task).run(any(), any()); + + assertThrows(RuntimeException.class, () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener)); + + assertNotNull(listener.getError()); + assertEquals("Stream deadline exceeded for consumption", listener.getError().getMessage()); + assertEquals(0, listener.getPutNextCount()); + assertEquals(0, flushCount.get()); + + verify(streamManager).removeStream(any(StreamTicket.class)); + verify(root).close(); 
+ } + + public void testGetStream_WithServerBackpressure() throws Exception { + Ticket ticket = new Ticket(new byte[]{1, 2, 3}); + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + + when(streamManager.getStream(any(StreamTicket.class))).thenReturn(arrowStreamProvider); + when(arrowStreamProvider.create(any(BufferAllocator.class))).thenReturn(task); + when(task.init(any(BufferAllocator.class))).thenReturn(root); + + TestServerStreamListener listener = new TestServerStreamListener(); + AtomicInteger flushCount = new AtomicInteger(0); + doAnswer(invocation -> { + ArrowStreamProvider.FlushSignal flushSignal = invocation.getArgument(1); + for (int i = 0; i < 5; i++) { + Thread clientThread = new Thread(() -> { + listener.setReady(false); + listener.setReady(true); + }); + listener.setReady(false); + clientThread.start(); + Thread.sleep(100); // simulating writer backpressure + flushSignal.awaitConsumption(100); + assertTrue(listener.getDataConsumed()); + flushCount.incrementAndGet(); + listener.resetConsumptionLatch(); + } + return null; + }).when(task).run(any(VectorSchemaRoot.class), any(ArrowStreamProvider.FlushSignal.class)); + + baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener); + + assertNull(listener.getError()); + assertEquals(5, listener.getPutNextCount()); + assertEquals(5, flushCount.get()); + + verify(streamManager).removeStream(any(StreamTicket.class)); + verify(root).close(); + } + + public void testGetStream_WithServerError() throws Exception { + Ticket ticket = new Ticket(new byte[]{1, 2, 3}); + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + + when(streamManager.getStream(any(StreamTicket.class))).thenReturn(arrowStreamProvider); + when(arrowStreamProvider.create(any(BufferAllocator.class))).thenReturn(task); + when(task.init(any(BufferAllocator.class))).thenReturn(root); + + TestServerStreamListener listener = new TestServerStreamListener(); + AtomicInteger flushCount = new 
AtomicInteger(0); + doAnswer(invocation -> { + ArrowStreamProvider.FlushSignal flushSignal = invocation.getArgument(1); + for (int i = 0; i < 5; i++) { + Thread clientThread = new Thread(() -> { + listener.setReady(false); + listener.setReady(true); + }); + listener.setReady(false); + clientThread.start(); + if (i == 4) { + throw new RuntimeException("Server error"); + } + flushSignal.awaitConsumption(100); + assertTrue(listener.getDataConsumed()); + flushCount.incrementAndGet(); + listener.resetConsumptionLatch(); + } + return null; + }).when(task).run(any(VectorSchemaRoot.class), any(ArrowStreamProvider.FlushSignal.class)); + + assertThrows(RuntimeException.class, () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener)); + + assertNotNull(listener.getError()); + assertEquals("Server error", listener.getError().getMessage()); + assertEquals(4, listener.getPutNextCount()); + assertEquals(4, flushCount.get()); + + verify(streamManager).removeStream(any(StreamTicket.class)); + verify(root).close(); + } + + public void testGetStream_WithMultipleConcurrentClients() throws Exception { + + } + + public void testGetStream_StreamNotFound() throws Exception { + Ticket ticket = new Ticket(new byte[]{1, 2, 3}); + + when(streamManager.getStream(any(StreamTicket.class))).thenReturn(null); + + TestServerStreamListener listener = new TestServerStreamListener(); + + baseFlightProducer.getStream(null, ticket, listener); + + assertNotNull(listener.getError()); + assertTrue(listener.getError().getMessage().contains("Stream not found")); + assertEquals(0, listener.getPutNextCount()); + + verify(streamManager).removeStream(any(StreamTicket.class)); + } +} diff --git a/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightServiceTests.java b/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightServiceTests.java new file mode 100644 index 0000000000000..0a549ade1ef3d --- /dev/null +++ 
b/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightServiceTests.java @@ -0,0 +1,65 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.flight; + +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.arrow.StreamManager; +import org.opensearch.common.settings.Settings; +import org.apache.arrow.flight.FlightClient; + +public class FlightServiceTests extends OpenSearchTestCase { + + private FlightService flightService; + + @Override + public void setUp() throws Exception { + super.setUp(); + Settings settings = Settings.builder().build(); + flightService = new FlightService(settings); + } + + public void testGetStreamManager() { + StreamManager streamManager = flightService.getStreamManager(); + assertNotNull(streamManager); + assertTrue(streamManager instanceof FlightStreamManager); + } + + public void testDoStart() throws Exception { + flightService.doStart(); + // Add assertions here + } + + public void testDoStop() throws Exception { + flightService.doStop(); + // Add assertions here + } + + public void testDoClose() throws Exception { + flightService.doClose(); + // Add assertions here + } + + public void testCreateFlightClient() { + FlightClient client = null; //flightService.createFlightClient(); + assertNotNull(client); + // Add more specific assertions based on the expected configuration of the FlightClient + } + + public void testCreateFlightClientWithCustomSettings() { + Settings customSettings = Settings.builder() + .put("plugins.flight.host", "custom-host") + .put("plugins.flight.port", 1234) + .build(); + FlightService customFlightService = new FlightService(customSettings); + + FlightClient client = null; // customFlightService.createFlightClient(); + assertNotNull(client); + // Add more specific assertions based on the expected 
configuration of the FlightClient + } +} diff --git a/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamManagerTests.java b/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamManagerTests.java new file mode 100644 index 0000000000000..d0a2fbabb5852 --- /dev/null +++ b/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamManagerTests.java @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.flight; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; +import org.opensearch.arrow.StreamTicket; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; + +import static org.mockito.Mockito.*; + +public class FlightStreamManagerTests extends OpenSearchTestCase { + + private FlightClient flightClient; + private FlightStreamManager flightStreamManager; + + @Override + public void setUp() throws Exception { + super.setUp(); + flightClient = mock(FlightClient.class); + flightStreamManager = new FlightStreamManager(flightClient); + } + + public void testGetVectorSchemaRoot() { + StreamTicket ticket = new StreamTicket(new byte[]{1, 2, 3}); + FlightStream mockFlightStream = mock(FlightStream.class); + VectorSchemaRoot mockRoot = mock(VectorSchemaRoot.class); + + when(flightClient.getStream(new Ticket(ticket.getBytes()))).thenReturn(mockFlightStream); + when(mockFlightStream.getRoot()).thenReturn(mockRoot); + when(mockRoot.getSchema()).thenReturn(new Schema(Collections.emptyList())); + + VectorSchemaRoot root = flightStreamManager.getVectorSchemaRoot(ticket); + + 
assertNotNull(root); + assertEquals(new Schema(Collections.emptyList()), root.getSchema()); + verify(flightClient).getStream(new Ticket(ticket.getBytes())); + } + + public void testGenerateUniqueTicket() { + StreamTicket ticket = flightStreamManager.generateUniqueTicket(); + assertNotNull(ticket); + assertNotNull(ticket.getBytes()); + assertTrue(ticket.getBytes().length > 0); + } + + public void testGetVectorSchemaRootWithException() { + StreamTicket ticket = new StreamTicket(new byte[]{1, 2, 3}); + when(flightClient.getStream(new Ticket(ticket.getBytes()))).thenThrow(new RuntimeException("Test exception")); + + expectThrows(RuntimeException.class, () -> flightStreamManager.getVectorSchemaRoot(ticket)); + verify(flightClient).getStream(new Ticket(ticket.getBytes())); + } + + public void testGenerateUniqueTicketMultipleCalls() { + StreamTicket ticket1 = flightStreamManager.generateUniqueTicket(); + StreamTicket ticket2 = flightStreamManager.generateUniqueTicket(); + + assertNotNull(ticket1); + assertNotNull(ticket2); + assertNotEquals(ticket1, ticket2); + } + + public void testGetVectorSchemaRootWithNullTicket() { + expectThrows(NullPointerException.class, () -> flightStreamManager.getVectorSchemaRoot(null)); + } +} diff --git a/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamPluginTests.java b/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamPluginTests.java new file mode 100644 index 0000000000000..adcfaae1898dd --- /dev/null +++ b/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamPluginTests.java @@ -0,0 +1,72 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a.java + * compatible open source license. 
+ */ + +package org.opensearch.flight; + +import org.opensearch.arrow.StreamManager; +import org.opensearch.common.settings.Setting; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.common.settings.Settings; + +import java.util.Collection; +import java.util.List; +import java.util.Optional; + +public class FlightStreamPluginTests extends OpenSearchTestCase { + + private Settings settings; + private FlightStreamPlugin flightStreamPlugin; + + @Override + public void setUp() throws Exception { + super.setUp(); + settings = Settings.builder().build(); + flightStreamPlugin = new FlightStreamPlugin(settings); + } + + public void testCreateComponents() { + Collection components = flightStreamPlugin.createComponents(null, null, null,null, null,null, null, null, null, null, null); + assertNotNull(components); + assertTrue(components.stream().anyMatch(component -> component instanceof FlightService)); + } + + public void testGetStreamManager() { + StreamManager streamManager = flightStreamPlugin.getStreamManager(); + assertNotNull(streamManager); + assertTrue(streamManager instanceof FlightStreamManager); + } + + public void testGetSettings() { + List> settingsList = flightStreamPlugin.getSettings(); + assertNotNull(settingsList); + assertFalse(settingsList.isEmpty()); + assertTrue(settingsList.stream().anyMatch(setting -> setting.getKey().equals("plugins.flight.port"))); + assertTrue(settingsList.stream().anyMatch(setting -> setting.getKey().equals("plugins.flight.host"))); + } + + public void testCreateComponentsWithNullArguments() { + Collection components = flightStreamPlugin.createComponents(null, null, null,null, null,null, null, null, null, null, null); + assertNotNull(components); + assertFalse(components.isEmpty()); + } + + public void testGetSettingsDefaultValues() { + List> settingsList = flightStreamPlugin.getSettings(); + Optional> portSetting = settingsList.stream() + .filter(setting -> setting.getKey().equals("plugins.flight.port")) + 
.findFirst(); + Optional> hostSetting = settingsList.stream() + .filter(setting -> setting.getKey().equals("plugins.flight.host")) + .findFirst(); + + assertTrue(portSetting.isPresent()); + assertTrue(hostSetting.isPresent()); + assertEquals(8980, portSetting.get().getDefault(Settings.EMPTY)); + assertEquals("127.0.0.1", hostSetting.get().getDefault(Settings.EMPTY)); + } +} diff --git a/server/src/main/java/org/opensearch/arrow/query/ArrowDocIdCollector.java b/server/src/main/java/org/opensearch/arrow/query/ArrowDocIdCollector.java index 2ab5e4ca6f098..0790f4e9044f6 100644 --- a/server/src/main/java/org/opensearch/arrow/query/ArrowDocIdCollector.java +++ b/server/src/main/java/org/opensearch/arrow/query/ArrowDocIdCollector.java @@ -68,7 +68,7 @@ public void collect(int doc) throws IOException { currentRow++; if (currentRow >= batchSize) { root.setRowCount(batchSize); - flushSignal.awaitConsumption(); + flushSignal.awaitConsumption(1000); currentRow = 0; } } @@ -77,7 +77,7 @@ public void collect(int doc) throws IOException { public void finish() throws IOException { if (currentRow > 0) { root.setRowCount(currentRow); - flushSignal.awaitConsumption(); + flushSignal.awaitConsumption(1000); currentRow = 0; } } diff --git a/server/src/main/java/org/opensearch/plugins/StreamManagerPlugin.java b/server/src/main/java/org/opensearch/plugins/StreamManagerPlugin.java index 00682026d0976..0a01e09fe2094 100644 --- a/server/src/main/java/org/opensearch/plugins/StreamManagerPlugin.java +++ b/server/src/main/java/org/opensearch/plugins/StreamManagerPlugin.java @@ -10,6 +10,12 @@ import org.opensearch.arrow.StreamManager; +/** + * An interface for OpenSearch plugins to implement to provide a StreamManager. + * This interface is used by the Arrow Flight plugin to get the StreamManager instance. + * Other plugins can also implement this interface to provide their own StreamManager implementation. 
+ * @see org.opensearch.arrow.StreamManager + */ public interface StreamManagerPlugin { /** * Returns the StreamManager instance for this plugin. diff --git a/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java b/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java index 5bc4bb4b1f533..24aa7c7dea41e 100644 --- a/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java +++ b/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java @@ -155,6 +155,11 @@ public void run(VectorSchemaRoot root, ArrowStreamProvider.FlushSignal flushSign throw new RuntimeException(e); } } + + @Override + public void onCancel() { + + } })); StreamSearchResult streamSearchResult = searchContext.streamSearchResult(); streamSearchResult.flights(List.of(new OSTicket(ticket.getBytes())));