001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.hbase.io.compress.zstd;
019
020import com.github.luben.zstd.ZstdDecompressCtx;
021import com.github.luben.zstd.ZstdDictDecompress;
022import edu.umd.cs.findbugs.annotations.Nullable;
023import java.io.IOException;
024import java.nio.ByteBuffer;
025import org.apache.hadoop.hbase.io.compress.BlockDecompressorHelper;
026import org.apache.hadoop.hbase.io.compress.ByteBuffDecompressor;
027import org.apache.hadoop.hbase.io.compress.Compression;
028import org.apache.hadoop.hbase.nio.ByteBuff;
029import org.apache.hadoop.hbase.nio.SingleByteBuff;
030import org.apache.yetus.audience.InterfaceAudience;
031
032/**
033 * Glue for ByteBuffDecompressor on top of zstd-jni
034 */
035@InterfaceAudience.Private
036public class ZstdByteBuffDecompressor implements ByteBuffDecompressor {
037
038  protected int dictId;
039  protected ZstdDecompressCtx ctx;
040  // Intended to be set to false by some unit tests
041  private boolean allowByteBuffDecompression;
042
043  ZstdByteBuffDecompressor(@Nullable byte[] dictionaryBytes) {
044    ctx = new ZstdDecompressCtx();
045    if (dictionaryBytes != null) {
046      this.ctx.loadDict(new ZstdDictDecompress(dictionaryBytes));
047      dictId = ZstdCodec.getDictionaryId(dictionaryBytes);
048    }
049    allowByteBuffDecompression = true;
050  }
051
052  @Override
053  public boolean canDecompress(ByteBuff output, ByteBuff input) {
054    return allowByteBuffDecompression && output instanceof SingleByteBuff
055      && input instanceof SingleByteBuff;
056  }
057
058  @Override
059  public int decompress(ByteBuff output, ByteBuff input, int inputLen) throws IOException {
060    return BlockDecompressorHelper.decompress(output, input, inputLen, this::decompressRaw);
061  }
062
063  private int decompressRaw(ByteBuff output, ByteBuff input, int inputLen) throws IOException {
064    if (output instanceof SingleByteBuff && input instanceof SingleByteBuff) {
065      ByteBuffer nioOutput = output.nioByteBuffers()[0];
066      ByteBuffer nioInput = input.nioByteBuffers()[0];
067      int origOutputPos = nioOutput.position();
068      int n;
069      if (nioOutput.isDirect() && nioInput.isDirect()) {
070        n = ctx.decompressDirectByteBuffer(nioOutput, nioOutput.position(),
071          nioOutput.limit() - nioOutput.position(), nioInput, nioInput.position(), inputLen);
072      } else if (!nioOutput.isDirect() && !nioInput.isDirect()) {
073        n = ctx.decompressByteArray(nioOutput.array(),
074          nioOutput.arrayOffset() + nioOutput.position(), nioOutput.limit() - nioOutput.position(),
075          nioInput.array(), nioInput.arrayOffset() + nioInput.position(), inputLen);
076      } else if (nioOutput.isDirect() && !nioInput.isDirect()) {
077        n = ctx.decompressByteArrayToDirectByteBuffer(nioOutput, nioOutput.position(),
078          nioOutput.limit() - nioOutput.position(), nioInput.array(),
079          nioInput.arrayOffset() + nioInput.position(), inputLen);
080      } else if (!nioOutput.isDirect() && nioInput.isDirect()) {
081        n = ctx.decompressDirectByteBufferToByteArray(nioOutput.array(),
082          nioOutput.arrayOffset() + nioOutput.position(), nioOutput.limit() - nioOutput.position(),
083          nioInput, nioInput.position(), inputLen);
084      } else {
085        throw new IllegalStateException("Unreachable line");
086      }
087
088      nioOutput.position(origOutputPos + n);
089      nioInput.position(input.position() + inputLen);
090
091      return n;
092    } else {
093      throw new IllegalStateException(
094        "At least one buffer is not a SingleByteBuff, this is not supported");
095    }
096  }
097
098  @Override
099  public void reinit(@Nullable Compression.HFileDecompressionContext newHFileDecompressionContext) {
100    if (newHFileDecompressionContext != null) {
101      if (newHFileDecompressionContext instanceof ZstdHFileDecompressionContext) {
102        ZstdHFileDecompressionContext zstdContext =
103          (ZstdHFileDecompressionContext) newHFileDecompressionContext;
104        allowByteBuffDecompression = zstdContext.isAllowByteBuffDecompression();
105        if (zstdContext.getDict() == null && dictId != 0) {
106          ctx.loadDict((byte[]) null);
107          dictId = 0;
108        } else if (zstdContext.getDictId() != dictId) {
109          this.ctx.loadDict(zstdContext.getDict());
110          this.dictId = zstdContext.getDictId();
111        }
112      } else {
113        throw new IllegalArgumentException(
114          "ZstdByteBuffDecompression#reinit() was given an HFileDecompressionContext that was not "
115            + "a ZstdHFileDecompressionContext, this should never happen");
116      }
117    }
118  }
119
120  @Override
121  public void close() {
122    ctx.close();
123  }
124
125}