001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.hbase.rest.filter;
019
020import java.io.IOException;
021import java.util.Collections;
022import java.util.HashMap;
023import java.util.HashSet;
024import java.util.Map;
025import java.util.Set;
026import java.util.regex.Matcher;
027import java.util.regex.Pattern;
028
029import javax.servlet.Filter;
030import javax.servlet.FilterChain;
031import javax.servlet.FilterConfig;
032import javax.servlet.ServletException;
033import javax.servlet.ServletRequest;
034import javax.servlet.ServletResponse;
035import javax.servlet.http.HttpServletRequest;
036import javax.servlet.http.HttpServletResponse;
037
038import org.apache.hadoop.conf.Configuration;
039
040import org.apache.yetus.audience.InterfaceAudience;
041
042import org.slf4j.Logger;
043import org.slf4j.LoggerFactory;
044
045/**
046 * This filter provides protection against cross site request forgery (CSRF)
047 * attacks for REST APIs. Enabling this filter on an endpoint results in the
048 * requirement of all client to send a particular (configurable) HTTP header
049 * with every request. In the absense of this header the filter will reject the
050 * attempt as a bad request.
051 */
052@InterfaceAudience.Public
053public class RestCsrfPreventionFilter implements Filter {
054  private static final Logger LOG =
055      LoggerFactory.getLogger(RestCsrfPreventionFilter.class);
056
057  public static final String HEADER_USER_AGENT = "User-Agent";
058  public static final String BROWSER_USER_AGENT_PARAM =
059      "browser-useragents-regex";
060  public static final String CUSTOM_HEADER_PARAM = "custom-header";
061  public static final String CUSTOM_METHODS_TO_IGNORE_PARAM =
062      "methods-to-ignore";
063  static final String  BROWSER_USER_AGENTS_DEFAULT = "^Mozilla.*,^Opera.*";
064  public static final String HEADER_DEFAULT = "X-XSRF-HEADER";
065  static final String  METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE";
066  private String  headerName = HEADER_DEFAULT;
067  private Set<String> methodsToIgnore = null;
068  private Set<Pattern> browserUserAgents;
069
070  @Override
071  public void init(FilterConfig filterConfig) {
072    String customHeader = filterConfig.getInitParameter(CUSTOM_HEADER_PARAM);
073    if (customHeader != null) {
074      headerName = customHeader;
075    }
076    String customMethodsToIgnore =
077        filterConfig.getInitParameter(CUSTOM_METHODS_TO_IGNORE_PARAM);
078    if (customMethodsToIgnore != null) {
079      parseMethodsToIgnore(customMethodsToIgnore);
080    } else {
081      parseMethodsToIgnore(METHODS_TO_IGNORE_DEFAULT);
082    }
083
084    String agents = filterConfig.getInitParameter(BROWSER_USER_AGENT_PARAM);
085    if (agents == null) {
086      agents = BROWSER_USER_AGENTS_DEFAULT;
087    }
088    parseBrowserUserAgents(agents);
089    LOG.info(String.format("Adding cross-site request forgery (CSRF) protection, "
090        + "headerName = %s, methodsToIgnore = %s, browserUserAgents = %s",
091        headerName, methodsToIgnore, browserUserAgents));
092  }
093
094  void parseBrowserUserAgents(String userAgents) {
095    String[] agentsArray =  userAgents.split(",");
096    browserUserAgents = new HashSet<>();
097    for (String patternString : agentsArray) {
098      browserUserAgents.add(Pattern.compile(patternString));
099    }
100  }
101
102  void parseMethodsToIgnore(String mti) {
103    String[] methods = mti.split(",");
104    methodsToIgnore = new HashSet<>();
105    Collections.addAll(methodsToIgnore, methods);
106  }
107
108  /**
109   * This method interrogates the User-Agent String and returns whether it
110   * refers to a browser.  If its not a browser, then the requirement for the
111   * CSRF header will not be enforced; if it is a browser, the requirement will
112   * be enforced.
113   * <p>
114   * A User-Agent String is considered to be a browser if it matches
115   * any of the regex patterns from browser-useragent-regex; the default
116   * behavior is to consider everything a browser that matches the following:
117   * "^Mozilla.*,^Opera.*".  Subclasses can optionally override
118   * this method to use different behavior.
119   *
120   * @param userAgent The User-Agent String, or null if there isn't one
121   * @return true if the User-Agent String refers to a browser, false if not
122   */
123  protected boolean isBrowser(String userAgent) {
124    if (userAgent == null) {
125      return false;
126    }
127    for (Pattern pattern : browserUserAgents) {
128      Matcher matcher = pattern.matcher(userAgent);
129      if (matcher.matches()) {
130        return true;
131      }
132    }
133    return false;
134  }
135
136  /**
137   * Defines the minimal API requirements for the filter to execute its
138   * filtering logic.  This interface exists to facilitate integration in
139   * components that do not run within a servlet container and therefore cannot
140   * rely on a servlet container to dispatch to the {@link #doFilter} method.
141   * Applications that do run inside a servlet container will not need to write
142   * code that uses this interface.  Instead, they can use typical servlet
143   * container configuration mechanisms to insert the filter.
144   */
145  public interface HttpInteraction {
146    /**
147     * Returns the value of a header.
148     *
149     * @param header name of header
150     * @return value of header
151     */
152    String getHeader(String header);
153
154    /**
155     * Returns the method.
156     *
157     * @return method
158     */
159    String getMethod();
160
161    /**
162     * Called by the filter after it decides that the request may proceed.
163     *
164     * @throws IOException if there is an I/O error
165     * @throws ServletException if the implementation relies on the servlet API
166     *     and a servlet API call has failed
167     */
168    void proceed() throws IOException, ServletException;
169
170    /**
171     * Called by the filter after it decides that the request is a potential
172     * CSRF attack and therefore must be rejected.
173     *
174     * @param code status code to send
175     * @param message response message
176     * @throws IOException if there is an I/O error
177     */
178    void sendError(int code, String message) throws IOException;
179  }
180
181  /**
182   * Handles an {@link HttpInteraction} by applying the filtering logic.
183   *
184   * @param httpInteraction caller's HTTP interaction
185   * @throws IOException if there is an I/O error
186   * @throws ServletException if the implementation relies on the servlet API
187   *     and a servlet API call has failed
188   */
189  public void handleHttpInteraction(HttpInteraction httpInteraction)
190      throws IOException, ServletException {
191    if (!isBrowser(httpInteraction.getHeader(HEADER_USER_AGENT)) ||
192        methodsToIgnore.contains(httpInteraction.getMethod()) ||
193        httpInteraction.getHeader(headerName) != null) {
194      httpInteraction.proceed();
195    } else {
196      httpInteraction.sendError(HttpServletResponse.SC_BAD_REQUEST,
197          "Missing Required Header for CSRF Vulnerability Protection");
198    }
199  }
200
201  @Override
202  public void doFilter(ServletRequest request, ServletResponse response,
203      final FilterChain chain) throws IOException, ServletException {
204    final HttpServletRequest httpRequest = (HttpServletRequest)request;
205    final HttpServletResponse httpResponse = (HttpServletResponse)response;
206    handleHttpInteraction(new ServletFilterHttpInteraction(httpRequest,
207        httpResponse, chain));
208  }
209
210  @Override
211  public void destroy() {
212  }
213
214  /**
215   * Constructs a mapping of configuration properties to be used for filter
216   * initialization.  The mapping includes all properties that start with the
217   * specified configuration prefix.  Property names in the mapping are trimmed
218   * to remove the configuration prefix.
219   *
220   * @param conf configuration to read
221   * @param confPrefix configuration prefix
222   * @return mapping of configuration properties to be used for filter
223   *     initialization
224   */
225  public static Map<String, String> getFilterParams(Configuration conf, String confPrefix) {
226    Map<String, String> filterConfigMap = new HashMap<>();
227    for (Map.Entry<String, String> entry : conf) {
228      String name = entry.getKey();
229      if (name.startsWith(confPrefix)) {
230        String value = conf.get(name);
231        name = name.substring(confPrefix.length());
232        filterConfigMap.put(name, value);
233      }
234    }
235    return filterConfigMap;
236  }
237
238  /**
239   * {@link HttpInteraction} implementation for use in the servlet filter.
240   */
241  private static final class ServletFilterHttpInteraction implements HttpInteraction {
242    private final FilterChain chain;
243    private final HttpServletRequest httpRequest;
244    private final HttpServletResponse httpResponse;
245
246    /**
247     * Creates a new ServletFilterHttpInteraction.
248     *
249     * @param httpRequest request to process
250     * @param httpResponse response to process
251     * @param chain filter chain to forward to if HTTP interaction is allowed
252     */
253    public ServletFilterHttpInteraction(HttpServletRequest httpRequest,
254        HttpServletResponse httpResponse, FilterChain chain) {
255      this.httpRequest = httpRequest;
256      this.httpResponse = httpResponse;
257      this.chain = chain;
258    }
259
260    @Override
261    public String getHeader(String header) {
262      return httpRequest.getHeader(header);
263    }
264
265    @Override
266    public String getMethod() {
267      return httpRequest.getMethod();
268    }
269
270    @Override
271    public void proceed() throws IOException, ServletException {
272      chain.doFilter(httpRequest, httpResponse);
273    }
274
275    @Override
276    public void sendError(int code, String message) throws IOException {
277      httpResponse.sendError(code, message);
278    }
279  }
280}