Skip to content

Commit

Permalink
Cache the shard routings with no weight for faster access (#12989)
Browse files Browse the repository at this point in the history
* Cache the shard routings with no weight for faster access

The list of shards to run a query is determined for every request and
the weight of the nodes guides the shard selection. Currently, IndexRoutingTable
caches the shard routings with weight for faster access. But, during cases
where the fail open option is enabled, shards with no weight is also returned
lower in the order along with shards with weights. They will be used as fall
back if the shards with weights can't be used due to some error.

The shard routing with no weight is not cached, hence it does a full loop for
every request, this impacts the search latency when the number of shards to
query or the number of nodes in the cluster is high. The latency impact is
very high when both the number of shards and the number of nodes are high.

This change introduces a caching mechanism for shard routing with no weights
similar to the existing cache for shard routing with weights.

Signed-off-by: Prabhakar Sithanandam <[email protected]>
Co-authored-by: Prabhakar Sithanandam <[email protected]>
  • Loading branch information
backslasht and Prabhakar Sithanandam authored Apr 3, 2024
1 parent b7396e1 commit fb5d036
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.metadata.WeightedRoutingMetadata;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.common.Nullable;
Expand Down Expand Up @@ -63,7 +62,6 @@
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.Collections.emptyMap;

Expand Down Expand Up @@ -96,8 +94,8 @@ public class IndexShardRoutingTable implements Iterable<ShardRouting> {
private volatile Map<AttributesKey, AttributesRoutings> initializingShardsByAttributes = emptyMap();
private final Object shardsByAttributeMutex = new Object();
private final Object shardsByWeightMutex = new Object();
private volatile Map<WeightedRoutingKey, List<ShardRouting>> activeShardsByWeight = emptyMap();
private volatile Map<WeightedRoutingKey, List<ShardRouting>> initializingShardsByWeight = emptyMap();
private volatile Map<WeightedRoutingKey, WeightedShardRoutings> activeShardsByWeight = emptyMap();
private volatile Map<WeightedRoutingKey, WeightedShardRoutings> initializingShardsByWeight = emptyMap();

private static final Logger logger = LogManager.getLogger(IndexShardRoutingTable.class);

Expand Down Expand Up @@ -249,7 +247,7 @@ public List<ShardRouting> assignedShards() {
return this.assignedShards;
}

public Map<WeightedRoutingKey, List<ShardRouting>> getActiveShardsByWeight() {
public Map<WeightedRoutingKey, WeightedShardRoutings> getActiveShardsByWeight() {
return activeShardsByWeight;
}

Expand Down Expand Up @@ -338,23 +336,7 @@ public ShardIterator activeInitializingShardsWeightedIt(
// append shards for attribute value with weight zero, so that shard search requests can be tried on
// shard copies in case of request failure from other attribute values.
if (isFailOpenEnabled) {
try {
Stream<String> keys = weightedRouting.weights()
.entrySet()
.stream()
.filter(entry -> entry.getValue().intValue() == WeightedRoutingMetadata.WEIGHED_AWAY_WEIGHT)
.map(Map.Entry::getKey);
keys.forEach(key -> {
ShardIterator iterator = onlyNodeSelectorActiveInitializingShardsIt(weightedRouting.attributeName() + ":" + key, nodes);
while (iterator.remaining() > 0) {
ordered.add(iterator.nextOrNull());
}
});
} catch (IllegalArgumentException e) {
// this exception is thrown by {@link onlyNodeSelectorActiveInitializingShardsIt} in case count of shard
// copies found is zero
logger.debug("no shard copies found for shard id [{}] for node attribute with weight zero", shardId);
}
ordered.addAll(activeInitializingShardsWithoutWeights(weightedRouting, nodes, defaultWeight));
}

return new PlainShardIterator(shardId, ordered);
Expand All @@ -378,6 +360,18 @@ private List<ShardRouting> activeInitializingShardsWithWeights(
return orderedListWithDistinctShards;
}

private List<ShardRouting> activeInitializingShardsWithoutWeights(
WeightedRouting weightedRouting,
DiscoveryNodes nodes,
double defaultWeight
) {
List<ShardRouting> ordered = new ArrayList<>(getActiveShardsWithoutWeight(weightedRouting, nodes, defaultWeight));
if (!allInitializingShards.isEmpty()) {
ordered.addAll(getInitializingShardsWithoutWeight(weightedRouting, nodes, defaultWeight));
}
return ordered.stream().distinct().collect(Collectors.toList());
}

/**
* Returns a list containing shard routings ordered using weighted round-robin scheduling.
*/
Expand Down Expand Up @@ -949,20 +943,60 @@ public int hashCode() {
}
}

/**
* Holder class for shard routing(s) which are classified and stored based on their weights.
*
* @opensearch.api
*/
@PublicApi(since = "2.14.0")
public static class WeightedShardRoutings {
private final List<ShardRouting> shardRoutingsWithWeight;
private final List<ShardRouting> shardRoutingWithoutWeight;

public WeightedShardRoutings(List<ShardRouting> shardRoutingsWithWeight, List<ShardRouting> shardRoutingWithoutWeight) {
this.shardRoutingsWithWeight = Collections.unmodifiableList(shardRoutingsWithWeight);
this.shardRoutingWithoutWeight = Collections.unmodifiableList(shardRoutingWithoutWeight);
}

public List<ShardRouting> getShardRoutingsWithWeight() {
return shardRoutingsWithWeight;
}

public List<ShardRouting> getShardRoutingWithoutWeight() {
return shardRoutingWithoutWeight;
}
}

/**
* *
* Gets active shard routing from memory if available, else calculates and put it in memory.
*/
private List<ShardRouting> getActiveShardsByWeight(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
List<ShardRouting> shardRoutings = activeShardsByWeight.get(key);
if (shardRoutings == null) {
synchronized (shardsByWeightMutex) {
shardRoutings = shardsOrderedByWeight(activeShards, weightedRouting, nodes, defaultWeight);
activeShardsByWeight = new MapBuilder().put(key, shardRoutings).immutableMap();
}
if (activeShardsByWeight.get(key) == null) {
populateActiveShardWeightsMap(weightedRouting, nodes, defaultWeight);
}
return activeShardsByWeight.get(key).getShardRoutingsWithWeight();
}

private List<ShardRouting> getActiveShardsWithoutWeight(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
if (activeShardsByWeight.get(key) == null) {
populateActiveShardWeightsMap(weightedRouting, nodes, defaultWeight);
}
return activeShardsByWeight.get(key).getShardRoutingWithoutWeight();
}

private void populateActiveShardWeightsMap(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
List<ShardRouting> weightedRoutings = shardsOrderedByWeight(activeShards, weightedRouting, nodes, defaultWeight);
List<ShardRouting> nonWeightedRoutings = activeShards.stream()
.filter(shard -> !weightedRoutings.contains(shard))
.collect(Collectors.toUnmodifiableList());
synchronized (shardsByWeightMutex) {
activeShardsByWeight = new MapBuilder().put(key, new WeightedShardRoutings(weightedRoutings, nonWeightedRoutings))
.immutableMap();
}
return shardRoutings;
}

/**
Expand All @@ -971,14 +1005,34 @@ private List<ShardRouting> getActiveShardsByWeight(WeightedRouting weightedRouti
*/
private List<ShardRouting> getInitializingShardsByWeight(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
List<ShardRouting> shardRoutings = initializingShardsByWeight.get(key);
if (shardRoutings == null) {
synchronized (shardsByWeightMutex) {
shardRoutings = shardsOrderedByWeight(activeShards, weightedRouting, nodes, defaultWeight);
initializingShardsByWeight = new MapBuilder().put(key, shardRoutings).immutableMap();
}
if (initializingShardsByWeight.get(key) == null) {
populateInitializingShardWeightsMap(weightedRouting, nodes, defaultWeight);
}
return initializingShardsByWeight.get(key).getShardRoutingsWithWeight();
}

private List<ShardRouting> getInitializingShardsWithoutWeight(
WeightedRouting weightedRouting,
DiscoveryNodes nodes,
double defaultWeight
) {
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
if (initializingShardsByWeight.get(key) == null) {
populateInitializingShardWeightsMap(weightedRouting, nodes, defaultWeight);
}
return initializingShardsByWeight.get(key).getShardRoutingWithoutWeight();
}

private void populateInitializingShardWeightsMap(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
List<ShardRouting> weightedRoutings = shardsOrderedByWeight(allInitializingShards, weightedRouting, nodes, defaultWeight);
List<ShardRouting> nonWeightedRoutings = allInitializingShards.stream()
.filter(shard -> !weightedRoutings.contains(shard))
.collect(Collectors.toUnmodifiableList());
synchronized (shardsByWeightMutex) {
initializingShardsByWeight = new MapBuilder().put(key, new WeightedShardRoutings(weightedRoutings, nonWeightedRoutings))
.immutableMap();
}
return shardRoutings;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,18 @@ public void testWeightedRoutingWithDifferentWeights() {
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, false, null);
assertEquals(1, shardIterator.size());
shardRouting = shardIterator.nextOrNull();
assertNotNull(shardRouting);
assertFalse(Arrays.asList("node2", "node1").contains(shardRouting.currentNodeId()));
assertEquals("node3", shardIterator.nextOrNull().currentNodeId());

weights = Map.of("zone1", -1.0, "zone2", 0.0, "zone3", 1.0);
weightedRouting = new WeightedRouting("zone", weights);
shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null);
assertEquals(3, shardIterator.size());
assertEquals("node3", shardIterator.nextOrNull().currentNodeId());
assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId());
assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId());

weights = Map.of("zone1", 3.0, "zone2", 2.0, "zone3", 0.0);
weightedRouting = new WeightedRouting("zone", weights);
Expand All @@ -711,8 +720,138 @@ public void testWeightedRoutingWithDifferentWeights() {
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null);
assertEquals(3, shardIterator.size());
shardRouting = shardIterator.nextOrNull();
assertNotNull(shardRouting);
assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId());
assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId());
assertEquals("node3", shardIterator.nextOrNull().currentNodeId());
} finally {
terminate(threadPool);
}
}

