001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.hbase.rest.filter;
019
020import java.io.IOException;
021import java.util.HashMap;
022import java.util.HashSet;
023import java.util.Map;
024import java.util.Set;
025import java.util.regex.Matcher;
026import java.util.regex.Pattern;
027
028import javax.servlet.Filter;
029import javax.servlet.FilterChain;
030import javax.servlet.FilterConfig;
031import javax.servlet.ServletException;
032import javax.servlet.ServletRequest;
033import javax.servlet.ServletResponse;
034import javax.servlet.http.HttpServletRequest;
035import javax.servlet.http.HttpServletResponse;
036
037import org.apache.yetus.audience.InterfaceAudience;
038import org.slf4j.Logger;
039import org.slf4j.LoggerFactory;
040import org.apache.hadoop.conf.Configuration;
041
042/**
043 * This filter provides protection against cross site request forgery (CSRF)
044 * attacks for REST APIs. Enabling this filter on an endpoint results in the
045 * requirement of all client to send a particular (configurable) HTTP header
046 * with every request. In the absense of this header the filter will reject the
047 * attempt as a bad request.
048 */
049@InterfaceAudience.Public
050public class RestCsrfPreventionFilter implements Filter {
051
052  private static final Logger LOG =
053      LoggerFactory.getLogger(RestCsrfPreventionFilter.class);
054
055  public static final String HEADER_USER_AGENT = "User-Agent";
056  public static final String BROWSER_USER_AGENT_PARAM =
057      "browser-useragents-regex";
058  public static final String CUSTOM_HEADER_PARAM = "custom-header";
059  public static final String CUSTOM_METHODS_TO_IGNORE_PARAM =
060      "methods-to-ignore";
061  static final String  BROWSER_USER_AGENTS_DEFAULT = "^Mozilla.*,^Opera.*";
062  public static final String HEADER_DEFAULT = "X-XSRF-HEADER";
063  static final String  METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE";
064  private String  headerName = HEADER_DEFAULT;
065  private Set<String> methodsToIgnore = null;
066  private Set<Pattern> browserUserAgents;
067
068  @Override
069  public void init(FilterConfig filterConfig) throws ServletException {
070    String customHeader = filterConfig.getInitParameter(CUSTOM_HEADER_PARAM);
071    if (customHeader != null) {
072      headerName = customHeader;
073    }
074    String customMethodsToIgnore =
075        filterConfig.getInitParameter(CUSTOM_METHODS_TO_IGNORE_PARAM);
076    if (customMethodsToIgnore != null) {
077      parseMethodsToIgnore(customMethodsToIgnore);
078    } else {
079      parseMethodsToIgnore(METHODS_TO_IGNORE_DEFAULT);
080    }
081
082    String agents = filterConfig.getInitParameter(BROWSER_USER_AGENT_PARAM);
083    if (agents == null) {
084      agents = BROWSER_USER_AGENTS_DEFAULT;
085    }
086    parseBrowserUserAgents(agents);
087    LOG.info(String.format("Adding cross-site request forgery (CSRF) protection, "
088        + "headerName = %s, methodsToIgnore = %s, browserUserAgents = %s",
089        headerName, methodsToIgnore, browserUserAgents));
090  }
091
092  void parseBrowserUserAgents(String userAgents) {
093    String[] agentsArray =  userAgents.split(",");
094    browserUserAgents = new HashSet<>();
095    for (String patternString : agentsArray) {
096      browserUserAgents.add(Pattern.compile(patternString));
097    }
098  }
099
100  void parseMethodsToIgnore(String mti) {
101    String[] methods = mti.split(",");
102    methodsToIgnore = new HashSet<>();
103    for (int i = 0; i < methods.length; i++) {
104      methodsToIgnore.add(methods[i]);
105    }
106  }
107
108  /**
109   * This method interrogates the User-Agent String and returns whether it
110   * refers to a browser.  If its not a browser, then the requirement for the
111   * CSRF header will not be enforced; if it is a browser, the requirement will
112   * be enforced.
113   * <p>
114   * A User-Agent String is considered to be a browser if it matches
115   * any of the regex patterns from browser-useragent-regex; the default
116   * behavior is to consider everything a browser that matches the following:
117   * "^Mozilla.*,^Opera.*".  Subclasses can optionally override
118   * this method to use different behavior.
119   *
120   * @param userAgent The User-Agent String, or null if there isn't one
121   * @return true if the User-Agent String refers to a browser, false if not
122   */
123  protected boolean isBrowser(String userAgent) {
124    if (userAgent == null) {
125      return false;
126    }
127    for (Pattern pattern : browserUserAgents) {
128      Matcher matcher = pattern.matcher(userAgent);
129      if (matcher.matches()) {
130        return true;
131      }
132    }
133    return false;
134  }
135
136  /**
137   * Defines the minimal API requirements for the filter to execute its
138   * filtering logic.  This interface exists to facilitate integration in
139   * components that do not run within a servlet container and therefore cannot
140   * rely on a servlet container to dispatch to the {@link #doFilter} method.
141   * Applications that do run inside a servlet container will not need to write
142   * code that uses this interface.  Instead, they can use typical servlet
143   * container configuration mechanisms to insert the filter.
144   */
145  public interface HttpInteraction {
146
147    /**
148     * Returns the value of a header.
149     *
150     * @param header name of header
151     * @return value of header
152     */
153    String getHeader(String header);
154
155    /**
156     * Returns the method.
157     *
158     * @return method
159     */
160    String getMethod();
161
162    /**
163     * Called by the filter after it decides that the request may proceed.
164     *
165     * @throws IOException if there is an I/O error
166     * @throws ServletException if the implementation relies on the servlet API
167     *     and a servlet API call has failed
168     */
169    void proceed() throws IOException, ServletException;
170
171    /**
172     * Called by the filter after it decides that the request is a potential
173     * CSRF attack and therefore must be rejected.
174     *
175     * @param code status code to send
176     * @param message response message
177     * @throws IOException if there is an I/O error
178     */
179    void sendError(int code, String message) throws IOException;
180  }
181
182  /**
183   * Handles an {@link HttpInteraction} by applying the filtering logic.
184   *
185   * @param httpInteraction caller's HTTP interaction
186   * @throws IOException if there is an I/O error
187   * @throws ServletException if the implementation relies on the servlet API
188   *     and a servlet API call has failed
189   */
190  public void handleHttpInteraction(HttpInteraction httpInteraction)
191      throws IOException, ServletException {
192    if (!isBrowser(httpInteraction.getHeader(HEADER_USER_AGENT)) ||
193        methodsToIgnore.contains(httpInteraction.getMethod()) ||
194        httpInteraction.getHeader(headerName) != null) {
195      httpInteraction.proceed();
196    } else {
197      httpInteraction.sendError(HttpServletResponse.SC_BAD_REQUEST,
198          "Missing Required Header for CSRF Vulnerability Protection");
199    }
200  }
201
202  @Override
203  public void doFilter(ServletRequest request, ServletResponse response,
204      final FilterChain chain) throws IOException, ServletException {
205    final HttpServletRequest httpRequest = (HttpServletRequest)request;
206    final HttpServletResponse httpResponse = (HttpServletResponse)response;
207    handleHttpInteraction(new ServletFilterHttpInteraction(httpRequest,
208        httpResponse, chain));
209  }
210
211  @Override
212  public void destroy() {
213  }
214
215  /**
216   * Constructs a mapping of configuration properties to be used for filter
217   * initialization.  The mapping includes all properties that start with the
218   * specified configuration prefix.  Property names in the mapping are trimmed
219   * to remove the configuration prefix.
220   *
221   * @param conf configuration to read
222   * @param confPrefix configuration prefix
223   * @return mapping of configuration properties to be used for filter
224   *     initialization
225   */
226  public static Map<String, String> getFilterParams(Configuration conf,
227      String confPrefix) {
228    Map<String, String> filterConfigMap = new HashMap<>();
229    for (Map.Entry<String, String> entry : conf) {
230      String name = entry.getKey();
231      if (name.startsWith(confPrefix)) {
232        String value = conf.get(name);
233        name = name.substring(confPrefix.length());
234        filterConfigMap.put(name, value);
235      }
236    }
237    return filterConfigMap;
238  }
239
240  /**
241   * {@link HttpInteraction} implementation for use in the servlet filter.
242   */
243  private static final class ServletFilterHttpInteraction
244      implements HttpInteraction {
245
246    private final FilterChain chain;
247    private final HttpServletRequest httpRequest;
248    private final HttpServletResponse httpResponse;
249
250    /**
251     * Creates a new ServletFilterHttpInteraction.
252     *
253     * @param httpRequest request to process
254     * @param httpResponse response to process
255     * @param chain filter chain to forward to if HTTP interaction is allowed
256     */
257    public ServletFilterHttpInteraction(HttpServletRequest httpRequest,
258        HttpServletResponse httpResponse, FilterChain chain) {
259      this.httpRequest = httpRequest;
260      this.httpResponse = httpResponse;
261      this.chain = chain;
262    }
263
264    @Override
265    public String getHeader(String header) {
266      return httpRequest.getHeader(header);
267    }
268
269    @Override
270    public String getMethod() {
271      return httpRequest.getMethod();
272    }
273
274    @Override
275    public void proceed() throws IOException, ServletException {
276      chain.doFilter(httpRequest, httpResponse);
277    }
278
279    @Override
280    public void sendError(int code, String message) throws IOException {
281      httpResponse.sendError(code, message);
282    }
283  }
284}