/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.falcon.security; import java.io.IOException; import java.util.HashSet; import java.util.Iterator; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.hadoop.classification.InterfaceAudience.Public; import org.apache.hadoop.classification.InterfaceStability.Evolving; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Source code forked from Hadoop 2.8.0+ org.apache.hadoop.security.http.RestCsrfPreventionFilter. 
*/
@Public
@Evolving
public class RestCsrfPreventionFilter implements Filter {

    private static final Logger LOG = LoggerFactory.getLogger(RestCsrfPreventionFilter.class);

    /** Name of the request header that carries the client's user agent string. */
    public static final String HEADER_USER_AGENT = "User-Agent";

    /** Init parameter: comma-separated regular expressions identifying browser user agents. */
    public static final String BROWSER_USER_AGENT_PARAM = "browser-useragents-regex";

    /** Init parameter: name of the custom header that must be present on protected requests. */
    public static final String CUSTOM_HEADER_PARAM = "custom-header";

    /** Init parameter: comma-separated HTTP methods that are exempt from the header check. */
    public static final String CUSTOM_METHODS_TO_IGNORE_PARAM = "methods-to-ignore";

    /** User agent patterns used when {@link #BROWSER_USER_AGENT_PARAM} is not configured. */
    static final String BROWSER_USER_AGENTS_DEFAULT = "^Mozilla.*,^Opera.*";

    /** Header name used when {@link #CUSTOM_HEADER_PARAM} is not configured. */
    public static final String HEADER_DEFAULT = "X-XSRF-HEADER";

    /** Exempt methods used when {@link #CUSTOM_METHODS_TO_IGNORE_PARAM} is not configured. */
    static final String METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE";

    /** Message sent with the 403 response when the required header is missing. */
    public static final String CSRF_ERROR_MESSAGE =
            "Missing Required Header for CSRF Vulnerability Protection";

    /** Name of the header whose presence is required; replaced in {@link #init} if configured. */
    protected String headerName = HEADER_DEFAULT;

    /** HTTP methods exempt from the header check; {@code null} until parsed. */
    protected Set<String> methodsToIgnore = null;

    /** Compiled user agent patterns that identify a browser; {@code null} until parsed. */
    protected Set<Pattern> browserUserAgents;

    /** Creates an unconfigured filter; {@link #init(FilterConfig)} supplies the configuration. */
    public RestCsrfPreventionFilter() {
    }

    /**
     * Reads the header name, the exempt methods and the browser user agent patterns from the
     * filter configuration, falling back to the built-in default for each missing parameter.
     *
     * @param filterConfig source of the init parameters
     * @throws ServletException declared by the {@link Filter} contract
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        String customHeader = filterConfig.getInitParameter(CUSTOM_HEADER_PARAM);
        if (customHeader != null) {
            headerName = customHeader;
        }

        String customMethodsToIgnore =
                filterConfig.getInitParameter(CUSTOM_METHODS_TO_IGNORE_PARAM);
        parseMethodsToIgnore(
                customMethodsToIgnore != null ? customMethodsToIgnore : METHODS_TO_IGNORE_DEFAULT);

        String agents = filterConfig.getInitParameter(BROWSER_USER_AGENT_PARAM);
        parseBrowserUserAgents(agents != null ? agents : BROWSER_USER_AGENTS_DEFAULT);

        LOG.info("Adding cross-site request forgery (CSRF) protection, headerName = {}, "
                + "methodsToIgnore = {}, browserUserAgents = {}",
                headerName, methodsToIgnore, browserUserAgents);
    }

    /**
     * Compiles a comma-separated list of regular expressions and installs the result as
     * {@link #browserUserAgents}.
     *
     * <p>Entries are trimmed and blank entries are skipped, so a value written as
     * {@code "^Mozilla.*, ^Opera.*"} yields two working patterns instead of one pattern that
     * can never match because of its leading space.
     *
     * @param userAgents comma-separated regular expressions
     * @throws java.util.regex.PatternSyntaxException if an entry is not a valid expression
     */
    void parseBrowserUserAgents(String userAgents) {
        // Build into a local set so a bad expression leaves the previous patterns in place.
        Set<Pattern> compiled = new HashSet<Pattern>();
        for (String entry : userAgents.split(",")) {
            String patternString = entry.trim();
            if (!patternString.isEmpty()) {
                compiled.add(Pattern.compile(patternString));
            }
        }
        browserUserAgents = compiled;
    }

    /**
     * Parses a comma-separated list of HTTP methods and installs the result as
     * {@link #methodsToIgnore}.
     *
     * <p>Entries are trimmed and blank entries are skipped, so {@code "GET, HEAD"} exempts both
     * methods. Matching against the request method stays case-sensitive.
     *
     * @param mti comma-separated HTTP method names
     */
    void parseMethodsToIgnore(String mti) {
        Set<String> parsed = new HashSet<String>();
        for (String entry : mti.split(",")) {
            String method = entry.trim();
            if (!method.isEmpty()) {
                parsed.add(method);
            }
        }
        methodsToIgnore = parsed;
    }

