001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, software 013 * distributed under the License is distributed on an "AS IS" BASIS, 014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 015 * See the License for the specific language governing permissions and 016 * limitations under the License. 017 */ 018package org.apache.hadoop.hbase.client; 019 020import static org.junit.jupiter.api.Assertions.assertEquals; 021import static org.junit.jupiter.api.Assertions.assertThrows; 022 023import java.io.IOException; 024import java.net.SocketAddress; 025import java.util.Collections; 026import java.util.Map; 027import java.util.Set; 028import java.util.concurrent.CompletableFuture; 029import java.util.concurrent.ExecutorService; 030import java.util.concurrent.Executors; 031import java.util.concurrent.atomic.AtomicInteger; 032import java.util.stream.Collectors; 033import java.util.stream.IntStream; 034import org.apache.hadoop.conf.Configuration; 035import org.apache.hadoop.hbase.HBaseCommonTestingUtil; 036import org.apache.hadoop.hbase.ServerName; 037import org.apache.hadoop.hbase.ipc.RpcClient; 038import org.apache.hadoop.hbase.ipc.RpcClientFactory; 039import org.apache.hadoop.hbase.security.User; 040import org.apache.hadoop.hbase.testclassification.ClientTests; 041import org.apache.hadoop.hbase.testclassification.SmallTests; 042import org.apache.hadoop.hbase.util.FutureUtils; 043import org.junit.jupiter.api.AfterAll; 044import org.junit.jupiter.api.BeforeAll; 045import org.junit.jupiter.api.BeforeEach; 046import org.junit.jupiter.api.Tag; 047import org.junit.jupiter.api.Test; 048import org.slf4j.Logger; 049import org.slf4j.LoggerFactory; 050 051import org.apache.hbase.thirdparty.com.google.common.util.concurrent.ThreadFactoryBuilder; 052import org.apache.hbase.thirdparty.com.google.protobuf.BlockingRpcChannel; 053import org.apache.hbase.thirdparty.com.google.protobuf.Descriptors.MethodDescriptor; 054import org.apache.hbase.thirdparty.com.google.protobuf.Message; 055import org.apache.hbase.thirdparty.com.google.protobuf.RpcCallback; 056import org.apache.hbase.thirdparty.com.google.protobuf.RpcChannel; 057import org.apache.hbase.thirdparty.com.google.protobuf.RpcController; 058 059import org.apache.hadoop.hbase.shaded.protobuf.generated.RegistryProtos.ConnectionRegistryService; 060import org.apache.hadoop.hbase.shaded.protobuf.generated.RegistryProtos.GetClusterIdResponse; 061import org.apache.hadoop.hbase.shaded.protobuf.generated.RegistryProtos.GetConnectionRegistryResponse; 062 063@Tag(ClientTests.TAG) 064@Tag(SmallTests.TAG) 065public class TestRpcBasedRegistryHedgedReads { 066 067 private static final Logger LOG = LoggerFactory.getLogger(TestRpcBasedRegistryHedgedReads.class); 068 069 private static final String HEDGED_REQS_FANOUT_CONFIG_NAME = "hbase.test.hedged.reqs.fanout"; 070 private static final String INITIAL_DELAY_SECS_CONFIG_NAME = 071 "hbase.test.refresh.initial.delay.secs"; 072 private static final String REFRESH_INTERVAL_SECS_CONFIG_NAME = 073 "hbase.test.refresh.interval.secs"; 074 private static final String MIN_REFRESH_INTERVAL_SECS_CONFIG_NAME = 075 "hbase.test.min.refresh.interval.secs"; 076 077 private static final HBaseCommonTestingUtil UTIL = new HBaseCommonTestingUtil(); 078 079 private static final ExecutorService EXECUTOR = 080 Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true).build()); 081 082 private static Set<ServerName> BOOTSTRAP_NODES; 083 084 private static AtomicInteger CALLED = new AtomicInteger(0); 085 086 private static volatile int BAD_RESP_INDEX; 087 088 private static volatile Set<Integer> GOOD_RESP_INDEXS; 089 090 private static GetClusterIdResponse RESP = 091 GetClusterIdResponse.newBuilder().setClusterId("id").build(); 092 093 public static final class RpcClientImpl implements RpcClient { 094 095 public RpcClientImpl(Configuration configuration, String clusterId, SocketAddress localAddress, 096 MetricsConnection metrics, Map<String, byte[]> attributes) { 097 } 098 099 @Override 100 public BlockingRpcChannel createBlockingRpcChannel(ServerName sn, User user, int rpcTimeout) { 101 throw new UnsupportedOperationException(); 102 } 103 104 @Override 105 public RpcChannel createRpcChannel(ServerName sn, User user, int rpcTimeout) { 106 return new RpcChannelImpl(); 107 } 108 109 @Override 110 public void cancelConnections(ServerName sn) { 111 } 112 113 @Override 114 public void close() { 115 } 116 117 @Override 118 public boolean hasCellBlockSupport() { 119 return false; 120 } 121 } 122 123 /** 124 * A dummy RpcChannel implementation that intercepts the GetClusterId() RPC calls and injects 125 * errors. All other RPCs are ignored. 126 */ 127 public static final class RpcChannelImpl implements RpcChannel { 128 129 @Override 130 public void callMethod(MethodDescriptor method, RpcController controller, Message request, 131 Message responsePrototype, RpcCallback<Message> done) { 132 if (method.getService().equals(ConnectionRegistryService.getDescriptor())) { 133 // this is for setting up the rpc client 134 done.run( 135 GetConnectionRegistryResponse.newBuilder().setClusterId(RESP.getClusterId()).build()); 136 return; 137 } 138 if (!method.getName().equals("GetClusterId")) { 139 // On RPC failures, MasterRegistry internally runs getMasters() RPC to keep the master list 140 // fresh. We do not want to intercept those RPCs here and double count. 141 return; 142 } 143 // simulate the asynchronous behavior otherwise all logic will perform in the same thread... 144 EXECUTOR.execute(() -> { 145 int index = CALLED.getAndIncrement(); 146 if (index == BAD_RESP_INDEX) { 147 done.run(GetClusterIdResponse.getDefaultInstance()); 148 } else if (GOOD_RESP_INDEXS.contains(index)) { 149 done.run(RESP); 150 } else { 151 controller.setFailed("inject error"); 152 done.run(null); 153 } 154 }); 155 } 156 } 157 158 private AbstractRpcBasedConnectionRegistry createRegistry(int hedged) throws IOException { 159 Configuration conf = UTIL.getConfiguration(); 160 conf.setInt(HEDGED_REQS_FANOUT_CONFIG_NAME, hedged); 161 return new AbstractRpcBasedConnectionRegistry(conf, User.getCurrent(), 162 HEDGED_REQS_FANOUT_CONFIG_NAME, INITIAL_DELAY_SECS_CONFIG_NAME, 163 REFRESH_INTERVAL_SECS_CONFIG_NAME, MIN_REFRESH_INTERVAL_SECS_CONFIG_NAME) { 164 165 @Override 166 protected Set<ServerName> getBootstrapNodes(Configuration conf) throws IOException { 167 return BOOTSTRAP_NODES; 168 } 169 170 @Override 171 protected CompletableFuture<Set<ServerName>> fetchEndpoints() { 172 return CompletableFuture.completedFuture(BOOTSTRAP_NODES); 173 } 174 175 @Override 176 public String getConnectionString() { 177 return "unimplemented"; 178 } 179 }; 180 } 181 182 @BeforeAll 183 public static void setUpBeforeClass() { 184 Configuration conf = UTIL.getConfiguration(); 185 conf.setClass(RpcClientFactory.CUSTOM_RPC_CLIENT_IMPL_CONF_KEY, RpcClientImpl.class, 186 RpcClient.class); 187 // disable refresh, we do not need to refresh in this test 188 conf.setLong(INITIAL_DELAY_SECS_CONFIG_NAME, Integer.MAX_VALUE); 189 conf.setLong(REFRESH_INTERVAL_SECS_CONFIG_NAME, Integer.MAX_VALUE); 190 conf.setLong(MIN_REFRESH_INTERVAL_SECS_CONFIG_NAME, Integer.MAX_VALUE - 1); 191 BOOTSTRAP_NODES = IntStream.range(0, 10) 192 .mapToObj(i -> ServerName.valueOf("localhost", (10000 + 100 * i), ServerName.NON_STARTCODE)) 193 .collect(Collectors.toSet()); 194 } 195 196 @AfterAll 197 public static void tearDownAfterClass() { 198 EXECUTOR.shutdownNow(); 199 } 200 201 @BeforeEach 202 public void setUp() { 203 CALLED.set(0); 204 BAD_RESP_INDEX = -1; 205 GOOD_RESP_INDEXS = Collections.emptySet(); 206 } 207 208 private <T> T logIfError(CompletableFuture<T> future) throws IOException { 209 try { 210 return FutureUtils.get(future); 211 } catch (Throwable t) { 212 LOG.warn("", t); 213 throw t; 214 } 215 } 216 217 @Test 218 public void testAllFailNoHedged() throws IOException { 219 try (AbstractRpcBasedConnectionRegistry registry = createRegistry(1)) { 220 assertThrows(IOException.class, () -> logIfError(registry.getClusterId())); 221 assertEquals(10, CALLED.get()); 222 } 223 } 224 225 @Test 226 public void testAllFailHedged3() throws IOException { 227 BAD_RESP_INDEX = 5; 228 try (AbstractRpcBasedConnectionRegistry registry = createRegistry(3)) { 229 assertThrows(IOException.class, () -> logIfError(registry.getClusterId())); 230 assertEquals(10, CALLED.get()); 231 } 232 } 233 234 @Test 235 public void testFirstSucceededNoHedge() throws IOException { 236 GOOD_RESP_INDEXS = 237 IntStream.range(0, 10).mapToObj(Integer::valueOf).collect(Collectors.toSet()); 238 // will be set to 1 239 try (AbstractRpcBasedConnectionRegistry registry = createRegistry(0)) { 240 String clusterId = logIfError(registry.getClusterId()); 241 assertEquals(RESP.getClusterId(), clusterId); 242 assertEquals(1, CALLED.get()); 243 } 244 } 245 246 @Test 247 public void testSecondRoundSucceededHedge4() throws IOException { 248 GOOD_RESP_INDEXS = Collections.singleton(6); 249 try (AbstractRpcBasedConnectionRegistry registry = createRegistry(4)) { 250 String clusterId = logIfError(registry.getClusterId()); 251 assertEquals(RESP.getClusterId(), clusterId); 252 UTIL.waitFor(5000, () -> CALLED.get() == 8); 253 } 254 } 255 256 @Test 257 public void testSucceededWithLargestHedged() throws IOException, InterruptedException { 258 GOOD_RESP_INDEXS = Collections.singleton(5); 259 try (AbstractRpcBasedConnectionRegistry registry = createRegistry(Integer.MAX_VALUE)) { 260 String clusterId = logIfError(registry.getClusterId()); 261 assertEquals(RESP.getClusterId(), clusterId); 262 UTIL.waitFor(5000, () -> CALLED.get() == 10); 263 Thread.sleep(1000); 264 // make sure we do not send more 265 assertEquals(10, CALLED.get()); 266 } 267 } 268}