public void testWeightedRoutingWithInitializingShards() {
TestThreadPool threadPool = null;
try {
Settings.Builder settings = Settings.builder()
.put("cluster.routing.allocation.node_concurrent_recoveries", 10)
.put("cluster.routing.allocation.awareness.attributes", "zone");
AllocationService strategy = createAllocationService(settings.build());

Metadata metadata = Metadata.builder()
.put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(2))
.build();

RoutingTable routingTable = RoutingTable.builder().addAsNew(metadata.index("test")).build();

ClusterState clusterState = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY))
.metadata(metadata)
.routingTable(routingTable)
.build();

threadPool = new TestThreadPool("testThatOnlyNodesSupport");
ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool);

Map<String, String> node1Attributes = new HashMap<>();
node1Attributes.put("zone", "zone1");
Map<String, String> node2Attributes = new HashMap<>();
node2Attributes.put("zone", "zone2");
Map<String, String> node3Attributes = new HashMap<>();
node3Attributes.put("zone", "zone3");

DiscoveryNodes nodes = DiscoveryNodes.builder()
.add(newNode("node1", unmodifiableMap(node1Attributes)))
.add(newNode("node2", unmodifiableMap(node2Attributes)))
.add(newNode("node3", unmodifiableMap(node3Attributes)))
.localNodeId("node1")
.build();
clusterState = ClusterState.builder(clusterState).nodes(nodes).build();
clusterState = strategy.reroute(clusterState, "reroute");

