Skip to content

Commit

Permalink
Added support for cluster and flightInfo API
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabhmaurya committed Oct 18, 2024
1 parent e7c70f4 commit fa9634d
Show file tree
Hide file tree
Showing 19 changed files with 673 additions and 215 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ public class OpenSearchNode implements TestClusterConfiguration {
private boolean isWorkingDirConfigured = false;
private String httpPort = "0";
private String transportPort = "0";
private String streamPort = "0";
private Path confPathData;
private String keystorePassword = "";
private boolean preserveDataDir = false;
Expand Down Expand Up @@ -1175,6 +1176,8 @@ private void createConfiguration() {
baseConfig.put("node.portsfile", "true");
baseConfig.put("http.port", httpPort);
baseConfig.put("transport.port", transportPort);
baseConfig.put("node.attr.transport.stream.port", streamPort);

// Default the watermarks to absurdly low to prevent the tests from failing on nodes without enough disk space
baseConfig.put("cluster.routing.allocation.disk.watermark.low", "1b");
baseConfig.put("cluster.routing.allocation.disk.watermark.high", "1b");
Expand Down Expand Up @@ -1447,6 +1450,10 @@ void setTransportPort(String transportPort) {
this.transportPort = transportPort;
}

void setStreamPort(String streamPort) {
this.streamPort = streamPort;
}

void setDataPath(Path dataPath) {
this.confPathData = dataPath;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public class RunTask extends DefaultTestClustersTask {
public static final String CUSTOM_SETTINGS_PREFIX = "tests.opensearch.";
private static final int DEFAULT_HTTP_PORT = 9200;
private static final int DEFAULT_TRANSPORT_PORT = 9300;
private static final int DEFAULT_STREAM_PORT = 8815;
private static final int DEFAULT_DEBUG_PORT = 5005;
public static final String LOCALHOST_ADDRESS_PREFIX = "127.0.0.1:";

Expand Down Expand Up @@ -140,6 +141,8 @@ public void beforeStart() {
int debugPort = DEFAULT_DEBUG_PORT;
int httpPort = DEFAULT_HTTP_PORT;
int transportPort = DEFAULT_TRANSPORT_PORT;
int streamPort = DEFAULT_STREAM_PORT;

Map<String, String> additionalSettings = System.getProperties()
.entrySet()
.stream()
Expand All @@ -164,15 +167,19 @@ public void beforeStart() {
firstNode.setHttpPort(String.valueOf(httpPort));
httpPort++;
firstNode.setTransportPort(String.valueOf(transportPort));
firstNode.setStreamPort(String.valueOf(streamPort));
transportPort++;
streamPort++;
firstNode.setting("discovery.seed_hosts", LOCALHOST_ADDRESS_PREFIX + DEFAULT_TRANSPORT_PORT);
cluster.setPreserveDataDir(preserveData);
for (OpenSearchNode node : cluster.getNodes()) {
if (node != firstNode) {
node.setHttpPort(String.valueOf(httpPort));
httpPort++;
node.setTransportPort(String.valueOf(transportPort));
node.setStreamPort(String.valueOf(streamPort));
transportPort++;
streamPort++;
node.setting("discovery.seed_hosts", LOCALHOST_ADDRESS_PREFIX + DEFAULT_TRANSPORT_PORT);
}
additionalSettings.forEach(node::setting);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ public interface ArrowStreamProvider {
*/
Task create(BufferAllocator allocator);

default int estimatedRowCount() {
return -1;
}

/**
* Represents a task for managing an Arrow stream.
*/
Expand Down
58 changes: 41 additions & 17 deletions libs/arrow/src/main/java/org/opensearch/arrow/StreamManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

package org.opensearch.arrow;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.opensearch.common.annotation.ExperimentalApi;

import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

/**
* Abstract class for managing Arrow streams.
Expand All @@ -20,25 +22,27 @@
*/
@ExperimentalApi
public abstract class StreamManager implements AutoCloseable {
private final ConcurrentHashMap<StreamTicket, ArrowStreamProvider> streams;

private final ConcurrentHashMap<String, StreamHolder> streamProviders;
private final Supplier<BufferAllocator> allocatorSupplier;
/**
* Constructs a new StreamManager with an empty stream map.
*/
public StreamManager() {
this.streams = new ConcurrentHashMap<>();
public StreamManager(Supplier<BufferAllocator> allocatorSupplier) {
this.allocatorSupplier = allocatorSupplier;
this.streamProviders = new ConcurrentHashMap<>();
}

/**
* Registers a new stream with the given ArrowStreamProvider.
*
* @param factory The ArrowStreamProvider to register.
* @param provider The ArrowStreamProvider to register.
* @return A new StreamTicket for the registered stream.
*/
public StreamTicket registerStream(ArrowStreamProvider factory) {
StreamTicket ticket = generateUniqueTicket();
streams.put(ticket, factory);
return ticket;
public StreamTicket registerStream(ArrowStreamProvider provider) {
String ticket = generateUniqueTicket();
VectorSchemaRoot root = provider.create(allocatorSupplier.get()).init(allocatorSupplier.get());
streamProviders.put(ticket, new StreamHolder(provider, root));
return new StreamTicket(ticket, getNodeId());
}

/**
Expand All @@ -47,8 +51,8 @@ public StreamTicket registerStream(ArrowStreamProvider factory) {
* @param ticket The StreamTicket of the desired stream.
* @return The ArrowStreamProvider associated with the ticket, or null if not found.
*/
public ArrowStreamProvider getStream(StreamTicket ticket) {
return streams.get(ticket);
public StreamHolder getStreamProvider(StreamTicket ticket) {
return streamProviders.get(ticket.getTicketID());
}

/**
Expand All @@ -64,25 +68,27 @@ public ArrowStreamProvider getStream(StreamTicket ticket) {
*
* @param ticket The StreamTicket of the stream to remove.
*/
public void removeStream(StreamTicket ticket) {
streams.remove(ticket);
public void removeStreamProvider(StreamTicket ticket) {
streamProviders.remove(ticket.getTicketID());
}

/**
* Returns the map of all registered streams.
*
* @return A ConcurrentHashMap of all registered streams.
*/
public ConcurrentHashMap<StreamTicket, ArrowStreamProvider> getStreams() {
return streams;
public ConcurrentHashMap<String, StreamHolder> getStreamProviders() {
return streamProviders;
}

/**
* Generates a unique StreamTicket.
*
* @return A new, unique StreamTicket.
*/
public abstract StreamTicket generateUniqueTicket();
public abstract String generateUniqueTicket();

public abstract String getNodeId();

/**
* Closes the StreamManager and cancels all associated streams.
Expand All @@ -91,6 +97,24 @@ public ConcurrentHashMap<StreamTicket, ArrowStreamProvider> getStreams() {
*/
public void close() {
// TODO: logic to cancel all threads and clear the streamManager queue
streams.clear();
streamProviders.clear();
}

public static class StreamHolder {
private final ArrowStreamProvider provider;
private final VectorSchemaRoot root;

public StreamHolder(ArrowStreamProvider provider, VectorSchemaRoot root) {
this.provider = provider;
this.root = root;
}

public ArrowStreamProvider getProvider() {
return provider;
}

public VectorSchemaRoot getRoot() {
return root;
}
}
}
123 changes: 79 additions & 44 deletions libs/arrow/src/main/java/org/opensearch/arrow/StreamTicket.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,64 +8,99 @@

package org.opensearch.arrow;

import java.util.Arrays;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Objects;
/**
* Represents a ticket for identifying and managing Arrow streams.
* This class encapsulates a byte array that serves as a unique identifier for a stream.
*/


public class StreamTicket {
private final byte[] bytes;

/**
* Constructs a new StreamTicket with the given byte array.
*
* @param bytes The byte array to use as the ticket identifier.
*/
public StreamTicket(byte[] bytes) {
this.bytes = bytes;
private final String ticketID;
private final String nodeID;

public StreamTicket(String ticketID, String nodeID) {
this.ticketID = ticketID;
this.nodeID = nodeID;
}

public String getTicketID() {
return ticketID;
}

public String getNodeID() {
return nodeID;
}

/**
* Retrieves the byte array representing this ticket.
*
* @return The byte array identifier of this ticket.
*/
public byte[] getBytes() {
return bytes;
public byte[] toBytes() {
byte[] ticketIDBytes = ticketID.getBytes(StandardCharsets.UTF_8);
byte[] nodeIDBytes = nodeID.getBytes(StandardCharsets.UTF_8);

if (ticketIDBytes.length > Short.MAX_VALUE || nodeIDBytes.length > Short.MAX_VALUE) {
throw new IllegalArgumentException("Field lengths exceed the maximum allowed size.");
}

ByteBuffer buffer = ByteBuffer.allocate(2 + ticketIDBytes.length + 2 + nodeIDBytes.length); // 2 bytes for length
buffer.putShort((short) ticketIDBytes.length);
buffer.put(ticketIDBytes);
buffer.putShort((short) nodeIDBytes.length);
buffer.put(nodeIDBytes);
return Base64.getEncoder().encode(buffer.array());
}

public static StreamTicket fromBytes(byte[] bytes) {
if (bytes == null || bytes.length < 4) {
throw new IllegalArgumentException("Invalid byte array input.");
}
ByteBuffer buffer = ByteBuffer.wrap(Base64.getDecoder().decode(bytes));

short ticketIDLength = buffer.getShort();
if (ticketIDLength < 0) {
throw new IllegalArgumentException("Invalid ticketID length.");
}

byte[] ticketIDBytes = new byte[ticketIDLength];
if (buffer.remaining() < ticketIDLength) {
throw new IllegalArgumentException("Malformed byte array. Not enough data for ticketID.");
}
buffer.get(ticketIDBytes);

short nodeIDLength = buffer.getShort();
if (nodeIDLength < 0) {
throw new IllegalArgumentException("Invalid nodeID length.");
}

byte[] nodeIDBytes = new byte[nodeIDLength];
if (buffer.remaining() < nodeIDLength) {
throw new IllegalArgumentException("Malformed byte array.");
}
buffer.get(nodeIDBytes);

String ticketID = new String(ticketIDBytes, StandardCharsets.UTF_8);
String nodeID = new String(nodeIDBytes, StandardCharsets.UTF_8);

return new StreamTicket(ticketID, nodeID);
}

/**
* Computes the hash code for this StreamTicket.
*
* @return The hash code value for this object.
*/
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + Arrays.hashCode(bytes);
return result;
return Objects.hash(ticketID, nodeID);
}

/**
* Compares this StreamTicket to the specified object for equality.
*
* @param obj The object to compare this StreamTicket against.
* @return true if the given object represents a StreamTicket equivalent to this ticket, false otherwise.
*/
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (!(obj instanceof StreamTicket)) {
return false;
}
StreamTicket other = (StreamTicket) obj;
return Arrays.equals(bytes, other.getBytes());
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
StreamTicket that = (StreamTicket) obj;
return Objects.equals(ticketID, that.ticketID) && Objects.equals(nodeID, that.nodeID);
}

@Override
public String toString() {
return "StreamTicket{ticketID='" + ticketID + "', nodeID='" + nodeID + "'}";
}
}

Loading

0 comments on commit fa9634d

Please sign in to comment.