001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.hbase.client;
019
020import static org.junit.jupiter.api.Assertions.assertEquals;
021import static org.junit.jupiter.api.Assertions.assertThrows;
022
023import java.io.IOException;
024import java.net.SocketAddress;
025import java.util.Collections;
026import java.util.Map;
027import java.util.Set;
028import java.util.concurrent.CompletableFuture;
029import java.util.concurrent.ExecutorService;
030import java.util.concurrent.Executors;
031import java.util.concurrent.atomic.AtomicInteger;
032import java.util.stream.Collectors;
033import java.util.stream.IntStream;
034import org.apache.hadoop.conf.Configuration;
035import org.apache.hadoop.hbase.HBaseCommonTestingUtil;
036import org.apache.hadoop.hbase.ServerName;
037import org.apache.hadoop.hbase.ipc.RpcClient;
038import org.apache.hadoop.hbase.ipc.RpcClientFactory;
039import org.apache.hadoop.hbase.security.User;
040import org.apache.hadoop.hbase.testclassification.ClientTests;
041import org.apache.hadoop.hbase.testclassification.SmallTests;
042import org.apache.hadoop.hbase.util.FutureUtils;
043import org.junit.jupiter.api.AfterAll;
044import org.junit.jupiter.api.BeforeAll;
045import org.junit.jupiter.api.BeforeEach;
046import org.junit.jupiter.api.Tag;
047import org.junit.jupiter.api.Test;
048import org.slf4j.Logger;
049import org.slf4j.LoggerFactory;
050
051import org.apache.hbase.thirdparty.com.google.common.util.concurrent.ThreadFactoryBuilder;
052import org.apache.hbase.thirdparty.com.google.protobuf.BlockingRpcChannel;
053import org.apache.hbase.thirdparty.com.google.protobuf.Descriptors.MethodDescriptor;
054import org.apache.hbase.thirdparty.com.google.protobuf.Message;
055import org.apache.hbase.thirdparty.com.google.protobuf.RpcCallback;
056import org.apache.hbase.thirdparty.com.google.protobuf.RpcChannel;
057import org.apache.hbase.thirdparty.com.google.protobuf.RpcController;
058
059import org.apache.hadoop.hbase.shaded.protobuf.generated.RegistryProtos.ConnectionRegistryService;
060import org.apache.hadoop.hbase.shaded.protobuf.generated.RegistryProtos.GetClusterIdResponse;
061import org.apache.hadoop.hbase.shaded.protobuf.generated.RegistryProtos.GetConnectionRegistryResponse;
062
063@Tag(ClientTests.TAG)
064@Tag(SmallTests.TAG)
065public class TestRpcBasedRegistryHedgedReads {
066
067  private static final Logger LOG = LoggerFactory.getLogger(TestRpcBasedRegistryHedgedReads.class);
068
069  private static final String HEDGED_REQS_FANOUT_CONFIG_NAME = "hbase.test.hedged.reqs.fanout";
070  private static final String INITIAL_DELAY_SECS_CONFIG_NAME =
071    "hbase.test.refresh.initial.delay.secs";
072  private static final String REFRESH_INTERVAL_SECS_CONFIG_NAME =
073    "hbase.test.refresh.interval.secs";
074  private static final String MIN_REFRESH_INTERVAL_SECS_CONFIG_NAME =
075    "hbase.test.min.refresh.interval.secs";
076
077  private static final HBaseCommonTestingUtil UTIL = new HBaseCommonTestingUtil();
078
079  private static final ExecutorService EXECUTOR =
080    Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true).build());
081
082  private static Set<ServerName> BOOTSTRAP_NODES;
083
084  private static AtomicInteger CALLED = new AtomicInteger(0);
085
086  private static volatile int BAD_RESP_INDEX;
087
088  private static volatile Set<Integer> GOOD_RESP_INDEXS;
089
090  private static GetClusterIdResponse RESP =
091    GetClusterIdResponse.newBuilder().setClusterId("id").build();
092
093  public static final class RpcClientImpl implements RpcClient {
094
095    public RpcClientImpl(Configuration configuration, String clusterId, SocketAddress localAddress,
096      MetricsConnection metrics, Map<String, byte[]> attributes) {
097    }
098
099    @Override
100    public BlockingRpcChannel createBlockingRpcChannel(ServerName sn, User user, int rpcTimeout) {
101      throw new UnsupportedOperationException();
102    }
103
104    @Override
105    public RpcChannel createRpcChannel(ServerName sn, User user, int rpcTimeout) {
106      return new RpcChannelImpl();
107    }
108
109    @Override
110    public void cancelConnections(ServerName sn) {
111    }
112
113    @Override
114    public void close() {
115    }
116
117    @Override
118    public boolean hasCellBlockSupport() {
119      return false;
120    }
121  }
122
123  /**
124   * A dummy RpcChannel implementation that intercepts the GetClusterId() RPC calls and injects
125   * errors. All other RPCs are ignored.
126   */
127  public static final class RpcChannelImpl implements RpcChannel {
128
129    @Override
130    public void callMethod(MethodDescriptor method, RpcController controller, Message request,
131      Message responsePrototype, RpcCallback<Message> done) {
132      if (method.getService().equals(ConnectionRegistryService.getDescriptor())) {
133        // this is for setting up the rpc client
134        done.run(
135          GetConnectionRegistryResponse.newBuilder().setClusterId(RESP.getClusterId()).build());
136        return;
137      }
138      if (!method.getName().equals("GetClusterId")) {
139        // On RPC failures, MasterRegistry internally runs getMasters() RPC to keep the master list
140        // fresh. We do not want to intercept those RPCs here and double count.
141        return;
142      }
143      // simulate the asynchronous behavior otherwise all logic will perform in the same thread...
144      EXECUTOR.execute(() -> {
145        int index = CALLED.getAndIncrement();
146        if (index == BAD_RESP_INDEX) {
147          done.run(GetClusterIdResponse.getDefaultInstance());
148        } else if (GOOD_RESP_INDEXS.contains(index)) {
149          done.run(RESP);
150        } else {
151          controller.setFailed("inject error");
152          done.run(null);
153        }
154      });
155    }
156  }
157
158  private AbstractRpcBasedConnectionRegistry createRegistry(int hedged) throws IOException {
159    Configuration conf = UTIL.getConfiguration();
160    conf.setInt(HEDGED_REQS_FANOUT_CONFIG_NAME, hedged);
161    return new AbstractRpcBasedConnectionRegistry(conf, User.getCurrent(),
162      HEDGED_REQS_FANOUT_CONFIG_NAME, INITIAL_DELAY_SECS_CONFIG_NAME,
163      REFRESH_INTERVAL_SECS_CONFIG_NAME, MIN_REFRESH_INTERVAL_SECS_CONFIG_NAME) {
164
165      @Override
166      protected Set<ServerName> getBootstrapNodes(Configuration conf) throws IOException {
167        return BOOTSTRAP_NODES;
168      }
169
170      @Override
171      protected CompletableFuture<Set<ServerName>> fetchEndpoints() {
172        return CompletableFuture.completedFuture(BOOTSTRAP_NODES);
173      }
174
175      @Override
176      public String getConnectionString() {
177        return "unimplemented";
178      }
179    };
180  }
181
182  @BeforeAll
183  public static void setUpBeforeClass() {
184    Configuration conf = UTIL.getConfiguration();
185    conf.setClass(RpcClientFactory.CUSTOM_RPC_CLIENT_IMPL_CONF_KEY, RpcClientImpl.class,
186      RpcClient.class);
187    // disable refresh, we do not need to refresh in this test
188    conf.setLong(INITIAL_DELAY_SECS_CONFIG_NAME, Integer.MAX_VALUE);
189    conf.setLong(REFRESH_INTERVAL_SECS_CONFIG_NAME, Integer.MAX_VALUE);
190    conf.setLong(MIN_REFRESH_INTERVAL_SECS_CONFIG_NAME, Integer.MAX_VALUE - 1);
191    BOOTSTRAP_NODES = IntStream.range(0, 10)
192      .mapToObj(i -> ServerName.valueOf("localhost", (10000 + 100 * i), ServerName.NON_STARTCODE))
193      .collect(Collectors.toSet());
194  }
195
196  @AfterAll
197  public static void tearDownAfterClass() {
198    EXECUTOR.shutdownNow();
199  }
200
201  @BeforeEach
202  public void setUp() {
203    CALLED.set(0);
204    BAD_RESP_INDEX = -1;
205    GOOD_RESP_INDEXS = Collections.emptySet();
206  }
207
208  private <T> T logIfError(CompletableFuture<T> future) throws IOException {
209    try {
210      return FutureUtils.get(future);
211    } catch (Throwable t) {
212      LOG.warn("", t);
213      throw t;
214    }
215  }
216
217  @Test
218  public void testAllFailNoHedged() throws IOException {
219    try (AbstractRpcBasedConnectionRegistry registry = createRegistry(1)) {
220      assertThrows(IOException.class, () -> logIfError(registry.getClusterId()));
221      assertEquals(10, CALLED.get());
222    }
223  }
224
225  @Test
226  public void testAllFailHedged3() throws IOException {
227    BAD_RESP_INDEX = 5;
228    try (AbstractRpcBasedConnectionRegistry registry = createRegistry(3)) {
229      assertThrows(IOException.class, () -> logIfError(registry.getClusterId()));
230      assertEquals(10, CALLED.get());
231    }
232  }
233
234  @Test
235  public void testFirstSucceededNoHedge() throws IOException {
236    GOOD_RESP_INDEXS =
237      IntStream.range(0, 10).mapToObj(Integer::valueOf).collect(Collectors.toSet());
238    // will be set to 1
239    try (AbstractRpcBasedConnectionRegistry registry = createRegistry(0)) {
240      String clusterId = logIfError(registry.getClusterId());
241      assertEquals(RESP.getClusterId(), clusterId);
242      assertEquals(1, CALLED.get());
243    }
244  }
245
246  @Test
247  public void testSecondRoundSucceededHedge4() throws IOException {
248    GOOD_RESP_INDEXS = Collections.singleton(6);
249    try (AbstractRpcBasedConnectionRegistry registry = createRegistry(4)) {
250      String clusterId = logIfError(registry.getClusterId());
251      assertEquals(RESP.getClusterId(), clusterId);
252      UTIL.waitFor(5000, () -> CALLED.get() == 8);
253    }
254  }
255
256  @Test
257  public void testSucceededWithLargestHedged() throws IOException, InterruptedException {
258    GOOD_RESP_INDEXS = Collections.singleton(5);
259    try (AbstractRpcBasedConnectionRegistry registry = createRegistry(Integer.MAX_VALUE)) {
260      String clusterId = logIfError(registry.getClusterId());
261      assertEquals(RESP.getClusterId(), clusterId);
262      UTIL.waitFor(5000, () -> CALLED.get() == 10);
263      Thread.sleep(1000);
264      // make sure we do not send more
265      assertEquals(10, CALLED.get());
266    }
267  }
268}