1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.hadoop.hbase.rest.filter;
19
20 import java.io.IOException;
21 import java.util.HashMap;
22 import java.util.HashSet;
23 import java.util.Map;
24 import java.util.Set;
25 import java.util.regex.Matcher;
26 import java.util.regex.Pattern;
27
28 import javax.servlet.Filter;
29 import javax.servlet.FilterChain;
30 import javax.servlet.FilterConfig;
31 import javax.servlet.ServletException;
32 import javax.servlet.ServletRequest;
33 import javax.servlet.ServletResponse;
34 import javax.servlet.http.HttpServletRequest;
35 import javax.servlet.http.HttpServletResponse;
36
37 import org.apache.commons.logging.Log;
38 import org.apache.commons.logging.LogFactory;
39 import org.apache.hadoop.classification.InterfaceAudience;
40 import org.apache.hadoop.classification.InterfaceStability;
41 import org.apache.hadoop.conf.Configuration;
42
43
44
45
46
47
48
49
50 @InterfaceAudience.Public
51 @InterfaceStability.Evolving
52 public class RestCsrfPreventionFilter implements Filter {
53
54 private static final Log LOG =
55 LogFactory.getLog(RestCsrfPreventionFilter.class);
56
57 public static final String HEADER_USER_AGENT = "User-Agent";
58 public static final String BROWSER_USER_AGENT_PARAM =
59 "browser-useragents-regex";
60 public static final String CUSTOM_HEADER_PARAM = "custom-header";
61 public static final String CUSTOM_METHODS_TO_IGNORE_PARAM =
62 "methods-to-ignore";
63 static final String BROWSER_USER_AGENTS_DEFAULT = "^Mozilla.*,^Opera.*";
64 public static final String HEADER_DEFAULT = "X-XSRF-HEADER";
65 static final String METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE";
66 private String headerName = HEADER_DEFAULT;
67 private Set<String> methodsToIgnore = null;
68 private Set<Pattern> browserUserAgents;
69
70 @Override
71 public void init(FilterConfig filterConfig) throws ServletException {
72 String customHeader = filterConfig.getInitParameter(CUSTOM_HEADER_PARAM);
73 if (customHeader != null) {
74 headerName = customHeader;
75 }
76 String customMethodsToIgnore =
77 filterConfig.getInitParameter(CUSTOM_METHODS_TO_IGNORE_PARAM);
78 if (customMethodsToIgnore != null) {
79 parseMethodsToIgnore(customMethodsToIgnore);
80 } else {
81 parseMethodsToIgnore(METHODS_TO_IGNORE_DEFAULT);
82 }
83
84 String agents = filterConfig.getInitParameter(BROWSER_USER_AGENT_PARAM);
85 if (agents == null) {
86 agents = BROWSER_USER_AGENTS_DEFAULT;
87 }
88 parseBrowserUserAgents(agents);
89 LOG.info(String.format("Adding cross-site request forgery (CSRF) protection, "
90 + "headerName = %s, methodsToIgnore = %s, browserUserAgents = %s",
91 headerName, methodsToIgnore, browserUserAgents));
92 }
93
94 void parseBrowserUserAgents(String userAgents) {
95 String[] agentsArray = userAgents.split(",");
96 browserUserAgents = new HashSet<Pattern>();
97 for (String patternString : agentsArray) {
98 browserUserAgents.add(Pattern.compile(patternString));
99 }
100 }
101
102 void parseMethodsToIgnore(String mti) {
103 String[] methods = mti.split(",");
104 methodsToIgnore = new HashSet<String>();
105 for (int i = 0; i < methods.length; i++) {
106 methodsToIgnore.add(methods[i]);
107 }
108 }
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125 protected boolean isBrowser(String userAgent) {
126 if (userAgent == null) {
127 return false;
128 }
129 for (Pattern pattern : browserUserAgents) {
130 Matcher matcher = pattern.matcher(userAgent);
131 if (matcher.matches()) {
132 return true;
133 }
134 }
135 return false;
136 }
137
138
139
140
141
142
143
144
145
146
147 public interface HttpInteraction {
148
149
150
151
152
153
154
155 String getHeader(String header);
156
157
158
159
160
161
162 String getMethod();
163
164
165
166
167
168
169
170
171 void proceed() throws IOException, ServletException;
172
173
174
175
176
177
178
179
180
181 void sendError(int code, String message) throws IOException;
182 }
183
184
185
186
187
188
189
190
191
192 public void handleHttpInteraction(HttpInteraction httpInteraction)
193 throws IOException, ServletException {
194 if (!isBrowser(httpInteraction.getHeader(HEADER_USER_AGENT)) ||
195 methodsToIgnore.contains(httpInteraction.getMethod()) ||
196 httpInteraction.getHeader(headerName) != null) {
197 httpInteraction.proceed();
198 } else {
199 httpInteraction.sendError(HttpServletResponse.SC_BAD_REQUEST,
200 "Missing Required Header for CSRF Vulnerability Protection");
201 }
202 }
203
204 @Override
205 public void doFilter(ServletRequest request, ServletResponse response,
206 final FilterChain chain) throws IOException, ServletException {
207 final HttpServletRequest httpRequest = (HttpServletRequest)request;
208 final HttpServletResponse httpResponse = (HttpServletResponse)response;
209 handleHttpInteraction(new ServletFilterHttpInteraction(httpRequest,
210 httpResponse, chain));
211 }
212
213 @Override
214 public void destroy() {
215 }
216
217
218
219
220
221
222
223
224
225
226
227
228 public static Map<String, String> getFilterParams(Configuration conf,
229 String confPrefix) {
230 Map<String, String> filterConfigMap = new HashMap<>();
231 for (Map.Entry<String, String> entry : conf) {
232 String name = entry.getKey();
233 if (name.startsWith(confPrefix)) {
234 String value = conf.get(name);
235 name = name.substring(confPrefix.length());
236 filterConfigMap.put(name, value);
237 }
238 }
239 return filterConfigMap;
240 }
241
242
243
244
245 private static final class ServletFilterHttpInteraction
246 implements HttpInteraction {
247
248 private final FilterChain chain;
249 private final HttpServletRequest httpRequest;
250 private final HttpServletResponse httpResponse;
251
252
253
254
255
256
257
258
259 public ServletFilterHttpInteraction(HttpServletRequest httpRequest,
260 HttpServletResponse httpResponse, FilterChain chain) {
261 this.httpRequest = httpRequest;
262 this.httpResponse = httpResponse;
263 this.chain = chain;
264 }
265
266 @Override
267 public String getHeader(String header) {
268 return httpRequest.getHeader(header);
269 }
270
271 @Override
272 public String getMethod() {
273 return httpRequest.getMethod();
274 }
275
276 @Override
277 public void proceed() throws IOException, ServletException {
278 chain.doFilter(httpRequest, httpResponse);
279 }
280
281 @Override
282 public void sendError(int code, String message) throws IOException {
283 httpResponse.sendError(code, message);
284 }
285 }
286 }