diff --git a/libs/arrow-spi/build.gradle b/libs/arrow-spi/build.gradle index d14b7e88cfb8c..425f42f2333d9 100644 --- a/libs/arrow-spi/build.gradle +++ b/libs/arrow-spi/build.gradle @@ -26,6 +26,13 @@ dependencies { api "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}" implementation "commons-codec:commons-codec:${versions.commonscodec}" + + testImplementation "com.carrotsearch.randomizedtesting:randomizedtesting-runner:${versions.randomizedrunner}" + testImplementation "junit:junit:${versions.junit}" + testImplementation "org.hamcrest:hamcrest:${versions.hamcrest}" + testImplementation(project(":test:framework")) { + exclude group: 'org.opensearch', module: 'opensearch-arrow-spi' + } } tasks.named('forbiddenApisMain').configure { diff --git a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamManager.java b/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamManager.java index c1043e07b176f..a94dd8cd9caed 100644 --- a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamManager.java +++ b/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamManager.java @@ -47,4 +47,12 @@ public interface StreamManager extends AutoCloseable { * @throws IllegalStateException if the stream has been cancelled or closed */ StreamReader getStreamReader(StreamTicket ticket); + + /** + * Gets the StreamTicketFactory instance associated with this StreamManager. + * By default, returns the singleton instance of StreamTicketFactory. + * + * @return the StreamTicketFactory instance + */ + StreamTicketFactory getStreamTicketFactory(); } diff --git a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java b/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java index 6dd443a2595cb..c5cd6f16adfdd 100644 --- a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java +++ b/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java @@ -102,17 +102,13 @@ public interface StreamProducer extends Closeable { * * @return Estimated number of rows, or -1 if unknown */ - default int estimatedRowCount() { - return -1; - } + int estimatedRowCount(); /** * Task action name * @return action name */ - default String getAction() { - return ""; - } + String getAction(); /** * BatchedJob interface for producing stream data in batches. diff --git a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamTicket.java b/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamTicket.java index fe463ffdc4e3d..6d823f5773b1e 100644 --- a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamTicket.java +++ b/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamTicket.java @@ -37,15 +37,4 @@ public interface StreamTicket { * @return Base64 encoded byte array containing the ticket information */ byte[] toBytes(); - - /** - * Creates a StreamTicket from its serialized byte representation. - * - * @param bytes Base64 encoded byte array containing ticket information - * @return a new StreamTicket instance - * @throws IllegalArgumentException if the input is invalid - */ - static StreamTicket fromBytes(byte[] bytes) { - throw new UnsupportedOperationException("Implementation must be provided by concrete class"); - } } diff --git a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamTicketFactory.java b/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamTicketFactory.java new file mode 100644 index 0000000000000..d587136c711e6 --- /dev/null +++ b/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamTicketFactory.java @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.spi; + +import org.opensearch.common.annotation.ExperimentalApi; + +/** + * Factory interface for creating and managing StreamTicket instances. + * This factory provides methods to create and deserialize StreamTickets, + * ensuring consistent ticket creation. + */ +@ExperimentalApi +public interface StreamTicketFactory { + /** + * Generates a new StreamTicket + * + * @return A new StreamTicket instance + */ + StreamTicket generateTicket(); + + /** + * Deserializes a StreamTicket from its byte representation. + * + * @param bytes The byte array containing the serialized ticket data + * @return A StreamTicket instance reconstructed from the byte array + * @throws IllegalArgumentException if bytes is null or invalid + */ + StreamTicket fromBytes(byte[] bytes); +} diff --git a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/BaseFlightProducer.java b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/BaseFlightProducer.java index ef1b64ad74354..71b89f1b4ee47 100644 --- a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/BaseFlightProducer.java +++ b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/BaseFlightProducer.java @@ -64,7 +64,7 @@ public BaseFlightProducer(FlightClientManager flightClientManager, FlightStreamM */ @Override public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { - StreamTicket streamTicket = FlightStreamTicket.fromBytes(ticket.getBytes()); + StreamTicket streamTicket = streamManager.getStreamTicketFactory().fromBytes(ticket.getBytes()); try { FlightStreamManager.StreamProducerHolder streamProducerHolder; if (streamTicket.getNodeID().equals(flightClientManager.getLocalNodeId())) { @@ -127,7 +127,7 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l @Override public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { // TODO: this api should only be used internally - StreamTicket streamTicket = FlightStreamTicket.fromBytes(descriptor.getCommand()); + StreamTicket streamTicket = streamManager.getStreamTicketFactory().fromBytes(descriptor.getCommand()); FlightStreamManager.StreamProducerHolder streamProducerHolder; if (streamTicket.getNodeID().equals(flightClientManager.getLocalNodeId())) { streamProducerHolder = streamManager.getStreamProducer(streamTicket); diff --git a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/DefaultStreamTicketFactory.java b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/DefaultStreamTicketFactory.java new file mode 100644 index 0000000000000..e8a85fe8af9e2 --- /dev/null +++ b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/DefaultStreamTicketFactory.java @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.core; + +import org.opensearch.arrow.spi.StreamTicket; +import org.opensearch.arrow.spi.StreamTicketFactory; +import org.opensearch.common.annotation.ExperimentalApi; + +import java.util.UUID; +import java.util.function.Supplier; + +/** + * Default implementation of StreamTicketFactory + */ +@ExperimentalApi +public class DefaultStreamTicketFactory implements StreamTicketFactory { + + private final Supplier nodeId; + + /** + * Constructs a new DefaultStreamTicketFactory instance. + * + * @param nodeId A Supplier that provides the node ID for the StreamTicket + */ + public DefaultStreamTicketFactory(Supplier nodeId) { + this.nodeId = nodeId; + } + + /** + * Generates a new StreamTicket with a unique ticket ID. + * + * @return A new StreamTicket instance + */ + @Override + public StreamTicket generateTicket() { + return new FlightStreamTicket(generateUniqueTicket(), nodeId.get()); + } + + /** + * Deserializes a StreamTicket from its byte representation. + * + * @param bytes The byte array containing the serialized ticket data + * @return A StreamTicket instance reconstructed from the byte array + * @throws IllegalArgumentException if bytes is null or invalid + */ + @Override + public StreamTicket fromBytes(byte[] bytes) { + return FlightStreamTicket.fromBytes(bytes); + } + + private String generateUniqueTicket() { + return UUID.randomUUID().toString(); + } +} diff --git a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/FlightStreamManager.java b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/FlightStreamManager.java index 52f03ce2f2168..5b63c1334adaf 100644 --- a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/FlightStreamManager.java +++ b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/FlightStreamManager.java @@ -17,6 +17,7 @@ import org.opensearch.arrow.spi.StreamProducer; import org.opensearch.arrow.spi.StreamReader; import org.opensearch.arrow.spi.StreamTicket; +import org.opensearch.arrow.spi.StreamTicketFactory; import org.opensearch.common.SetOnce; import org.opensearch.common.cache.Cache; import org.opensearch.common.cache.CacheBuilder; @@ -33,7 +34,7 @@ */ public class FlightStreamManager implements StreamManager { - private final StreamTicketFactory ticketFactory; + private final DefaultStreamTicketFactory ticketFactory; private final FlightClientManager clientManager; private final Supplier allocatorSupplier; private final Cache streamProducers; @@ -54,7 +55,7 @@ public FlightStreamManager(Supplier allocatorSupplier, FlightCl .setExpireAfterWrite(expireAfter) .setMaximumWeight(MAX_PRODUCERS) .build(); - this.ticketFactory = new StreamTicketFactory(clientManager::getLocalNodeId); + this.ticketFactory = new DefaultStreamTicketFactory(clientManager::getLocalNodeId); } /** @@ -65,7 +66,7 @@ public FlightStreamManager(Supplier allocatorSupplier, FlightCl */ @Override public StreamTicket registerStream(StreamProducer provider, TaskId parentTaskId) { - FlightStreamTicket ticket = ticketFactory.createTicket(); + StreamTicket ticket = ticketFactory.generateTicket(); streamProducers.put(ticket.getTicketID(), new StreamProducerHolder(provider, allocatorSupplier.get())); return ticket; } @@ -81,6 +82,11 @@ public StreamReader getStreamReader(StreamTicket ticket) { return new FlightStreamReader(stream); } + @Override + public StreamTicketFactory getStreamTicketFactory() { + return ticketFactory; + } + /** * Retrieves the ArrowStreamProvider associated with the given StreamTicket. * diff --git a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/FlightStreamTicket.java b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/FlightStreamTicket.java index 2c37f5e160764..bbb588a460285 100644 --- a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/FlightStreamTicket.java +++ b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/FlightStreamTicket.java @@ -53,7 +53,7 @@ public byte[] toBytes() { return Base64.getEncoder().encode(buffer.array()); } - public static StreamTicket fromBytes(byte[] bytes) { + static StreamTicket fromBytes(byte[] bytes) { if (bytes == null || bytes.length < 4) { throw new IllegalArgumentException("Invalid byte array input."); } diff --git a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/ProxyStreamProducer.java b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/ProxyStreamProducer.java index fe314fb0cb190..2a50b5023d4b3 100644 --- a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/ProxyStreamProducer.java +++ b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/ProxyStreamProducer.java @@ -52,6 +52,24 @@ public BatchedJob createJob(BufferAllocator allocator) { return new ProxyBatchedJob(remoteStream); } + /** + * Provides an estimate of the total number of rows that will be produced. + */ + @Override + public int estimatedRowCount() { + // TODO get it from remote flight stream + return -1; + } + + /** + * Task action name + */ + @Override + public String getAction() { + // TODO get it from remote flight stream + return ""; + } + /** * Closes the remote FlightStream. */ diff --git a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/StreamTicketFactory.java b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/StreamTicketFactory.java deleted file mode 100644 index 47d8e667ef377..0000000000000 --- a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/core/StreamTicketFactory.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.arrow.flight.core; - -import java.util.UUID; -import java.util.function.Supplier; - -class StreamTicketFactory { - private final Supplier nodeId; - - public StreamTicketFactory(Supplier nodeId) { - this.nodeId = nodeId; - } - - /** - * Generates a new StreamTicket with a unique ticket ID. - * - * @return A new StreamTicket instance - */ - public FlightStreamTicket createTicket() { - return new FlightStreamTicket(generateUniqueTicket(), nodeId.get()); - } - - private String generateUniqueTicket() { - return UUID.randomUUID().toString(); - } -} diff --git a/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightServiceTests.java b/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightServiceTests.java index 905ad07ee3a77..1c42c71cf1d6e 100644 --- a/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightServiceTests.java +++ b/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightServiceTests.java @@ -135,14 +135,6 @@ public void testInitializeWithoutSecureTransportSettingsProvider() { }); } - public void testDoubleInitialization() { - flightService.initialize(clusterService, threadPool); - - flightService.initialize(clusterService, threadPool); - - assertNotNull(flightService.getStreamManager()); - } - public void testStopWithoutStart() { flightService.initialize(clusterService, threadPool); @@ -182,6 +174,11 @@ public void testLifecycleStateTransitions() throws Exception { assertEquals("CLOSED", testService.lifecycleState().toString()); } + @Override + public void tearDown() throws Exception { + super.tearDown(); + } + private void verifyServerRunning(FlightService flightService, int clientPort) throws InterruptedException { FlightClientBuilder builder = new FlightClientBuilder( "localhost", diff --git a/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/core/BaseFlightProducerTests.java b/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/core/BaseFlightProducerTests.java index f8598da033679..146bf4f01952a 100644 --- a/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/core/BaseFlightProducerTests.java +++ b/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/core/BaseFlightProducerTests.java @@ -44,6 +44,7 @@ public class BaseFlightProducerTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); streamManager = mock(FlightStreamManager.class); + when(streamManager.getStreamTicketFactory()).thenReturn(new DefaultStreamTicketFactory(() -> LOCAL_NODE_ID)); when(flightClientManager.getLocalNodeId()).thenReturn(LOCAL_NODE_ID); allocator = mock(BufferAllocator.class); streamProducer = mock(StreamProducer.class); diff --git a/server/src/main/java/org/opensearch/arrow/spi/StreamManagerWrapper.java b/server/src/main/java/org/opensearch/arrow/spi/StreamManagerWrapper.java index 0d7bbb36d15a1..4f308d641181e 100644 --- a/server/src/main/java/org/opensearch/arrow/spi/StreamManagerWrapper.java +++ b/server/src/main/java/org/opensearch/arrow/spi/StreamManagerWrapper.java @@ -48,6 +48,11 @@ public StreamReader getStreamReader(StreamTicket ticket) { return streamManager.getStreamReader(ticket); } + @Override + public StreamTicketFactory getStreamTicketFactory() { + return streamManager.getStreamTicketFactory(); + } + @Override public void close() throws Exception { streamManager.close(); diff --git a/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java b/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java index cc9391bcacc14..ca427f3770407 100644 --- a/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java +++ b/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java @@ -198,7 +198,7 @@ public String getAction() { } }, searchContext.getTask().getParentTaskId()); StreamSearchResult streamSearchResult = searchContext.streamSearchResult(); - streamSearchResult.flights(List.of(new OSTicket(ticket))); + streamSearchResult.flights(List.of(new OSTicket(ticket.toBytes()))); return false; } } diff --git a/server/src/main/java/org/opensearch/search/stream/OSTicket.java b/server/src/main/java/org/opensearch/search/stream/OSTicket.java index 34a195bafb14f..33c303948b1bb 100644 --- a/server/src/main/java/org/opensearch/search/stream/OSTicket.java +++ b/server/src/main/java/org/opensearch/search/stream/OSTicket.java @@ -8,6 +8,7 @@ package org.opensearch.search.stream; +import org.opensearch.arrow.spi.StreamManager; import org.opensearch.arrow.spi.StreamTicket; import org.opensearch.common.annotation.ExperimentalApi; import org.opensearch.core.common.io.stream.StreamInput; @@ -25,30 +26,28 @@ @ExperimentalApi public class OSTicket implements Writeable, ToXContentFragment { - private final StreamTicket streamTicket; + private final byte[] bytes; - public OSTicket(StreamTicket ticket) { - this.streamTicket = ticket; + public OSTicket(byte[] bytes) { + this.bytes = bytes; } public OSTicket(StreamInput in) throws IOException { - byte[] bytes = in.readByteArray(); - this.streamTicket = StreamTicket.fromBytes(bytes); + bytes = in.readByteArray(); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - byte[] bytes = streamTicket.toBytes(); return builder.value(new String(bytes, StandardCharsets.UTF_8)); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeByteArray(streamTicket.toBytes()); + out.writeByteArray(bytes); } @Override public String toString() { - return "OSTicket{" + "ticketID='" + streamTicket.getTicketID() + '\'' + ", nodeID='" + streamTicket.getNodeID() + '\'' + '}'; + return "OSTicket{" + new String(bytes, StandardCharsets.UTF_8) + "}"; } }