    /**
     * Tells whether a user agent string matches one of the configured browser patterns.
     *
     * @param userAgent value of the {@code User-Agent} header, possibly {@code null}
     * @return {@code true} if some pattern matches the entire string; {@code false} if none
     *         does or if {@code userAgent} is {@code null}
     */
    protected boolean isBrowser(String userAgent) {
        if (userAgent == null) {
            return false;
        }
        Iterator<Pattern> patterns = browserUserAgents.iterator();
        while (patterns.hasNext()) {
            // matches() succeeds only when the pattern covers the whole user agent string.
            Matcher matcher = patterns.next().matcher(userAgent);
            if (matcher.matches()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Applies the CSRF check to one HTTP interaction.
     *
     * <p>The interaction is answered with {@link HttpServletResponse#SC_FORBIDDEN} and
     * {@link #CSRF_ERROR_MESSAGE} when all three conditions hold: the user agent is a browser,
     * the method is not exempt, and the required header is absent. Otherwise it proceeds.
     *
     * @param httpInteraction the request/response pair being checked
     * @throws IOException if sending the error or proceeding fails with an I/O error
     * @throws ServletException if proceeding fails with a servlet error
     */
    public void handleHttpInteraction(RestCsrfPreventionFilter.HttpInteraction httpInteraction)
        throws IOException, ServletException {
        boolean fromBrowser = isBrowser(httpInteraction.getHeader(HEADER_USER_AGENT));
        boolean blocked = fromBrowser
                && !methodsToIgnore.contains(httpInteraction.getMethod())
                && httpInteraction.getHeader(headerName) == null;

        if (blocked) {
            httpInteraction.sendError(HttpServletResponse.SC_FORBIDDEN, CSRF_ERROR_MESSAGE);
        } else {
            httpInteraction.proceed();
        }
    }

    /**
     * Wraps the servlet request, response and chain in an {@link HttpInteraction} and runs the
     * CSRF check on it.
     *
     * @param request the incoming request; cast to {@link HttpServletRequest}
     * @param response the outgoing response; cast to {@link HttpServletResponse}
     * @param chain the remainder of the filter chain
     * @throws IOException if the check or the downstream chain fails with an I/O error
     * @throws ServletException if the downstream chain fails with a servlet error
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
        throws IOException, ServletException {
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        HttpServletResponse httpResponse = (HttpServletResponse) response;
        handleHttpInteraction(
                new ServletFilterHttpInteraction(httpRequest, httpResponse, chain));
    }

    /** Does nothing. */
    @Override
    public void destroy() {
    }

    /**
     * {@link HttpInteraction} backed by a servlet request, a servlet response and the filter
     * chain that continues processing.
     */
    private static final class ServletFilterHttpInteraction
            implements RestCsrfPreventionFilter.HttpInteraction {

        private final FilterChain chain;
        private final HttpServletRequest httpRequest;
        private final HttpServletResponse httpResponse;

        /**
         * @param httpRequest request whose headers and method are inspected
         * @param httpResponse response used to report a rejected request
         * @param chain filter chain invoked when the request is allowed
         */
        public ServletFilterHttpInteraction(HttpServletRequest httpRequest,
                                            HttpServletResponse httpResponse,
                                            FilterChain chain) {
            this.httpRequest = httpRequest;
            this.httpResponse = httpResponse;
            this.chain = chain;
        }

        @Override
        public String getHeader(String header) {
            return httpRequest.getHeader(header);
        }

        @Override
        public String getMethod() {
            return httpRequest.getMethod();
        }

        @Override
        public void proceed() throws IOException, ServletException {
            chain.doFilter(httpRequest, httpResponse);
        }

        @Override
        public void sendError(int code, String message) throws IOException {
            httpResponse.sendError(code, message);
        }
    }

    /**
     * Interface for HttpInteraction: the view of a request/response pair that
     * {@link RestCsrfPreventionFilter#handleHttpInteraction} needs to perform its check.
     */
    public interface HttpInteraction {

        /**
         * @param header name of the request header to read
         * @return the header value, or {@code null} when the header is absent
         */
        String getHeader(String header);

        /**
         * @return the HTTP method of the request
         */
        String getMethod();

        /**
         * Continues processing of a request that passed the check.
         *
         * @throws IOException if processing fails with an I/O error
         * @throws ServletException if processing fails with a servlet error
         */
        void proceed() throws IOException, ServletException;

        /**
         * Rejects the request with an error response.
         *
         * @param code HTTP status code to send
         * @param message message to send with the status
         * @throws IOException if the error cannot be sent
         */
        void sendError(int code, String message) throws IOException;
    }
}