001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018
019package org.apache.hadoop.hbase.security;
020
021import java.io.BufferedInputStream;
022import java.io.BufferedOutputStream;
023import java.io.DataInputStream;
024import java.io.DataOutputStream;
025import java.io.FilterInputStream;
026import java.io.FilterOutputStream;
027import java.io.IOException;
028import java.io.InputStream;
029import java.io.OutputStream;
030import java.net.InetAddress;
031import java.nio.ByteBuffer;
032
033import javax.security.sasl.Sasl;
034import javax.security.sasl.SaslException;
035
036import org.apache.hadoop.conf.Configuration;
037import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
038import org.apache.hadoop.hbase.security.provider.SaslClientAuthenticationProvider;
039import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos;
040import org.apache.hadoop.io.WritableUtils;
041import org.apache.hadoop.ipc.RemoteException;
042import org.apache.hadoop.security.SaslInputStream;
043import org.apache.hadoop.security.SaslOutputStream;
044import org.apache.hadoop.security.token.Token;
045import org.apache.hadoop.security.token.TokenIdentifier;
046import org.apache.yetus.audience.InterfaceAudience;
047import org.slf4j.Logger;
048import org.slf4j.LoggerFactory;
049
050/**
051 * A utility class that encapsulates SASL logic for RPC client. Copied from
052 * <code>org.apache.hadoop.security</code>
053 */
054@InterfaceAudience.Private
055public class HBaseSaslRpcClient extends AbstractHBaseSaslRpcClient {
056
057  private static final Logger LOG = LoggerFactory.getLogger(HBaseSaslRpcClient.class);
058  private boolean cryptoAesEnable;
059  private CryptoAES cryptoAES;
060  private InputStream saslInputStream;
061  private InputStream cryptoInputStream;
062  private OutputStream saslOutputStream;
063  private OutputStream cryptoOutputStream;
064  private boolean initStreamForCrypto;
065
066  public HBaseSaslRpcClient(Configuration conf, SaslClientAuthenticationProvider provider,
067      Token<? extends TokenIdentifier> token, InetAddress serverAddr, SecurityInfo securityInfo,
068      boolean fallbackAllowed) throws IOException {
069    super(conf, provider, token, serverAddr, securityInfo, fallbackAllowed);
070  }
071
072  public HBaseSaslRpcClient(Configuration conf, SaslClientAuthenticationProvider provider,
073      Token<? extends TokenIdentifier> token, InetAddress serverAddr, SecurityInfo securityInfo,
074      boolean fallbackAllowed, String rpcProtection, boolean initStreamForCrypto)
075          throws IOException {
076    super(conf, provider, token, serverAddr, securityInfo, fallbackAllowed, rpcProtection);
077    this.initStreamForCrypto = initStreamForCrypto;
078  }
079
080  private static void readStatus(DataInputStream inStream) throws IOException {
081    int status = inStream.readInt(); // read status
082    if (status != SaslStatus.SUCCESS.state) {
083      throw new RemoteException(WritableUtils.readString(inStream),
084          WritableUtils.readString(inStream));
085    }
086  }
087
088  /**
089   * Do client side SASL authentication with server via the given InputStream and OutputStream
090   * @param inS InputStream to use
091   * @param outS OutputStream to use
092   * @return true if connection is set up, or false if needs to switch to simple Auth.
093   * @throws IOException
094   */
095  public boolean saslConnect(InputStream inS, OutputStream outS) throws IOException {
096    DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS));
097    DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream(outS));
098
099    try {
100      byte[] saslToken = getInitialResponse();
101      if (saslToken != null) {
102        outStream.writeInt(saslToken.length);
103        outStream.write(saslToken, 0, saslToken.length);
104        outStream.flush();
105        if (LOG.isDebugEnabled()) {
106          LOG.debug("Have sent token of size " + saslToken.length + " from initSASLContext.");
107        }
108      }
109      if (!isComplete()) {
110        readStatus(inStream);
111        int len = inStream.readInt();
112        if (len == SaslUtil.SWITCH_TO_SIMPLE_AUTH) {
113          if (!fallbackAllowed) {
114            throw new IOException("Server asks us to fall back to SIMPLE auth, "
115                + "but this client is configured to only allow secure connections.");
116          }
117          if (LOG.isDebugEnabled()) {
118            LOG.debug("Server asks us to fall back to simple auth.");
119          }
120          dispose();
121          return false;
122        }
123        saslToken = new byte[len];
124        if (LOG.isDebugEnabled()) {
125          LOG.debug("Will read input token of size " + saslToken.length
126              + " for processing by initSASLContext");
127        }
128        inStream.readFully(saslToken);
129      }
130
131      while (!isComplete()) {
132        saslToken = evaluateChallenge(saslToken);
133        if (saslToken != null) {
134          if (LOG.isDebugEnabled()) {
135            LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext.");
136          }
137          outStream.writeInt(saslToken.length);
138          outStream.write(saslToken, 0, saslToken.length);
139          outStream.flush();
140        }
141        if (!isComplete()) {
142          readStatus(inStream);
143          saslToken = new byte[inStream.readInt()];
144          if (LOG.isDebugEnabled()) {
145            LOG.debug("Will read input token of size " + saslToken.length
146                + " for processing by initSASLContext");
147          }
148          inStream.readFully(saslToken);
149        }
150      }
151
152      try {
153        readStatus(inStream);
154      }
155      catch (IOException e){
156        if(e instanceof RemoteException){
157          LOG.debug("Sasl connection failed: ", e);
158          throw e;
159        }
160      }
161      if (LOG.isDebugEnabled()) {
162        LOG.debug("SASL client context established. Negotiated QoP: "
163            + saslClient.getNegotiatedProperty(Sasl.QOP));
164      }
165      // initial the inputStream, outputStream for both Sasl encryption
166      // and Crypto AES encryption if necessary
167      // if Crypto AES encryption enabled, the saslInputStream/saslOutputStream is
168      // only responsible for connection header negotiation,
169      // cryptoInputStream/cryptoOutputStream is responsible for rpc encryption with Crypto AES
170      saslInputStream = new SaslInputStream(inS, saslClient);
171      saslOutputStream = new SaslOutputStream(outS, saslClient);
172      if (initStreamForCrypto) {
173        cryptoInputStream = new WrappedInputStream(inS);
174        cryptoOutputStream = new WrappedOutputStream(outS);
175      }
176
177      return true;
178    } catch (IOException e) {
179      try {
180        saslClient.dispose();
181      } catch (SaslException ignored) {
182        // ignore further exceptions during cleanup
183      }
184      throw e;
185    }
186  }
187
188  public String getSaslQOP() {
189    return (String) saslClient.getNegotiatedProperty(Sasl.QOP);
190  }
191
192  public void initCryptoCipher(RPCProtos.CryptoCipherMeta cryptoCipherMeta,
193      Configuration conf) throws IOException {
194    // create SaslAES for client
195    cryptoAES = EncryptionUtil.createCryptoAES(cryptoCipherMeta, conf);
196    cryptoAesEnable = true;
197  }
198
199  /**
200   * Get a SASL wrapped InputStream. Can be called only after saslConnect() has been called.
201   * @return a SASL wrapped InputStream
202   * @throws IOException
203   */
204  public InputStream getInputStream() throws IOException {
205    if (!saslClient.isComplete()) {
206      throw new IOException("Sasl authentication exchange hasn't completed yet");
207    }
208    // If Crypto AES is enabled, return cryptoInputStream which unwrap the data with Crypto AES.
209    if (cryptoAesEnable && cryptoInputStream != null) {
210      return cryptoInputStream;
211    }
212    return saslInputStream;
213  }
214
215  class WrappedInputStream extends FilterInputStream {
216    private ByteBuffer unwrappedRpcBuffer = ByteBuffer.allocate(0);
217    public WrappedInputStream(InputStream in) throws IOException {
218      super(in);
219    }
220
221    @Override
222    public int read() throws IOException {
223      byte[] b = new byte[1];
224      int n = read(b, 0, 1);
225      return (n != -1) ? b[0] : -1;
226    }
227
228    @Override
229    public int read(byte b[]) throws IOException {
230      return read(b, 0, b.length);
231    }
232
233    @Override
234    public synchronized int read(byte[] buf, int off, int len) throws IOException {
235      // fill the buffer with the next RPC message
236      if (unwrappedRpcBuffer.remaining() == 0) {
237        readNextRpcPacket();
238      }
239      // satisfy as much of the request as possible
240      int readLen = Math.min(len, unwrappedRpcBuffer.remaining());
241      unwrappedRpcBuffer.get(buf, off, readLen);
242      return readLen;
243    }
244
245    // unwrap messages with Crypto AES
246    private void readNextRpcPacket() throws IOException {
247      LOG.debug("reading next wrapped RPC packet");
248      DataInputStream dis = new DataInputStream(in);
249      int rpcLen = dis.readInt();
250      byte[] rpcBuf = new byte[rpcLen];
251      dis.readFully(rpcBuf);
252
253      // unwrap with Crypto AES
254      rpcBuf = cryptoAES.unwrap(rpcBuf, 0, rpcBuf.length);
255      if (LOG.isDebugEnabled()) {
256        LOG.debug("unwrapping token of length:" + rpcBuf.length);
257      }
258      unwrappedRpcBuffer = ByteBuffer.wrap(rpcBuf);
259    }
260  }
261
262  /**
263   * Get a SASL wrapped OutputStream. Can be called only after saslConnect() has been called.
264   * @return a SASL wrapped OutputStream
265   * @throws IOException
266   */
267  public OutputStream getOutputStream() throws IOException {
268    if (!saslClient.isComplete()) {
269      throw new IOException("Sasl authentication exchange hasn't completed yet");
270    }
271    // If Crypto AES is enabled, return cryptoOutputStream which wrap the data with Crypto AES.
272    if (cryptoAesEnable && cryptoOutputStream != null) {
273      return cryptoOutputStream;
274    }
275    return saslOutputStream;
276  }
277
278  class WrappedOutputStream extends FilterOutputStream {
279    public WrappedOutputStream(OutputStream out) throws IOException {
280      super(out);
281    }
282    @Override
283    public void write(byte[] buf, int off, int len) throws IOException {
284      if (LOG.isDebugEnabled()) {
285        LOG.debug("wrapping token of length:" + len);
286      }
287
288      // wrap with Crypto AES
289      byte[] wrapped = cryptoAES.wrap(buf, off, len);
290      DataOutputStream dob = new DataOutputStream(out);
291      dob.writeInt(wrapped.length);
292      dob.write(wrapped, 0, wrapped.length);
293      dob.flush();
294    }
295  }
296}