From bc784c0ef9f7727497f8543492979c60d5ce5ef8 Mon Sep 17 00:00:00 2001 From: Larry Safran Date: Tue, 26 Sep 2023 17:31:58 -0700 Subject: [PATCH] Revert "Change Round Robin and WeightedRoundRobin into petiole policies (#10528)" (#10575) This reverts commit e1334eae7bba39d85a952bc5ab5aeb4cb05a56d8. --- .../grpc/internal/ManagedChannelImplTest.java | 2 +- .../java/io/grpc/internal/TestUtils.java | 5 - examples/android/strictmode/app/build.gradle | 1 - .../android/strictmode/app/proguard-rules.pro | 1 - util/build.gradle | 12 +- .../io/grpc/util/MultiChildLoadBalancer.java | 241 +++-------- .../io/grpc/util/RoundRobinLoadBalancer.java | 234 +++++++++-- .../OutlierDetectionLoadBalancerTest.java | 4 +- .../grpc/util/RoundRobinLoadBalancerTest.java | 220 +++++----- .../java/io/grpc/util/AbstractTestHelper.java | 156 ------- xds/build.gradle | 3 +- .../grpc/xds/ClusterManagerLoadBalancer.java | 92 +---- .../xds/WeightedRoundRobinLoadBalancer.java | 243 ++++------- .../java/io/grpc/xds/orca/OrcaOobUtil.java | 8 +- .../WeightedRoundRobinLoadBalancerTest.java | 391 ++++++++---------- 15 files changed, 629 insertions(+), 984 deletions(-) delete mode 100644 util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index df35afae163..da2bc072afc 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -161,7 +161,7 @@ /** Unit tests for {@link ManagedChannelImpl}. */ @RunWith(JUnit4.class) // TODO(creamsoup) remove backward compatible check when fully migrated -@SuppressWarnings({"deprecation", "DataFlowIssue"}) +@SuppressWarnings("deprecation") public class ManagedChannelImplTest { private static final int DEFAULT_PORT = 447; diff --git a/core/src/testFixtures/java/io/grpc/internal/TestUtils.java b/core/src/testFixtures/java/io/grpc/internal/TestUtils.java index 02df28f2e70..974f36e595c 100644 --- a/core/src/testFixtures/java/io/grpc/internal/TestUtils.java +++ b/core/src/testFixtures/java/io/grpc/internal/TestUtils.java @@ -24,7 +24,6 @@ import io.grpc.CallOptions; import io.grpc.ChannelLogger; import io.grpc.ClientStreamTracer; -import io.grpc.EquivalentAddressGroup; import io.grpc.InternalLogId; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; @@ -144,10 +143,6 @@ public Runnable answer(InvocationOnMock invocation) throws Throwable { return captor; } - public static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { - return new EquivalentAddressGroup(eag.getAddresses()); - } - private TestUtils() { } diff --git a/examples/android/strictmode/app/build.gradle b/examples/android/strictmode/app/build.gradle index 85e283b1137..c00b8fbd99b 100644 --- a/examples/android/strictmode/app/build.gradle +++ b/examples/android/strictmode/app/build.gradle @@ -53,7 +53,6 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-core:1.59.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'io.grpc:grpc-okhttp:1.59.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'io.grpc:grpc-protobuf-lite:1.59.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'io.grpc:grpc-stub:1.59.0-SNAPSHOT' // CURRENT_GRPC_VERSION diff --git a/examples/android/strictmode/app/proguard-rules.pro b/examples/android/strictmode/app/proguard-rules.pro index d5715fd16cf..1507a526787 100644 --- a/examples/android/strictmode/app/proguard-rules.pro +++ b/examples/android/strictmode/app/proguard-rules.pro @@ -15,4 +15,3 @@ -dontwarn javax.naming.** -dontwarn okio.** -dontwarn sun.misc.Unsafe - diff --git a/util/build.gradle b/util/build.gradle index cdd32e0ceb5..a05c55b27bb 100644 --- a/util/build.gradle +++ b/util/build.gradle @@ -1,6 +1,5 @@ plugins { id "java-library" - id "java-test-fixtures" id "maven-publish" id "me.champeau.jmh" @@ -20,18 +19,11 @@ dependencies { implementation libraries.animalsniffer.annotations, libraries.guava - testImplementation libraries.guava.testlib, - testFixtures(project(':grpc-api')), + testImplementation testFixtures(project(':grpc-api')), testFixtures(project(':grpc-core')), project(':grpc-testing') + testImplementation libraries.guava.testlib - testFixturesApi project(':grpc-core') - testFixturesImplementation libraries.guava, - libraries.junit, - libraries.mockito.core, - testFixtures(project(':grpc-api')), - testFixtures(project(':grpc-core')), - project(':grpc-testing') jmh project(':grpc-testing') signature libraries.signature.java diff --git a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java index 2f0aa04cf9d..8f2269af261 100644 --- a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java @@ -16,29 +16,25 @@ package io.grpc.util; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; -import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; -import io.grpc.Attributes; import io.grpc.ConnectivityState; -import io.grpc.EquivalentAddressGroup; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancerProvider; import io.grpc.Status; -import io.grpc.internal.PickFirstLoadBalancerProvider; -import java.util.Collection; -import java.util.Collections; +import io.grpc.SynchronizationContext; +import io.grpc.SynchronizationContext.ScheduledHandle; +import io.grpc.internal.ServiceConfigUtil.PolicySelection; import java.util.HashMap; -import java.util.List; import java.util.Map; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -50,34 +46,23 @@ @Internal public abstract class MultiChildLoadBalancer extends LoadBalancer { + @VisibleForTesting + public static final int DELAYED_CHILD_DELETION_TIME_MINUTES = 15; private static final Logger logger = Logger.getLogger(MultiChildLoadBalancer.class.getName()); private final Map childLbStates = new HashMap<>(); private final Helper helper; + protected final SynchronizationContext syncContext; + private final ScheduledExecutorService timeService; // Set to true if currently in the process of handling resolved addresses. - @VisibleForTesting - boolean resolvingAddresses; - - protected final PickFirstLoadBalancerProvider pickFirstLbProvider = - new PickFirstLoadBalancerProvider(); - + private boolean resolvingAddresses; protected MultiChildLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); + this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); + this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); logger.log(Level.FINE, "Created"); } - @SuppressWarnings("ReferenceEquality") - protected static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { - if (eag.getAttributes() == Attributes.EMPTY) { - return eag; - } else { - return new EquivalentAddressGroup(eag.getAddresses()); - } - } - - protected abstract SubchannelPicker getSubchannelPicker( - Map childPickers); - protected SubchannelPicker getInitialPicker() { return EMPTY_PICKER; } @@ -86,43 +71,11 @@ protected SubchannelPicker getErrorPicker(Status error) { return new FixedResultPicker(PickResult.withError(error)); } - @VisibleForTesting - protected Collection getChildLbStates() { - return childLbStates.values(); - } - - protected ChildLbState getChildLbState(Object key) { - if (key == null) { - return null; - } - return childLbStates.get(key); - } - - protected ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) { - return getChildLbState(stripAttrs(eag)); - } - - /** - * Override to utilize parsing of the policy configuration or alternative helper/lb generation. - */ - protected Map createChildLbMap(ResolvedAddresses resolvedAddresses) { - Map childLbMap = new HashMap<>(); - List addresses = resolvedAddresses.getAddresses(); - Object policyConfig = resolvedAddresses.getLoadBalancingPolicyConfig(); - for (EquivalentAddressGroup eag : addresses) { - EquivalentAddressGroup strippedEag = stripAttrs(eag); // keys need to be just addresses - if (!childLbMap.containsKey(strippedEag)) { - childLbMap.put(strippedEag, - createChildLbState(strippedEag, policyConfig, getInitialPicker())); - } - } - return childLbMap; - } + protected abstract Map getPolicySelectionMap( + ResolvedAddresses resolvedAddresses); - protected ChildLbState createChildLbState(Object key, Object policyConfig, - SubchannelPicker initialPicker) { - return new ChildLbState(key, pickFirstLbProvider, policyConfig, initialPicker); - } + protected abstract SubchannelPicker getSubchannelPicker( + Map childPickers); @Override public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { @@ -134,61 +87,25 @@ public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { } } - protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses, - Object childConfig) { - checkArgument(key instanceof EquivalentAddressGroup, "key is wrong type"); - - // Retrieve the non-stripped version - EquivalentAddressGroup eag = null; - for (EquivalentAddressGroup equivalentAddressGroup : resolvedAddresses.getAddresses()) { - if (stripAttrs(equivalentAddressGroup).equals(key)) { - eag = equivalentAddressGroup; - break; - } - } - - checkNotNull(eag, key.toString() + " no longer present in load balancer children"); - - return resolvedAddresses.toBuilder() - .setAddresses(Collections.singletonList(eag)) - .setLoadBalancingPolicyConfig(childConfig) - .build(); - } - - - private boolean acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) { logger.log(Level.FINE, "Received resolution result: {0}", resolvedAddresses); - Map newChildren = createChildLbMap(resolvedAddresses); - - if (newChildren.isEmpty()) { - handleNameResolutionError(Status.UNAVAILABLE.withDescription( - "NameResolver returned no usable address. " + resolvedAddresses)); - return false; - } - - // Do adds and updates - for (Map.Entry entry : newChildren.entrySet()) { + Map newChildPolicies = getPolicySelectionMap(resolvedAddresses); + for (Map.Entry entry : newChildPolicies.entrySet()) { final Object key = entry.getKey(); - LoadBalancerProvider childPolicyProvider = entry.getValue().getPolicyProvider(); + LoadBalancerProvider childPolicyProvider = entry.getValue().getProvider(); Object childConfig = entry.getValue().getConfig(); if (!childLbStates.containsKey(key)) { - childLbStates.put(key, entry.getValue()); + childLbStates.put(key, new ChildLbState(key, childPolicyProvider, getInitialPicker())); } else { - // Reuse the existing one - ChildLbState existingChildLbState = childLbStates.get(key); - if (existingChildLbState.isDeactivated()) { - existingChildLbState.reactivate(childPolicyProvider); - } + childLbStates.get(key).reactivate(childPolicyProvider); } - LoadBalancer childLb = childLbStates.get(key).lb; - childLb.handleResolvedAddresses(getChildAddresses(key, resolvedAddresses, childConfig)); + ResolvedAddresses childAddresses = + resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build(); + childLb.handleResolvedAddresses(childAddresses); } - - // Do removals - for (Object key : ImmutableList.copyOf(childLbStates.keySet())) { - if (!newChildren.containsKey(key)) { + for (Object key : childLbStates.keySet()) { + if (!newChildPolicies.containsKey(key)) { childLbStates.get(key).deactivate(); } } @@ -222,10 +139,10 @@ public void shutdown() { childLbStates.clear(); } - protected void updateOverallBalancingState() { + private void updateOverallBalancingState() { ConnectivityState overallState = null; final Map childPickers = new HashMap<>(); - for (ChildLbState childLbState : getChildLbStates()) { + for (ChildLbState childLbState : childLbStates.values()) { if (childLbState.deactivated) { continue; } @@ -238,7 +155,7 @@ protected void updateOverallBalancingState() { } @Nullable - protected static ConnectivityState aggregateState( + private static ConnectivityState aggregateState( @Nullable ConnectivityState overallState, ConnectivityState childState) { if (overallState == null) { return childState; @@ -255,109 +172,67 @@ protected static ConnectivityState aggregateState( return overallState; } - protected Helper getHelper() { - return helper; - } - - protected void removeChild(Object key) { - childLbStates.remove(key); - } - - - public class ChildLbState { + private final class ChildLbState { private final Object key; - private final Object config; private final GracefulSwitchLoadBalancer lb; private LoadBalancerProvider policyProvider; private ConnectivityState currentState = CONNECTING; private SubchannelPicker currentPicker; private boolean deactivated; + @Nullable + ScheduledHandle deletionTimer; - public ChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig, - SubchannelPicker initialPicker) { + ChildLbState(Object key, LoadBalancerProvider policyProvider, SubchannelPicker initialPicker) { this.key = key; this.policyProvider = policyProvider; lb = new GracefulSwitchLoadBalancer(new ChildLbStateHelper()); lb.switchTo(policyProvider); currentPicker = initialPicker; - config = childConfig; - } - - - @Override - public String toString() { - return "Address = " + key - + ", state = " + currentState - + ", picker type: " + currentPicker.getClass() - + ", lb: " + lb.delegate().getClass() - + (deactivated ? ", deactivated" : ""); - } - - public Object getKey() { - return key; - } - - Object getConfig() { - return config; - } - - public LoadBalancerProvider getPolicyProvider() { - return policyProvider; - } - - protected Subchannel getSubchannels(PickSubchannelArgs args) { - return getCurrentPicker().pickSubchannel(args).getSubchannel(); - } - - ConnectivityState getCurrentState() { - return currentState; } - public SubchannelPicker getCurrentPicker() { - return currentPicker; - } - - public boolean isDeactivated() { - return deactivated; - } - - @VisibleForTesting - LoadBalancer getLb() { - return this.lb; - } - - protected void setDeactivated() { - deactivated = true; - } - - protected void deactivate() { + void deactivate() { if (deactivated) { return; } - shutdown(); - childLbStates.remove(key); + class DeletionTask implements Runnable { + @Override + public void run() { + shutdown(); + childLbStates.remove(key); + } + } + + deletionTimer = + syncContext.schedule( + new DeletionTask(), + DELAYED_CHILD_DELETION_TIME_MINUTES, + TimeUnit.MINUTES, + timeService); deactivated = true; logger.log(Level.FINE, "Child balancer {0} deactivated", key); } - protected void reactivate(LoadBalancerProvider policyProvider) { + void reactivate(LoadBalancerProvider policyProvider) { + if (deletionTimer != null && deletionTimer.isPending()) { + deletionTimer.cancel(); + deactivated = false; + logger.log(Level.FINE, "Child balancer {0} reactivated", key); + } if (!this.policyProvider.getPolicyName().equals(policyProvider.getPolicyName())) { Object[] objects = { key, this.policyProvider.getPolicyName(),policyProvider.getPolicyName()}; logger.log(Level.FINE, "Child balancer {0} switching policy from {1} to {2}", objects); lb.switchTo(policyProvider); this.policyProvider = policyProvider; - } else { - logger.log(Level.FINE, "Child balancer {0} reactivated", key); } - - deactivated = false; } - protected void shutdown() { + void shutdown() { + if (deletionTimer != null && deletionTimer.isPending()) { + deletionTimer.cancel(); + } lb.shutdown(); - this.currentState = SHUTDOWN; logger.log(Level.FINE, "Child balancer {0} deleted", key); } diff --git a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java index 9873e3e451d..56097084928 100644 --- a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java @@ -16,9 +16,11 @@ package io.grpc.util; +import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; +import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.annotations.VisibleForTesting; @@ -35,10 +37,13 @@ import io.grpc.Status; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import javax.annotation.Nonnull; @@ -47,23 +52,131 @@ * EquivalentAddressGroup}s from the {@link NameResolver}. */ @Internal -public class RoundRobinLoadBalancer extends MultiChildLoadBalancer { +public class RoundRobinLoadBalancer extends LoadBalancer { @VisibleForTesting static final Attributes.Key> STATE_INFO = Attributes.Key.create("state-info"); + private final Helper helper; + private final Map subchannels = + new HashMap<>(); private final Random random; private ConnectivityState currentState; protected RoundRobinPicker currentPicker = new EmptyPicker(EMPTY_OK); public RoundRobinLoadBalancer(Helper helper) { - super(helper); + this.helper = checkNotNull(helper, "helper"); this.random = new Random(); } @Override - protected SubchannelPicker getSubchannelPicker(Map childPickers) { - throw new UnsupportedOperationException(); // local updateOverallBalancingState doesn't use this + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + if (resolvedAddresses.getAddresses().isEmpty()) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription( + "NameResolver returned no usable address. addrs=" + resolvedAddresses.getAddresses() + + ", attrs=" + resolvedAddresses.getAttributes())); + return false; + } + + List servers = resolvedAddresses.getAddresses(); + Set currentAddrs = subchannels.keySet(); + Map latestAddrs = stripAttrs(servers); + Set removedAddrs = setsDifference(currentAddrs, latestAddrs.keySet()); + + for (Map.Entry latestEntry : + latestAddrs.entrySet()) { + EquivalentAddressGroup strippedAddressGroup = latestEntry.getKey(); + EquivalentAddressGroup originalAddressGroup = latestEntry.getValue(); + Subchannel existingSubchannel = subchannels.get(strippedAddressGroup); + if (existingSubchannel != null) { + // EAG's Attributes may have changed. + existingSubchannel.updateAddresses(Collections.singletonList(originalAddressGroup)); + continue; + } + // Create new subchannels for new addresses. + + // NB(lukaszx0): we don't merge `attributes` with `subchannelAttr` because subchannel + // doesn't need them. They're describing the resolved server list but we're not taking + // any action based on this information. + Attributes.Builder subchannelAttrs = Attributes.newBuilder() + // NB(lukaszx0): because attributes are immutable we can't set new value for the key + // after creation but since we can mutate the values we leverage that and set + // AtomicReference which will allow mutating state info for given channel. + .set(STATE_INFO, + new Ref<>(ConnectivityStateInfo.forNonError(IDLE))); + + final Subchannel subchannel = checkNotNull( + helper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(originalAddressGroup) + .setAttributes(subchannelAttrs.build()) + .build()), + "subchannel"); + subchannel.start(new SubchannelStateListener() { + @Override + public void onSubchannelState(ConnectivityStateInfo state) { + processSubchannelState(subchannel, state); + } + }); + subchannels.put(strippedAddressGroup, subchannel); + subchannel.requestConnection(); + } + + ArrayList removedSubchannels = new ArrayList<>(); + for (EquivalentAddressGroup addressGroup : removedAddrs) { + removedSubchannels.add(subchannels.remove(addressGroup)); + } + + // Update the picker before shutting down the subchannels, to reduce the chance of the race + // between picking a subchannel and shutting it down. + updateBalancingState(); + + // Shutdown removed subchannels + for (Subchannel removedSubchannel : removedSubchannels) { + shutdownSubchannel(removedSubchannel); + } + + return true; + } + + @Override + public void handleNameResolutionError(Status error) { + if (currentState != READY) { + updateBalancingState(TRANSIENT_FAILURE, new EmptyPicker(error)); + } + } + + private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { + if (subchannels.get(stripAttrs(subchannel.getAddresses())) != subchannel) { + return; + } + if (stateInfo.getState() == TRANSIENT_FAILURE || stateInfo.getState() == IDLE) { + helper.refreshNameResolution(); + } + if (stateInfo.getState() == IDLE) { + subchannel.requestConnection(); + } + Ref subchannelStateRef = getSubchannelStateInfoRef(subchannel); + if (subchannelStateRef.value.getState().equals(TRANSIENT_FAILURE)) { + if (stateInfo.getState().equals(CONNECTING) || stateInfo.getState().equals(IDLE)) { + return; + } + } + subchannelStateRef.value = stateInfo; + updateBalancingState(); + } + + private void shutdownSubchannel(Subchannel subchannel) { + subchannel.shutdown(); + getSubchannelStateInfoRef(subchannel).value = + ConnectivityStateInfo.forNonError(SHUTDOWN); + } + + @Override + public void shutdown() { + for (Subchannel subchannel : getSubchannels()) { + shutdownSubchannel(subchannel); + } + subchannels.clear(); } private static final Status EMPTY_OK = Status.OK.withDescription("no subchannels ready"); @@ -71,27 +184,29 @@ protected SubchannelPicker getSubchannelPicker(Map chi /** * Updates picker with the list of active subchannels (state == READY). */ - @Override - protected void updateOverallBalancingState() { - List activeList = getReadyChildren(); + @SuppressWarnings("ReferenceEquality") + private void updateBalancingState() { + List activeList = filterNonFailingSubchannels(getSubchannels()); if (activeList.isEmpty()) { - // No READY subchannels - - // RRLB will request connection immediately on subchannel IDLE. + // No READY subchannels, determine aggregate state and error status boolean isConnecting = false; - for (ChildLbState childLbState : getChildLbStates()) { - ConnectivityState state = childLbState.getCurrentState(); - if (state == CONNECTING || state == IDLE) { + Status aggStatus = EMPTY_OK; + for (Subchannel subchannel : getSubchannels()) { + ConnectivityStateInfo stateInfo = getSubchannelStateInfoRef(subchannel).value; + // This subchannel IDLE is not because of channel IDLE_TIMEOUT, + // in which case LB is already shutdown. + // RRLB will request connection immediately on subchannel IDLE. + if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) { isConnecting = true; - break; + } + if (aggStatus == EMPTY_OK || !aggStatus.isOk()) { + aggStatus = stateInfo.getStatus(); } } - - if (isConnecting) { - updateBalancingState(CONNECTING, new EmptyPicker(Status.OK)); - } else { - updateBalancingState(TRANSIENT_FAILURE, createReadyPicker(getChildLbStates())); - } + updateBalancingState(isConnecting ? CONNECTING : TRANSIENT_FAILURE, + // If all subchannels are TRANSIENT_FAILURE, return the Status associated with + // an arbitrary subchannel, otherwise return OK. + new EmptyPicker(aggStatus)); } else { updateBalancingState(READY, createReadyPicker(activeList)); } @@ -99,39 +214,72 @@ protected void updateOverallBalancingState() { private void updateBalancingState(ConnectivityState state, RoundRobinPicker picker) { if (state != currentState || !picker.isEquivalentTo(currentPicker)) { - getHelper().updateBalancingState(state, picker); + helper.updateBalancingState(state, picker); currentState = state; currentPicker = picker; } } - protected RoundRobinPicker createReadyPicker(Collection children) { + protected RoundRobinPicker createReadyPicker(List activeList) { // initialize the Picker to a random start index to ensure that a high frequency of Picker // churn does not skew subchannel selection. - int startIndex = random.nextInt(children.size()); + int startIndex = random.nextInt(activeList.size()); + return new ReadyPicker(activeList, startIndex); + } - List pickerList = new ArrayList<>(); - for (ChildLbState child : children) { - SubchannelPicker picker = child.getCurrentPicker(); - pickerList.add(picker); + /** + * Filters out non-ready subchannels. + */ + private static List filterNonFailingSubchannels( + Collection subchannels) { + List readySubchannels = new ArrayList<>(subchannels.size()); + for (Subchannel subchannel : subchannels) { + if (isReady(subchannel)) { + readySubchannels.add(subchannel); + } } - - return new ReadyPicker(pickerList, startIndex); + return readySubchannels; } /** - * Filters out non-ready and deactivated child load balancers (subchannels). + * Converts list of {@link EquivalentAddressGroup} to {@link EquivalentAddressGroup} set and + * remove all attributes. The values are the original EAGs. */ - private List getReadyChildren() { - List activeChildren = new ArrayList<>(); - for (ChildLbState child : getChildLbStates()) { - if (!child.isDeactivated() && child.getCurrentState() == READY) { - activeChildren.add(child); - } + private static Map stripAttrs( + List groupList) { + Map addrs = new HashMap<>(groupList.size() * 2); + for (EquivalentAddressGroup group : groupList) { + addrs.put(stripAttrs(group), group); } - return activeChildren; + return addrs; + } + + private static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { + return new EquivalentAddressGroup(eag.getAddresses()); + } + + @VisibleForTesting + protected Collection getSubchannels() { + return subchannels.values(); + } + + private static Ref getSubchannelStateInfoRef( + Subchannel subchannel) { + return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO"); + } + + // package-private to avoid synthetic access + static boolean isReady(Subchannel subchannel) { + return getSubchannelStateInfoRef(subchannel).value.getState() == READY; + } + + private static Set setsDifference(Set a, Set b) { + Set aCopy = new HashSet<>(a); + aCopy.removeAll(b); + return aCopy; } + // Only subclasses are ReadyPicker or EmptyPicker public abstract static class RoundRobinPicker extends SubchannelPicker { public abstract boolean isEquivalentTo(RoundRobinPicker picker); } @@ -141,11 +289,11 @@ static class ReadyPicker extends RoundRobinPicker { private static final AtomicIntegerFieldUpdater indexUpdater = AtomicIntegerFieldUpdater.newUpdater(ReadyPicker.class, "index"); - private final List list; // non-empty + private final List list; // non-empty @SuppressWarnings("unused") private volatile int index; - public ReadyPicker(List list, int startIndex) { + public ReadyPicker(List list, int startIndex) { Preconditions.checkArgument(!list.isEmpty(), "empty list"); this.list = list; this.index = startIndex - 1; @@ -153,7 +301,7 @@ public ReadyPicker(List list, int startIndex) { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - return list.get(nextIndex()).pickSubchannel(args); + return PickResult.withSubchannel(nextSubchannel()); } @Override @@ -161,7 +309,7 @@ public String toString() { return MoreObjects.toStringHelper(ReadyPicker.class).add("list", list).toString(); } - private int nextIndex() { + private Subchannel nextSubchannel() { int size = list.size(); int i = indexUpdater.incrementAndGet(this); if (i >= size) { @@ -169,11 +317,11 @@ private int nextIndex() { i %= size; indexUpdater.compareAndSet(this, oldi, i); } - return i; + return list.get(i); } @VisibleForTesting - List getList() { + List getList() { return list; } diff --git a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java index ac5bd8b98c4..13f13421a1e 100644 --- a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java @@ -512,7 +512,7 @@ public void successRateOneOutlier_configChange() { loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); - generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 12); + generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 8); // Move forward in time to a point where the detection timer has fired. forwardTime(config); @@ -546,7 +546,7 @@ public void successRateOneOutlier_unejected() { assertEjectedSubchannels(ImmutableSet.of(servers.get(0).getAddresses().get(0))); // Now we produce more load, but the subchannel start working and is no longer an outlier. - generateLoad(ImmutableMap.of(), 12); + generateLoad(ImmutableMap.of(), 8); // Move forward in time to a point where the detection timer has fired. fakeClock.forwardTime(config.maxEjectionTimeNanos + 1, TimeUnit.NANOSECONDS); diff --git a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index 3b7f6599d03..23b6e1c10c8 100644 --- a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -22,21 +22,23 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.util.RoundRobinLoadBalancer.STATE_INFO; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -53,19 +55,16 @@ import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Status; -import io.grpc.internal.TestUtils; -import io.grpc.util.MultiChildLoadBalancer.ChildLbState; import io.grpc.util.RoundRobinLoadBalancer.EmptyPicker; import io.grpc.util.RoundRobinLoadBalancer.ReadyPicker; +import io.grpc.util.RoundRobinLoadBalancer.Ref; import java.net.SocketAddress; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -76,8 +75,10 @@ import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; +import org.mockito.stubbing.Answer; /** Unit test for {@link RoundRobinLoadBalancer}. */ @RunWith(JUnit4.class) @@ -88,9 +89,7 @@ public class RoundRobinLoadBalancerTest { private RoundRobinLoadBalancer loadBalancer; private final List servers = Lists.newArrayList(); - private final Map, Subchannel> subchannels = - new ConcurrentHashMap<>(); - private final Map mockToRealSubChannelMap = new HashMap<>(); + private final Map, Subchannel> subchannels = Maps.newLinkedHashMap(); private final Map subchannelStateListeners = Maps.newLinkedHashMap(); private final Attributes affinity = @@ -102,7 +101,8 @@ public class RoundRobinLoadBalancerTest { private ArgumentCaptor stateCaptor; @Captor private ArgumentCaptor createArgsCaptor; - private Helper mockHelper = mock(Helper.class, delegatesTo(new TestHelper())); + @Mock + private Helper mockHelper; @Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown(). private PickSubchannelArgs mockArgs; @@ -113,14 +113,32 @@ public void setUp() { SocketAddress addr = new FakeSocketAddress("server" + i); EquivalentAddressGroup eag = new EquivalentAddressGroup(addr); servers.add(eag); + Subchannel sc = mock(Subchannel.class); + subchannels.put(Arrays.asList(eag), sc); } - loadBalancer = new RoundRobinLoadBalancer(mockHelper); - } + when(mockHelper.createSubchannel(any(CreateSubchannelArgs.class))) + .then(new Answer() { + @Override + public Subchannel answer(InvocationOnMock invocation) throws Throwable { + CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; + final Subchannel subchannel = subchannels.get(args.getAddresses()); + when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); + when(subchannel.getAttributes()).thenReturn(args.getAttributes()); + doAnswer( + new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + subchannelStateListeners.put( + subchannel, (SubchannelStateListener) invocation.getArguments()[0]); + return null; + } + }).when(subchannel).start(any(SubchannelStateListener.class)); + return subchannel; + } + }); - private boolean acceptAddresses(List eagList, Attributes attrs) { - return loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(eagList).setAttributes(attrs).build()); + loadBalancer = new RoundRobinLoadBalancer(mockHelper); } @After @@ -130,9 +148,10 @@ public void tearDown() throws Exception { @Test public void pickAfterResolved() throws Exception { - boolean addressesAccepted = acceptAddresses(servers, affinity); - assertThat(addressesAccepted).isTrue(); final Subchannel readySubchannel = subchannels.values().iterator().next(); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); + assertThat(addressesAccepted).isTrue(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); verify(mockHelper, times(3)).createSubchannel(createArgsCaptor.capture()); @@ -159,6 +178,10 @@ public void pickAfterResolved() throws Exception { @Test public void pickAfterResolvedUpdatedHosts() throws Exception { + Subchannel removedSubchannel = mock(Subchannel.class); + Subchannel oldSubchannel = mock(Subchannel.class); + Subchannel newSubchannel = mock(Subchannel.class); + Attributes.Key key = Attributes.Key.create("check-that-it-is-propagated"); FakeSocketAddress removedAddr = new FakeSocketAddress("removed"); EquivalentAddressGroup removedEag = new EquivalentAddressGroup(removedAddr); @@ -170,13 +193,6 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { EquivalentAddressGroup newEag = new EquivalentAddressGroup( newAddr, Attributes.newBuilder().set(key, "newattr").build()); - Subchannel removedSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder() - .setAddresses(removedEag).build()); - Subchannel oldSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder() - .setAddresses(oldEag1).build()); - Subchannel newSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder() - .setAddresses(newEag).build()); - subchannels.put(Collections.singletonList(removedEag), removedSubchannel); subchannels.put(Collections.singletonList(oldEag1), oldSubchannel); subchannels.put(Collections.singletonList(newEag), newSubchannel); @@ -185,7 +201,9 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = acceptAddresses(currentServers, affinity); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(currentServers).setAttributes(affinity) + .build()); assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); @@ -200,11 +218,8 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { verify(removedSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).requestConnection(); - assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(2); - assertThat(loadBalancer.getChildLbStateEag(removedEag).getCurrentPicker().pickSubchannel(null) - .getSubchannel()).isEqualTo(removedSubchannel); - assertThat(loadBalancer.getChildLbStateEag(oldEag1).getCurrentPicker().pickSubchannel(null) - .getSubchannel()).isEqualTo(oldSubchannel); + assertThat(loadBalancer.getSubchannels()).containsExactly(removedSubchannel, + oldSubchannel); // This time with Attributes List latestServers = Lists.newArrayList(oldEag2, newEag); @@ -217,15 +232,13 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { verify(oldSubchannel, times(1)).updateAddresses(Arrays.asList(oldEag2)); verify(removedSubchannel, times(1)).shutdown(); + deliverSubchannelState(removedSubchannel, ConnectivityStateInfo.forNonError(SHUTDOWN)); deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY)); - assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(2); - assertThat(loadBalancer.getChildLbStateEag(newEag).getCurrentPicker() - .pickSubchannel(null).getSubchannel()).isEqualTo(newSubchannel); - assertThat(loadBalancer.getChildLbStateEag(oldEag2).getCurrentPicker() - .pickSubchannel(null).getSubchannel()).isEqualTo(oldSubchannel); + assertThat(loadBalancer.getSubchannels()).containsExactly(oldSubchannel, + newSubchannel); - verify(mockHelper, times(6)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); picker = pickerCaptor.getValue(); @@ -237,26 +250,29 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { @Test public void pickAfterStateChange() throws Exception { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) + .build()); assertThat(addressesAccepted).isTrue(); - - // TODO figure out if this method testing the right things - - ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); - Subchannel subchannel = childLbState.getCurrentPicker().pickSubchannel(null).getSubchannel(); + Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); + Ref subchannelStateInfo = subchannel.getAttributes().get( + STATE_INFO); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); - assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING); + assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(IDLE)); - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, + ConnectivityStateInfo.forNonError(READY)); inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class); - assertThat(childLbState.getCurrentState()).isEqualTo(READY); + assertThat(subchannelStateInfo.value).isEqualTo( + ConnectivityStateInfo.forNonError(READY)); Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"); deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); - assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); + assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); + assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error); inOrder.verify(mockHelper).refreshNameResolution(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class); @@ -264,7 +280,8 @@ public void pickAfterStateChange() throws Exception { deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(mockHelper).refreshNameResolution(); - assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); + assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); + assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error); verify(subchannel, times(2)).requestConnection(); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); @@ -274,14 +291,15 @@ public void pickAfterStateChange() throws Exception { @Test public void ignoreShutdownSubchannelStateChange() { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) + .build()); assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); loadBalancer.shutdown(); - for (ChildLbState child : loadBalancer.getChildLbStates()) { - Subchannel sc = child.getCurrentPicker().pickSubchannel(null).getSubchannel(); - verify(child).shutdown(); + for (Subchannel sc : loadBalancer.getSubchannels()) { + verify(sc).shutdown(); // When the subchannel is being shut down, a SHUTDOWN connectivity state is delivered // back to the subchannel state listener. deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(SHUTDOWN)); @@ -293,34 +311,36 @@ public void ignoreShutdownSubchannelStateChange() { @Test public void stayTransientFailureUntilReady() { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) + .build()); assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); - Map childToSubChannelMap = new HashMap<>(); // Simulate state transitions for each subchannel individually. - for ( ChildLbState child : loadBalancer.getChildLbStates()) { - Subchannel sc = child.getSubchannels(mockArgs); - childToSubChannelMap.put(child, sc); + for (Subchannel sc : loadBalancer.getSubchannels()) { Status error = Status.UNKNOWN.withDescription("connection broken"); deliverSubchannelState( sc, ConnectivityStateInfo.forTransientFailure(error)); - assertEquals(TRANSIENT_FAILURE, child.getCurrentState()); inOrder.verify(mockHelper).refreshNameResolution(); deliverSubchannelState( sc, ConnectivityStateInfo.forNonError(CONNECTING)); - assertEquals(TRANSIENT_FAILURE, child.getCurrentState()); + Ref scStateInfo = sc.getAttributes().get( + STATE_INFO); + assertThat(scStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); + assertThat(scStateInfo.value.getStatus()).isEqualTo(error); } - inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(ReadyPicker.class)); + inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(EmptyPicker.class)); inOrder.verifyNoMoreInteractions(); - ChildLbState child = loadBalancer.getChildLbStates().iterator().next(); - Subchannel subchannel = childToSubChannelMap.get(child); + Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - assertThat(child.getCurrentState()).isEqualTo(READY); + Ref subchannelStateInfo = subchannel.getAttributes().get( + STATE_INFO); + assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(READY)); inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class)); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); @@ -330,15 +350,16 @@ public void stayTransientFailureUntilReady() { @Test public void refreshNameResolutionWhenSubchannelConnectionBroken() { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) + .build()); assertThat(addressesAccepted).isTrue(); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); // Simulate state transitions for each subchannel individually. - for (ChildLbState child : loadBalancer.getChildLbStates()) { - Subchannel sc = child.getSubchannels(mockArgs); + for (Subchannel sc : loadBalancer.getSubchannels()) { verify(sc).requestConnection(); deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING)); Status error = Status.UNKNOWN.withDescription("connection broken"); @@ -349,7 +370,7 @@ public void refreshNameResolutionWhenSubchannelConnectionBroken() { // Simulate receiving go-away so READY subchannels transit to IDLE. deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(mockHelper).refreshNameResolution(); - verify(sc, times(1)).requestConnection(); + verify(sc, times(2)).requestConnection(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); } @@ -362,13 +383,12 @@ public void pickerRoundRobin() throws Exception { Subchannel subchannel1 = mock(Subchannel.class); Subchannel subchannel2 = mock(Subchannel.class); - ArrayList pickers = Lists.newArrayList( - TestUtils.pickerOf(subchannel), TestUtils.pickerOf(subchannel1), - TestUtils.pickerOf(subchannel2)); - - ReadyPicker picker = new ReadyPicker(Collections.unmodifiableList(pickers), + ReadyPicker picker = new ReadyPicker(Collections.unmodifiableList( + Lists.newArrayList(subchannel, subchannel1, subchannel2)), 0 /* startIndex */); + assertThat(picker.getList()).containsExactly(subchannel, subchannel1, subchannel2); + assertEquals(subchannel, picker.pickSubchannel(mockArgs).getSubchannel()); assertEquals(subchannel1, picker.pickSubchannel(mockArgs).getSubchannel()); assertEquals(subchannel2, picker.pickSubchannel(mockArgs).getSubchannel()); @@ -379,7 +399,7 @@ public void pickerRoundRobin() throws Exception { public void pickerEmptyList() throws Exception { SubchannelPicker picker = new EmptyPicker(Status.UNKNOWN); - assertNull(picker.pickSubchannel(mockArgs).getSubchannel()); + assertEquals(null, picker.pickSubchannel(mockArgs).getSubchannel()); assertEquals(Status.UNKNOWN, picker.pickSubchannel(mockArgs).getStatus()); } @@ -397,13 +417,12 @@ public void nameResolutionErrorWithNoChannels() throws Exception { @Test public void nameResolutionErrorWithActiveChannels() throws Exception { - boolean addressesAccepted = acceptAddresses(servers, affinity); final Subchannel readySubchannel = subchannels.values().iterator().next(); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); assertThat(addressesAccepted).isTrue(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); - loadBalancer.resolvingAddresses = true; loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError")); - loadBalancer.resolvingAddresses = false; verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(mockHelper, times(2)) @@ -424,14 +443,15 @@ public void nameResolutionErrorWithActiveChannels() throws Exception { @Test public void subchannelStateIsolation() throws Exception { - boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); - assertThat(addressesAccepted).isTrue(); - Iterator subchannelIterator = subchannels.values().iterator(); Subchannel sc1 = subchannelIterator.next(); Subchannel sc2 = subchannelIterator.next(); Subchannel sc3 = subchannelIterator.next(); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) + .build()); + assertThat(addressesAccepted).isTrue(); verify(sc1, times(1)).requestConnection(); verify(sc2, times(1)).requestConnection(); verify(sc3, times(1)).requestConnection(); @@ -458,7 +478,7 @@ public void subchannelStateIsolation() throws Exception { // The IDLE subchannel is dropped from the picker, but a reconnection is requested assertEquals(READY, stateIterator.next()); assertThat(getList(pickers.next())).containsExactly(sc1, sc3); - verify(sc2, times(1)).requestConnection(); + verify(sc2, times(2)).requestConnection(); // The failing subchannel is dropped from the picker, with no requested reconnect assertEquals(READY, stateIterator.next()); assertThat(getList(pickers.next())).containsExactly(sc1); @@ -471,7 +491,7 @@ public void subchannelStateIsolation() throws Exception { public void readyPicker_emptyList() { // ready picker list must be non-empty try { - new ReadyPicker(Collections.emptyList(), 0); + new ReadyPicker(Collections.emptyList(), 0); fail(); } catch (IllegalArgumentException expected) { } @@ -483,10 +503,9 @@ public void internalPickerComparisons() { EmptyPicker emptyOk2 = new EmptyPicker(Status.OK.withDescription("different OK")); EmptyPicker emptyErr = new EmptyPicker(Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯")); - acceptAddresses(servers, Attributes.EMPTY); // create subchannels Iterator subchannelIterator = subchannels.values().iterator(); - SubchannelPicker sc1 = TestUtils.pickerOf(subchannelIterator.next()); - SubchannelPicker sc2 = TestUtils.pickerOf(subchannelIterator.next()); + Subchannel sc1 = subchannelIterator.next(); + Subchannel sc2 = subchannelIterator.next(); ReadyPicker ready1 = new ReadyPicker(Arrays.asList(sc1, sc2), 0); ReadyPicker ready2 = new ReadyPicker(Arrays.asList(sc1), 0); ReadyPicker ready3 = new ReadyPicker(Arrays.asList(sc2, sc1), 1); @@ -507,27 +526,18 @@ public void internalPickerComparisons() { public void emptyAddresses() { assertThat(loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() - .setAddresses(Collections.emptyList()) + .setAddresses(Collections.emptyList()) .setAttributes(affinity) .build())).isFalse(); } - private List getList(SubchannelPicker picker) { - - if (picker instanceof ReadyPicker) { - List subchannelList = new ArrayList<>(); - for (SubchannelPicker childPicker : ((ReadyPicker) picker).getList()) { - subchannelList.add(childPicker.pickSubchannel(mockArgs).getSubchannel()); - } - return subchannelList; - } else { - return new ArrayList<>(); - } + private static List getList(SubchannelPicker picker) { + return picker instanceof ReadyPicker ? ((ReadyPicker) picker).getList() : + Collections.emptyList(); } private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { - Subchannel realSc = mockToRealSubChannelMap.get(subchannel); - subchannelStateListeners.get(realSc).onSubchannelState(newState); + subchannelStateListeners.get(subchannel).onSubchannelState(newState); } private static class FakeSocketAddress extends SocketAddress { @@ -542,22 +552,4 @@ public String toString() { return "FakeSocketAddress-" + name; } } - - private class TestHelper extends AbstractTestHelper { - - @Override - public Map, Subchannel> getSubchannelMap() { - return subchannels; - } - - @Override - public Map getMockToRealSubChannelMap() { - return mockToRealSubChannelMap; - } - - @Override - public Map getSubchannelStateListeners() { - return subchannelStateListeners; - } - } } diff --git a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java deleted file mode 100644 index 40986178328..00000000000 --- a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Copyright 2023 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.util; - -import static org.mockito.AdditionalAnswers.delegatesTo; -import static org.mockito.Mockito.mock; - -import io.grpc.Attributes; -import io.grpc.Channel; -import io.grpc.ChannelLogger; -import io.grpc.ConnectivityState; -import io.grpc.EquivalentAddressGroup; -import io.grpc.LoadBalancer.CreateSubchannelArgs; -import io.grpc.LoadBalancer.Helper; -import io.grpc.LoadBalancer.Subchannel; -import io.grpc.LoadBalancer.SubchannelPicker; -import io.grpc.LoadBalancer.SubchannelStateListener; -import java.util.Collections; -import java.util.List; -import java.util.Map; - -/** - * A real class that can be used as a delegate of a mock Helper to provide more real representation - * and track the subchannels as is needed with petiole policies where the subchannels are no - * longer direct children of the loadbalancer. - *
- * To use it replace
- * \@mock Helper mockHelper
- * with
- *

Helper mockHelper = mock(Helper.class, delegatesTo(new TestHelper()));

- *
- * TestHelper will need to define accessors for the maps that information is store within as - * those maps need to be defined in the Test class. - */ -public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper { - - public abstract Map, Subchannel> getSubchannelMap(); - - public abstract Map getMockToRealSubChannelMap(); - - public abstract Map getSubchannelStateListeners(); - - @Override - public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { - // do nothing, should have been done in the wrapper helpers - } - - @Override - protected Helper delegate() { - throw new UnsupportedOperationException("This helper class is only for use in this test"); - } - - @Override - public Subchannel createSubchannel(CreateSubchannelArgs args) { - Subchannel subchannel = getSubchannelMap().get(args.getAddresses()); - if (subchannel == null) { - TestSubchannel delegate = new TestSubchannel(args); - subchannel = mock(Subchannel.class, delegatesTo(delegate)); - getSubchannelMap().put(args.getAddresses(), subchannel); - getMockToRealSubChannelMap().put(subchannel, delegate); - } - - return subchannel; - } - - @Override - public void refreshNameResolution() { - // no-op - } - - public void setChannel(Subchannel subchannel, Channel channel) { - ((TestSubchannel)subchannel).channel = channel; - } - - @Override - public String toString() { - return "Test Helper"; - } - - private class TestSubchannel extends ForwardingSubchannel { - final CreateSubchannelArgs args; - Channel channel; - - public TestSubchannel(CreateSubchannelArgs args) { - this.args = args; - } - - @Override - protected Subchannel delegate() { - throw new UnsupportedOperationException("Only to be used in tests"); - } - - @Override - public List getAllAddresses() { - return args.getAddresses(); - } - - @Override - public Attributes getAttributes() { - return args.getAttributes(); - } - - @Override - public void requestConnection() { - // Ignore, we will manually update state - } - - @Override - public void updateAddresses(List addrs) { - // Do nothing, will be handled in wrappers - } - - @Override - public void start(SubchannelStateListener listener) { - getSubchannelStateListeners().put(this, listener); - } - - @Override - public void shutdown() { - getSubchannelStateListeners().remove(this); - for (EquivalentAddressGroup eag : getAllAddresses()) { - getSubchannelMap().remove(Collections.singletonList(eag)); - } - } - - @Override - public Channel asChannel() { - return channel; - } - - @Override - public ChannelLogger getChannelLogger() { - return mock(ChannelLogger.class); - } - - @Override - public String toString() { - return "Mock Subchannel" + args.toString(); - } - } -} - diff --git a/xds/build.gradle b/xds/build.gradle index a6db9db9937..3f3cf6a0f6e 100644 --- a/xds/build.gradle +++ b/xds/build.gradle @@ -58,8 +58,7 @@ dependencies { def nettyDependency = implementation project(':grpc-netty') testImplementation project(':grpc-rls') - testImplementation testFixtures(project(':grpc-core')), - testFixtures(project(':grpc-util')) + testImplementation testFixtures(project(':grpc-core')) annotationProcessor libraries.auto.value // At runtime use the epoll included in grpc-netty-shaded diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java index 895125d3229..a4489204236 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java @@ -16,68 +16,36 @@ package io.grpc.xds; -import static com.google.common.base.Preconditions.checkNotNull; - -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import io.grpc.InternalLogId; -import io.grpc.LoadBalancerProvider; import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.util.MultiChildLoadBalancer; import io.grpc.xds.ClusterManagerLoadBalancerProvider.ClusterManagerConfig; import io.grpc.xds.XdsLogger.XdsLogLevel; import java.util.HashMap; import java.util.Map; -import java.util.Map.Entry; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import javax.annotation.Nullable; /** * The top-level load balancing policy. */ class ClusterManagerLoadBalancer extends MultiChildLoadBalancer { - @VisibleForTesting - public static final int DELAYED_CHILD_DELETION_TIME_MINUTES = 15; - protected final SynchronizationContext syncContext; - private final ScheduledExecutorService timeService; private final XdsLogger logger; ClusterManagerLoadBalancer(Helper helper) { super(helper); - this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); - this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); logger = XdsLogger.withLogId( InternalLogId.allocate("cluster_manager-lb", helper.getAuthority())); - logger.log(XdsLogLevel.INFO, "Created"); } @Override - protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses, - Object childConfig) { - return resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build(); - } - - @Override - protected Map createChildLbMap(ResolvedAddresses resolvedAddresses) { + protected Map getPolicySelectionMap( + ResolvedAddresses resolvedAddresses) { ClusterManagerConfig config = (ClusterManagerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - Map newChildPolicies = new HashMap<>(); - if (config != null) { - for (Entry entry : config.childPolicies.entrySet()) { - ChildLbState child = getChildLbState(entry.getKey()); - if (child == null) { - child = new ClusterManagerLbState(entry.getKey(), - entry.getValue().getProvider(), entry.getValue().getConfig(), getInitialPicker()); - } - newChildPolicies.put(entry.getKey(), child); - } - } + Map newChildPolicies = new HashMap<>(config.childPolicies); logger.log( XdsLogLevel.INFO, "Received cluster_manager lb config: child names={0}", newChildPolicies.keySet()); @@ -107,58 +75,4 @@ public String toString() { } }; } - - private class ClusterManagerLbState extends ChildLbState { - @Nullable - ScheduledHandle deletionTimer; - - public ClusterManagerLbState(Object key, LoadBalancerProvider policyProvider, - Object childConfig, SubchannelPicker initialPicker) { - super(key, policyProvider, childConfig, initialPicker); - } - - @Override - protected void shutdown() { - if (deletionTimer != null && deletionTimer.isPending()) { - deletionTimer.cancel(); - } - super.shutdown(); - } - - @Override - protected void reactivate(LoadBalancerProvider policyProvider) { - if (deletionTimer != null && deletionTimer.isPending()) { - deletionTimer.cancel(); - logger.log(XdsLogLevel.DEBUG, "Child balancer {0} reactivated", getKey()); - } - - super.reactivate(policyProvider); - } - - @Override - protected void deactivate() { - if (isDeactivated()) { - return; - } - - class DeletionTask implements Runnable { - - @Override - public void run() { - shutdown(); - removeChild(getKey()); - } - } - - deletionTimer = - syncContext.schedule( - new DeletionTask(), - DELAYED_CHILD_DELETION_TIME_MINUTES, - TimeUnit.MINUTES, - timeService); - setDeactivated(); - logger.log(XdsLogLevel.DEBUG, "Child balancer {0} deactivated", getKey()); - } - - } } diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index 216221d2505..833683729c2 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -17,20 +17,17 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkElementIndex; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.Deadline.Ticker; import io.grpc.EquivalentAddressGroup; import io.grpc.ExperimentalApi; import io.grpc.LoadBalancer; -import io.grpc.LoadBalancerProvider; import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.SynchronizationContext; @@ -43,13 +40,11 @@ import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener; import io.grpc.xds.orca.OrcaPerRequestUtil; import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener; -import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; -import java.util.Set; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -95,14 +90,6 @@ public WeightedRoundRobinLoadBalancer(WrrHelper helper, Ticker ticker, Random ra this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker, random); } - @Override - protected ChildLbState createChildLbState(Object key, Object policyConfig, - SubchannelPicker initialPicker) { - ChildLbState childLbState = new WeightedChildLbState(key, pickFirstLbProvider, policyConfig, - initialPicker); - return childLbState; - } - @Override public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { if (resolvedAddresses.getLoadBalancingPolicyConfig() == null) { @@ -124,100 +111,9 @@ public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { } @Override - public RoundRobinPicker createReadyPicker(Collection activeList) { - return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList), - config.enableOobLoadReport, config.errorUtilizationPenalty); - } - - @Override - protected ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) { - return super.getChildLbStateEag(eag); - } - - @VisibleForTesting - final class WeightedChildLbState extends ChildLbState { - - private final Set subchannels = new HashSet<>(); - private volatile long lastUpdated; - private volatile long nonEmptySince; - private volatile double weight = 0; - - private OrcaReportListener orcaReportListener; - - public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig, - SubchannelPicker initialPicker) { - super(key, policyProvider, childConfig, initialPicker); - } - - @VisibleForTesting - EquivalentAddressGroup getEag() { - return stripAttrs((EquivalentAddressGroup) getKey()); - } - - private double getWeight() { - if (config == null) { - return 0; - } - long now = ticker.nanoTime(); - if (now - lastUpdated >= config.weightExpirationPeriodNanos) { - nonEmptySince = infTime; - return 0; - } else if (now - nonEmptySince < config.blackoutPeriodNanos - && config.blackoutPeriodNanos > 0) { - return 0; - } else { - return weight; - } - } - - public void addSubchannel(WrrSubchannel wrrSubchannel) { - subchannels.add(wrrSubchannel); - } - - public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty) { - if (orcaReportListener != null - && orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty) { - return orcaReportListener; - } - orcaReportListener = new OrcaReportListener(errorUtilizationPenalty); - return orcaReportListener; - } - - public void removeSubchannel(WrrSubchannel wrrSubchannel) { - subchannels.remove(wrrSubchannel); - } - - final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener { - private final float errorUtilizationPenalty; - - OrcaReportListener(float errorUtilizationPenalty) { - this.errorUtilizationPenalty = errorUtilizationPenalty; - } - - @Override - public void onLoadReport(MetricReport report) { - double newWeight = 0; - // Prefer application utilization and fallback to CPU utilization if unset. - double utilization = - report.getApplicationUtilization() > 0 ? report.getApplicationUtilization() - : report.getCpuUtilization(); - if (utilization > 0 && report.getQps() > 0) { - double penalty = 0; - if (report.getEps() > 0 && errorUtilizationPenalty > 0) { - penalty = report.getEps() / report.getQps() * errorUtilizationPenalty; - } - newWeight = report.getQps() / (utilization + penalty); - } - if (newWeight == 0) { - return; - } - if (nonEmptySince == infTime) { - nonEmptySince = ticker.nanoTime(); - } - lastUpdated = ticker.nanoTime(); - weight = newWeight; - } - } + public RoundRobinPicker createReadyPicker(List activeList) { + return new WeightedRoundRobinPicker(activeList, config.enableOobLoadReport, + config.errorUtilizationPenalty); } private final class UpdateWeightTask implements Runnable { @@ -232,18 +128,16 @@ public void run() { } private void afterAcceptAddresses() { - for (ChildLbState child : getChildLbStates()) { - WeightedChildLbState wChild = (WeightedChildLbState) child; - for (WrrSubchannel weightedSubchannel : wChild.subchannels) { - if (config.enableOobLoadReport) { - OrcaOobUtil.setListener(weightedSubchannel, - wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty), - OrcaOobUtil.OrcaReportingConfig.newBuilder() - .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS) - .build()); - } else { - OrcaOobUtil.setListener(weightedSubchannel, null, null); - } + for (Subchannel subchannel : getSubchannels()) { + WrrSubchannel weightedSubchannel = (WrrSubchannel) subchannel; + if (config.enableOobLoadReport) { + OrcaOobUtil.setListener(weightedSubchannel, + weightedSubchannel.new OrcaReportListener(config.errorUtilizationPenalty), + OrcaOobUtil.OrcaReportingConfig.newBuilder() + .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS) + .build()); + } else { + OrcaOobUtil.setListener(weightedSubchannel, null, null); } } } @@ -275,69 +169,105 @@ protected Helper delegate() { @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { - checkElementIndex(0, args.getAddresses().size(), "Empty address group"); - WeightedChildLbState childLbState = - (WeightedChildLbState) wrr.getChildLbStateEag(args.getAddresses().get(0)); - return wrr.new WrrSubchannel(delegate().createSubchannel(args), childLbState); + return wrr.new WrrSubchannel(delegate().createSubchannel(args)); } } @VisibleForTesting final class WrrSubchannel extends ForwardingSubchannel { private final Subchannel delegate; - private final WeightedChildLbState owner; + private volatile long lastUpdated; + private volatile long nonEmptySince; + private volatile double weight; - WrrSubchannel(Subchannel delegate, WeightedChildLbState owner) { + WrrSubchannel(Subchannel delegate) { this.delegate = checkNotNull(delegate, "delegate"); - this.owner = checkNotNull(owner, "owner"); } @Override public void start(SubchannelStateListener listener) { - owner.addSubchannel(this); delegate().start(new SubchannelStateListener() { @Override public void onSubchannelState(ConnectivityStateInfo newState) { if (newState.getState().equals(ConnectivityState.READY)) { - owner.nonEmptySince = infTime; + nonEmptySince = infTime; } listener.onSubchannelState(newState); } }); } + private double getWeight() { + if (config == null) { + return 0; + } + long now = ticker.nanoTime(); + if (now - lastUpdated >= config.weightExpirationPeriodNanos) { + nonEmptySince = infTime; + return 0; + } else if (now - nonEmptySince < config.blackoutPeriodNanos + && config.blackoutPeriodNanos > 0) { + return 0; + } else { + return weight; + } + } + @Override protected Subchannel delegate() { return delegate; } - @Override - public void shutdown() { - super.shutdown(); - owner.removeSubchannel(this); + final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener { + private final float errorUtilizationPenalty; + + OrcaReportListener(float errorUtilizationPenalty) { + this.errorUtilizationPenalty = errorUtilizationPenalty; + } + + @Override + public void onLoadReport(MetricReport report) { + double newWeight = 0; + // Prefer application utilization and fallback to CPU utilization if unset. + double utilization = + report.getApplicationUtilization() > 0 ? report.getApplicationUtilization() + : report.getCpuUtilization(); + if (utilization > 0 && report.getQps() > 0) { + double penalty = 0; + if (report.getEps() > 0 && errorUtilizationPenalty > 0) { + penalty = report.getEps() / report.getQps() * errorUtilizationPenalty; + } + newWeight = report.getQps() / (utilization + penalty); + } + if (newWeight == 0) { + return; + } + if (nonEmptySince == infTime) { + nonEmptySince = ticker.nanoTime(); + } + lastUpdated = ticker.nanoTime(); + weight = newWeight; + } } } @VisibleForTesting final class WeightedRoundRobinPicker extends RoundRobinPicker { - private final List children; + private final List list; private final Map subchannelToReportListenerMap = new HashMap<>(); private final boolean enableOobLoadReport; private final float errorUtilizationPenalty; private volatile StaticStrideScheduler scheduler; - WeightedRoundRobinPicker(List children, boolean enableOobLoadReport, + WeightedRoundRobinPicker(List list, boolean enableOobLoadReport, float errorUtilizationPenalty) { - checkNotNull(children, "children"); - Preconditions.checkArgument(!children.isEmpty(), "empty child list"); - this.children = children; - for (ChildLbState child : children) { - WeightedChildLbState wChild = (WeightedChildLbState) child; - for (WrrSubchannel subchannel : wChild.subchannels) { - this.subchannelToReportListenerMap - .put(subchannel, wChild.getOrCreateOrcaListener(errorUtilizationPenalty)); - } + checkNotNull(list, "list"); + Preconditions.checkArgument(!list.isEmpty(), "empty list"); + this.list = list; + for (Subchannel subchannel : list) { + this.subchannelToReportListenerMap.put(subchannel, + ((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty)); } this.enableOobLoadReport = enableOobLoadReport; this.errorUtilizationPenalty = errorUtilizationPenalty; @@ -346,24 +276,22 @@ final class WeightedRoundRobinPicker extends RoundRobinPicker { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - ChildLbState childLbState = children.get(scheduler.pick()); - WeightedChildLbState wChild = (WeightedChildLbState) childLbState; - PickResult pickResult = childLbState.getCurrentPicker().pickSubchannel(args); - Subchannel subchannel = pickResult.getSubchannel(); + Subchannel subchannel = list.get(scheduler.pick()); if (!enableOobLoadReport) { return PickResult.withSubchannel(subchannel, - OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( subchannelToReportListenerMap.getOrDefault(subchannel, - wChild.getOrCreateOrcaListener(errorUtilizationPenalty)))); + ((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty)))); } else { return PickResult.withSubchannel(subchannel); } } private void updateWeight() { - float[] newWeights = new float[children.size()]; - for (int i = 0; i < children.size(); i++) { - double newWeight = ((WeightedChildLbState)children.get(i)).getWeight(); + float[] newWeights = new float[list.size()]; + for (int i = 0; i < list.size(); i++) { + WrrSubchannel subchannel = (WrrSubchannel) list.get(i); + double newWeight = subchannel.getWeight(); newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f; } this.scheduler = new StaticStrideScheduler(newWeights, sequence); @@ -374,12 +302,12 @@ public String toString() { return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class) .add("enableOobLoadReport", enableOobLoadReport) .add("errorUtilizationPenalty", errorUtilizationPenalty) - .add("list", children).toString(); + .add("list", list).toString(); } @VisibleForTesting - List getChildren() { - return children; + List getList() { + return list; } @Override @@ -394,8 +322,7 @@ public boolean isEquivalentTo(RoundRobinPicker picker) { // the lists cannot contain duplicate subchannels return enableOobLoadReport == other.enableOobLoadReport && Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0 - && children.size() == other.children.size() && new HashSet<>( - children).containsAll(other.children); + && list.size() == other.list.size() && new HashSet<>(list).containsAll(other.list); } } @@ -577,13 +504,11 @@ private Builder() { } - @SuppressWarnings("UnusedReturnValue") Builder setBlackoutPeriodNanos(long blackoutPeriodNanos) { this.blackoutPeriodNanos = blackoutPeriodNanos; return this; } - @SuppressWarnings("UnusedReturnValue") Builder setWeightExpirationPeriodNanos(long weightExpirationPeriodNanos) { this.weightExpirationPeriodNanos = weightExpirationPeriodNanos; return this; diff --git a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java index 32e905225d2..c90a9f58d31 100644 --- a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java +++ b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java @@ -202,9 +202,7 @@ public interface OrcaOobReportListener { */ public static void setListener(Subchannel subchannel, OrcaOobReportListener listener, OrcaReportingConfig config) { - Attributes attributes = subchannel.getAttributes(); - SubchannelImpl orcaSubchannel = - (attributes == null) ? null : attributes.get(ORCA_REPORTING_STATE_KEY); + SubchannelImpl orcaSubchannel = subchannel.getAttributes().get(ORCA_REPORTING_STATE_KEY); if (orcaSubchannel == null) { throw new IllegalArgumentException("Subchannel does not have orca Out-Of-Band stream enabled." + " Try to use a subchannel created by OrcaOobUtil.OrcaHelper."); @@ -243,9 +241,7 @@ protected Helper delegate() { public Subchannel createSubchannel(CreateSubchannelArgs args) { syncContext.throwIfNotInThisSynchronizationContext(); Subchannel subchannel = super.createSubchannel(args); - Attributes attributes = subchannel.getAttributes(); - SubchannelImpl orcaSubchannel = - (attributes == null) ? null : attributes.get(ORCA_REPORTING_STATE_KEY); + SubchannelImpl orcaSubchannel = subchannel.getAttributes().get(ORCA_REPORTING_STATE_KEY); OrcaReportingState orcaState; if (orcaSubchannel == null) { // Only the first load balancing policy requesting ORCA reports instantiates an diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index c59ad1318e2..ac08f69f88c 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -17,10 +17,11 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; -import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -34,6 +35,7 @@ import com.google.protobuf.Duration; import io.grpc.Attributes; import io.grpc.Channel; +import io.grpc.ChannelLogger; import io.grpc.ClientCall; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; @@ -48,15 +50,12 @@ import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; -import io.grpc.internal.TestUtils; import io.grpc.services.InternalCallMetricRecorder; import io.grpc.services.MetricReport; -import io.grpc.util.AbstractTestHelper; -import io.grpc.util.MultiChildLoadBalancer.ChildLbState; import io.grpc.xds.WeightedRoundRobinLoadBalancer.StaticStrideScheduler; -import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedChildLbState; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker; +import io.grpc.xds.WeightedRoundRobinLoadBalancer.WrrSubchannel; import java.net.SocketAddress; import java.util.Arrays; import java.util.HashMap; @@ -68,7 +67,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CyclicBarrier; -import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Before; @@ -89,8 +87,8 @@ public class WeightedRoundRobinLoadBalancerTest { @Rule public final MockitoRule mockito = MockitoJUnit.rule(); - private final TestHelper testHelperInstance = new TestHelper(); - private Helper helper = mock(Helper.class, delegatesTo(testHelperInstance)); + @Mock + Helper helper; @Mock private LoadBalancer.PickSubchannelArgs mockArgs; @@ -101,8 +99,9 @@ public class WeightedRoundRobinLoadBalancerTest { private ArgumentCaptor pickerCaptor2; private final List servers = Lists.newArrayList(); + private final Map, Subchannel> subchannels = Maps.newLinkedHashMap(); - private final Map mockToRealSubChannelMap = new HashMap<>(); + private final Map subchannelStateListeners = Maps.newLinkedHashMap(); @@ -135,8 +134,7 @@ public void setup() { SocketAddress addr = new FakeSocketAddress("server" + i); EquivalentAddressGroup eag = new EquivalentAddressGroup(addr); servers.add(eag); - Subchannel sc = helper.createSubchannel(CreateSubchannelArgs.newBuilder().setAddresses(eag) - .build()); + Subchannel sc = mock(Subchannel.class); Channel channel = mock(Channel.class); when(channel.newCall(any(), any())).then( new Answer>() { @@ -149,13 +147,35 @@ public ClientCall answer( return clientCall; } }); - testHelperInstance.setChannel(mockToRealSubChannelMap.get(sc), channel); + when(sc.asChannel()).thenReturn(channel); subchannels.put(Arrays.asList(eag), sc); } + when(helper.getSynchronizationContext()).thenReturn(syncContext); + when(helper.getScheduledExecutorService()).thenReturn( + fakeClock.getScheduledExecutorService()); + when(helper.createSubchannel(any(CreateSubchannelArgs.class))) + .then(new Answer() { + @Override + public Subchannel answer(InvocationOnMock invocation) throws Throwable { + CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; + final Subchannel subchannel = subchannels.get(args.getAddresses()); + when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); + when(subchannel.getAttributes()).thenReturn(args.getAttributes()); + when(subchannel.getChannelLogger()).thenReturn(mock(ChannelLogger.class)); + doAnswer( + new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + subchannelStateListeners.put( + subchannel, (SubchannelStateListener) invocation.getArguments()[0]); + return null; + } + }).when(subchannel).start(any(SubchannelStateListener.class)); + return subchannel; + } + }); wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker(), new FakeRandom(0)); - - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); } @Test @@ -163,44 +183,44 @@ public void wrrLifeCycle() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel connectingSubchannel = it.next(); - getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(connectingSubchannel).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.CONNECTING)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); - assertThat(weightedPicker.getChildren().size()).isEqualTo(1); + assertThat(weightedPicker.getList().size()).isEqualTo(1); weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - assertThat(weightedPicker.getChildren().size()).isEqualTo(2); + assertThat(weightedPicker.getList().size()).isEqualTo(2); String weightedPickerStr = weightedPicker.toString(); assertThat(weightedPickerStr).contains("enableOobLoadReport=false"); assertThat(weightedPickerStr).contains("errorUtilizationPenalty=1.0"); assertThat(weightedPickerStr).contains("list="); - WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); - WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); - - assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(weightedChild1.getEag()); + assertThat(weightedPicker.pickSubchannel(mockArgs) + .getSubchannel()).isEqualTo(weightedSubchannel1); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder() .setWeightUpdatePeriodNanos(500_000_000L) //.5s @@ -218,44 +238,35 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on verifyNoMoreInteractions(mockArgs); } - /** - * Picks subchannel using mockArgs, gets its EAG, and then strips the Attrs to make a key. - */ - private EquivalentAddressGroup getAddressesFromPick(WeightedRoundRobinPicker weightedPicker) { - return TestUtils.stripAttrs( - weightedPicker.pickSubchannel(mockArgs).getSubchannel().getAddresses()); - } - @Test public void enableOobLoadReportConfig() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); - WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.9, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); - assertThat(getAddresses(pickResult)) - .isEqualTo(weightedChild1.getEag()); + assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1); assertThat(pickResult.getStreamTracerFactory()).isNotNull(); // verify per-request listener assertThat(oobCalls.isEmpty()).isTrue(); @@ -269,8 +280,7 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on eq(ConnectivityState.READY), pickerCaptor2.capture()); weightedPicker = (WeightedRoundRobinPicker) pickerCaptor2.getAllValues().get(2); pickResult = weightedPicker.pickSubchannel(mockArgs); - assertThat(getAddresses(pickResult)) - .isEqualTo(weightedChild1.getEag()); + assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1); assertThat(pickResult.getStreamTracerFactory()).isNull(); OrcaLoadReportRequest golden = OrcaLoadReportRequest.newBuilder().setReportInterval( Duration.newBuilder().setSeconds(20).setNanos(30000000).build()).build(); @@ -285,52 +295,46 @@ private void pickByWeight(MetricReport r1, MetricReport r2, MetricReport r3, syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel3 = it.next(); - getSubchannelStateListener(readySubchannel3).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel3).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); - WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); - WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r1); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r2); - weightedChild3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r3); - + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2); + weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + r1); + weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + r2); + weightedSubchannel3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + r3); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); - Map pickCount = new HashMap<>(); + Map pickCount = new HashMap<>(); for (int i = 0; i < 10000; i++) { - EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); + Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(3); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - subchannel1PickRatio)) - .isAtMost(0.0002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - subchannel2PickRatio )) - .isAtMost(0.0002); - assertThat(Math.abs(pickCount.get(weightedChild3.getEag()) / 10000.0 - subchannel3PickRatio )) - .isAtMost(0.0002); - } - - private SubchannelStateListener getSubchannelStateListener(Subchannel mockSubChannel) { - return subchannelStateListeners.get(mockToRealSubChannelMap.get(mockSubChannel)); - } - - private static ChildLbState getChild(WeightedRoundRobinPicker picker, int index) { - return picker.getChildren().get(index); + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 10000.0 - subchannel1PickRatio)) + .isLessThan(0.0002); + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 10000.0 - subchannel2PickRatio )) + .isLessThan(0.0002); + assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 10000.0 - subchannel3PickRatio )) + .isLessThan(0.0002); } @Test @@ -468,14 +472,14 @@ public void emptyConfig() { assertThat(wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(null) .setAttributes(affinity).build())).isFalse(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(fakeClock.getPendingTasks()).isEmpty(); syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); assertThat(pickerCaptor.getValue().getClass().getName()) @@ -488,51 +492,51 @@ public void blackoutPeriod() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); - WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1); - Map pickCount = new HashMap<>(); - for (int i = 0; i < 10000; i++) { - EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); + Map pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); // within blackout period, fallback to simple round robin - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - 0.5)).isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - 0.5)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 0.5)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 0.5)).isLessThan(0.002); assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1); pickCount = new HashMap<>(); - for (int i = 0; i < 10000; i++) { - EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); + for (int i = 0; i < 1000; i++) { + Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); // after blackout period - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - 2.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - 1.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3)) .isLessThan(0.002); } @@ -541,39 +545,39 @@ public void updateWeightTimer() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel connectingSubchannel = it.next(); - getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(connectingSubchannel).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.CONNECTING)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); - assertThat(weightedPicker.getChildren().size()).isEqualTo(1); + assertThat(weightedPicker.getList().size()).isEqualTo(1); weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - assertThat(weightedPicker.getChildren().size()).isEqualTo(2); - WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); - WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + assertThat(weightedPicker.getList().size()).isEqualTo(2); + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); - assertThat(getAddressesFromPick(weightedPicker)) - .isEqualTo(weightedChild1.getEag()); + assertThat(weightedPicker.pickSubchannel(mockArgs) + .getSubchannel()).isEqualTo(weightedSubchannel1); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder() .setWeightUpdatePeriodNanos(500_000_000L) //.5s @@ -582,18 +586,17 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); //timer fires, new weight updated assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(1); - assertThat(getAddressesFromPick(weightedPicker)) - .isEqualTo(weightedChild2.getEag()); - assertThat(getAddressesFromPick(weightedPicker)) - .isEqualTo(weightedChild1.getEag()); + assertThat(weightedPicker.pickSubchannel(mockArgs) + .getSubchannel()).isEqualTo(weightedSubchannel2); + } @Test @@ -601,52 +604,52 @@ public void weightExpired() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); - WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); - Map pickCount = new HashMap<>(); + Map pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { - EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); + Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 2.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3)) .isLessThan(0.002); // weight expired, fallback to simple round robin assertThat(fakeClock.forwardTime(300, TimeUnit.SECONDS)).isEqualTo(1); pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { - EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); + Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 0.5)) + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 0.5)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 0.5)) + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 0.5)) .isLessThan(0.002); } @@ -655,113 +658,107 @@ public void rrFallback() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); - WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); - WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - Map qpsByChannel = ImmutableMap.of(weightedChild1.getEag(), 2, - weightedChild2.getEag(), 1); - Map pickCount = new HashMap<>(); + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + Map qpsByChannel = ImmutableMap.of(weightedSubchannel1, 2, + weightedSubchannel2, 1); + Map pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); - EquivalentAddressGroup addresses = getAddresses(pickResult); - pickCount.merge(addresses, 1, Integer::sum); + pickCount.put(pickResult.getSubchannel(), + pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1); assertThat(pickResult.getStreamTracerFactory()).isNotNull(); - WeightedChildLbState childLbState = (WeightedChildLbState) wrr.getChildLbStateEag(addresses); - childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WrrSubchannel subchannel = (WrrSubchannel)pickResult.getSubchannel(); + subchannel.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( - 0.1, 0, 0.1, qpsByChannel.get(addresses), 0, + 0.1, 0, 0.1, qpsByChannel.get(subchannel), 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); } - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 1.0 / 2)) + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 1.0 / 2)) .isAtMost(0.1); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 2)) + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 2)) .isAtMost(0.1); - - // Identical to above except forwards time after each pick pickCount.clear(); for (int i = 0; i < 1000; i++) { PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); - EquivalentAddressGroup addresses = getAddresses(pickResult); - pickCount.merge(addresses, 1, Integer::sum); + pickCount.put(pickResult.getSubchannel(), + pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1); assertThat(pickResult.getStreamTracerFactory()).isNotNull(); - WeightedChildLbState childLbState = (WeightedChildLbState) wrr.getChildLbStateEag(addresses); - childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WrrSubchannel subchannel = (WrrSubchannel) pickResult.getSubchannel(); + subchannel.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( - 0.1, 0, 0.1, qpsByChannel.get(addresses), 0, + 0.1, 0, 0.1, qpsByChannel.get(subchannel), 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); fakeClock.forwardTime(50, TimeUnit.MILLISECONDS); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 2.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3)) .isAtMost(0.1); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3)) .isAtMost(0.1); } - private static EquivalentAddressGroup getAddresses(PickResult pickResult) { - return TestUtils.stripAttrs(pickResult.getSubchannel().getAddresses()); - } - @Test public void unknownWeightIsAvgWeight() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( - any(CreateSubchannelArgs.class)); // 3 from setup plus 3 from the execute + verify(helper, times(3)).createSubchannel( + any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - getSubchannelStateListener(readySubchannel1) - .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - getSubchannelStateListener(readySubchannel2) - .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); Subchannel readySubchannel3 = it.next(); - getSubchannelStateListener(readySubchannel3) - .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + subchannelStateListeners.get(readySubchannel3).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); - WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); - WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2); + weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); - Map pickCount = new HashMap<>(); + Map pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); - pickCount.merge(result.getAddresses(), 1, Integer::sum); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(3); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 4.0 / 9)) + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 4.0 / 9)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 2.0 / 9)) + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 2.0 / 9)) .isLessThan(0.002); // subchannel3's weight is average of subchannel1 and subchannel2 - assertThat(Math.abs(pickCount.get(weightedChild3.getEag()) / 1000.0 - 3.0 / 9)) + assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 1000.0 - 3.0 / 9)) .isLessThan(0.002); } @@ -770,33 +767,33 @@ public void pickFromOtherThread() throws Exception { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); - WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); CyclicBarrier barrier = new CyclicBarrier(2); - Map pickCount = new ConcurrentHashMap<>(); - pickCount.put(weightedChild1.getEag(), new AtomicInteger(0)); - pickCount.put(weightedChild2.getEag(), new AtomicInteger(0)); + Map pickCount = new ConcurrentHashMap<>(); + pickCount.put(weightedSubchannel1, new AtomicInteger(0)); + pickCount.put(weightedSubchannel2, new AtomicInteger(0)); new Thread(new Runnable() { @Override public void run() { @@ -805,7 +802,7 @@ public void run() { barrier.await(); for (int i = 0; i < 1000; i++) { Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); - pickCount.get(result.getAddresses()).addAndGet(1); + pickCount.get(result).addAndGet(1); } barrier.await(); } catch (Exception ex) { @@ -816,15 +813,15 @@ public void run() { assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); barrier.await(); for (int i = 0; i < 1000; i++) { - EquivalentAddressGroup result = getAddresses(weightedPicker.pickSubchannel(mockArgs)); + Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); pickCount.get(result).addAndGet(1); } barrier.await(); assertThat(pickCount.size()).isEqualTo(2); // after blackout period - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()).get() / 2000.0 - 2.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedSubchannel1).get() / 2000.0 - 2.0 / 3)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()).get() / 2000.0 - 1.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedSubchannel2).get() / 2000.0 - 1.0 / 3)) .isLessThan(0.002); } @@ -1107,34 +1104,4 @@ public int nextInt() { return nextInt; } } - - private class TestHelper extends AbstractTestHelper { - - @Override - public Map, Subchannel> getSubchannelMap() { - return subchannels; - } - - @Override - public Map getMockToRealSubChannelMap() { - return mockToRealSubChannelMap; - } - - @Override - public Map getSubchannelStateListeners() { - return subchannelStateListeners; - } - - @Override - public SynchronizationContext getSynchronizationContext() { - return syncContext; - } - - @Override - public ScheduledExecutorService getScheduledExecutorService() { - return fakeClock.getScheduledExecutorService(); - } - - - } }