001/** 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, software 013 * distributed under the License is distributed on an "AS IS" BASIS, 014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 015 * See the License for the specific language governing permissions and 016 * limitations under the License. 017 */ 018package org.apache.hadoop.hbase.rest.filter; 019 020import java.io.IOException; 021import java.util.Collections; 022import java.util.HashMap; 023import java.util.HashSet; 024import java.util.Map; 025import java.util.Set; 026import java.util.regex.Matcher; 027import java.util.regex.Pattern; 028 029import javax.servlet.Filter; 030import javax.servlet.FilterChain; 031import javax.servlet.FilterConfig; 032import javax.servlet.ServletException; 033import javax.servlet.ServletRequest; 034import javax.servlet.ServletResponse; 035import javax.servlet.http.HttpServletRequest; 036import javax.servlet.http.HttpServletResponse; 037 038import org.apache.hadoop.conf.Configuration; 039 040import org.apache.yetus.audience.InterfaceAudience; 041 042import org.slf4j.Logger; 043import org.slf4j.LoggerFactory; 044 045/** 046 * This filter provides protection against cross site request forgery (CSRF) 047 * attacks for REST APIs. Enabling this filter on an endpoint results in the 048 * requirement of all client to send a particular (configurable) HTTP header 049 * with every request. In the absense of this header the filter will reject the 050 * attempt as a bad request. 051 */ 052@InterfaceAudience.Public 053public class RestCsrfPreventionFilter implements Filter { 054 private static final Logger LOG = 055 LoggerFactory.getLogger(RestCsrfPreventionFilter.class); 056 057 public static final String HEADER_USER_AGENT = "User-Agent"; 058 public static final String BROWSER_USER_AGENT_PARAM = 059 "browser-useragents-regex"; 060 public static final String CUSTOM_HEADER_PARAM = "custom-header"; 061 public static final String CUSTOM_METHODS_TO_IGNORE_PARAM = 062 "methods-to-ignore"; 063 static final String BROWSER_USER_AGENTS_DEFAULT = "^Mozilla.*,^Opera.*"; 064 public static final String HEADER_DEFAULT = "X-XSRF-HEADER"; 065 static final String METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE"; 066 private String headerName = HEADER_DEFAULT; 067 private Set<String> methodsToIgnore = null; 068 private Set<Pattern> browserUserAgents; 069 070 @Override 071 public void init(FilterConfig filterConfig) { 072 String customHeader = filterConfig.getInitParameter(CUSTOM_HEADER_PARAM); 073 if (customHeader != null) { 074 headerName = customHeader; 075 } 076 String customMethodsToIgnore = 077 filterConfig.getInitParameter(CUSTOM_METHODS_TO_IGNORE_PARAM); 078 if (customMethodsToIgnore != null) { 079 parseMethodsToIgnore(customMethodsToIgnore); 080 } else { 081 parseMethodsToIgnore(METHODS_TO_IGNORE_DEFAULT); 082 } 083 084 String agents = filterConfig.getInitParameter(BROWSER_USER_AGENT_PARAM); 085 if (agents == null) { 086 agents = BROWSER_USER_AGENTS_DEFAULT; 087 } 088 parseBrowserUserAgents(agents); 089 LOG.info(String.format("Adding cross-site request forgery (CSRF) protection, " 090 + "headerName = %s, methodsToIgnore = %s, browserUserAgents = %s", 091 headerName, methodsToIgnore, browserUserAgents)); 092 } 093 094 void parseBrowserUserAgents(String userAgents) { 095 String[] agentsArray = userAgents.split(","); 096 browserUserAgents = new HashSet<>(); 097 for (String patternString : agentsArray) { 098 browserUserAgents.add(Pattern.compile(patternString)); 099 } 100 } 101 102 void parseMethodsToIgnore(String mti) { 103 String[] methods = mti.split(","); 104 methodsToIgnore = new HashSet<>(); 105 Collections.addAll(methodsToIgnore, methods); 106 } 107 108 /** 109 * This method interrogates the User-Agent String and returns whether it 110 * refers to a browser. If its not a browser, then the requirement for the 111 * CSRF header will not be enforced; if it is a browser, the requirement will 112 * be enforced. 113 * <p> 114 * A User-Agent String is considered to be a browser if it matches 115 * any of the regex patterns from browser-useragent-regex; the default 116 * behavior is to consider everything a browser that matches the following: 117 * "^Mozilla.*,^Opera.*". Subclasses can optionally override 118 * this method to use different behavior. 119 * 120 * @param userAgent The User-Agent String, or null if there isn't one 121 * @return true if the User-Agent String refers to a browser, false if not 122 */ 123 protected boolean isBrowser(String userAgent) { 124 if (userAgent == null) { 125 return false; 126 } 127 for (Pattern pattern : browserUserAgents) { 128 Matcher matcher = pattern.matcher(userAgent); 129 if (matcher.matches()) { 130 return true; 131 } 132 } 133 return false; 134 } 135 136 /** 137 * Defines the minimal API requirements for the filter to execute its 138 * filtering logic. This interface exists to facilitate integration in 139 * components that do not run within a servlet container and therefore cannot 140 * rely on a servlet container to dispatch to the {@link #doFilter} method. 141 * Applications that do run inside a servlet container will not need to write 142 * code that uses this interface. Instead, they can use typical servlet 143 * container configuration mechanisms to insert the filter. 144 */ 145 public interface HttpInteraction { 146 /** 147 * Returns the value of a header. 148 * 149 * @param header name of header 150 * @return value of header 151 */ 152 String getHeader(String header); 153 154 /** 155 * Returns the method. 156 * 157 * @return method 158 */ 159 String getMethod(); 160 161 /** 162 * Called by the filter after it decides that the request may proceed. 163 * 164 * @throws IOException if there is an I/O error 165 * @throws ServletException if the implementation relies on the servlet API 166 * and a servlet API call has failed 167 */ 168 void proceed() throws IOException, ServletException; 169 170 /** 171 * Called by the filter after it decides that the request is a potential 172 * CSRF attack and therefore must be rejected. 173 * 174 * @param code status code to send 175 * @param message response message 176 * @throws IOException if there is an I/O error 177 */ 178 void sendError(int code, String message) throws IOException; 179 } 180 181 /** 182 * Handles an {@link HttpInteraction} by applying the filtering logic. 183 * 184 * @param httpInteraction caller's HTTP interaction 185 * @throws IOException if there is an I/O error 186 * @throws ServletException if the implementation relies on the servlet API 187 * and a servlet API call has failed 188 */ 189 public void handleHttpInteraction(HttpInteraction httpInteraction) 190 throws IOException, ServletException { 191 if (!isBrowser(httpInteraction.getHeader(HEADER_USER_AGENT)) || 192 methodsToIgnore.contains(httpInteraction.getMethod()) || 193 httpInteraction.getHeader(headerName) != null) { 194 httpInteraction.proceed(); 195 } else { 196 httpInteraction.sendError(HttpServletResponse.SC_BAD_REQUEST, 197 "Missing Required Header for CSRF Vulnerability Protection"); 198 } 199 } 200 201 @Override 202 public void doFilter(ServletRequest request, ServletResponse response, 203 final FilterChain chain) throws IOException, ServletException { 204 final HttpServletRequest httpRequest = (HttpServletRequest)request; 205 final HttpServletResponse httpResponse = (HttpServletResponse)response; 206 handleHttpInteraction(new ServletFilterHttpInteraction(httpRequest, 207 httpResponse, chain)); 208 } 209 210 @Override 211 public void destroy() { 212 } 213 214 /** 215 * Constructs a mapping of configuration properties to be used for filter 216 * initialization. The mapping includes all properties that start with the 217 * specified configuration prefix. Property names in the mapping are trimmed 218 * to remove the configuration prefix. 219 * 220 * @param conf configuration to read 221 * @param confPrefix configuration prefix 222 * @return mapping of configuration properties to be used for filter 223 * initialization 224 */ 225 public static Map<String, String> getFilterParams(Configuration conf, String confPrefix) { 226 Map<String, String> filterConfigMap = new HashMap<>(); 227 for (Map.Entry<String, String> entry : conf) { 228 String name = entry.getKey(); 229 if (name.startsWith(confPrefix)) { 230 String value = conf.get(name); 231 name = name.substring(confPrefix.length()); 232 filterConfigMap.put(name, value); 233 } 234 } 235 return filterConfigMap; 236 } 237 238 /** 239 * {@link HttpInteraction} implementation for use in the servlet filter. 240 */ 241 private static final class ServletFilterHttpInteraction implements HttpInteraction { 242 private final FilterChain chain; 243 private final HttpServletRequest httpRequest; 244 private final HttpServletResponse httpResponse; 245 246 /** 247 * Creates a new ServletFilterHttpInteraction. 248 * 249 * @param httpRequest request to process 250 * @param httpResponse response to process 251 * @param chain filter chain to forward to if HTTP interaction is allowed 252 */ 253 public ServletFilterHttpInteraction(HttpServletRequest httpRequest, 254 HttpServletResponse httpResponse, FilterChain chain) { 255 this.httpRequest = httpRequest; 256 this.httpResponse = httpResponse; 257 this.chain = chain; 258 } 259 260 @Override 261 public String getHeader(String header) { 262 return httpRequest.getHeader(header); 263 } 264 265 @Override 266 public String getMethod() { 267 return httpRequest.getMethod(); 268 } 269 270 @Override 271 public void proceed() throws IOException, ServletException { 272 chain.doFilter(httpRequest, httpResponse); 273 } 274 275 @Override 276 public void sendError(int code, String message) throws IOException { 277 httpResponse.sendError(code, message); 278 } 279 } 280}