Skip to content

Commit

Permalink
setup for join on coordinator
Browse files Browse the repository at this point in the history
Signed-off-by: bowenlan-amzn <[email protected]>
  • Loading branch information
bowenlan-amzn committed Oct 17, 2024
1 parent 130d554 commit 1c8f75a
Show file tree
Hide file tree
Showing 31 changed files with 630 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
*/
@ExperimentalApi
public abstract class StreamManager implements AutoCloseable {

public abstract void setFlightClient(Object flightClient);

private final ConcurrentHashMap<StreamTicket, ArrowStreamProvider> streams;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ protected void doStart() {
final Location location = Location.forGrpcInsecure(host, port);
server = FlightServer.builder(allocator, location, producer).build();
client = FlightClient.builder(allocator, location).build();
streamManager.setFlightClient(client);
server.start();
logger.info("Arrow Flight server started successfully");
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
*/
public class FlightStreamManager extends StreamManager {

private FlightClient flightClient;
private final FlightClient flightClient;

/**
* Constructs a new FlightStreamManager.
Expand All @@ -42,12 +42,6 @@ public FlightStreamManager(FlightClient flightClient) {
* @param ticket The StreamTicket identifying the desired stream.
* @return The VectorSchemaRoot associated with the given ticket.
*/
@Override
public void setFlightClient(Object flightClient) {
assert flightClient instanceof FlightClient;
this.flightClient = (FlightClient) flightClient;
}

@Override
public VectorSchemaRoot getVectorSchemaRoot(StreamTicket ticket) {
// TODO: for remote streams, register streams in cluster state with node details
Expand All @@ -58,6 +52,7 @@ public VectorSchemaRoot getVectorSchemaRoot(StreamTicket ticket) {

/**
 * Generates a unique ticket used to identify a stream.
 *
 * @return a new {@link StreamTicket} whose bytes are a random UUID string;
 *         uniqueness relies on {@link UUID#randomUUID()}.
 */
@Override
public StreamTicket generateUniqueTicket() {
    // UUID.toString() is pure ASCII, so encoding with UTF-8 yields the same
    // bytes as the platform default while making the charset explicit
    // (bare getBytes() depends on the JVM's default charset before JDK 18).
    return new StreamTicket(UUID.randomUUID().toString().getBytes(java.nio.charset.StandardCharsets.UTF_8)) {};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ ReducedQueryPhase reducedQueryPhase(
for (SearchPhaseResult entry : queryResults) {
QuerySearchResult result = entry.queryResult();
if (entry instanceof StreamSearchResult) {
tickets.addAll(((StreamSearchResult)entry).getFlightTickets());
tickets.addAll(((StreamSearchResult) entry).getFlightTickets());
}
from = result.from();
// sorted queries can set the size to 0 if they have enough competitive hits.
Expand Down Expand Up @@ -728,7 +728,7 @@ public static final class ReducedQueryPhase {
this.from = from;
this.isEmptyResult = isEmptyResult;
this.sortValueFormats = sortValueFormats;
this.osTickets = osTickets;
this.osTickets = osTickets;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
import org.opensearch.search.query.QuerySearchRequest;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.ScrollQuerySearchResult;
import org.opensearch.search.query.StreamQueryResponse;
import org.opensearch.search.stream.StreamSearchResult;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.RemoteClusterService;
Expand Down Expand Up @@ -244,7 +243,6 @@ public void sendExecuteQuery(
// we optimize this and expect a QueryFetchSearchResult if we only have a single shard in the search request
// this used to be the QUERY_AND_FETCH which doesn't exist anymore.


if (request.isStreamRequest()) {
Writeable.Reader<SearchPhaseResult> reader = StreamSearchResult::new;
final ActionListener handler = responseWrapper.apply(connection, listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public static SearchType fromId(byte id) {
} else if (id == 1 || id == 3) { // TODO this bwc layer can be removed once this is back-ported to 5.3 QUERY_AND_FETCH is removed
// now
return QUERY_THEN_FETCH;
} else if (id == 5) {
} else if (id == 5) {
return STREAM;
} else {
throw new IllegalArgumentException("No search type for [" + id + "]");
Expand Down
128 changes: 109 additions & 19 deletions server/src/main/java/org/opensearch/action/search/StreamAsyncAction.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,30 +33,22 @@
package org.opensearch.action.search;

import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.TopFieldDocs;
import org.opensearch.arrow.StreamManager;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.routing.GroupShardsIterator;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.AtomicArray;
import org.opensearch.core.action.ActionListener;
import org.opensearch.search.SearchExtBuilder;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.stream.OSTicket;
import org.opensearch.search.stream.StreamSearchResult;
import org.opensearch.search.suggest.Suggest;
import org.opensearch.search.stream.join.Join;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.transport.Transport;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand All @@ -66,14 +58,58 @@
import java.util.function.BiFunction;

/**
* Async transport action for query then fetch
* Stream at coordinator layer
*
* @opensearch.internal
*/
class StreamAsyncAction extends SearchQueryThenFetchAsyncAction {

public StreamAsyncAction(Logger logger, SearchTransportService searchTransportService, BiFunction<String, String, Transport.Connection> nodeIdToConnection, Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts, Map<String, Set<String>> indexRoutings, SearchPhaseController searchPhaseController, Executor executor, QueryPhaseResultConsumer resultConsumer, SearchRequest request, ActionListener<SearchResponse> listener, GroupShardsIterator<SearchShardIterator> shardsIts, TransportSearchAction.SearchTimeProvider timeProvider, ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters, SearchRequestContext searchRequestContext, Tracer tracer) {
super(logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, resultConsumer, request, listener, shardsIts, timeProvider, clusterState, task, clusters, searchRequestContext, tracer);
private final StreamManager streamManager;
private final Join join;

public StreamAsyncAction(
Logger logger,
SearchTransportService searchTransportService,
BiFunction<String, String, Transport.Connection> nodeIdToConnection,
Map<String, AliasFilter> aliasFilter,
Map<String, Float> concreteIndexBoosts,
Map<String, Set<String>> indexRoutings,
SearchPhaseController searchPhaseController,
Executor executor,
QueryPhaseResultConsumer resultConsumer,
SearchRequest request,
ActionListener<SearchResponse> listener,
GroupShardsIterator<SearchShardIterator> shardsIts,
TransportSearchAction.SearchTimeProvider timeProvider,
ClusterState clusterState,
SearchTask task,
SearchResponse.Clusters clusters,
SearchRequestContext searchRequestContext,
Tracer tracer,
StreamManager streamManager
) {
super(
logger,
searchTransportService,
nodeIdToConnection,
aliasFilter,
concreteIndexBoosts,
indexRoutings,
searchPhaseController,
executor,
resultConsumer,
request,
listener,
shardsIts,
timeProvider,
clusterState,
task,
clusters,
searchRequestContext,
tracer
);
this.streamManager = streamManager;
this.join = searchRequestContext.getRequest().source().getJoin();
}

@Override
Expand All @@ -82,7 +118,8 @@ protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> r
}

class StreamSearchReducePhase extends SearchPhase {
private SearchPhaseContext context;
private final SearchPhaseContext context;

protected StreamSearchReducePhase(String name, SearchPhaseContext context) {
super(name);
this.context = context;
Expand All @@ -92,24 +129,77 @@ protected StreamSearchReducePhase(String name, SearchPhaseContext context) {
public void run() {
context.execute(new StreamReduceAction(context, this));
}
};
}

class StreamReduceAction extends AbstractRunnable {
private SearchPhaseContext context;
private final SearchPhaseContext context;
private SearchPhase phase;

StreamReduceAction(SearchPhaseContext context, SearchPhase phase) {
this.context = context;

}

@Override
protected void doRun() throws Exception {

List<OSTicket> tickets = new ArrayList<>();
for (SearchPhaseResult entry : results.getAtomicArray().asList()) {
if (entry instanceof StreamSearchResult) {
tickets.addAll(((StreamSearchResult) entry).getFlightTickets());
((StreamSearchResult) entry).getFlightTickets().forEach(osTicket -> {
System.out.println("Ticket: " + new String(osTicket.getBytes(), StandardCharsets.UTF_8));
// VectorSchemaRoot root = streamManager.getVectorSchemaRoot(osTicket);
// System.out.println("Number of rows: " + root.getRowCount());
});
}
}
InternalSearchResponse internalSearchResponse = new InternalSearchResponse(SearchHits.empty(),null, null, null, false, false, 1, Collections.emptyList(), tickets);

// shard/table, schema

// ticket should contain which IndexShard it comes from
// based on the search request, perform join using these tickets

// join operate on 2 indexes using condition
// join contain already contain the schema, or at least hold the schema data

// StreamTicket joinResult = streamManager.registerStream((allocator) -> new ArrowStreamProvider.Task() {
// @Override
// public VectorSchemaRoot init(BufferAllocator allocator) {
// IntVector docIDVector = new IntVector("docID", allocator);
// FieldVector[] vectors = new FieldVector[]{
// docIDVector
// };
// VectorSchemaRoot root = new VectorSchemaRoot(Arrays.asList(vectors));
// return root;
// }
//
// public void run(VectorSchemaRoot root, ArrowStreamProvider.FlushSignal flushSignal) {
// // TODO perform join algo
// IntVector docIDVector = (IntVector) root.getVector("docID");
// for (int i = 0; i < 10; i++) {
// docIDVector.set(i, i);
// }
// root.setRowCount(10);
// flushSignal.awaitConsumption(1000);
// }
//
// @Override
// public void onCancel() {
//
// }
// });

InternalSearchResponse internalSearchResponse = new InternalSearchResponse(
SearchHits.empty(),
null,
null,
null,
false,
false,
1,
Collections.emptyList(),
List.of(new OSTicket("456".getBytes(), null))
);
context.sendSearchResponse(internalSearchResponse, results.getAtomicArray());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@
import java.util.stream.StreamSupport;

import static org.opensearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;
import static org.opensearch.action.search.SearchType.*;
import static org.opensearch.action.search.SearchType.DFS_QUERY_THEN_FETCH;
import static org.opensearch.action.search.SearchType.QUERY_THEN_FETCH;
import static org.opensearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort;

/**
Expand Down Expand Up @@ -1324,7 +1325,8 @@ private AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction
task,
clusters,
searchRequestContext,
tracer
tracer,
searchService.getStreamManager()
);
break;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,7 @@ public class FeatureFlags {
);

public static final String ARROW_STREAMS = "opensearch.experimental.feature.arrow.streams.enabled";
public static final Setting<Boolean> ARROW_STREAMS_SETTING = Setting.boolSetting(
ARROW_STREAMS,
true,
Property.NodeScope
);
public static final Setting<Boolean> ARROW_STREAMS_SETTING = Setting.boolSetting(ARROW_STREAMS, true, Property.NodeScope);

private static final List<Setting<Boolean>> ALL_FEATURE_FLAG_SETTINGS = List.of(
REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING,
Expand Down
9 changes: 6 additions & 3 deletions server/src/main/java/org/opensearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@
import java.util.stream.Stream;

import static java.util.stream.Collectors.toList;
import static org.opensearch.common.util.FeatureFlags.*;
import static org.opensearch.common.util.FeatureFlags.ARROW_STREAMS_SETTING;
import static org.opensearch.common.util.FeatureFlags.BACKGROUND_TASK_EXECUTION_EXPERIMENTAL;
import static org.opensearch.common.util.FeatureFlags.TELEMETRY;
import static org.opensearch.env.NodeEnvironment.collectFileCacheDataPath;
import static org.opensearch.index.ShardIndexingPressureSettings.SHARD_INDEXING_PRESSURE_ENABLED_ATTRIBUTE_KEY;
import static org.opensearch.indices.RemoteStoreSettings.CLUSTER_REMOTE_STORE_PINNED_TIMESTAMP_ENABLED;
Expand Down Expand Up @@ -1364,12 +1366,13 @@ protected Node(
throw new IllegalStateException(
String.format(
Locale.ROOT,
"Only one StreamManagerPlugin can be installed. Found: %d", streamManagerPlugins.size()
"Only one StreamManagerPlugin can be installed. Found: %d",
streamManagerPlugins.size()

)
);
}
if(!streamManagerPlugins.isEmpty()) {
if (!streamManagerPlugins.isEmpty()) {
streamManager = streamManagerPlugins.get(0).getStreamManager();
logger.info("StreamManager initialized");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
};
}

/**
 * Returns a new array containing every element of {@code originalArray}
 * followed by {@code newString}. The input array is not modified.
 *
 * @param originalArray the existing elements; must not be null
 * @param newString     the element to append (may be null)
 * @return a fresh array of length {@code originalArray.length + 1}
 */
static String[] addString(String[] originalArray, String newString) {
    // Arrays.copyOf allocates the enlarged array and copies the prefix in one
    // step; fully qualified so this file's import block is untouched.
    String[] newArray = java.util.Arrays.copyOf(originalArray, originalArray.length + 1);
    newArray[newArray.length - 1] = newString;
    return newArray;
}

/**
* Parses the rest request on top of the SearchRequest, preserving values that are not overridden by the rest request.
*
Expand All @@ -163,6 +170,10 @@ public static void parseSearchRequest(
searchRequest.source().parseXContent(requestContentParser, true);
}

if (searchRequest.source().getJoin() != null) {
searchRequest.indices(addString(searchRequest.indices(), searchRequest.source().getJoin().getIndex()));
}

final int batchedReduceSize = request.paramAsInt("batched_reduce_size", searchRequest.getBatchedReduceSize());
searchRequest.setBatchedReduceSize(batchedReduceSize);
if (request.hasParam("pre_filter_shard_size")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ final class DefaultSearchContext extends SearchContext {
private final IndexShard indexShard;
private final ClusterService clusterService;
private final IndexService indexService;
private final StreamManager streamManager;
private final StreamManager streamManager;
private final ContextIndexSearcher searcher;
private final DfsSearchResult dfsResult;
private final QuerySearchResult queryResult;
Expand Down
Loading

0 comments on commit 1c8f75a

Please sign in to comment.