/*
* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.savedrequest;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.security.web.PortResolver;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import java.util.*;
/**
* Represents central information from a {@code HttpServletRequest}.
* <p>
* This class is used by
* {@link org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter}
* and {@link org.springframework.security.web.savedrequest.SavedRequestAwareWrapper} to
* reproduce the request after successful authentication. An instance of this class is
* stored at the time of an authentication exception by
* {@link org.springframework.security.web.access.ExceptionTranslationFilter}.
* <p>
* <em>IMPLEMENTATION NOTE</em>: It is assumed that this object is accessed only from the
* context of a single thread, so no synchronization around internal collection classes is
* performed.
* <p>
* This class is based on code in Apache Tomcat.
*
* @author Craig McClanahan
* @author Andrey Grebnev
* @author Ben Alex
* @author Luke Taylor
*/
public class DefaultSavedRequest implements SavedRequest {
	// ~ Static fields/initializers
	// =====================================================================================

	protected static final Log logger = LogFactory.getLog(DefaultSavedRequest.class);

	// Headers that are never stored in a saved request. SEC-1412, SEC-1624.
	private static final String HEADER_IF_NONE_MATCH = "If-None-Match";
	private static final String HEADER_IF_MODIFIED_SINCE = "If-Modified-Since";

	// ~ Instance fields
	// ================================================================================================

	// NOTE(review): the declared field types are intentionally left as concrete
	// classes. If instances are Java-serialized without an explicit
	// serialVersionUID, changing a field's declared type would presumably alter
	// the computed UID -- confirm before "programming to the interface" here.
	private final ArrayList<SavedCookie> cookies = new ArrayList<SavedCookie>();
	private final ArrayList<Locale> locales = new ArrayList<Locale>();
	// Header names are matched case-insensitively.
	private final Map<String, List<String>> headers = new TreeMap<String, List<String>>(
			String.CASE_INSENSITIVE_ORDER);
	private final Map<String, String[]> parameters = new TreeMap<String, String[]>();
	private final String contextPath;
	private final String method;
	private final String pathInfo;
	private final String queryString;
	private final String requestURI;
	private final String requestURL;
	private final String scheme;
	private final String serverName;
	private final String servletPath;
	private final int serverPort;

	// ~ Constructors
	// ===================================================================================================

