move StreamTicket to an interface
Signed-off-by: Rishabh Maurya <[email protected]>
rishabhmaurya committed Nov 24, 2024
1 parent a4d3d99 commit 12ad318
Showing 28 changed files with 270 additions and 265 deletions.
9 changes: 1 addition & 8 deletions libs/arrow-spi/build.gradle
@@ -9,8 +9,8 @@
* GitHub history for details.
*/

testingConventions.enabled = false
dependencies {
implementation project(':libs:opensearch-common')
implementation project(':libs:opensearch-core')
api "org.apache.arrow:arrow-vector:${versions.arrow}"
api "org.apache.arrow:arrow-format:${versions.arrow}"
@@ -29,13 +29,6 @@ dependencies {
api "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}"

implementation "commons-codec:commons-codec:${versions.commonscodec}"

testImplementation "com.carrotsearch.randomizedtesting:randomizedtesting-runner:${versions.randomizedrunner}"
testImplementation "junit:junit:${versions.junit}"
testImplementation "org.hamcrest:hamcrest:${versions.hamcrest}"
testImplementation(project(":test:framework")) {
exclude group: 'org.opensearch', module: 'opensearch-arrow-spi'
}
}

tasks.named('forbiddenApisMain').configure {
@@ -37,30 +37,14 @@ public interface StreamManager extends AutoCloseable {
StreamTicket registerStream(StreamProducer producer, TaskId parentTaskId);

/**
* Creates a stream iterator for consuming Arrow data using a valid ticket.
* Creates a stream reader for consuming Arrow data using a valid ticket.
* This method may trigger lazy initialization of Arrow resources if this is
* the first access to the stream.
*
* @param ticket The StreamTicket obtained from registerStream
* @return A StreamIterator for consuming the Arrow data
* @return A StreamReader for consuming the Arrow data
* @throws IllegalArgumentException if the ticket is invalid
* @throws IllegalStateException if the stream has been cancelled or closed
*/
StreamReader getStreamIterator(StreamTicket ticket);

/**
* Generates a unique ticket identifier for stream registration.
*
* @return A unique string identifier for use in StreamTicket creation
*/
String generateUniqueTicket();

/**
* Returns the identifier of the node where this StreamManager instance is running.
* This node ID is embedded in stream tickets to enable routing of stream requests
* in a distributed environment.
*
* @return The identifier of the local node
*/
String getLocalNodeId();
StreamReader getStreamReader(StreamTicket ticket);
}
@@ -19,7 +19,7 @@
* Represents a producer of Arrow streams. The producer first needs to define the job by implementing this interface and
* then register the job with the {@link StreamManager#registerStream(StreamProducer, TaskId)}, which will return {@link StreamTicket}
* which can be distributed to the consumer. The consumer can then use the ticket to retrieve the stream using
* {@link StreamManager#getStreamIterator(StreamTicket)} and then consume the stream using {@link StreamReader}.
* {@link StreamManager#getStreamReader(StreamTicket)} and then consume the stream using {@link StreamReader}.
* <p>
* BatchedJob supports streaming of intermediate results, allowing consumers to begin processing data before the entire
* result set is generated. This is particularly useful for memory-intensive operations or when dealing with large datasets
@@ -25,7 +25,7 @@
* StreamTicket ticket = streamManager.registerStream(producer, taskId);
*
* // consumer
* StreamIterator iterator = streamManager.getStreamIterator(ticket);
* StreamReader iterator = streamManager.getStreamReader(ticket);
* try (VectorSchemaRoot root = iterator.getRoot()) {
* while (iterator.next()) {
* VarCharVector idVector = (VarCharVector)root.getVector("id");
125 changes: 10 additions & 115 deletions libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamTicket.java
@@ -10,147 +10,42 @@

import org.opensearch.common.annotation.ExperimentalApi;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Objects;

/**
* A ticket that uniquely identifies a stream. This ticket is created when a producer registers
* a stream with {@link StreamManager} and can be used by consumers to retrieve the stream using
* {@link StreamManager#getStreamIterator(StreamTicket)}.
* {@link StreamManager#getStreamReader(StreamTicket)}.
*/
@ExperimentalApi
public class StreamTicket {
private static final int MAX_TOTAL_SIZE = 4096;
private static final int MAX_ID_LENGTH = 256;

private final String ticketID;
private final String nodeID;

/**
* Constructs a new StreamTicket with the specified ticket ID and node ID.
*
* @param ticketID the unique identifier for the ticket
* @param nodeID the identifier of the node associated with this ticket
*/
public StreamTicket(String ticketID, String nodeID) {
this.ticketID = ticketID;
this.nodeID = nodeID;
}

public interface StreamTicket {
/**
* Returns the ticket ID associated with this stream ticket.
*
* @return the ticket ID string
*/
public String getTicketID() {
return ticketID;
}
String getTicketID();

/**
* Returns the node ID associated with this stream ticket.
*
* @return the node ID string
*/
public String getNodeID() {
return nodeID;
}
String getNodeID();

/**
* Serializes this ticket into a Base64 encoded byte array that can be deserialized using
* {@link #fromBytes(byte[])}.
* Serializes this ticket into a Base64 encoded byte array.
*
* @return Base64 encoded byte array containing the ticket information
*/
public byte[] toBytes() {
byte[] ticketIDBytes = ticketID.getBytes(StandardCharsets.UTF_8);
byte[] nodeIDBytes = nodeID.getBytes(StandardCharsets.UTF_8);

if (ticketIDBytes.length > Short.MAX_VALUE || nodeIDBytes.length > Short.MAX_VALUE) {
throw new IllegalArgumentException("Field lengths exceed the maximum allowed size.");
}
ByteBuffer buffer = ByteBuffer.allocate(2 + ticketIDBytes.length + 2 + nodeIDBytes.length);
buffer.putShort((short) ticketIDBytes.length);
buffer.putShort((short) nodeIDBytes.length);
buffer.put(ticketIDBytes);
buffer.put(nodeIDBytes);
return Base64.getEncoder().encode(buffer.array());
}
byte[] toBytes();

/**
* Creates a StreamTicket from its serialized byte representation. The byte array should be
* a Base64 encoded string containing the ticketID and nodeID.
* Creates a StreamTicket from its serialized byte representation.
*
* @param bytes Base64 encoded byte array containing ticket information
* @return a new StreamTicket instance
* @throws IllegalArgumentException if the input is invalid
*/
public static StreamTicket fromBytes(byte[] bytes) {
if (bytes == null || bytes.length < 4) {
throw new IllegalArgumentException("Invalid byte array input.");
}

if (bytes.length > MAX_TOTAL_SIZE) {
throw new IllegalArgumentException("Input exceeds maximum allowed size");
}

ByteBuffer buffer = ByteBuffer.wrap(Base64.getDecoder().decode(bytes));

short ticketIDLength = buffer.getShort();
if (ticketIDLength < 0 || ticketIDLength > MAX_ID_LENGTH) {
throw new IllegalArgumentException("Invalid ticketID length: " + ticketIDLength);
}

short nodeIDLength = buffer.getShort();
if (nodeIDLength < 0 || nodeIDLength > MAX_ID_LENGTH) {
throw new IllegalArgumentException("Invalid nodeID length: " + nodeIDLength);
}
byte[] ticketIDBytes = new byte[ticketIDLength];
if (buffer.remaining() < ticketIDLength) {
throw new IllegalArgumentException("Malformed byte array. Not enough data for TicketId.");
}
buffer.get(ticketIDBytes);
byte[] nodeIDBytes = new byte[nodeIDLength];
if (buffer.remaining() < nodeIDLength) {
throw new IllegalArgumentException("Malformed byte array. Not enough data for NodeId.");
}
buffer.get(nodeIDBytes);
String ticketID = new String(ticketIDBytes, StandardCharsets.UTF_8);
String nodeID = new String(nodeIDBytes, StandardCharsets.UTF_8);
return new StreamTicket(ticketID, nodeID);
}

/**
* Returns a hash code value for this StreamTicket.
*
* @return a hash code value for this object
*/
@Override
public int hashCode() {
return Objects.hash(ticketID, nodeID);
}

/**
* Indicates whether some other object is "equal to" this one.
*
* @param obj the reference object with which to compare
* @return true if this object is the same as the obj argument; false otherwise
*/
@Override
public boolean equals(Object obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
StreamTicket that = (StreamTicket) obj;
return Objects.equals(ticketID, that.ticketID) && Objects.equals(nodeID, that.nodeID);
}

/**
* Returns a string representation of this StreamTicket.
*
* @return a string representation of this object
*/
@Override
public String toString() {
return "StreamTicket{ticketID='" + ticketID + "', nodeID='" + nodeID + "'}";
static StreamTicket fromBytes(byte[] bytes) {
throw new UnsupportedOperationException("Implementation must be provided by concrete class");
}
}
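
Note: with StreamTicket reduced to an interface whose static fromBytes only throws, concrete plugins now supply their own (de)serialization. Below is a minimal sketch of what such an implementation could look like, reusing the length-prefixed, Base64-encoded layout of the removed class above; the name DefaultStreamTicket is hypothetical and not part of this commit.

    import org.opensearch.arrow.spi.StreamTicket;

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;
    import java.util.Base64;

    // Hypothetical concrete ticket, sketched for illustration only.
    public class DefaultStreamTicket implements StreamTicket {
        private final String ticketID;
        private final String nodeID;

        public DefaultStreamTicket(String ticketID, String nodeID) {
            this.ticketID = ticketID;
            this.nodeID = nodeID;
        }

        @Override
        public String getTicketID() {
            return ticketID;
        }

        @Override
        public String getNodeID() {
            return nodeID;
        }

        @Override
        public byte[] toBytes() {
            byte[] ticket = ticketID.getBytes(StandardCharsets.UTF_8);
            byte[] node = nodeID.getBytes(StandardCharsets.UTF_8);
            // Two short length prefixes, then both payloads, Base64-encoded as a whole.
            ByteBuffer buffer = ByteBuffer.allocate(2 + 2 + ticket.length + node.length);
            buffer.putShort((short) ticket.length);
            buffer.putShort((short) node.length);
            buffer.put(ticket);
            buffer.put(node);
            return Base64.getEncoder().encode(buffer.array());
        }

        // Static factory mirroring toBytes(); length/size validation omitted for brevity.
        public static DefaultStreamTicket fromBytes(byte[] bytes) {
            ByteBuffer buffer = ByteBuffer.wrap(Base64.getDecoder().decode(bytes));
            byte[] ticket = new byte[buffer.getShort()];
            byte[] node = new byte[buffer.getShort()];
            buffer.get(ticket);
            buffer.get(node);
            return new DefaultStreamTicket(new String(ticket, StandardCharsets.UTF_8), new String(node, StandardCharsets.UTF_8));
        }
    }
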
2 changes: 1 addition & 1 deletion modules/arrow-flight-rpc/build.gradle
@@ -100,7 +100,6 @@ tasks.named('thirdPartyAudit').configure {
'org.apache.parquet.schema.Types$ListBuilder',
'org.apache.parquet.schema.Types$PrimitiveBuilder'
)

ignoreViolations(
// Guava internal classes
'com.google.common.cache.Striped64',
@@ -124,5 +123,6 @@ tasks.named('thirdPartyAudit').configure {
'org.apache.arrow.memory.util.MemoryUtil$1',
'org.apache.arrow.memory.util.hash.MurmurHasher',
'org.apache.arrow.memory.util.hash.SimpleHasher'

)
}
@@ -8,6 +8,7 @@

package org.opensearch.arrow.flight;

import org.opensearch.arrow.flight.bootstrap.FlightStreamPluginImpl;
import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
@@ -22,7 +23,6 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.arrow.flight.bootstrap.FlightStreamPluginImpl;
import org.opensearch.plugins.SecureTransportSettingsProvider;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
@@ -14,11 +14,6 @@
import org.apache.arrow.util.VisibleForTesting;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.SetOnce;
import org.opensearch.common.lifecycle.AbstractLifecycleComponent;
import org.opensearch.common.settings.Settings;
import org.opensearch.arrow.flight.bootstrap.client.FlightClientManager;
import org.opensearch.arrow.flight.bootstrap.server.FlightServerBuilder;
import org.opensearch.arrow.flight.bootstrap.server.ServerConfig;
@@ -27,6 +22,11 @@
import org.opensearch.arrow.flight.bootstrap.tls.SslContextProvider;
import org.opensearch.arrow.flight.core.BaseFlightProducer;
import org.opensearch.arrow.flight.core.FlightStreamManager;
import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.SetOnce;
import org.opensearch.common.lifecycle.AbstractLifecycleComponent;
import org.opensearch.common.settings.Settings;
import org.opensearch.plugins.SecureTransportSettingsProvider;
import org.opensearch.threadpool.ThreadPool;

@@ -8,6 +8,8 @@

package org.opensearch.arrow.flight.bootstrap;

import org.opensearch.arrow.flight.BaseFlightStreamPlugin;
import org.opensearch.arrow.flight.bootstrap.server.ServerConfig;
import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
@@ -21,8 +23,6 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.arrow.flight.BaseFlightStreamPlugin;
import org.opensearch.arrow.flight.bootstrap.server.ServerConfig;
import org.opensearch.plugins.SecureTransportSettingsProvider;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
@@ -12,11 +12,11 @@
import org.apache.arrow.flight.OpenSearchFlightClient;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.VisibleForTesting;
import org.opensearch.arrow.flight.bootstrap.tls.SslContextProvider;
import org.opensearch.cluster.ClusterChangedEvent;
import org.opensearch.cluster.ClusterStateListener;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.arrow.flight.bootstrap.tls.SslContextProvider;

import java.util.Map;
import java.util.Objects;
@@ -9,6 +9,7 @@
package org.opensearch.arrow.flight.bootstrap.server;

import org.apache.arrow.flight.Location;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
@@ -102,6 +103,7 @@ public ServerConfig() {}
*
* @param settings The OpenSearch settings to initialize the server with
*/
@SuppressForbidden(reason = "required for arrow allocator")
public static void init(Settings settings) {
System.setProperty("arrow.allocation.manager.type", ARROW_ALLOCATION_MANAGER_TYPE.get(settings));
System.setProperty("arrow.enable_null_check_for_get", Boolean.toString(ARROW_ENABLE_NULL_CHECK_FOR_GET.get(settings)));
@@ -191,6 +193,7 @@ private static class Netty4Configs {

public static final Setting<Boolean> NETTY_TRY_UNSAFE = Setting.boolSetting("io.netty.tryUnsafe", true, Setting.Property.NodeScope);

@SuppressForbidden(reason = "required for netty allocator configuration")
public static void init(Settings settings) {
System.setProperty("io.netty.allocator.numDirectArenas", Integer.toString(NETTY_ALLOCATOR_NUM_DIRECT_ARENAS.get(settings)));
System.setProperty("io.netty.noUnsafe", Boolean.toString(NETTY_NO_UNSAFE.get(settings)));
@@ -14,6 +14,7 @@
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Locale;
import java.util.function.Supplier;

import io.netty.handler.ssl.ApplicationProtocolConfig;
@@ -64,8 +65,8 @@ public SslContext getServerSslContext() {
(PrivilegedExceptionAction<SslContext>) () -> io.netty.handler.ssl.SslContextBuilder.forServer(
parameters.keyManagerFactory()
)
.sslProvider(SslProvider.valueOf(parameters.sslProvider().toUpperCase()))
.clientAuth(ClientAuth.valueOf(parameters.clientAuth().toUpperCase()))
.sslProvider(SslProvider.valueOf(parameters.sslProvider().toUpperCase(Locale.ROOT)))
.clientAuth(ClientAuth.valueOf(parameters.clientAuth().toUpperCase(Locale.ROOT)))
.protocols(parameters.protocols())
.ciphers(parameters.cipherSuites(), SupportedCipherSuiteFilter.INSTANCE)
.sessionCacheSize(0)
@@ -101,7 +102,7 @@ public SslContext getClientSslContext() {
.get();
return AccessController.doPrivileged(
(PrivilegedExceptionAction<SslContext>) () -> io.netty.handler.ssl.SslContextBuilder.forClient()
.sslProvider(SslProvider.valueOf(parameters.sslProvider().toUpperCase()))
.sslProvider(SslProvider.valueOf(parameters.sslProvider().toUpperCase(Locale.ROOT)))
.protocols(parameters.protocols())
.ciphers(parameters.cipherSuites(), SupportedCipherSuiteFilter.INSTANCE)
.applicationProtocolConfig(
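
A side note on the toUpperCase(Locale.ROOT) changes above: the no-argument String.toUpperCase() uses the JVM default locale, so lookups such as ClientAuth.valueOf(parameters.clientAuth().toUpperCase()) can fail under locales with non-ASCII case mapping (the Turkish dotted İ is the classic case). A small standalone sketch of the failure mode, using a hypothetical configuration value:

    import io.netty.handler.ssl.ClientAuth;

    import java.util.Locale;

    public class LocaleCaseDemo {
        public static void main(String[] args) {
            String clientAuth = "optional"; // hypothetical setting value
            Locale.setDefault(Locale.forLanguageTag("tr-TR"));
            // Locale-sensitive mapping turns 'i' into dotted 'İ' (U+0130), which valueOf rejects.
            System.out.println(clientAuth.toUpperCase());            // OPTİONAL
            // Locale.ROOT keeps the ASCII mapping the enum constant expects.
            System.out.println(clientAuth.toUpperCase(Locale.ROOT)); // OPTIONAL
            ClientAuth auth = ClientAuth.valueOf(clientAuth.toUpperCase(Locale.ROOT));
            System.out.println(auth);                                // OPTIONAL
        }
    }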