001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.hbase.security;
019
020import java.io.BufferedInputStream;
021import java.io.BufferedOutputStream;
022import java.io.DataInputStream;
023import java.io.DataOutputStream;
024import java.io.FilterInputStream;
025import java.io.FilterOutputStream;
026import java.io.IOException;
027import java.io.InputStream;
028import java.io.OutputStream;
029import java.net.InetAddress;
030import java.nio.ByteBuffer;
031import javax.security.sasl.Sasl;
032import javax.security.sasl.SaslException;
033import org.apache.hadoop.conf.Configuration;
034import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
035import org.apache.hadoop.hbase.security.provider.SaslClientAuthenticationProvider;
036import org.apache.hadoop.io.WritableUtils;
037import org.apache.hadoop.ipc.RemoteException;
038import org.apache.hadoop.security.SaslInputStream;
039import org.apache.hadoop.security.SaslOutputStream;
040import org.apache.hadoop.security.token.Token;
041import org.apache.hadoop.security.token.TokenIdentifier;
042import org.apache.yetus.audience.InterfaceAudience;
043import org.slf4j.Logger;
044import org.slf4j.LoggerFactory;
045
046import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos;
047
048/**
049 * A utility class that encapsulates SASL logic for RPC client. Copied from
050 * <code>org.apache.hadoop.security</code>
051 */
052@InterfaceAudience.Private
053public class HBaseSaslRpcClient extends AbstractHBaseSaslRpcClient {
054
055  private static final Logger LOG = LoggerFactory.getLogger(HBaseSaslRpcClient.class);
056  private boolean cryptoAesEnable;
057  private CryptoAES cryptoAES;
058  private InputStream saslInputStream;
059  private InputStream cryptoInputStream;
060  private OutputStream saslOutputStream;
061  private OutputStream cryptoOutputStream;
062  private boolean initStreamForCrypto;
063
064  public HBaseSaslRpcClient(Configuration conf, SaslClientAuthenticationProvider provider,
065    Token<? extends TokenIdentifier> token, InetAddress serverAddr, SecurityInfo securityInfo,
066    boolean fallbackAllowed) throws IOException {
067    super(conf, provider, token, serverAddr, securityInfo, fallbackAllowed);
068  }
069
070  public HBaseSaslRpcClient(Configuration conf, SaslClientAuthenticationProvider provider,
071    Token<? extends TokenIdentifier> token, InetAddress serverAddr, SecurityInfo securityInfo,
072    boolean fallbackAllowed, String rpcProtection, boolean initStreamForCrypto) throws IOException {
073    super(conf, provider, token, serverAddr, securityInfo, fallbackAllowed, rpcProtection);
074    this.initStreamForCrypto = initStreamForCrypto;
075  }
076
077  private static void readStatus(DataInputStream inStream) throws IOException {
078    int status = inStream.readInt(); // read status
079    if (status != SaslStatus.SUCCESS.state) {
080      throw new RemoteException(WritableUtils.readString(inStream),
081        WritableUtils.readString(inStream));
082    }
083  }
084
085  /**
086   * Do client side SASL authentication with server via the given InputStream and OutputStream
087   * @param inS  InputStream to use
088   * @param outS OutputStream to use
089   * @return true if connection is set up, or false if needs to switch to simple Auth. n
090   */
091  public boolean saslConnect(InputStream inS, OutputStream outS) throws IOException {
092    DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS));
093    DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream(outS));
094
095    try {
096      byte[] saslToken = getInitialResponse();
097      if (saslToken != null) {
098        outStream.writeInt(saslToken.length);
099        outStream.write(saslToken, 0, saslToken.length);
100        outStream.flush();
101        if (LOG.isDebugEnabled()) {
102          LOG.debug("Have sent token of size " + saslToken.length + " from initSASLContext.");
103        }
104      }
105      if (!isComplete()) {
106        readStatus(inStream);
107        int len = inStream.readInt();
108        if (len == SaslUtil.SWITCH_TO_SIMPLE_AUTH) {
109          if (!fallbackAllowed) {
110            throw new IOException("Server asks us to fall back to SIMPLE auth, "
111              + "but this client is configured to only allow secure connections.");
112          }
113          if (LOG.isDebugEnabled()) {
114            LOG.debug("Server asks us to fall back to simple auth.");
115          }
116          dispose();
117          return false;
118        }
119        saslToken = new byte[len];
120        if (LOG.isDebugEnabled()) {
121          LOG.debug("Will read input token of size " + saslToken.length
122            + " for processing by initSASLContext");
123        }
124        inStream.readFully(saslToken);
125      }
126
127      while (!isComplete()) {
128        saslToken = evaluateChallenge(saslToken);
129        if (saslToken != null) {
130          if (LOG.isDebugEnabled()) {
131            LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext.");
132          }
133          outStream.writeInt(saslToken.length);
134          outStream.write(saslToken, 0, saslToken.length);
135          outStream.flush();
136        }
137        if (!isComplete()) {
138          readStatus(inStream);
139          saslToken = new byte[inStream.readInt()];
140          if (LOG.isDebugEnabled()) {
141            LOG.debug("Will read input token of size " + saslToken.length
142              + " for processing by initSASLContext");
143          }
144          inStream.readFully(saslToken);
145        }
146      }
147
148      if (LOG.isDebugEnabled()) {
149        LOG.debug("SASL client context established. Negotiated QoP: "
150          + saslClient.getNegotiatedProperty(Sasl.QOP));
151      }
152      // initial the inputStream, outputStream for both Sasl encryption
153      // and Crypto AES encryption if necessary
154      // if Crypto AES encryption enabled, the saslInputStream/saslOutputStream is
155      // only responsible for connection header negotiation,
156      // cryptoInputStream/cryptoOutputStream is responsible for rpc encryption with Crypto AES
157      saslInputStream = new SaslInputStream(inS, saslClient);
158      saslOutputStream = new SaslOutputStream(outS, saslClient);
159      if (initStreamForCrypto) {
160        cryptoInputStream = new WrappedInputStream(inS);
161        cryptoOutputStream = new WrappedOutputStream(outS);
162      }
163
164      return true;
165    } catch (IOException e) {
166      try {
167        saslClient.dispose();
168      } catch (SaslException ignored) {
169        // ignore further exceptions during cleanup
170      }
171      throw e;
172    }
173  }
174
175  public String getSaslQOP() {
176    return (String) saslClient.getNegotiatedProperty(Sasl.QOP);
177  }
178
179  public void initCryptoCipher(RPCProtos.CryptoCipherMeta cryptoCipherMeta, Configuration conf)
180    throws IOException {
181    // create SaslAES for client
182    cryptoAES = EncryptionUtil.createCryptoAES(cryptoCipherMeta, conf);
183    cryptoAesEnable = true;
184  }
185
186  /**
187   * Get a SASL wrapped InputStream. Can be called only after saslConnect() has been called.
188   * @return a SASL wrapped InputStream n
189   */
190  public InputStream getInputStream() throws IOException {
191    if (!saslClient.isComplete()) {
192      throw new IOException("Sasl authentication exchange hasn't completed yet");
193    }
194    // If Crypto AES is enabled, return cryptoInputStream which unwrap the data with Crypto AES.
195    if (cryptoAesEnable && cryptoInputStream != null) {
196      return cryptoInputStream;
197    }
198    return saslInputStream;
199  }
200
201  class WrappedInputStream extends FilterInputStream {
202    private ByteBuffer unwrappedRpcBuffer = ByteBuffer.allocate(0);
203
204    public WrappedInputStream(InputStream in) throws IOException {
205      super(in);
206    }
207
208    @Override
209    public int read() throws IOException {
210      byte[] b = new byte[1];
211      int n = read(b, 0, 1);
212      return (n != -1) ? b[0] : -1;
213    }
214
215    @Override
216    public int read(byte b[]) throws IOException {
217      return read(b, 0, b.length);
218    }
219
220    @Override
221    public synchronized int read(byte[] buf, int off, int len) throws IOException {
222      // fill the buffer with the next RPC message
223      if (unwrappedRpcBuffer.remaining() == 0) {
224        readNextRpcPacket();
225      }
226      // satisfy as much of the request as possible
227      int readLen = Math.min(len, unwrappedRpcBuffer.remaining());
228      unwrappedRpcBuffer.get(buf, off, readLen);
229      return readLen;
230    }
231
232    // unwrap messages with Crypto AES
233    private void readNextRpcPacket() throws IOException {
234      LOG.debug("reading next wrapped RPC packet");
235      DataInputStream dis = new DataInputStream(in);
236      int rpcLen = dis.readInt();
237      byte[] rpcBuf = new byte[rpcLen];
238      dis.readFully(rpcBuf);
239
240      // unwrap with Crypto AES
241      rpcBuf = cryptoAES.unwrap(rpcBuf, 0, rpcBuf.length);
242      if (LOG.isDebugEnabled()) {
243        LOG.debug("unwrapping token of length:" + rpcBuf.length);
244      }
245      unwrappedRpcBuffer = ByteBuffer.wrap(rpcBuf);
246    }
247  }
248
249  /**
250   * Get a SASL wrapped OutputStream. Can be called only after saslConnect() has been called.
251   * @return a SASL wrapped OutputStream n
252   */
253  public OutputStream getOutputStream() throws IOException {
254    if (!saslClient.isComplete()) {
255      throw new IOException("Sasl authentication exchange hasn't completed yet");
256    }
257    // If Crypto AES is enabled, return cryptoOutputStream which wrap the data with Crypto AES.
258    if (cryptoAesEnable && cryptoOutputStream != null) {
259      return cryptoOutputStream;
260    }
261    return saslOutputStream;
262  }
263
264  class WrappedOutputStream extends FilterOutputStream {
265    public WrappedOutputStream(OutputStream out) throws IOException {
266      super(out);
267    }
268
269    @Override
270    public void write(byte[] buf, int off, int len) throws IOException {
271      if (LOG.isDebugEnabled()) {
272        LOG.debug("wrapping token of length:" + len);
273      }
274
275      // wrap with Crypto AES
276      byte[] wrapped = cryptoAES.wrap(buf, off, len);
277      DataOutputStream dob = new DataOutputStream(out);
278      dob.writeInt(wrapped.length);
279      dob.write(wrapped, 0, wrapped.length);
280      dob.flush();
281    }
282  }
283}