001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.hbase.rest.filter;
019
020import java.io.IOException;
021import java.util.Collections;
022import java.util.HashMap;
023import java.util.HashSet;
024import java.util.Map;
025import java.util.Set;
026import java.util.regex.Matcher;
027import java.util.regex.Pattern;
028import javax.servlet.Filter;
029import javax.servlet.FilterChain;
030import javax.servlet.FilterConfig;
031import javax.servlet.ServletException;
032import javax.servlet.ServletRequest;
033import javax.servlet.ServletResponse;
034import javax.servlet.http.HttpServletRequest;
035import javax.servlet.http.HttpServletResponse;
036import org.apache.hadoop.conf.Configuration;
037import org.apache.yetus.audience.InterfaceAudience;
038import org.slf4j.Logger;
039import org.slf4j.LoggerFactory;
040
041/**
042 * This filter provides protection against cross site request forgery (CSRF) attacks for REST APIs.
043 * Enabling this filter on an endpoint results in the requirement of all client to send a particular
044 * (configurable) HTTP header with every request. In the absense of this header the filter will
045 * reject the attempt as a bad request.
046 */
047@InterfaceAudience.Public
048public class RestCsrfPreventionFilter implements Filter {
049  private static final Logger LOG = LoggerFactory.getLogger(RestCsrfPreventionFilter.class);
050
051  public static final String HEADER_USER_AGENT = "User-Agent";
052  public static final String BROWSER_USER_AGENT_PARAM = "browser-useragents-regex";
053  public static final String CUSTOM_HEADER_PARAM = "custom-header";
054  public static final String CUSTOM_METHODS_TO_IGNORE_PARAM = "methods-to-ignore";
055  static final String BROWSER_USER_AGENTS_DEFAULT = "^Mozilla.*,^Opera.*";
056  public static final String HEADER_DEFAULT = "X-XSRF-HEADER";
057  static final String METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE";
058  private String headerName = HEADER_DEFAULT;
059  private Set<String> methodsToIgnore = null;
060  private Set<Pattern> browserUserAgents;
061
062  @Override
063  public void init(FilterConfig filterConfig) {
064    String customHeader = filterConfig.getInitParameter(CUSTOM_HEADER_PARAM);
065    if (customHeader != null) {
066      headerName = customHeader;
067    }
068    String customMethodsToIgnore = filterConfig.getInitParameter(CUSTOM_METHODS_TO_IGNORE_PARAM);
069    if (customMethodsToIgnore != null) {
070      parseMethodsToIgnore(customMethodsToIgnore);
071    } else {
072      parseMethodsToIgnore(METHODS_TO_IGNORE_DEFAULT);
073    }
074
075    String agents = filterConfig.getInitParameter(BROWSER_USER_AGENT_PARAM);
076    if (agents == null) {
077      agents = BROWSER_USER_AGENTS_DEFAULT;
078    }
079    parseBrowserUserAgents(agents);
080    LOG.info(String.format(
081      "Adding cross-site request forgery (CSRF) protection, "
082        + "headerName = %s, methodsToIgnore = %s, browserUserAgents = %s",
083      headerName, methodsToIgnore, browserUserAgents));
084  }
085
086  void parseBrowserUserAgents(String userAgents) {
087    String[] agentsArray = userAgents.split(",");
088    browserUserAgents = new HashSet<>();
089    for (String patternString : agentsArray) {
090      browserUserAgents.add(Pattern.compile(patternString));
091    }
092  }
093
094  void parseMethodsToIgnore(String mti) {
095    String[] methods = mti.split(",");
096    methodsToIgnore = new HashSet<>();
097    Collections.addAll(methodsToIgnore, methods);
098  }
099
100  /**
101   * This method interrogates the User-Agent String and returns whether it refers to a browser. If
102   * its not a browser, then the requirement for the CSRF header will not be enforced; if it is a
103   * browser, the requirement will be enforced.
104   * <p>
105   * A User-Agent String is considered to be a browser if it matches any of the regex patterns from
106   * browser-useragent-regex; the default behavior is to consider everything a browser that matches
107   * the following: "^Mozilla.*,^Opera.*". Subclasses can optionally override this method to use
108   * different behavior.
109   * @param userAgent The User-Agent String, or null if there isn't one
110   * @return true if the User-Agent String refers to a browser, false if not
111   */
112  protected boolean isBrowser(String userAgent) {
113    if (userAgent == null) {
114      return false;
115    }
116    for (Pattern pattern : browserUserAgents) {
117      Matcher matcher = pattern.matcher(userAgent);
118      if (matcher.matches()) {
119        return true;
120      }
121    }
122    return false;
123  }
124
125  /**
126   * Defines the minimal API requirements for the filter to execute its filtering logic. This
127   * interface exists to facilitate integration in components that do not run within a servlet
128   * container and therefore cannot rely on a servlet container to dispatch to the {@link #doFilter}
129   * method. Applications that do run inside a servlet container will not need to write code that
130   * uses this interface. Instead, they can use typical servlet container configuration mechanisms
131   * to insert the filter.
132   */
133  public interface HttpInteraction {
134    /**
135     * Returns the value of a header.
136     * @param header name of header
137     * @return value of header
138     */
139    String getHeader(String header);
140
141    /**
142     * Returns the method.
143     */
144    String getMethod();
145
146    /**
147     * Called by the filter after it decides that the request may proceed.
148     * @throws IOException      if there is an I/O error
149     * @throws ServletException if the implementation relies on the servlet API and a servlet API
150     *                          call has failed
151     */
152    void proceed() throws IOException, ServletException;
153
154    /**
155     * Called by the filter after it decides that the request is a potential CSRF attack and
156     * therefore must be rejected.
157     * @param code    status code to send
158     * @param message response message
159     * @throws IOException if there is an I/O error
160     */
161    void sendError(int code, String message) throws IOException;
162  }
163
164  /**
165   * Handles an {@link HttpInteraction} by applying the filtering logic.
166   * @param httpInteraction caller's HTTP interaction
167   * @throws IOException      if there is an I/O error
168   * @throws ServletException if the implementation relies on the servlet API and a servlet API call
169   *                          has failed
170   */
171  public void handleHttpInteraction(HttpInteraction httpInteraction)
172    throws IOException, ServletException {
173    if (
174      !isBrowser(httpInteraction.getHeader(HEADER_USER_AGENT))
175        || methodsToIgnore.contains(httpInteraction.getMethod())
176        || httpInteraction.getHeader(headerName) != null
177    ) {
178      httpInteraction.proceed();
179    } else {
180      httpInteraction.sendError(HttpServletResponse.SC_BAD_REQUEST,
181        "Missing Required Header for CSRF Vulnerability Protection");
182    }
183  }
184
185  @Override
186  public void doFilter(ServletRequest request, ServletResponse response, final FilterChain chain)
187    throws IOException, ServletException {
188    final HttpServletRequest httpRequest = (HttpServletRequest) request;
189    final HttpServletResponse httpResponse = (HttpServletResponse) response;
190    handleHttpInteraction(new ServletFilterHttpInteraction(httpRequest, httpResponse, chain));
191  }
192
193  @Override
194  public void destroy() {
195  }
196
197  /**
198   * Constructs a mapping of configuration properties to be used for filter initialization. The
199   * mapping includes all properties that start with the specified configuration prefix. Property
200   * names in the mapping are trimmed to remove the configuration prefix.
201   * @param conf       configuration to read
202   * @param confPrefix configuration prefix
203   * @return mapping of configuration properties to be used for filter initialization
204   */
205  public static Map<String, String> getFilterParams(Configuration conf, String confPrefix) {
206    Map<String, String> filterConfigMap = new HashMap<>();
207    for (Map.Entry<String, String> entry : conf) {
208      String name = entry.getKey();
209      if (name.startsWith(confPrefix)) {
210        String value = conf.get(name);
211        name = name.substring(confPrefix.length());
212        filterConfigMap.put(name, value);
213      }
214    }
215    return filterConfigMap;
216  }
217
218  /**
219   * {@link HttpInteraction} implementation for use in the servlet filter.
220   */
221  private static final class ServletFilterHttpInteraction implements HttpInteraction {
222    private final FilterChain chain;
223    private final HttpServletRequest httpRequest;
224    private final HttpServletResponse httpResponse;
225
226    /**
227     * Creates a new ServletFilterHttpInteraction.
228     * @param httpRequest  request to process
229     * @param httpResponse response to process
230     * @param chain        filter chain to forward to if HTTP interaction is allowed
231     */
232    public ServletFilterHttpInteraction(HttpServletRequest httpRequest,
233      HttpServletResponse httpResponse, FilterChain chain) {
234      this.httpRequest = httpRequest;
235      this.httpResponse = httpResponse;
236      this.chain = chain;
237    }
238
239    @Override
240    public String getHeader(String header) {
241      return httpRequest.getHeader(header);
242    }
243
244    @Override
245    public String getMethod() {
246      return httpRequest.getMethod();
247    }
248
249    @Override
250    public void proceed() throws IOException, ServletException {
251      chain.doFilter(httpRequest, httpResponse);
252    }
253
254    @Override
255    public void sendError(int code, String message) throws IOException {
256      httpResponse.sendError(code, message);
257    }
258  }
259}