001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.hbase.client;
019
020import static org.junit.Assert.assertEquals;
021import static org.junit.Assert.assertThrows;
022
023import java.io.IOException;
024import java.net.SocketAddress;
025import java.util.Collections;
026import java.util.Set;
027import java.util.concurrent.CompletableFuture;
028import java.util.concurrent.ExecutorService;
029import java.util.concurrent.Executors;
030import java.util.concurrent.atomic.AtomicInteger;
031import java.util.stream.Collectors;
032import java.util.stream.IntStream;
033import org.apache.hadoop.conf.Configuration;
034import org.apache.hadoop.hbase.HBaseClassTestRule;
035import org.apache.hadoop.hbase.HBaseCommonTestingUtility;
036import org.apache.hadoop.hbase.ServerName;
037import org.apache.hadoop.hbase.ipc.RpcClient;
038import org.apache.hadoop.hbase.ipc.RpcClientFactory;
039import org.apache.hadoop.hbase.security.User;
040import org.apache.hadoop.hbase.testclassification.ClientTests;
041import org.apache.hadoop.hbase.testclassification.SmallTests;
042import org.apache.hadoop.hbase.util.FutureUtils;
043import org.junit.AfterClass;
044import org.junit.Before;
045import org.junit.BeforeClass;
046import org.junit.ClassRule;
047import org.junit.Test;
048import org.junit.experimental.categories.Category;
049import org.slf4j.Logger;
050import org.slf4j.LoggerFactory;
051
052import org.apache.hbase.thirdparty.com.google.common.util.concurrent.ThreadFactoryBuilder;
053import org.apache.hbase.thirdparty.com.google.protobuf.BlockingRpcChannel;
054import org.apache.hbase.thirdparty.com.google.protobuf.Descriptors.MethodDescriptor;
055import org.apache.hbase.thirdparty.com.google.protobuf.Message;
056import org.apache.hbase.thirdparty.com.google.protobuf.RpcCallback;
057import org.apache.hbase.thirdparty.com.google.protobuf.RpcChannel;
058import org.apache.hbase.thirdparty.com.google.protobuf.RpcController;
059
060import org.apache.hadoop.hbase.shaded.protobuf.generated.RegistryProtos.GetClusterIdResponse;
061
062@Category({ ClientTests.class, SmallTests.class })
063public class TestRpcBasedRegistryHedgedReads {
064
065  @ClassRule
066  public static final HBaseClassTestRule CLASS_RULE =
067    HBaseClassTestRule.forClass(TestRpcBasedRegistryHedgedReads.class);
068
069  private static final Logger LOG = LoggerFactory.getLogger(TestRpcBasedRegistryHedgedReads.class);
070
071  private static final String HEDGED_REQS_FANOUT_CONFIG_NAME = "hbase.test.hedged.reqs.fanout";
072  private static final String INITIAL_DELAY_SECS_CONFIG_NAME =
073    "hbase.test.refresh.initial.delay.secs";
074  private static final String REFRESH_INTERVAL_SECS_CONFIG_NAME =
075    "hbase.test.refresh.interval.secs";
076  private static final String MIN_REFRESH_INTERVAL_SECS_CONFIG_NAME =
077    "hbase.test.min.refresh.interval.secs";
078
079  private static final HBaseCommonTestingUtility UTIL = new HBaseCommonTestingUtility();
080
081  private static final ExecutorService EXECUTOR =
082    Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true).build());
083
084  private static Set<ServerName> BOOTSTRAP_NODES;
085
086  private static AtomicInteger CALLED = new AtomicInteger(0);
087
088  private static volatile int BAD_RESP_INDEX;
089
090  private static volatile Set<Integer> GOOD_RESP_INDEXS;
091
092  private static GetClusterIdResponse RESP =
093    GetClusterIdResponse.newBuilder().setClusterId("id").build();
094
095  public static final class RpcClientImpl implements RpcClient {
096
097    public RpcClientImpl(Configuration configuration, String clusterId, SocketAddress localAddress,
098      MetricsConnection metrics) {
099    }
100
101    @Override
102    public BlockingRpcChannel createBlockingRpcChannel(ServerName sn, User user, int rpcTimeout) {
103      throw new UnsupportedOperationException();
104    }
105
106    @Override
107    public RpcChannel createRpcChannel(ServerName sn, User user, int rpcTimeout) {
108      return new RpcChannelImpl();
109    }
110
111    @Override
112    public void cancelConnections(ServerName sn) {
113    }
114
115    @Override
116    public void close() {
117    }
118
119    @Override
120    public boolean hasCellBlockSupport() {
121      return false;
122    }
123  }
124
125  /**
126   * A dummy RpcChannel implementation that intercepts the GetClusterId() RPC calls and injects
127   * errors. All other RPCs are ignored.
128   */
129  public static final class RpcChannelImpl implements RpcChannel {
130
131    @Override
132    public void callMethod(MethodDescriptor method, RpcController controller, Message request,
133      Message responsePrototype, RpcCallback<Message> done) {
134      if (!method.getName().equals("GetClusterId")) {
135        // On RPC failures, MasterRegistry internally runs getMasters() RPC to keep the master list
136        // fresh. We do not want to intercept those RPCs here and double count.
137        return;
138      }
139      // simulate the asynchronous behavior otherwise all logic will perform in the same thread...
140      EXECUTOR.execute(() -> {
141        int index = CALLED.getAndIncrement();
142        if (index == BAD_RESP_INDEX) {
143          done.run(GetClusterIdResponse.getDefaultInstance());
144        } else if (GOOD_RESP_INDEXS.contains(index)) {
145          done.run(RESP);
146        } else {
147          controller.setFailed("inject error");
148          done.run(null);
149        }
150      });
151    }
152  }
153
154  private AbstractRpcBasedConnectionRegistry createRegistry(int hedged) throws IOException {
155    Configuration conf = UTIL.getConfiguration();
156    conf.setInt(HEDGED_REQS_FANOUT_CONFIG_NAME, hedged);
157    return new AbstractRpcBasedConnectionRegistry(conf, HEDGED_REQS_FANOUT_CONFIG_NAME,
158      INITIAL_DELAY_SECS_CONFIG_NAME, REFRESH_INTERVAL_SECS_CONFIG_NAME,
159      MIN_REFRESH_INTERVAL_SECS_CONFIG_NAME) {
160
161      @Override
162      protected Set<ServerName> getBootstrapNodes(Configuration conf) throws IOException {
163        return BOOTSTRAP_NODES;
164      }
165
166      @Override
167      protected CompletableFuture<Set<ServerName>> fetchEndpoints() {
168        return CompletableFuture.completedFuture(BOOTSTRAP_NODES);
169      }
170
171      @Override
172      public String getConnectionString() {
173        return "unimplemented";
174      }
175    };
176  }
177
178  @BeforeClass
179  public static void setUpBeforeClass() {
180    Configuration conf = UTIL.getConfiguration();
181    conf.setClass(RpcClientFactory.CUSTOM_RPC_CLIENT_IMPL_CONF_KEY, RpcClientImpl.class,
182      RpcClient.class);
183    // disable refresh, we do not need to refresh in this test
184    conf.setLong(INITIAL_DELAY_SECS_CONFIG_NAME, Integer.MAX_VALUE);
185    conf.setLong(REFRESH_INTERVAL_SECS_CONFIG_NAME, Integer.MAX_VALUE);
186    conf.setLong(MIN_REFRESH_INTERVAL_SECS_CONFIG_NAME, Integer.MAX_VALUE - 1);
187    BOOTSTRAP_NODES = IntStream.range(0, 10)
188      .mapToObj(i -> ServerName.valueOf("localhost", (10000 + 100 * i), ServerName.NON_STARTCODE))
189      .collect(Collectors.toSet());
190  }
191
192  @AfterClass
193  public static void tearDownAfterClass() {
194    EXECUTOR.shutdownNow();
195  }
196
197  @Before
198  public void setUp() {
199    CALLED.set(0);
200    BAD_RESP_INDEX = -1;
201    GOOD_RESP_INDEXS = Collections.emptySet();
202  }
203
204  private <T> T logIfError(CompletableFuture<T> future) throws IOException {
205    try {
206      return FutureUtils.get(future);
207    } catch (Throwable t) {
208      LOG.warn("", t);
209      throw t;
210    }
211  }
212
213  @Test
214  public void testAllFailNoHedged() throws IOException {
215    try (AbstractRpcBasedConnectionRegistry registry = createRegistry(1)) {
216      assertThrows(IOException.class, () -> logIfError(registry.getClusterId()));
217      assertEquals(10, CALLED.get());
218    }
219  }
220
221  @Test
222  public void testAllFailHedged3() throws IOException {
223    BAD_RESP_INDEX = 5;
224    try (AbstractRpcBasedConnectionRegistry registry = createRegistry(3)) {
225      assertThrows(IOException.class, () -> logIfError(registry.getClusterId()));
226      assertEquals(10, CALLED.get());
227    }
228  }
229
230  @Test
231  public void testFirstSucceededNoHedge() throws IOException {
232    GOOD_RESP_INDEXS =
233      IntStream.range(0, 10).mapToObj(Integer::valueOf).collect(Collectors.toSet());
234    // will be set to 1
235    try (AbstractRpcBasedConnectionRegistry registry = createRegistry(0)) {
236      String clusterId = logIfError(registry.getClusterId());
237      assertEquals(RESP.getClusterId(), clusterId);
238      assertEquals(1, CALLED.get());
239    }
240  }
241
242  @Test
243  public void testSecondRoundSucceededHedge4() throws IOException {
244    GOOD_RESP_INDEXS = Collections.singleton(6);
245    try (AbstractRpcBasedConnectionRegistry registry = createRegistry(4)) {
246      String clusterId = logIfError(registry.getClusterId());
247      assertEquals(RESP.getClusterId(), clusterId);
248      UTIL.waitFor(5000, () -> CALLED.get() == 8);
249    }
250  }
251
252  @Test
253  public void testSucceededWithLargestHedged() throws IOException, InterruptedException {
254    GOOD_RESP_INDEXS = Collections.singleton(5);
255    try (AbstractRpcBasedConnectionRegistry registry = createRegistry(Integer.MAX_VALUE)) {
256      String clusterId = logIfError(registry.getClusterId());
257      assertEquals(RESP.getClusterId(), clusterId);
258      UTIL.waitFor(5000, () -> CALLED.get() == 10);
259      Thread.sleep(1000);
260      // make sure we do not send more
261      assertEquals(10, CALLED.get());
262    }
263  }
264}