View Javadoc

1   /**
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   * http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   */
18  package org.apache.hadoop.hbase.security;
19  
20  import io.netty.buffer.ByteBuf;
21  import io.netty.channel.Channel;
22  import io.netty.channel.ChannelDuplexHandler;
23  import io.netty.channel.ChannelFuture;
24  import io.netty.channel.ChannelFutureListener;
25  import io.netty.channel.ChannelHandlerContext;
26  import io.netty.channel.ChannelPromise;
27  
28  import org.apache.commons.logging.Log;
29  import org.apache.commons.logging.LogFactory;
30  import org.apache.hadoop.hbase.classification.InterfaceAudience;
31  import org.apache.hadoop.ipc.RemoteException;
32  import org.apache.hadoop.security.UserGroupInformation;
33  import org.apache.hadoop.security.token.Token;
34  import org.apache.hadoop.security.token.TokenIdentifier;
35  
36  import javax.security.auth.callback.CallbackHandler;
37  import javax.security.sasl.Sasl;
38  import javax.security.sasl.SaslClient;
39  import javax.security.sasl.SaslException;
40  
41  import java.io.IOException;
42  import java.nio.charset.Charset;
43  import java.security.PrivilegedExceptionAction;
44  import java.util.Random;
45  
46  /**
47   * Handles Sasl connections
48   */
49  @InterfaceAudience.Private
50  public class SaslClientHandler extends ChannelDuplexHandler {
51    public static final Log LOG = LogFactory.getLog(SaslClientHandler.class);
52  
53    private final boolean fallbackAllowed;
54  
55    private final UserGroupInformation ticket;
56  
57    /**
58     * Used for client or server's token to send or receive from each other.
59     */
60    private final SaslClient saslClient;
61    private final SaslExceptionHandler exceptionHandler;
62    private final SaslSuccessfulConnectHandler successfulConnectHandler;
63    private byte[] saslToken;
64    private boolean firstRead = true;
65  
66    private int retryCount = 0;
67    private Random random;
68  
69    /**
70     * Constructor
71     *
72     * @param ticket                   the ugi
73     * @param method                   auth method
74     * @param token                    for Sasl
75     * @param serverPrincipal          Server's Kerberos principal name
76     * @param fallbackAllowed          True if server may also fall back to less secure connection
77     * @param rpcProtection            Quality of protection. Can be 'authentication', 'integrity' or
78     *                                 'privacy'.
79     * @param exceptionHandler         handler for exceptions
80     * @param successfulConnectHandler handler for succesful connects
81     * @throws java.io.IOException if handler could not be created
82     */
83    public SaslClientHandler(UserGroupInformation ticket, AuthMethod method,
84        Token<? extends TokenIdentifier> token, String serverPrincipal, boolean fallbackAllowed,
85        String rpcProtection, SaslExceptionHandler exceptionHandler,
86        SaslSuccessfulConnectHandler successfulConnectHandler) throws IOException {
87      this.ticket = ticket;
88      this.fallbackAllowed = fallbackAllowed;
89  
90      this.exceptionHandler = exceptionHandler;
91      this.successfulConnectHandler = successfulConnectHandler;
92  
93      SaslUtil.initSaslProperties(rpcProtection);
94      switch (method) {
95      case DIGEST:
96        if (LOG.isDebugEnabled())
97          LOG.debug("Creating SASL " + AuthMethod.DIGEST.getMechanismName()
98              + " client to authenticate to service at " + token.getService());
99        saslClient = createDigestSaslClient(new String[] { AuthMethod.DIGEST.getMechanismName() },
100           SaslUtil.SASL_DEFAULT_REALM, new HBaseSaslRpcClient.SaslClientCallbackHandler(token));
101       break;
102     case KERBEROS:
103       if (LOG.isDebugEnabled()) {
104         LOG.debug("Creating SASL " + AuthMethod.KERBEROS.getMechanismName()
105             + " client. Server's Kerberos principal name is " + serverPrincipal);
106       }
107       if (serverPrincipal == null || serverPrincipal.isEmpty()) {
108         throw new IOException("Failed to specify server's Kerberos principal name");
109       }
110       String[] names = SaslUtil.splitKerberosName(serverPrincipal);
111       if (names.length != 3) {
112         throw new IOException(
113             "Kerberos principal does not have the expected format: " + serverPrincipal);
114       }
115       saslClient = createKerberosSaslClient(new String[] { AuthMethod.KERBEROS.getMechanismName() },
116           names[0], names[1]);
117       break;
118     default:
119       throw new IOException("Unknown authentication method " + method);
120     }
121     if (saslClient == null) {
122       throw new IOException("Unable to find SASL client implementation");
123     }
124   }
125 
126   /**
127    * Create a Digest Sasl client
128    *
129    * @param mechanismNames            names of mechanisms
130    * @param saslDefaultRealm          default realm for sasl
131    * @param saslClientCallbackHandler handler for the client
132    * @return new SaslClient
133    * @throws java.io.IOException if creation went wrong
134    */
135   protected SaslClient createDigestSaslClient(String[] mechanismNames, String saslDefaultRealm,
136       CallbackHandler saslClientCallbackHandler) throws IOException {
137     return Sasl.createSaslClient(mechanismNames, null, null, saslDefaultRealm, SaslUtil.SASL_PROPS,
138         saslClientCallbackHandler);
139   }
140 
141   /**
142    * Create Kerberos client
143    *
144    * @param mechanismNames names of mechanisms
145    * @param userFirstPart  first part of username
146    * @param userSecondPart second part of username
147    * @return new SaslClient
148    * @throws java.io.IOException if fails
149    */
150   protected SaslClient createKerberosSaslClient(String[] mechanismNames, String userFirstPart,
151       String userSecondPart) throws IOException {
152     return Sasl
153         .createSaslClient(mechanismNames, null, userFirstPart, userSecondPart, SaslUtil.SASL_PROPS,
154             null);
155   }
156 
157   @Override
158   public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
159     saslClient.dispose();
160   }
161 
162   private byte[] evaluateChallenge(final byte[] challenge) throws Exception {
163     return ticket.doAs(new PrivilegedExceptionAction<byte[]>() {
164 
165       @Override
166       public byte[] run() throws Exception {
167         return saslClient.evaluateChallenge(challenge);
168       }
169     });
170   }
171 
172   @Override
173   public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
174     saslToken = new byte[0];
175     if (saslClient.hasInitialResponse()) {
176       saslToken = evaluateChallenge(saslToken);
177     }
178     if (saslToken != null) {
179       writeSaslToken(ctx, saslToken);
180       if (LOG.isDebugEnabled()) {
181         LOG.debug("Have sent token of size " + saslToken.length + " from initSASLContext.");
182       }
183     }
184   }
185 
186   @Override
187   public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
188     ByteBuf in = (ByteBuf) msg;
189 
190     // If not complete, try to negotiate
191     if (!saslClient.isComplete()) {
192       while (!saslClient.isComplete() && in.isReadable()) {
193         readStatus(in);
194         int len = in.readInt();
195         if (firstRead) {
196           firstRead = false;
197           if (len == SaslUtil.SWITCH_TO_SIMPLE_AUTH) {
198             if (!fallbackAllowed) {
199               throw new IOException("Server asks us to fall back to SIMPLE auth, " + "but this "
200                   + "client is configured to only allow secure connections.");
201             }
202             if (LOG.isDebugEnabled()) {
203               LOG.debug("Server asks us to fall back to simple auth.");
204             }
205             saslClient.dispose();
206 
207             ctx.pipeline().remove(this);
208             successfulConnectHandler.onSuccess(ctx.channel());
209             return;
210           }
211         }
212         saslToken = new byte[len];
213         if (LOG.isDebugEnabled()) {
214           LOG.debug("Will read input token of size " + saslToken.length
215               + " for processing by initSASLContext");
216         }
217         in.readBytes(saslToken);
218 
219         saslToken = evaluateChallenge(saslToken);
220         if (saslToken != null) {
221           if (LOG.isDebugEnabled()) {
222             LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext.");
223           }
224           writeSaslToken(ctx, saslToken);
225         }
226       }
227 
228       if (saslClient.isComplete()) {
229         String qop = (String) saslClient.getNegotiatedProperty(Sasl.QOP);
230 
231         if (LOG.isDebugEnabled()) {
232           LOG.debug("SASL client context established. Negotiated QoP: " + qop);
233         }
234 
235         boolean useWrap = qop != null && !"auth".equalsIgnoreCase(qop);
236 
237         if (!useWrap) {
238           ctx.pipeline().remove(this);
239         }
240         successfulConnectHandler.onSuccess(ctx.channel());
241       }
242     }
243     // Normal wrapped reading
244     else {
245       try {
246         int length = in.readInt();
247         if (LOG.isDebugEnabled()) {
248           LOG.debug("Actual length is " + length);
249         }
250         saslToken = new byte[length];
251         in.readBytes(saslToken);
252       } catch (IndexOutOfBoundsException e) {
253         return;
254       }
255       try {
256         ByteBuf b = ctx.channel().alloc().buffer(saslToken.length);
257 
258         b.writeBytes(saslClient.unwrap(saslToken, 0, saslToken.length));
259         ctx.fireChannelRead(b);
260 
261       } catch (SaslException se) {
262         try {
263           saslClient.dispose();
264         } catch (SaslException ignored) {
265           LOG.debug("Ignoring SASL exception", ignored);
266         }
267         throw se;
268       }
269     }
270   }
271 
272   /**
273    * Write SASL token
274    * @param ctx to write to
275    * @param saslToken to write
276    */
277   private void writeSaslToken(final ChannelHandlerContext ctx, byte[] saslToken) {
278     ByteBuf b = ctx.alloc().buffer(4 + saslToken.length);
279     b.writeInt(saslToken.length);
280     b.writeBytes(saslToken, 0, saslToken.length);
281     ctx.writeAndFlush(b).addListener(new ChannelFutureListener() {
282       @Override
283       public void operationComplete(ChannelFuture future) throws Exception {
284         if (!future.isSuccess()) {
285           exceptionCaught(ctx, future.cause());
286         }
287       }
288     });
289   }
290 
291   /**
292    * Get the read status
293    *
294    * @param inStream to read
295    * @throws org.apache.hadoop.ipc.RemoteException if status was not success
296    */
297   private static void readStatus(ByteBuf inStream) throws RemoteException {
298     int status = inStream.readInt(); // read status
299     if (status != SaslStatus.SUCCESS.state) {
300       throw new RemoteException(inStream.toString(Charset.forName("UTF-8")),
301           inStream.toString(Charset.forName("UTF-8")));
302     }
303   }
304 
305   @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
306       throws Exception {
307     saslClient.dispose();
308 
309     ctx.close();
310 
311     if (this.random == null) {
312       this.random = new Random();
313     }
314     exceptionHandler.handle(this.retryCount++, this.random, cause);
315   }
316 
317   @Override
318   public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
319       throws Exception {
320     // If not complete, try to negotiate
321     if (!saslClient.isComplete()) {
322       super.write(ctx, msg, promise);
323     } else {
324       ByteBuf in = (ByteBuf) msg;
325 
326       try {
327         saslToken = saslClient.wrap(in.array(), in.readerIndex(), in.readableBytes());
328       } catch (SaslException se) {
329         try {
330           saslClient.dispose();
331         } catch (SaslException ignored) {
332           LOG.debug("Ignoring SASL exception", ignored);
333         }
334         promise.setFailure(se);
335       }
336       if (saslToken != null) {
337         ByteBuf out = ctx.channel().alloc().buffer(4 + saslToken.length);
338         out.writeInt(saslToken.length);
339         out.writeBytes(saslToken, 0, saslToken.length);
340 
341         ctx.write(out).addListener(new ChannelFutureListener() {
342           @Override public void operationComplete(ChannelFuture future) throws Exception {
343             if (!future.isSuccess()) {
344               exceptionCaught(ctx, future.cause());
345             }
346           }
347         });
348 
349         saslToken = null;
350       }
351     }
352   }
353 
354   /**
355    * Handler for exceptions during Sasl connection
356    */
357   public interface SaslExceptionHandler {
358     /**
359      * Handle the exception
360      *
361      * @param retryCount current retry count
362      * @param random     to create new backoff with
363      * @param cause      of fail
364      */
365     public void handle(int retryCount, Random random, Throwable cause);
366   }
367 
368   /**
369    * Handler for successful connects
370    */
371   public interface SaslSuccessfulConnectHandler {
372     /**
373      * Runs on success
374      *
375      * @param channel which is successfully authenticated
376      */
377     public void onSuccess(Channel channel);
378   }
379 }