1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.hadoop.hbase.security;
19
20 import io.netty.buffer.ByteBuf;
21 import io.netty.channel.Channel;
22 import io.netty.channel.ChannelDuplexHandler;
23 import io.netty.channel.ChannelFuture;
24 import io.netty.channel.ChannelFutureListener;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelPromise;
27
28 import org.apache.commons.logging.Log;
29 import org.apache.commons.logging.LogFactory;
30 import org.apache.hadoop.hbase.classification.InterfaceAudience;
31 import org.apache.hadoop.ipc.RemoteException;
32 import org.apache.hadoop.security.UserGroupInformation;
33 import org.apache.hadoop.security.token.Token;
34 import org.apache.hadoop.security.token.TokenIdentifier;
35
36 import javax.security.auth.callback.CallbackHandler;
37 import javax.security.sasl.Sasl;
38 import javax.security.sasl.SaslClient;
39 import javax.security.sasl.SaslException;
40
41 import java.io.IOException;
42 import java.nio.charset.Charset;
43 import java.security.PrivilegedExceptionAction;
44 import java.util.Random;
45
46
47
48
49 @InterfaceAudience.Private
50 public class SaslClientHandler extends ChannelDuplexHandler {
51 public static final Log LOG = LogFactory.getLog(SaslClientHandler.class);
52
53 private final boolean fallbackAllowed;
54
55 private final UserGroupInformation ticket;
56
57
58
59
60 private final SaslClient saslClient;
61 private final SaslExceptionHandler exceptionHandler;
62 private final SaslSuccessfulConnectHandler successfulConnectHandler;
63 private byte[] saslToken;
64 private boolean firstRead = true;
65
66 private int retryCount = 0;
67 private Random random;
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83 public SaslClientHandler(UserGroupInformation ticket, AuthMethod method,
84 Token<? extends TokenIdentifier> token, String serverPrincipal, boolean fallbackAllowed,
85 String rpcProtection, SaslExceptionHandler exceptionHandler,
86 SaslSuccessfulConnectHandler successfulConnectHandler) throws IOException {
87 this.ticket = ticket;
88 this.fallbackAllowed = fallbackAllowed;
89
90 this.exceptionHandler = exceptionHandler;
91 this.successfulConnectHandler = successfulConnectHandler;
92
93 SaslUtil.initSaslProperties(rpcProtection);
94 switch (method) {
95 case DIGEST:
96 if (LOG.isDebugEnabled())
97 LOG.debug("Creating SASL " + AuthMethod.DIGEST.getMechanismName()
98 + " client to authenticate to service at " + token.getService());
99 saslClient = createDigestSaslClient(new String[] { AuthMethod.DIGEST.getMechanismName() },
100 SaslUtil.SASL_DEFAULT_REALM, new HBaseSaslRpcClient.SaslClientCallbackHandler(token));
101 break;
102 case KERBEROS:
103 if (LOG.isDebugEnabled()) {
104 LOG.debug("Creating SASL " + AuthMethod.KERBEROS.getMechanismName()
105 + " client. Server's Kerberos principal name is " + serverPrincipal);
106 }
107 if (serverPrincipal == null || serverPrincipal.isEmpty()) {
108 throw new IOException("Failed to specify server's Kerberos principal name");
109 }
110 String[] names = SaslUtil.splitKerberosName(serverPrincipal);
111 if (names.length != 3) {
112 throw new IOException(
113 "Kerberos principal does not have the expected format: " + serverPrincipal);
114 }
115 saslClient = createKerberosSaslClient(new String[] { AuthMethod.KERBEROS.getMechanismName() },
116 names[0], names[1]);
117 break;
118 default:
119 throw new IOException("Unknown authentication method " + method);
120 }
121 if (saslClient == null) {
122 throw new IOException("Unable to find SASL client implementation");
123 }
124 }
125
126
127
128
129
130
131
132
133
134
135 protected SaslClient createDigestSaslClient(String[] mechanismNames, String saslDefaultRealm,
136 CallbackHandler saslClientCallbackHandler) throws IOException {
137 return Sasl.createSaslClient(mechanismNames, null, null, saslDefaultRealm, SaslUtil.SASL_PROPS,
138 saslClientCallbackHandler);
139 }
140
141
142
143
144
145
146
147
148
149
150 protected SaslClient createKerberosSaslClient(String[] mechanismNames, String userFirstPart,
151 String userSecondPart) throws IOException {
152 return Sasl
153 .createSaslClient(mechanismNames, null, userFirstPart, userSecondPart, SaslUtil.SASL_PROPS,
154 null);
155 }
156
157 @Override
158 public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
159 saslClient.dispose();
160 }
161
162 private byte[] evaluateChallenge(final byte[] challenge) throws Exception {
163 return ticket.doAs(new PrivilegedExceptionAction<byte[]>() {
164
165 @Override
166 public byte[] run() throws Exception {
167 return saslClient.evaluateChallenge(challenge);
168 }
169 });
170 }
171
172 @Override
173 public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
174 saslToken = new byte[0];
175 if (saslClient.hasInitialResponse()) {
176 saslToken = evaluateChallenge(saslToken);
177 }
178 if (saslToken != null) {
179 writeSaslToken(ctx, saslToken);
180 if (LOG.isDebugEnabled()) {
181 LOG.debug("Have sent token of size " + saslToken.length + " from initSASLContext.");
182 }
183 }
184 }
185
186 @Override
187 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
188 ByteBuf in = (ByteBuf) msg;
189
190
191 if (!saslClient.isComplete()) {
192 while (!saslClient.isComplete() && in.isReadable()) {
193 readStatus(in);
194 int len = in.readInt();
195 if (firstRead) {
196 firstRead = false;
197 if (len == SaslUtil.SWITCH_TO_SIMPLE_AUTH) {
198 if (!fallbackAllowed) {
199 throw new IOException("Server asks us to fall back to SIMPLE auth, " + "but this "
200 + "client is configured to only allow secure connections.");
201 }
202 if (LOG.isDebugEnabled()) {
203 LOG.debug("Server asks us to fall back to simple auth.");
204 }
205 saslClient.dispose();
206
207 ctx.pipeline().remove(this);
208 successfulConnectHandler.onSuccess(ctx.channel());
209 return;
210 }
211 }
212 saslToken = new byte[len];
213 if (LOG.isDebugEnabled()) {
214 LOG.debug("Will read input token of size " + saslToken.length
215 + " for processing by initSASLContext");
216 }
217 in.readBytes(saslToken);
218
219 saslToken = evaluateChallenge(saslToken);
220 if (saslToken != null) {
221 if (LOG.isDebugEnabled()) {
222 LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext.");
223 }
224 writeSaslToken(ctx, saslToken);
225 }
226 }
227
228 if (saslClient.isComplete()) {
229 String qop = (String) saslClient.getNegotiatedProperty(Sasl.QOP);
230
231 if (LOG.isDebugEnabled()) {
232 LOG.debug("SASL client context established. Negotiated QoP: " + qop);
233 }
234
235 boolean useWrap = qop != null && !"auth".equalsIgnoreCase(qop);
236
237 if (!useWrap) {
238 ctx.pipeline().remove(this);
239 }
240 successfulConnectHandler.onSuccess(ctx.channel());
241 }
242 }
243
244 else {
245 try {
246 int length = in.readInt();
247 if (LOG.isDebugEnabled()) {
248 LOG.debug("Actual length is " + length);
249 }
250 saslToken = new byte[length];
251 in.readBytes(saslToken);
252 } catch (IndexOutOfBoundsException e) {
253 return;
254 }
255 try {
256 ByteBuf b = ctx.channel().alloc().buffer(saslToken.length);
257
258 b.writeBytes(saslClient.unwrap(saslToken, 0, saslToken.length));
259 ctx.fireChannelRead(b);
260
261 } catch (SaslException se) {
262 try {
263 saslClient.dispose();
264 } catch (SaslException ignored) {
265 LOG.debug("Ignoring SASL exception", ignored);
266 }
267 throw se;
268 }
269 }
270 }
271
272
273
274
275
276
277 private void writeSaslToken(final ChannelHandlerContext ctx, byte[] saslToken) {
278 ByteBuf b = ctx.alloc().buffer(4 + saslToken.length);
279 b.writeInt(saslToken.length);
280 b.writeBytes(saslToken, 0, saslToken.length);
281 ctx.writeAndFlush(b).addListener(new ChannelFutureListener() {
282 @Override
283 public void operationComplete(ChannelFuture future) throws Exception {
284 if (!future.isSuccess()) {
285 exceptionCaught(ctx, future.cause());
286 }
287 }
288 });
289 }
290
291
292
293
294
295
296
297 private static void readStatus(ByteBuf inStream) throws RemoteException {
298 int status = inStream.readInt();
299 if (status != SaslStatus.SUCCESS.state) {
300 throw new RemoteException(inStream.toString(Charset.forName("UTF-8")),
301 inStream.toString(Charset.forName("UTF-8")));
302 }
303 }
304
305 @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
306 throws Exception {
307 saslClient.dispose();
308
309 ctx.close();
310
311 if (this.random == null) {
312 this.random = new Random();
313 }
314 exceptionHandler.handle(this.retryCount++, this.random, cause);
315 }
316
317 @Override
318 public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
319 throws Exception {
320
321 if (!saslClient.isComplete()) {
322 super.write(ctx, msg, promise);
323 } else {
324 ByteBuf in = (ByteBuf) msg;
325
326 try {
327 saslToken = saslClient.wrap(in.array(), in.readerIndex(), in.readableBytes());
328 } catch (SaslException se) {
329 try {
330 saslClient.dispose();
331 } catch (SaslException ignored) {
332 LOG.debug("Ignoring SASL exception", ignored);
333 }
334 promise.setFailure(se);
335 }
336 if (saslToken != null) {
337 ByteBuf out = ctx.channel().alloc().buffer(4 + saslToken.length);
338 out.writeInt(saslToken.length);
339 out.writeBytes(saslToken, 0, saslToken.length);
340
341 ctx.write(out).addListener(new ChannelFutureListener() {
342 @Override public void operationComplete(ChannelFuture future) throws Exception {
343 if (!future.isSuccess()) {
344 exceptionCaught(ctx, future.cause());
345 }
346 }
347 });
348
349 saslToken = null;
350 }
351 }
352 }
353
354
355
356
357 public interface SaslExceptionHandler {
358
359
360
361
362
363
364
365 public void handle(int retryCount, Random random, Throwable cause);
366 }
367
368
369
370
371 public interface SaslSuccessfulConnectHandler {
372
373
374
375
376
377 public void onSuccess(Channel channel);
378 }
379 }