	/**
	 * Captures cookies, headers, locales, parameters and URL components of the given
	 * request. The {@code If-Modified-Since} and {@code If-None-Match} headers are not
	 * stored.
	 *
	 * @param request the request to capture; must not be {@code null}
	 * @param portResolver used to obtain the server port of the request; must not be
	 * {@code null}
	 */
	@SuppressWarnings("unchecked")
	public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver) {
		Assert.notNull(request, "Request required");
		Assert.notNull(portResolver, "PortResolver required");

		// Cookies
		addCookies(request.getCookies());

		// Headers
		Enumeration<String> names = request.getHeaderNames();
		while (names.hasMoreElements()) {
			String name = names.nextElement();
			// Skip If-Modified-Since and If-None-Match header. SEC-1412, SEC-1624.
			if (HEADER_IF_MODIFIED_SINCE.equalsIgnoreCase(name)
					|| HEADER_IF_NONE_MATCH.equalsIgnoreCase(name)) {
				continue;
			}
			Enumeration<String> values = request.getHeaders(name);
			while (values.hasMoreElements()) {
				this.addHeader(name, values.nextElement());
			}
		}

		// Locales
		addLocales(request.getLocales());

		// Parameters
		addParameters(request.getParameterMap());

		// Primitives
		this.method = request.getMethod();
		this.pathInfo = request.getPathInfo();
		this.queryString = request.getQueryString();
		this.requestURI = request.getRequestURI();
		this.serverPort = portResolver.getServerPort(request);
		this.requestURL = request.getRequestURL().toString();
		this.scheme = request.getScheme();
		this.serverName = request.getServerName();
		this.contextPath = request.getContextPath();
		this.servletPath = request.getServletPath();
	}

	/**
	 * Private constructor invoked through {@link Builder}. Only the scalar properties
	 * are copied here; {@link Builder#build()} fills in the collections afterwards.
	 */
	private DefaultSavedRequest(Builder builder) {
		this.contextPath = builder.contextPath;
		this.method = builder.method;
		this.pathInfo = builder.pathInfo;
		this.queryString = builder.queryString;
		this.requestURI = builder.requestURI;
		this.requestURL = builder.requestURL;
		this.scheme = builder.scheme;
		this.serverName = builder.serverName;
		this.servletPath = builder.servletPath;
		this.serverPort = builder.serverPort;
	}

	// ~ Methods
	// ========================================================================================================

	/**
	 * Stores every cookie of the given array; a {@code null} array is ignored.
	 *
	 * @since 4.2
	 */
	private void addCookies(Cookie[] cookies) {
		if (cookies == null) {
			return;
		}
		for (Cookie cookie : cookies) {
			this.addCookie(cookie);
		}
	}

	/** Wraps the cookie in a {@code SavedCookie} and stores it. */
	private void addCookie(Cookie cookie) {
		cookies.add(new SavedCookie(cookie));
	}

	/** Appends {@code value} to the values already stored for header {@code name}. */
	private void addHeader(String name, String value) {
		List<String> values = headers.get(name);
		if (values == null) {
			values = new ArrayList<String>();
			headers.put(name, values);
		}
		values.add(value);
	}

	/**
	 * Stores every locale of the given enumeration; a {@code null} enumeration is
	 * ignored.
	 *
	 * @since 4.2
	 */
	private void addLocales(Enumeration<Locale> locales) {
		if (locales == null) {
			return;
		}
		while (locales.hasMoreElements()) {
			this.addLocale(locales.nextElement());
		}
	}

	private void addLocale(Locale locale) {
		locales.add(locale);
	}

	/**
	 * Stores every entry of the given parameter map whose value is a
	 * {@code String[]}. Other values are skipped with a warning; a {@code null} or
	 * empty map is ignored.
	 *
	 * @since 4.2
	 */
	private void addParameters(Map<String, String[]> parameters) {
		if (ObjectUtils.isEmpty(parameters)) {
			return;
		}
		// The value is read as Object so that the instanceof check below stays
		// meaningful for maps that do not honour the declared value type.
		for (Map.Entry<String, ?> entry : parameters.entrySet()) {
			Object paramValues = entry.getValue();
			if (paramValues instanceof String[]) {
				this.addParameter(entry.getKey(), (String[]) paramValues);
			}
			else if (logger.isWarnEnabled()) {
				logger.warn("ServletRequest.getParameterMap() returned non-String array");
			}
		}
	}

	private void addParameter(String name, String[] values) {
		parameters.put(name, values);
	}

	/**
	 * Determines if the current request matches the <code>DefaultSavedRequest</code>.
	 * <p>
	 * All URL arguments are considered but not cookies, locales, headers or parameters.
	 *
	 * @param request the actual request to be matched against this one
	 * @param portResolver used to obtain the server port of the request
	 * @return true if the request is deemed to match this one.
	 */
	public boolean doesRequestMatch(HttpServletRequest request, PortResolver portResolver) {
		if (!propertyEquals("pathInfo", this.pathInfo, request.getPathInfo())) {
			return false;
		}
		if (!propertyEquals("queryString", this.queryString, request.getQueryString())) {
			return false;
		}
		if (!propertyEquals("requestURI", this.requestURI, request.getRequestURI())) {
			return false;
		}
		if (!"GET".equals(request.getMethod()) && "GET".equals(method)) {
			// A saved GET should not match an incoming non-GET method
			return false;
		}
		if (!propertyEquals("serverPort", Integer.valueOf(this.serverPort),
				Integer.valueOf(portResolver.getServerPort(request)))) {
			return false;
		}
		if (!propertyEquals("requestURL", this.requestURL,
				request.getRequestURL().toString())) {
			return false;
		}
		if (!propertyEquals("scheme", this.scheme, request.getScheme())) {
			return false;
		}
		if (!propertyEquals("serverName", this.serverName, request.getServerName())) {
			return false;
		}
		if (!propertyEquals("contextPath", this.contextPath, request.getContextPath())) {
			return false;
		}
		return propertyEquals("servletPath", this.servletPath, request.getServletPath());
	}

	public String getContextPath() {
		return contextPath;
	}

	/**
	 * @return a new list holding one {@code Cookie} per stored cookie
	 */
	public List<Cookie> getCookies() {
		List<Cookie> cookieList = new ArrayList<Cookie>(cookies.size());
		for (SavedCookie savedCookie : cookies) {
			cookieList.add(savedCookie.getCookie());
		}
		return cookieList;
	}

	/**
	 * Indicates the URL that the user agent used for this request.
	 *
	 * @return the full URL of this request
	 */
	public String getRedirectUrl() {
		return UrlUtils.buildFullRequestUrl(scheme, serverName, serverPort, requestURI,
				queryString);
	}

	/**
	 * @return the stored header names (a view of the internal map's key set)
	 */
	public Collection<String> getHeaderNames() {
		return headers.keySet();
	}

	/**
	 * @param name the header name, matched case-insensitively
	 * @return the stored values for the header (the internal list), or an empty list
	 * if the header was not stored
	 */
	public List<String> getHeaderValues(String name) {
		List<String> values = headers.get(name);
		if (values == null) {
			return Collections.emptyList();
		}
		return values;
	}

	/**
	 * @return the stored locales (the internal list, not a copy)
	 */
	public List<Locale> getLocales() {
		return locales;
	}

	public String getMethod() {
		return method;
	}

	/**
	 * @return the stored parameters (the internal map, not a copy)
	 */
	public Map<String, String[]> getParameterMap() {
		return parameters;
	}

	public Collection<String> getParameterNames() {
		return parameters.keySet();
	}

	/**
	 * @return the stored values for the parameter, or {@code null} if none were stored
	 */
	public String[] getParameterValues(String name) {
		return parameters.get(name);
	}

	public String getPathInfo() {
		return pathInfo;
	}

	public String getQueryString() {
		return this.queryString;
	}

	public String getRequestURI() {
		return this.requestURI;
	}

	public String getRequestURL() {
		return requestURL;
	}

	public String getScheme() {
		return scheme;
	}

	public String getServerName() {
		return serverName;
	}

	public int getServerPort() {
		return serverPort;
	}

	public String getServletPath() {
		return servletPath;
	}

	/**
	 * Null-safe equality check that logs its outcome at debug level.
	 *
	 * @param log the property name used as prefix of the debug message
	 * @param arg1 the saved value, may be {@code null}
	 * @param arg2 the value of the current request, may be {@code null}
	 * @return {@code true} if both are {@code null} or {@code arg1.equals(arg2)}
	 */
	private boolean propertyEquals(String log, Object arg1, Object arg2) {
		if (arg1 == null && arg2 == null) {
			if (logger.isDebugEnabled()) {
				logger.debug(log + ": both null (property equals)");
			}
			return true;
		}
		boolean equal = arg1 != null && arg2 != null && arg1.equals(arg2);
		if (logger.isDebugEnabled()) {
			logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2
					+ (equal ? " (property equals)" : " (property not equals)"));
		}
		return equal;
	}

	@Override
	public String toString() {
		return "DefaultSavedRequest[" + getRedirectUrl() + "]";
	}

	/**
	 * Builder for {@link DefaultSavedRequest}, annotated for use as a Jackson POJO
	 * builder whose properties are populated through the {@code set*} methods.
	 * The server port defaults to 80 when it is not set.
	 *
	 * @since 4.2
	 */
	@JsonIgnoreProperties(ignoreUnknown = true)
	@JsonPOJOBuilder(withPrefix = "set")
	public static class Builder {

		private List<SavedCookie> cookies = null;
		private List<Locale> locales = null;
		private final Map<String, List<String>> headers = new TreeMap<String, List<String>>(
				String.CASE_INSENSITIVE_ORDER);
		private Map<String, String[]> parameters = new TreeMap<String, String[]>();
		private String contextPath;
		private String method;
		private String pathInfo;
		private String queryString;
		private String requestURI;
		private String requestURL;
		private String scheme;
		private String serverName;
		private String servletPath;
		private int serverPort = 80;

		public Builder setCookies(List<SavedCookie> cookies) {
			this.cookies = cookies;
			return this;
		}

		public Builder setLocales(List<Locale> locales) {
			this.locales = locales;
			return this;
		}

		/**
		 * Adds the given headers to those already collected by this builder (entries
		 * with the same case-insensitive name are replaced). A {@code null} map is
		 * ignored.
		 */
		public Builder setHeaders(Map<String, List<String>> header) {
			if (header != null) {
				this.headers.putAll(header);
			}
			return this;
		}

		public Builder setParameters(Map<String, String[]> parameters) {
			this.parameters = parameters;
			return this;
		}

		public Builder setContextPath(String contextPath) {
			this.contextPath = contextPath;
			return this;
		}

		public Builder setMethod(String method) {
			this.method = method;
			return this;
		}

		public Builder setPathInfo(String pathInfo) {
			this.pathInfo = pathInfo;
			return this;
		}

		public Builder setQueryString(String queryString) {
			this.queryString = queryString;
			return this;
		}

		public Builder setRequestURI(String requestURI) {
			this.requestURI = requestURI;
			return this;
		}

		public Builder setRequestURL(String requestURL) {
			this.requestURL = requestURL;
			return this;
		}

		public Builder setScheme(String scheme) {
			this.scheme = scheme;
			return this;
		}

		public Builder setServerName(String serverName) {
			this.serverName = serverName;
			return this;
		}

		public Builder setServletPath(String servletPath) {
			this.servletPath = servletPath;
			return this;
		}

		public Builder setServerPort(int serverPort) {
			this.serverPort = serverPort;
			return this;
		}

		/**
		 * Creates the saved request from the current builder state.
		 * <p>
		 * The builder itself is left unmodified, and header value lists are copied,
		 * so the returned request does not share mutable header state with this
		 * builder or with a request produced by another call to this method. The
		 * {@code If-Modified-Since} and {@code If-None-Match} headers are not copied.
		 *
		 * @return a new {@link DefaultSavedRequest}
		 */
		public DefaultSavedRequest build() {
			DefaultSavedRequest savedRequest = new DefaultSavedRequest(this);
			if (this.cookies != null) {
				for (SavedCookie cookie : this.cookies) {
					savedRequest.addCookie(cookie.getCookie());
				}
			}
			if (this.locales != null) {
				savedRequest.locales.addAll(this.locales);
			}
			savedRequest.addParameters(this.parameters);
			for (Map.Entry<String, List<String>> header : this.headers.entrySet()) {
				String name = header.getKey();
				// Skip If-Modified-Since and If-None-Match header. SEC-1412, SEC-1624.
				if (HEADER_IF_MODIFIED_SINCE.equalsIgnoreCase(name)
						|| HEADER_IF_NONE_MATCH.equalsIgnoreCase(name)) {
					continue;
				}
				List<String> values = header.getValue();
				// A null value list is stored as an empty list: the header name stays
				// visible and getHeaderValues(name) still yields no values.
				savedRequest.headers.put(name, (values != null)
						? new ArrayList<String>(values) : new ArrayList<String>());
			}
			return savedRequest;
		}
	}
}