001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.hbase.security;
019
020import java.io.BufferedInputStream;
021import java.io.BufferedOutputStream;
022import java.io.DataInputStream;
023import java.io.DataOutputStream;
024import java.io.FilterInputStream;
025import java.io.FilterOutputStream;
026import java.io.IOException;
027import java.io.InputStream;
028import java.io.OutputStream;
029import java.net.InetAddress;
030import java.nio.ByteBuffer;
031import javax.security.sasl.Sasl;
032import javax.security.sasl.SaslException;
033import org.apache.hadoop.conf.Configuration;
034import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
035import org.apache.hadoop.hbase.ipc.FallbackDisallowedException;
036import org.apache.hadoop.hbase.security.provider.SaslClientAuthenticationProvider;
037import org.apache.hadoop.io.WritableUtils;
038import org.apache.hadoop.ipc.RemoteException;
039import org.apache.hadoop.security.SaslInputStream;
040import org.apache.hadoop.security.SaslOutputStream;
041import org.apache.hadoop.security.token.Token;
042import org.apache.hadoop.security.token.TokenIdentifier;
043import org.apache.yetus.audience.InterfaceAudience;
044import org.slf4j.Logger;
045import org.slf4j.LoggerFactory;
046
047import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos;
048
049/**
050 * A utility class that encapsulates SASL logic for RPC client. Copied from
051 * <code>org.apache.hadoop.security</code>
052 */
053@InterfaceAudience.Private
054public class HBaseSaslRpcClient extends AbstractHBaseSaslRpcClient {
055
056  private static final Logger LOG = LoggerFactory.getLogger(HBaseSaslRpcClient.class);
057  private boolean cryptoAesEnable;
058  private CryptoAES cryptoAES;
059  private InputStream saslInputStream;
060  private InputStream cryptoInputStream;
061  private OutputStream saslOutputStream;
062  private OutputStream cryptoOutputStream;
063  private boolean initStreamForCrypto;
064
065  public HBaseSaslRpcClient(Configuration conf, SaslClientAuthenticationProvider provider,
066    Token<? extends TokenIdentifier> token, InetAddress serverAddr, String servicePrincipal,
067    boolean fallbackAllowed) throws IOException {
068    super(conf, provider, token, serverAddr, servicePrincipal, fallbackAllowed);
069  }
070
071  public HBaseSaslRpcClient(Configuration conf, SaslClientAuthenticationProvider provider,
072    Token<? extends TokenIdentifier> token, InetAddress serverAddr, String servicePrincipal,
073    boolean fallbackAllowed, String rpcProtection, boolean initStreamForCrypto) throws IOException {
074    super(conf, provider, token, serverAddr, servicePrincipal, fallbackAllowed, rpcProtection);
075    this.initStreamForCrypto = initStreamForCrypto;
076  }
077
078  private static void readStatus(DataInputStream inStream) throws IOException {
079    int status = inStream.readInt(); // read status
080    if (status != SaslStatus.SUCCESS.state) {
081      throw new RemoteException(WritableUtils.readString(inStream),
082        WritableUtils.readString(inStream));
083    }
084  }
085
086  /**
087   * Do client side SASL authentication with server via the given InputStream and OutputStream
088   * @param inS  InputStream to use
089   * @param outS OutputStream to use
090   * @return true if connection is set up, or false if needs to switch to simple Auth.
091   */
092  public boolean saslConnect(InputStream inS, OutputStream outS) throws IOException {
093    DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS));
094    DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream(outS));
095
096    try {
097      byte[] saslToken = getInitialResponse();
098      if (saslToken != null) {
099        outStream.writeInt(saslToken.length);
100        outStream.write(saslToken, 0, saslToken.length);
101        outStream.flush();
102        if (LOG.isDebugEnabled()) {
103          LOG.debug("Have sent token of size " + saslToken.length + " from initSASLContext.");
104        }
105      }
106      if (!isComplete()) {
107        readStatus(inStream);
108        int len = inStream.readInt();
109        if (len == SaslUtil.SWITCH_TO_SIMPLE_AUTH) {
110          if (!fallbackAllowed) {
111            throw new FallbackDisallowedException();
112          }
113          LOG.debug("Server asks us to fall back to simple auth.");
114          dispose();
115          return false;
116        }
117        saslToken = new byte[len];
118        if (LOG.isDebugEnabled()) {
119          LOG.debug("Will read input token of size " + saslToken.length
120            + " for processing by initSASLContext");
121        }
122        inStream.readFully(saslToken);
123      }
124
125      while (!isComplete()) {
126        saslToken = evaluateChallenge(saslToken);
127        if (saslToken != null) {
128          if (LOG.isDebugEnabled()) {
129            LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext.");
130          }
131          outStream.writeInt(saslToken.length);
132          outStream.write(saslToken, 0, saslToken.length);
133          outStream.flush();
134        }
135        if (!isComplete()) {
136          readStatus(inStream);
137          saslToken = new byte[inStream.readInt()];
138          if (LOG.isDebugEnabled()) {
139            LOG.debug("Will read input token of size " + saslToken.length
140              + " for processing by initSASLContext");
141          }
142          inStream.readFully(saslToken);
143        }
144      }
145
146      if (LOG.isDebugEnabled()) {
147        LOG.debug("SASL client context established. Negotiated QoP: "
148          + saslClient.getNegotiatedProperty(Sasl.QOP));
149      }
150      // initial the inputStream, outputStream for both Sasl encryption
151      // and Crypto AES encryption if necessary
152      // if Crypto AES encryption enabled, the saslInputStream/saslOutputStream is
153      // only responsible for connection header negotiation,
154      // cryptoInputStream/cryptoOutputStream is responsible for rpc encryption with Crypto AES
155      saslInputStream = new SaslInputStream(inS, saslClient);
156      saslOutputStream = new SaslOutputStream(outS, saslClient);
157      if (initStreamForCrypto) {
158        cryptoInputStream = new WrappedInputStream(inS);
159        cryptoOutputStream = new WrappedOutputStream(outS);
160      }
161
162      return true;
163    } catch (IOException e) {
164      try {
165        saslClient.dispose();
166      } catch (SaslException ignored) {
167        // ignore further exceptions during cleanup
168      }
169      throw e;
170    }
171  }
172
173  public String getSaslQOP() {
174    return (String) saslClient.getNegotiatedProperty(Sasl.QOP);
175  }
176
177  public void initCryptoCipher(RPCProtos.CryptoCipherMeta cryptoCipherMeta, Configuration conf)
178    throws IOException {
179    // create SaslAES for client
180    cryptoAES = EncryptionUtil.createCryptoAES(cryptoCipherMeta, conf);
181    cryptoAesEnable = true;
182  }
183
184  /**
185   * Get a SASL wrapped InputStream. Can be called only after saslConnect() has been called.
186   * @return a SASL wrapped InputStream
187   */
188  public InputStream getInputStream() throws IOException {
189    if (!saslClient.isComplete()) {
190      throw new IOException("Sasl authentication exchange hasn't completed yet");
191    }
192    // If Crypto AES is enabled, return cryptoInputStream which unwrap the data with Crypto AES.
193    if (cryptoAesEnable && cryptoInputStream != null) {
194      return cryptoInputStream;
195    }
196    return saslInputStream;
197  }
198
199  class WrappedInputStream extends FilterInputStream {
200    private ByteBuffer unwrappedRpcBuffer = ByteBuffer.allocate(0);
201
202    public WrappedInputStream(InputStream in) throws IOException {
203      super(in);
204    }
205
206    @Override
207    public int read() throws IOException {
208      byte[] b = new byte[1];
209      int n = read(b, 0, 1);
210      return (n != -1) ? b[0] : -1;
211    }
212
213    @Override
214    public int read(byte b[]) throws IOException {
215      return read(b, 0, b.length);
216    }
217
218    @Override
219    public synchronized int read(byte[] buf, int off, int len) throws IOException {
220      // fill the buffer with the next RPC message
221      if (unwrappedRpcBuffer.remaining() == 0) {
222        readNextRpcPacket();
223      }
224      // satisfy as much of the request as possible
225      int readLen = Math.min(len, unwrappedRpcBuffer.remaining());
226      unwrappedRpcBuffer.get(buf, off, readLen);
227      return readLen;
228    }
229
230    // unwrap messages with Crypto AES
231    private void readNextRpcPacket() throws IOException {
232      LOG.debug("reading next wrapped RPC packet");
233      DataInputStream dis = new DataInputStream(in);
234      int rpcLen = dis.readInt();
235      byte[] rpcBuf = new byte[rpcLen];
236      dis.readFully(rpcBuf);
237
238      // unwrap with Crypto AES
239      rpcBuf = cryptoAES.unwrap(rpcBuf, 0, rpcBuf.length);
240      if (LOG.isDebugEnabled()) {
241        LOG.debug("unwrapping token of length:" + rpcBuf.length);
242      }
243      unwrappedRpcBuffer = ByteBuffer.wrap(rpcBuf);
244    }
245  }
246
247  /**
248   * Get a SASL wrapped OutputStream. Can be called only after saslConnect() has been called.
249   * @return a SASL wrapped OutputStream
250   */
251  public OutputStream getOutputStream() throws IOException {
252    if (!saslClient.isComplete()) {
253      throw new IOException("Sasl authentication exchange hasn't completed yet");
254    }
255    // If Crypto AES is enabled, return cryptoOutputStream which wrap the data with Crypto AES.
256    if (cryptoAesEnable && cryptoOutputStream != null) {
257      return cryptoOutputStream;
258    }
259    return saslOutputStream;
260  }
261
262  class WrappedOutputStream extends FilterOutputStream {
263    public WrappedOutputStream(OutputStream out) throws IOException {
264      super(out);
265    }
266
267    @Override
268    public void write(byte[] buf, int off, int len) throws IOException {
269      if (LOG.isDebugEnabled()) {
270        LOG.debug("wrapping token of length:" + len);
271      }
272
273      // wrap with Crypto AES
274      byte[] wrapped = cryptoAES.wrap(buf, off, len);
275      DataOutputStream dob = new DataOutputStream(out);
276      dob.writeInt(wrapped.length);
277      dob.write(wrapped, 0, wrapped.length);
278      dob.flush();
279    }
280  }
281}