1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  package org.apache.hadoop.hbase.security;
19  
20  import io.netty.buffer.ByteBuf;
21  import io.netty.channel.Channel;
22  import io.netty.channel.ChannelDuplexHandler;
23  import io.netty.channel.ChannelFuture;
24  import io.netty.channel.ChannelFutureListener;
25  import io.netty.channel.ChannelHandlerContext;
26  import io.netty.channel.ChannelPromise;
27  
28  import org.apache.commons.logging.Log;
29  import org.apache.commons.logging.LogFactory;
30  import org.apache.hadoop.hbase.classification.InterfaceAudience;
31  import org.apache.hadoop.ipc.RemoteException;
32  import org.apache.hadoop.security.UserGroupInformation;
33  import org.apache.hadoop.security.token.Token;
34  import org.apache.hadoop.security.token.TokenIdentifier;
35  
36  import javax.security.auth.callback.CallbackHandler;
37  import javax.security.sasl.Sasl;
38  import javax.security.sasl.SaslClient;
39  import javax.security.sasl.SaslException;
40  
41  import java.io.IOException;
42  import java.nio.charset.Charset;
43  import java.security.PrivilegedExceptionAction;
44  import java.util.Random;
45  
46  
47  
48  
49  @InterfaceAudience.Private
50  public class SaslClientHandler extends ChannelDuplexHandler {
51    private static final Log LOG = LogFactory.getLog(SaslClientHandler.class);
52  
53    private final boolean fallbackAllowed;
54  
55    private final UserGroupInformation ticket;
56  
57    
58  
59  
60    private final SaslClient saslClient;
61    private final SaslExceptionHandler exceptionHandler;
62    private final SaslSuccessfulConnectHandler successfulConnectHandler;
63    private byte[] saslToken;
64    private boolean firstRead = true;
65  
66    private int retryCount = 0;
67    private Random random;
68  
69    
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83    public SaslClientHandler(UserGroupInformation ticket, AuthMethod method,
84        Token<? extends TokenIdentifier> token, String serverPrincipal, boolean fallbackAllowed,
85        String rpcProtection, SaslExceptionHandler exceptionHandler,
86        SaslSuccessfulConnectHandler successfulConnectHandler) throws IOException {
87      this.ticket = ticket;
88      this.fallbackAllowed = fallbackAllowed;
89  
90      this.exceptionHandler = exceptionHandler;
91      this.successfulConnectHandler = successfulConnectHandler;
92  
93      SaslUtil.initSaslProperties(rpcProtection);
94      switch (method) {
95      case DIGEST:
96        if (LOG.isDebugEnabled())
97          LOG.debug("Creating SASL " + AuthMethod.DIGEST.getMechanismName()
98              + " client to authenticate to service at " + token.getService());
99        saslClient = createDigestSaslClient(new String[] { AuthMethod.DIGEST.getMechanismName() },
100           SaslUtil.SASL_DEFAULT_REALM, new HBaseSaslRpcClient.SaslClientCallbackHandler(token));
101       break;
102     case KERBEROS:
103       if (LOG.isDebugEnabled()) {
104         LOG.debug("Creating SASL " + AuthMethod.KERBEROS.getMechanismName()
105             + " client. Server's Kerberos principal name is " + serverPrincipal);
106       }
107       if (serverPrincipal == null || serverPrincipal.isEmpty()) {
108         throw new IOException("Failed to specify server's Kerberos principal name");
109       }
110       String[] names = SaslUtil.splitKerberosName(serverPrincipal);
111       if (names.length != 3) {
112         throw new IOException(
113             "Kerberos principal does not have the expected format: " + serverPrincipal);
114       }
115       saslClient = createKerberosSaslClient(new String[] { AuthMethod.KERBEROS.getMechanismName() },
116           names[0], names[1]);
117       break;
118     default:
119       throw new IOException("Unknown authentication method " + method);
120     }
121     if (saslClient == null) {
122       throw new IOException("Unable to find SASL client implementation");
123     }
124   }
125 
126   
127 
128 
129 
130 
131 
132 
133 
134 
135   protected SaslClient createDigestSaslClient(String[] mechanismNames, String saslDefaultRealm,
136       CallbackHandler saslClientCallbackHandler) throws IOException {
137     return Sasl.createSaslClient(mechanismNames, null, null, saslDefaultRealm, SaslUtil.SASL_PROPS,
138         saslClientCallbackHandler);
139   }
140 
141   
142 
143 
144 
145 
146 
147 
148 
149 
150   protected SaslClient createKerberosSaslClient(String[] mechanismNames, String userFirstPart,
151       String userSecondPart) throws IOException {
152     return Sasl
153         .createSaslClient(mechanismNames, null, userFirstPart, userSecondPart, SaslUtil.SASL_PROPS,
154             null);
155   }
156 
157   @Override
158   public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
159     saslClient.dispose();
160   }
161 
162   private byte[] evaluateChallenge(final byte[] challenge) throws Exception {
163     return ticket.doAs(new PrivilegedExceptionAction<byte[]>() {
164 
165       @Override
166       public byte[] run() throws Exception {
167         return saslClient.evaluateChallenge(challenge);
168       }
169     });
170   }
171 
172   @Override
173   public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
174     saslToken = new byte[0];
175     if (saslClient.hasInitialResponse()) {
176       saslToken = evaluateChallenge(saslToken);
177     }
178     if (saslToken != null) {
179       writeSaslToken(ctx, saslToken);
180       if (LOG.isDebugEnabled()) {
181         LOG.debug("Have sent token of size " + saslToken.length + " from initSASLContext.");
182       }
183     }
184   }
185 
186   @Override
187   public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
188     ByteBuf in = (ByteBuf) msg;
189 
190     
191     if (!saslClient.isComplete()) {
192       while (!saslClient.isComplete() && in.isReadable()) {
193         readStatus(in);
194         int len = in.readInt();
195         if (firstRead) {
196           firstRead = false;
197           if (len == SaslUtil.SWITCH_TO_SIMPLE_AUTH) {
198             if (!fallbackAllowed) {
199               throw new IOException("Server asks us to fall back to SIMPLE auth, " + "but this "
200                   + "client is configured to only allow secure connections.");
201             }
202             if (LOG.isDebugEnabled()) {
203               LOG.debug("Server asks us to fall back to simple auth.");
204             }
205             saslClient.dispose();
206 
207             ctx.pipeline().remove(this);
208             successfulConnectHandler.onSuccess(ctx.channel());
209             return;
210           }
211         }
212         saslToken = new byte[len];
213         if (LOG.isDebugEnabled()) {
214           LOG.debug("Will read input token of size " + saslToken.length
215               + " for processing by initSASLContext");
216         }
217         in.readBytes(saslToken);
218 
219         saslToken = evaluateChallenge(saslToken);
220         if (saslToken != null) {
221           if (LOG.isDebugEnabled()) {
222             LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext.");
223           }
224           writeSaslToken(ctx, saslToken);
225         }
226       }
227 
228       if (saslClient.isComplete()) {
229         String qop = (String) saslClient.getNegotiatedProperty(Sasl.QOP);
230 
231         if (LOG.isDebugEnabled()) {
232           LOG.debug("SASL client context established. Negotiated QoP: " + qop);
233         }
234 
235         boolean useWrap = qop != null && !"auth".equalsIgnoreCase(qop);
236 
237         if (!useWrap) {
238           ctx.pipeline().remove(this);
239         }
240         successfulConnectHandler.onSuccess(ctx.channel());
241       }
242     }
243     
244     else {
245       try {
246         int length = in.readInt();
247         if (LOG.isDebugEnabled()) {
248           LOG.debug("Actual length is " + length);
249         }
250         saslToken = new byte[length];
251         in.readBytes(saslToken);
252       } catch (IndexOutOfBoundsException e) {
253         return;
254       }
255       try {
256         ByteBuf b = ctx.channel().alloc().buffer(saslToken.length);
257 
258         b.writeBytes(saslClient.unwrap(saslToken, 0, saslToken.length));
259         ctx.fireChannelRead(b);
260 
261       } catch (SaslException se) {
262         try {
263           saslClient.dispose();
264         } catch (SaslException ignored) {
265           LOG.debug("Ignoring SASL exception", ignored);
266         }
267         throw se;
268       }
269     }
270   }
271 
272   
273 
274 
275 
276 
277   private void writeSaslToken(final ChannelHandlerContext ctx, byte[] saslToken) {
278     ByteBuf b = ctx.alloc().buffer(4 + saslToken.length);
279     b.writeInt(saslToken.length);
280     b.writeBytes(saslToken, 0, saslToken.length);
281     ctx.writeAndFlush(b).addListener(new ChannelFutureListener() {
282       @Override
283       public void operationComplete(ChannelFuture future) throws Exception {
284         if (!future.isSuccess()) {
285           exceptionCaught(ctx, future.cause());
286         }
287       }
288     });
289   }
290 
291   
292 
293 
294 
295 
296 
297   private static void readStatus(ByteBuf inStream) throws RemoteException {
298     int status = inStream.readInt(); 
299     if (status != SaslStatus.SUCCESS.state) {
300       throw new RemoteException(inStream.toString(Charset.forName("UTF-8")),
301           inStream.toString(Charset.forName("UTF-8")));
302     }
303   }
304 
305   @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
306       throws Exception {
307     saslClient.dispose();
308 
309     ctx.close();
310 
311     if (this.random == null) {
312       this.random = new Random();
313     }
314     exceptionHandler.handle(this.retryCount++, this.random, cause);
315   }
316 
317   @Override
318   public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
319       throws Exception {
320     
321     if (!saslClient.isComplete()) {
322       super.write(ctx, msg, promise);
323     } else {
324       ByteBuf in = (ByteBuf) msg;
325 
326       try {
327         saslToken = saslClient.wrap(in.array(), in.readerIndex(), in.readableBytes());
328       } catch (SaslException se) {
329         try {
330           saslClient.dispose();
331         } catch (SaslException ignored) {
332           LOG.debug("Ignoring SASL exception", ignored);
333         }
334         promise.setFailure(se);
335       }
336       if (saslToken != null) {
337         ByteBuf out = ctx.channel().alloc().buffer(4 + saslToken.length);
338         out.writeInt(saslToken.length);
339         out.writeBytes(saslToken, 0, saslToken.length);
340 
341         ctx.write(out).addListener(new ChannelFutureListener() {
342           @Override public void operationComplete(ChannelFuture future) throws Exception {
343             if (!future.isSuccess()) {
344               exceptionCaught(ctx, future.cause());
345             }
346           }
347         });
348 
349         saslToken = null;
350       }
351     }
352   }
353 
354   
355 
356 
357   public interface SaslExceptionHandler {
358     
359 
360 
361 
362 
363 
364 
365     public void handle(int retryCount, Random random, Throwable cause);
366   }
367 
368   
369 
370 
371   public interface SaslSuccessfulConnectHandler {
372     
373 
374 
375 
376 
377     public void onSuccess(Channel channel);
378   }
379 }