// Making the first shard as active
clusterState = startInitializingShardsAndReroute(strategy, clusterState);
// Making the second shard as active
clusterState = startRandomInitializingShard(clusterState, strategy);

String[] startedNodes = new String[2];
String[] startedZones = new String[2];
String initializingNode = null;
String initializingZone = null;
int i = 0;
for (ShardRouting shard : clusterState.routingTable().allShards()) {
if (shard.initializing()) {
initializingNode = shard.currentNodeId();
initializingZone = nodes.resolveNode(shard.currentNodeId()).getAttributes().get("zone");

} else {
startedNodes[i] = shard.currentNodeId();
startedZones[i++] = nodes.resolveNode(shard.currentNodeId()).getAttributes().get("zone");
}
}

Map<String, Double> weights = Map.of(startedZones[0], 1.0, initializingZone, 1.0, startedZones[1], 0.0);
WeightedRouting weightedRouting = new WeightedRouting("zone", weights);

// With fail open enabled set to false, we expect 2 shard routing, first one started, followed by initializing
ShardIterator shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, false, null);

assertEquals(2, shardIterator.size());
assertEquals(startedNodes[0], shardIterator.nextOrNull().currentNodeId());
assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId());

// With fail open enabled set to true, we expect 3 shard routing, first one started, followed by initializing, third one started
// with zero weight
shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null);

assertEquals(3, shardIterator.size());
assertEquals(startedNodes[0], shardIterator.nextOrNull().currentNodeId());
assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId());
assertEquals(startedNodes[1], shardIterator.nextOrNull().currentNodeId());

weights = Map.of(initializingZone, 1.0, startedZones[0], 0.0, startedZones[1], 0.0);
weightedRouting = new WeightedRouting("zone", weights);

// only initializing shard has weight with fail open true
shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, false, null);
assertEquals(1, shardIterator.size());
assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId());

// only initializing shard has weight with fail open false
shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null);
assertEquals(3, shardIterator.size());
assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId());
assertNotEquals(initializingNode, shardIterator.nextOrNull().currentNodeId());
assertNotEquals(initializingNode, shardIterator.nextOrNull().currentNodeId());

weights = Map.of(initializingZone, 0.0, startedZones[0], 1.0, startedZones[1], 0.0);
weightedRouting = new WeightedRouting("zone", weights);

shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, false, null);
assertEquals(1, shardIterator.size());
assertEquals(startedNodes[0], shardIterator.nextOrNull().currentNodeId());

shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null);
assertEquals(3, shardIterator.size());
assertEquals(startedNodes[0], shardIterator.nextOrNull().currentNodeId());
assertEquals(startedNodes[1], shardIterator.nextOrNull().currentNodeId());
assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId());

} finally {
terminate(threadPool);
}
Expand Down

0 comments on commit fb5d036

Please sign in to comment.