Skip to content

Commit

Permalink
Provide option to disable subchannel retries to let PFLeafLB take con…
Browse files Browse the repository at this point in the history
…trol of retries.
  • Loading branch information
larry-safran committed Sep 26, 2024
1 parent 1d5d64f commit 96f9b7b
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 55 deletions.
13 changes: 13 additions & 0 deletions api/src/main/java/io/grpc/LoadBalancer.java
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ public abstract class LoadBalancer {
HEALTH_CONSUMER_LISTENER_ARG_KEY =
LoadBalancer.CreateSubchannelArgs.Key.create("internal:health-check-consumer-listener");

@Internal
public static final LoadBalancer.CreateSubchannelArgs.Key<Boolean>
DISABLE_SUBCHANNEL_RECONNECT_KEY =
LoadBalancer.CreateSubchannelArgs.Key.createWithDefault(
"internal:disable-subchannel-reconnect", Boolean.FALSE);

@Internal
public static final Attributes.Key<Boolean>
HAS_HEALTH_PRODUCER_LISTENER_KEY =
Expand Down Expand Up @@ -825,6 +831,13 @@ public String toString() {
.toString();
}

@Internal
public Object[][] getOptions() {
Object[][] retVal = new Object[customOptions.length][2];
System.arraycopy(retVal, 0, customOptions, 0, customOptions.length);
return retVal;
}

@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1771")
public static final class Builder {

Expand Down
26 changes: 18 additions & 8 deletions core/src/main/java/io/grpc/internal/InternalSubchannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import io.grpc.InternalInstrumented;
import io.grpc.InternalLogId;
import io.grpc.InternalWithLogId;
import io.grpc.LoadBalancer;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
Expand Down Expand Up @@ -77,6 +78,7 @@ final class InternalSubchannel implements InternalInstrumented<ChannelStats>, Tr
private final CallTracer callsTracer;
private final ChannelTracer channelTracer;
private final ChannelLogger channelLogger;
private final boolean recconectDisabled;

private final List<ClientTransportFilter> transportFilters;

Expand Down Expand Up @@ -159,13 +161,15 @@ protected void handleNotInUse() {

private volatile Attributes connectedAddressAttributes;

InternalSubchannel(List<EquivalentAddressGroup> addressGroups, String authority, String userAgent,
BackoffPolicy.Provider backoffPolicyProvider,
ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor,
Supplier<Stopwatch> stopwatchSupplier, SynchronizationContext syncContext, Callback callback,
InternalChannelz channelz, CallTracer callsTracer, ChannelTracer channelTracer,
InternalLogId logId, ChannelLogger channelLogger,
List<ClientTransportFilter> transportFilters) {
InternalSubchannel(LoadBalancer.CreateSubchannelArgs args, String authority, String userAgent,
BackoffPolicy.Provider backoffPolicyProvider,
ClientTransportFactory transportFactory,
ScheduledExecutorService scheduledExecutor,
Supplier<Stopwatch> stopwatchSupplier, SynchronizationContext syncContext,
Callback callback, InternalChannelz channelz, CallTracer callsTracer,
ChannelTracer channelTracer, InternalLogId logId,
ChannelLogger channelLogger, List<ClientTransportFilter> transportFilters) {
List<EquivalentAddressGroup> addressGroups = args.getAddresses();
Preconditions.checkNotNull(addressGroups, "addressGroups");
Preconditions.checkArgument(!addressGroups.isEmpty(), "addressGroups is empty");
checkListHasNoNulls(addressGroups, "addressGroups contains null entry");
Expand All @@ -187,6 +191,7 @@ protected void handleNotInUse() {
this.logId = Preconditions.checkNotNull(logId, "logId");
this.channelLogger = Preconditions.checkNotNull(channelLogger, "channelLogger");
this.transportFilters = transportFilters;
this.recconectDisabled = args.getOption(LoadBalancer.DISABLE_SUBCHANNEL_RECONNECT_KEY);
}

ChannelLogger getChannelLogger() {
Expand All @@ -196,7 +201,7 @@ ChannelLogger getChannelLogger() {
@Override
public ClientTransport obtainActiveTransport() {
ClientTransport savedTransport = activeTransport;
if (savedTransport != null) {
if (savedTransport != null && state.getState() != IDLE) {
return savedTransport;
}
syncContext.execute(new Runnable() {
Expand Down Expand Up @@ -289,6 +294,11 @@ public void run() {
}

gotoState(ConnectivityStateInfo.forTransientFailure(status));

if (recconectDisabled) {
return;
}

if (reconnectPolicy == null) {
reconnectPolicy = backoffPolicyProvider.get();
}
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/java/io/grpc/internal/ManagedChannelImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -1475,7 +1475,7 @@ void onStateChange(InternalSubchannel is, ConnectivityStateInfo newState) {
}

final InternalSubchannel internalSubchannel = new InternalSubchannel(
addressGroup,
CreateSubchannelArgs.newBuilder().setAddresses(addressGroup).build(),
authority, userAgent, backoffPolicyProvider, oobTransportFactory,
oobTransportFactory.getScheduledExecutorService(), stopwatchSupplier, syncContext,
// All callback methods are run from syncContext
Expand Down Expand Up @@ -1907,7 +1907,7 @@ void onNotInUse(InternalSubchannel is) {
}

final InternalSubchannel internalSubchannel = new InternalSubchannel(
args.getAddresses(),
args,
authority(),
userAgent,
backoffPolicyProvider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ final class PickFirstLeafLoadBalancer extends LoadBalancer {
private BackoffPolicy reconnectPolicy;
@Nullable
private ScheduledHandle reconnectTask = null;
private boolean serializingRetries = isSerializingRetries();
private final boolean serializingRetries = isSerializingRetries();

PickFirstLeafLoadBalancer(Helper helper) {
this.helper = checkNotNull(helper, "helper");
Expand Down Expand Up @@ -234,9 +234,10 @@ void processSubchannelState(SubchannelData subchannelData, ConnectivityStateInfo
return;
}

if (newState == IDLE) {
if (newState == IDLE && subchannelData.state == READY) {
helper.refreshNameResolution();
}

// If we are transitioning from a TRANSIENT_FAILURE to CONNECTING or IDLE we ignore this state
// transition and still keep the LB in TRANSIENT_FAILURE state. This is referred to as "sticky
// transient failure". Only a subchannel state change to READY will get the LB out of
Expand Down Expand Up @@ -291,7 +292,7 @@ void processSubchannelState(SubchannelData subchannelData, ConnectivityStateInfo
}
}

if ( isPassComplete()) {
if (isPassComplete()) {
rawConnectivityState = TRANSIENT_FAILURE;
updateBalancingState(TRANSIENT_FAILURE,
new Picker(PickResult.withError(stateInfo.getStatus())));
Expand Down Expand Up @@ -462,11 +463,8 @@ public void requestConnection() {
requestConnection();
} else {
if (!addressIndex.isValid()) {
subchannelData.subchannel.shutdown(); // shutdown the previous subchannel
scheduleBackoff();
} else {
subchannelData.subchannel.shutdown(); // shutdown the previous subchannel
subchannels.remove(currentAddress);
requestConnection();
}
}
Expand Down Expand Up @@ -515,9 +513,10 @@ private SubchannelData createNewSubchannel(SocketAddress addr, Attributes attrs)
HealthListener hcListener = new HealthListener();
final Subchannel subchannel = helper.createSubchannel(
CreateSubchannelArgs.newBuilder()
.setAddresses(Lists.newArrayList(
new EquivalentAddressGroup(addr, attrs)))
.addOption(HEALTH_CONSUMER_LISTENER_ARG_KEY, hcListener)
.setAddresses(Lists.newArrayList(
new EquivalentAddressGroup(addr, attrs)))
.addOption(HEALTH_CONSUMER_LISTENER_ARG_KEY, hcListener)
.addOption(LoadBalancer.DISABLE_SUBCHANNEL_RECONNECT_KEY, serializingRetries)
.build());
if (subchannel == null) {
log.warning("Was not able to create subchannel for " + addr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import io.grpc.InternalChannelz;
import io.grpc.InternalLogId;
import io.grpc.InternalWithLogId;
import io.grpc.LoadBalancer;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.internal.InternalSubchannel.CallTracingTransport;
Expand Down Expand Up @@ -1381,7 +1382,9 @@ private void createInternalSubchannel(EquivalentAddressGroup ... addrs) {
InternalLogId logId = InternalLogId.allocate("Subchannel", /*details=*/ AUTHORITY);
ChannelTracer subchannelTracer = new ChannelTracer(logId, 10,
fakeClock.getTimeProvider().currentTimeNanos(), "Subchannel");
internalSubchannel = new InternalSubchannel(addressGroups, AUTHORITY, USER_AGENT,
internalSubchannel = new InternalSubchannel(
LoadBalancer.CreateSubchannelArgs.newBuilder().setAddresses(addressGroups).build(),
AUTHORITY, USER_AGENT,
mockBackoffPolicyProvider, mockTransportFactory, fakeClock.getScheduledExecutorService(),
fakeClock.getStopwatchSupplier(), syncContext, mockInternalSubchannelCallback,
channelz, CallTracer.getDefaultFactory().create(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assume.assumeTrue;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
Expand Down Expand Up @@ -74,7 +75,6 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.junit.After;
import org.junit.Assume;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand All @@ -93,18 +93,22 @@
public class PickFirstLeafLoadBalancerTest {
public static final Status CONNECTION_ERROR =
Status.UNAVAILABLE.withDescription("Simulated connection error");

@Parameterized.Parameters(name = "{0}")
public static List<Boolean> enableHappyEyeballs() {
if (PickFirstLeafLoadBalancer.isSerializingRetries()) {
return Arrays.asList(false);
} else {
return Arrays.asList(false, true);
}
public static final String GRPC_SERIALIZE_RETRIES = "GRPC_SERIALIZE_RETRIES";

@Parameterized.Parameters(name = "{0}-{1}")
public static List<Object[]> data() {
return Arrays.asList(new Object[][] {
{false, false},
{false, true},
{true, false}});
}

@Parameterized.Parameter
@Parameterized.Parameter(value = 0)
public boolean serializeRetries;

@Parameterized.Parameter(value = 1)
public boolean enableHappyEyeballs;

private PickFirstLeafLoadBalancer loadBalancer;
private final List<EquivalentAddressGroup> servers = Lists.newArrayList();
private static final Attributes.Key<String> FOO = Attributes.Key.create("foo");
Expand Down Expand Up @@ -142,14 +146,18 @@ public void uncaughtException(Thread t, Throwable e) {
private PickSubchannelArgs mockArgs;

private String originalHappyEyeballsEnabledValue;
private String originalSerializeRetriesValue;

@Before
public void setUp() {
assumeTrue(!serializeRetries || !enableHappyEyeballs); // they are not compatible
originalSerializeRetriesValue = System.getProperty(GRPC_SERIALIZE_RETRIES);
System.setProperty(GRPC_SERIALIZE_RETRIES, Boolean.toString(serializeRetries));

originalHappyEyeballsEnabledValue =
System.getProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS);
System.setProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS,
!PickFirstLeafLoadBalancer.isSerializingRetries() && enableHappyEyeballs
? "true" : "false");
Boolean.toString(enableHappyEyeballs));

for (int i = 1; i <= 5; i++) {
SocketAddress addr = new FakeSocketAddress("server" + i);
Expand Down Expand Up @@ -182,6 +190,11 @@ public void setUp() {

@After
public void tearDown() {
if (originalSerializeRetriesValue == null) {
System.clearProperty(GRPC_SERIALIZE_RETRIES);
} else {
System.setProperty(GRPC_SERIALIZE_RETRIES, originalSerializeRetriesValue);
}
if (originalHappyEyeballsEnabledValue == null) {
System.clearProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS);
} else {
Expand Down Expand Up @@ -529,20 +542,7 @@ public void pickAfterStateChangeAfterResolution() {
inOrder.verify(mockSubchannel1).start(stateListenerCaptor.capture());
stateListeners[0] = stateListenerCaptor.getValue();

if (enableHappyEyeballs) {
forwardTimeByConnectionDelay();
inOrder.verify(mockSubchannel2).start(stateListenerCaptor.capture());
stateListeners[1] = stateListenerCaptor.getValue();
forwardTimeByConnectionDelay();
inOrder.verify(mockSubchannel3).start(stateListenerCaptor.capture());
stateListeners[2] = stateListenerCaptor.getValue();
forwardTimeByConnectionDelay();
inOrder.verify(mockSubchannel4).start(stateListenerCaptor.capture());
stateListeners[3] = stateListenerCaptor.getValue();
}

reset(mockHelper);

stateListeners[0].onSubchannelState(ConnectivityStateInfo.forNonError(READY));
stateListeners[0].onSubchannelState(ConnectivityStateInfo.forNonError(IDLE));
inOrder.verify(mockHelper).refreshNameResolution();
inOrder.verify(mockHelper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
Expand All @@ -552,11 +552,23 @@ public void pickAfterStateChangeAfterResolution() {
stateListeners[0].onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING));

Status error = Status.UNAVAILABLE.withDescription("boom!");
reset(mockHelper);

if (enableHappyEyeballs) {
for (SubchannelStateListener listener : stateListeners) {
listener.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error));
}
stateListeners[0].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error));
forwardTimeByConnectionDelay();
inOrder.verify(mockSubchannel2).start(stateListenerCaptor.capture());
stateListeners[1] = stateListenerCaptor.getValue();
stateListeners[1].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error));
forwardTimeByConnectionDelay();
inOrder.verify(mockSubchannel3).start(stateListenerCaptor.capture());
stateListeners[2] = stateListenerCaptor.getValue();
stateListeners[2].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error));
forwardTimeByConnectionDelay();
inOrder.verify(mockSubchannel4).start(stateListenerCaptor.capture());
stateListeners[3] = stateListenerCaptor.getValue();
stateListeners[3].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error));
forwardTimeByConnectionDelay();
} else {
stateListeners[0].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error));
for (int i = 1; i < stateListeners.length; i++) {
Expand Down Expand Up @@ -1533,7 +1545,7 @@ public void updateAddresses_intersecting_ready() {

@Test
public void updateAddresses_intersecting_transient_failure() {
Assume.assumeTrue(!isSerializingRetries());
assumeTrue(!isSerializingRetries());

// Starting first connection attempt
InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2,
Expand Down Expand Up @@ -1799,7 +1811,7 @@ public void updateAddresses_identical_ready() {

@Test
public void updateAddresses_identical_transient_failure() {
Assume.assumeTrue(!isSerializingRetries());
assumeTrue(!isSerializingRetries());

InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2,
mockSubchannel3, mockSubchannel4);
Expand Down Expand Up @@ -2314,7 +2326,7 @@ public void ready_then_transient_failure_again() {

@Test
public void happy_eyeballs_trigger_connection_delay() {
Assume.assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs
assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs
// Starting first connection attempt
InOrder inOrder = inOrder(mockHelper, mockSubchannel1,
mockSubchannel2, mockSubchannel3, mockSubchannel4);
Expand Down Expand Up @@ -2359,7 +2371,7 @@ public void happy_eyeballs_trigger_connection_delay() {

@Test
public void happy_eyeballs_connection_results_happen_after_get_to_end() {
Assume.assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs
assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs

InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3);
Status error = Status.UNAUTHENTICATED.withDescription("simulated failure");
Expand Down Expand Up @@ -2412,7 +2424,7 @@ public void happy_eyeballs_connection_results_happen_after_get_to_end() {

@Test
public void happy_eyeballs_pick_pushes_index_over_end() {
Assume.assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs
assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs

InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3,
mockSubchannel2n2, mockSubchannel3n2);
Expand Down Expand Up @@ -2490,7 +2502,7 @@ public void happy_eyeballs_pick_pushes_index_over_end() {

@Test
public void happy_eyeballs_fail_then_trigger_connection_delay() {
Assume.assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs
assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs
// Starting first connection attempt
InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3);
assertEquals(IDLE, loadBalancer.getConcludedConnectivityState());
Expand Down

0 comments on commit 96f9b7b

Please sign in to comment.