001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.hbase.security.provider;
019
020import java.io.IOException;
021import java.lang.reflect.InvocationTargetException;
022import java.util.HashMap;
023import java.util.Optional;
024import java.util.ServiceLoader;
025import java.util.concurrent.atomic.AtomicReference;
026import java.util.stream.Collectors;
027
028import org.apache.hadoop.conf.Configuration;
029import org.apache.yetus.audience.InterfaceAudience;
030import org.slf4j.Logger;
031import org.slf4j.LoggerFactory;
032
033@InterfaceAudience.Private
034public final class SaslServerAuthenticationProviders {
035  private static final Logger LOG = LoggerFactory.getLogger(
036      SaslClientAuthenticationProviders.class);
037
038  public static final String EXTRA_PROVIDERS_KEY = "hbase.server.sasl.provider.extras";
039  private static final AtomicReference<SaslServerAuthenticationProviders> holder =
040      new AtomicReference<>();
041
042  private final HashMap<Byte, SaslServerAuthenticationProvider> providers;
043
044  private SaslServerAuthenticationProviders(Configuration conf,
045      HashMap<Byte, SaslServerAuthenticationProvider> providers) {
046    this.providers = providers;
047  }
048
049  /**
050   * Returns the number of registered providers.
051   */
052  public int getNumRegisteredProviders() {
053    return providers.size();
054  }
055
056  /**
057   * Returns a singleton instance of {@link SaslServerAuthenticationProviders}.
058   */
059  public static SaslServerAuthenticationProviders getInstance(Configuration conf) {
060    SaslServerAuthenticationProviders providers = holder.get();
061    if (null == providers) {
062      synchronized (holder) {
063        // Someone else beat us here
064        providers = holder.get();
065        if (null != providers) {
066          return providers;
067        }
068
069        providers = createProviders(conf);
070        holder.set(providers);
071      }
072    }
073    return providers;
074  }
075
076  /**
077   * Removes the cached singleton instance of {@link SaslServerAuthenticationProviders}.
078   */
079  public static void reset() {
080    synchronized (holder) {
081      holder.set(null);
082    }
083  }
084
085  /**
086   * Adds the given provider into the map of providers if a mapping for the auth code does not
087   * already exist in the map.
088   */
089  static void addProviderIfNotExists(SaslServerAuthenticationProvider provider,
090      HashMap<Byte,SaslServerAuthenticationProvider> providers) {
091    final byte newProviderAuthCode = provider.getSaslAuthMethod().getCode();
092    final SaslServerAuthenticationProvider alreadyRegisteredProvider = providers.get(
093        newProviderAuthCode);
094    if (alreadyRegisteredProvider != null) {
095      throw new RuntimeException("Trying to load SaslServerAuthenticationProvider "
096          + provider.getClass() + ", but "+ alreadyRegisteredProvider.getClass()
097          + " is already registered with the same auth code");
098    }
099    providers.put(newProviderAuthCode, provider);
100  }
101
102  /**
103   * Adds any providers defined in the configuration.
104   */
105  static void addExtraProviders(Configuration conf,
106      HashMap<Byte,SaslServerAuthenticationProvider> providers) {
107    for (String implName : conf.getStringCollection(EXTRA_PROVIDERS_KEY)) {
108      Class<?> clz;
109      try {
110        clz = Class.forName(implName);
111      } catch (ClassNotFoundException e) {
112        LOG.warn("Failed to find SaslServerAuthenticationProvider class {}", implName, e);
113        continue;
114      }
115
116      if (!SaslServerAuthenticationProvider.class.isAssignableFrom(clz)) {
117        LOG.warn("Server authentication class {} is not an instance of "
118            + "SaslServerAuthenticationProvider", clz);
119        continue;
120      }
121
122      try {
123        SaslServerAuthenticationProvider provider =
124            (SaslServerAuthenticationProvider) clz.getConstructor().newInstance();
125        addProviderIfNotExists(provider, providers);
126      } catch (InstantiationException | IllegalAccessException | NoSuchMethodException
127          | InvocationTargetException e) {
128        LOG.warn("Failed to instantiate {}", clz, e);
129      }
130    }
131  }
132
133  /**
134   * Loads server authentication providers from the classpath and configuration, and then creates
135   * the SaslServerAuthenticationProviders instance.
136   */
137  static SaslServerAuthenticationProviders createProviders(Configuration conf) {
138    ServiceLoader<SaslServerAuthenticationProvider> loader =
139        ServiceLoader.load(SaslServerAuthenticationProvider.class);
140    HashMap<Byte,SaslServerAuthenticationProvider> providers = new HashMap<>();
141    for (SaslServerAuthenticationProvider provider : loader) {
142      addProviderIfNotExists(provider, providers);
143    }
144
145    addExtraProviders(conf, providers);
146
147    if (LOG.isTraceEnabled()) {
148      String loadedProviders = providers.values().stream()
149          .map((provider) -> provider.getClass().getName())
150          .collect(Collectors.joining(", "));
151      if (loadedProviders.isEmpty()) {
152        loadedProviders = "None!";
153      }
154      LOG.trace("Found SaslServerAuthenticationProviders {}", loadedProviders);
155    }
156
157    // Initialize the providers once, before we get into the RPC path.
158    providers.forEach((b,provider) -> {
159      try {
160        // Give them a copy, just to make sure there is no funny-business going on.
161        provider.init(new Configuration(conf));
162      } catch (IOException e) {
163        LOG.error("Failed to initialize {}", provider.getClass(), e);
164        throw new RuntimeException(
165            "Failed to initialize " + provider.getClass().getName(), e);
166      }
167    });
168
169    return new SaslServerAuthenticationProviders(conf, providers);
170  }
171
172  /**
173   * Selects the appropriate SaslServerAuthenticationProvider from those available. If there is no
174   * matching provider for the given {@code authByte}, this method will return null.
175   */
176  public SaslServerAuthenticationProvider selectProvider(byte authByte) {
177    return providers.get(Byte.valueOf(authByte));
178  }
179
180  /**
181   * Extracts the SIMPLE authentication provider.
182   */
183  public SaslServerAuthenticationProvider getSimpleProvider() {
184    Optional<SaslServerAuthenticationProvider> opt = providers.values()
185        .stream()
186        .filter((p) -> p instanceof SimpleSaslServerAuthenticationProvider)
187        .findFirst();
188    if (!opt.isPresent()) {
189      throw new RuntimeException("SIMPLE authentication provider not available when it should be");
190    }
191    return opt.get();
192  }
193}