001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.hbase;
019
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
import org.apache.hadoop.hbase.client.AsyncConnection;
import org.apache.hadoop.hbase.client.Connection;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.Extension;
import org.junit.jupiter.api.extension.ExtensionContext;
029
030/**
031 * An {@link Extension} that manages the lifecycle of an instance of {@link AsyncConnection}.
032 * </p>
033 * Use in combination with {@link MiniClusterExtension}, for example:
034 *
035 * <pre>
036 * {
037 *   public class TestMyClass {
038 *
039 *     &#64;Order(1)
040 *     &#64;RegisterExtension
041 *     private static final MiniClusterExtension miniClusterExtension =
042 *        miniClusterExtension.newBuilder().build();
043 *
044 *     &#64;Order(2)
045 *     &#64;RegisterExtension
046 *     private static final ConnectionExtension connectionExtension =
047 *         ConnectionExtension.createAsyncConnectionExtension(
048 *            miniClusterExtension::createConnection);
049 *   }
050 * </pre>
051 */
052public final class ConnectionExtension implements BeforeAllCallback, AfterAllCallback {
053
054  private final Supplier<Connection> connectionSupplier;
055  private final Supplier<CompletableFuture<AsyncConnection>> asyncConnectionSupplier;
056
057  private Connection connection;
058  private AsyncConnection asyncConnection;
059
060  public static ConnectionExtension
061    createConnectionExtension(final Supplier<Connection> connectionSupplier) {
062    return new ConnectionExtension(connectionSupplier, null);
063  }
064
065  public static ConnectionExtension createAsyncConnectionExtension(
066    final Supplier<CompletableFuture<AsyncConnection>> asyncConnectionSupplier) {
067    return new ConnectionExtension(null, asyncConnectionSupplier);
068  }
069
070  public static ConnectionExtension createConnectionExtension(
071    final Supplier<Connection> connectionSupplier,
072    final Supplier<CompletableFuture<AsyncConnection>> asyncConnectionSupplier) {
073    return new ConnectionExtension(connectionSupplier, asyncConnectionSupplier);
074  }
075
076  private ConnectionExtension(final Supplier<Connection> connectionSupplier,
077    final Supplier<CompletableFuture<AsyncConnection>> asyncConnectionSupplier) {
078    this.connectionSupplier = connectionSupplier;
079    this.asyncConnectionSupplier = asyncConnectionSupplier;
080  }
081
082  public Connection getConnection() {
083    if (connection == null) {
084      throw new IllegalStateException(
085        "ConnectionExtension not initialized with a synchronous connection.");
086    }
087    return connection;
088  }
089
090  public AsyncConnection getAsyncConnection() {
091    if (asyncConnection == null) {
092      throw new IllegalStateException(
093        "ConnectionExtension not initialized with an asynchronous connection.");
094    }
095    return asyncConnection;
096  }
097
098  @Override
099  public void beforeAll(ExtensionContext context) {
100    if (connectionSupplier != null) {
101      this.connection = connectionSupplier.get();
102    }
103    if (asyncConnectionSupplier != null) {
104      this.asyncConnection = asyncConnectionSupplier.get().join();
105    }
106    if (connection == null && asyncConnection != null) {
107      this.connection = asyncConnection.toConnection();
108    }
109  }
110
111  @Override
112  public void afterAll(ExtensionContext context) {
113    CompletableFuture<Void> closeConnection = CompletableFuture.runAsync(() -> {
114      if (this.connection != null) {
115        try {
116          connection.close();
117        } catch (IOException e) {
118          throw new RuntimeException(e);
119        }
120      }
121    });
122    CompletableFuture<Void> closeAsyncConnection = CompletableFuture.runAsync(() -> {
123      if (this.asyncConnection != null) {
124        try {
125          asyncConnection.close();
126        } catch (IOException e) {
127          throw new RuntimeException(e);
128        }
129      }
130    });
131    CompletableFuture.allOf(closeConnection, closeAsyncConnection).join();
132  }
133}