/* * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.web.filter; import java.io.IOException; import java.util.Collections; import java.util.Enumeration; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; import org.springframework.http.HttpRequest; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.util.StringUtils; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UrlPathHelper; /** * Filter that wraps the request and response in order to override its * {@link HttpServletRequest#getServerName() getServerName()}, * {@link HttpServletRequest#getServerPort() getServerPort()}, * {@link HttpServletRequest#getScheme() getScheme()}, * {@link HttpServletRequest#isSecure() isSecure()}, * {@link HttpServletResponse#sendRedirect(String) sendRedirect(String)}, * methods with values derived from "Forwarded" or "X-Forwarded-*" * headers. In effect the wrapped request and response reflects the * client-originated protocol and address. * * @author Rossen Stoyanchev * @author EddĂș MelĂ©ndez * @author Rob Winch * @since 4.3 */ public class ForwardedHeaderFilter extends OncePerRequestFilter { private static final Set<String> FORWARDED_HEADER_NAMES = Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(5, Locale.ENGLISH)); static { FORWARDED_HEADER_NAMES.add("Forwarded"); FORWARDED_HEADER_NAMES.add("X-Forwarded-Host"); FORWARDED_HEADER_NAMES.add("X-Forwarded-Port"); FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto"); FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix"); } private final UrlPathHelper pathHelper; public ForwardedHeaderFilter() { this.pathHelper = new UrlPathHelper(); this.pathHelper.setUrlDecode(false); this.pathHelper.setRemoveSemicolonContent(false); } @Override protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { Enumeration<String> names = request.getHeaderNames(); while (names.hasMoreElements()) { String name = names.nextElement(); if (FORWARDED_HEADER_NAMES.contains(name)) { return false; } } return true; } @Override protected boolean shouldNotFilterAsyncDispatch() { return false; } @Override protected boolean shouldNotFilterErrorDispatch() { return false; } @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { ForwardedHeaderRequestWrapper wrappedRequest = new ForwardedHeaderRequestWrapper(request, this.pathHelper); ForwardedHeaderResponseWrapper wrappedResponse = new ForwardedHeaderResponseWrapper(response, wrappedRequest); filterChain.doFilter(wrappedRequest, wrappedResponse); } private static class ForwardedHeaderRequestWrapper extends HttpServletRequestWrapper { private final String scheme; private final boolean secure; private final String host; private final int port; private final String contextPath; private final String requestUri; private final String requestUrl; private final Map<String, List<String>> headers; public ForwardedHeaderRequestWrapper(HttpServletRequest request, UrlPathHelper pathHelper) { super(request); HttpRequest httpRequest = new ServletServerHttpRequest(request); UriComponents uriComponents = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); int port = uriComponents.getPort(); this.scheme = uriComponents.getScheme(); this.secure = "https".equals(scheme); this.host = uriComponents.getHost(); this.port = (port == -1 ? (this.secure ? 443 : 80) : port); String prefix = getForwardedPrefix(request); this.contextPath = (prefix != null ? prefix : request.getContextPath()); this.requestUri = this.contextPath + pathHelper.getPathWithinApplication(request); this.requestUrl = this.scheme + "://" + this.host + (port == -1 ? "" : ":" + port) + this.requestUri; this.headers = initHeaders(request); } private static String getForwardedPrefix(HttpServletRequest request) { String prefix = null; Enumeration<String> names = request.getHeaderNames(); while (names.hasMoreElements()) { String name = names.nextElement(); if ("X-Forwarded-Prefix".equalsIgnoreCase(name)) { prefix = request.getHeader(name); } } if (prefix != null) { while (prefix.endsWith("/")) { prefix = prefix.substring(0, prefix.length() - 1); } } return prefix; } /** * Copy the headers excluding any {@link #FORWARDED_HEADER_NAMES}. */ private static Map<String, List<String>> initHeaders(HttpServletRequest request) { Map<String, List<String>> headers = new LinkedCaseInsensitiveMap<>(Locale.ENGLISH); Enumeration<String> names = request.getHeaderNames(); while (names.hasMoreElements()) { String name = names.nextElement(); if (!FORWARDED_HEADER_NAMES.contains(name)) { headers.put(name, Collections.list(request.getHeaders(name))); } } return headers; } @Override public String getScheme() { return this.scheme; } @Override public String getServerName() { return this.host; } @Override public int getServerPort() { return this.port; } @Override public boolean isSecure() { return this.secure; } @Override public String getContextPath() { return this.contextPath; } @Override public String getRequestURI() { return this.requestUri; } @Override public StringBuffer getRequestURL() { return new StringBuffer(this.requestUrl); } // Override header accessors to not expose forwarded headers @Override public String getHeader(String name) { List<String> value = this.headers.get(name); return (CollectionUtils.isEmpty(value) ? null : value.get(0)); } @Override public Enumeration<String> getHeaders(String name) { List<String> value = this.headers.get(name); return (Collections.enumeration(value != null ? value : Collections.emptySet())); } @Override public Enumeration<String> getHeaderNames() { return Collections.enumeration(this.headers.keySet()); } } private static class ForwardedHeaderResponseWrapper extends HttpServletResponseWrapper { private static final String FOLDER_SEPARATOR = "/"; private final HttpServletRequest request; public ForwardedHeaderResponseWrapper(HttpServletResponse response, HttpServletRequest request) { super(response); this.request = request; } @Override public void sendRedirect(String location) throws IOException { UriComponentsBuilder builder = UriComponentsBuilder.fromUriString(location); // Absolute location if (builder.build().getScheme() != null) { super.sendRedirect(location); return; } // Network-path reference if (location.startsWith("//")) { String scheme = this.request.getScheme(); super.sendRedirect(builder.scheme(scheme).toUriString()); return; } // Relative to Servlet container root or to current request String path = (location.startsWith(FOLDER_SEPARATOR) ? location : StringUtils.applyRelativePath(this.request.getRequestURI(), location)); String result = UriComponentsBuilder .fromHttpRequest(new ServletServerHttpRequest(this.request)) .replacePath(path) .build().normalize().toUriString(); super.sendRedirect(result); } } }