package com.revolsys.ui.web.security; import java.io.IOException; import java.lang.reflect.Method; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; import javax.servlet.http.HttpSession; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.AuthenticationTrustResolverImpl; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; public class HttpSessionSecurityContextRepository implements SecurityContextRepository { final class SaveToSessionResponseWrapper extends HttpServletResponseWrapper { private final int contextHashBeforeChainExecution; private boolean contextSaved = false; private final boolean disableUrlRewriting; private final boolean httpSessionExistedAtStartOfRequest; private final HttpServletRequest request; SaveToSessionResponseWrapper(final HttpServletResponse response, final HttpServletRequest request, final boolean httpSessionExistedAtStartOfRequest, final int contextHashBeforeChainExecution) { super(response); this.disableUrlRewriting = isDisableUrlRewriting(); this.request = request; this.httpSessionExistedAtStartOfRequest = httpSessionExistedAtStartOfRequest; this.contextHashBeforeChainExecution = contextHashBeforeChainExecution; } private HttpSession createNewSessionIfAllowed(final SecurityContext context) { if (this.httpSessionExistedAtStartOfRequest) { return null; } if (!HttpSessionSecurityContextRepository.this.allowSessionCreation) { return null; } if (HttpSessionSecurityContextRepository.this.contextObject.equals(context)) { return null; } try { return this.request.getSession(true); } catch (final IllegalStateException e) { HttpSessionSecurityContextRepository.this.logger .warn("Failed to Construct a new session, as response has been committed. Unable to store" + " SecurityContext."); } return null; } private void doSaveContext() { saveContext(SecurityContextHolder.getContext()); this.contextSaved = true; } @Override public final String encodeRedirectUrl(final String url) { if (this.disableUrlRewriting) { return url; } return super.encodeRedirectUrl(url); } @Override public final String encodeRedirectURL(final String url) { if (this.disableUrlRewriting) { return url; } return super.encodeRedirectURL(url); } @Override public final String encodeUrl(final String url) { if (this.disableUrlRewriting) { return url; } return super.encodeUrl(url); } @Override public final String encodeURL(final String url) { if (this.disableUrlRewriting) { return url; } return super.encodeURL(url); } public final boolean isContextSaved() { return this.contextSaved; } protected void saveContext(final SecurityContext context) { final Authentication authentication = context.getAuthentication(); HttpSession httpSession = this.request.getSession(false); if (authentication == null || HttpSessionSecurityContextRepository.this.authenticationTrustResolver .isAnonymous(authentication)) { if (httpSession != null) { httpSession .removeAttribute(HttpSessionSecurityContextRepository.this.springSecurityContextKey); } return; } if (httpSession == null) { httpSession = createNewSessionIfAllowed(context); } if (httpSession != null && (context.hashCode() != this.contextHashBeforeChainExecution || httpSession.getAttribute( HttpSessionSecurityContextRepository.this.springSecurityContextKey) == null)) { httpSession.setAttribute(HttpSessionSecurityContextRepository.this.springSecurityContextKey, context); } } @Override public final void sendError(final int sc) throws IOException { doSaveContext(); super.sendError(sc); } @Override public final void sendError(final int sc, final String msg) throws IOException { doSaveContext(); super.sendError(sc, msg); } @Override public final void sendRedirect(final String location) throws IOException { doSaveContext(); super.sendRedirect(location); } } private boolean allowSessionCreation = true; private final AuthenticationTrustResolver authenticationTrustResolver = new AuthenticationTrustResolverImpl(); private final boolean cloneFromHttpSession = false; private final Object contextObject = SecurityContextHolder.createEmptyContext(); private boolean disableUrlRewriting = false; protected final Log logger = LogFactory.getLog(this.getClass()); private final Class<? extends SecurityContext> securityContextClass = null; public String springSecurityContextKey = "SPRING_SECURITY_CONTEXT"; private Object cloneContext(final Object context) { Object clonedContext = null; Assert.isInstanceOf(Cloneable.class, context, "Context must implement Cloneable and provide a Object.clone() method"); try { final Method m = context.getClass().getMethod("clone", new Class[] {}); if (!m.isAccessible()) { m.setAccessible(true); } clonedContext = m.invoke(context, new Object[] {}); } catch (final Exception ex) { ReflectionUtils.handleReflectionException(ex); } return clonedContext; } @Override public boolean containsContext(final HttpServletRequest request) { final HttpSession session = request.getSession(false); if (session == null) { return false; } return session.getAttribute(this.springSecurityContextKey) != null; } SecurityContext generateNewContext() { SecurityContext context = null; if (this.securityContextClass == null) { context = SecurityContextHolder.createEmptyContext(); return context; } try { context = this.securityContextClass.newInstance(); } catch (final Exception e) { ReflectionUtils.handleReflectionException(e); } return context; } public String getSpringSecurityContextKey() { return this.springSecurityContextKey; } public boolean isDisableUrlRewriting() { return this.disableUrlRewriting; } @Override public SecurityContext loadContext(final HttpRequestResponseHolder requestResponseHolder) { final HttpServletRequest request = requestResponseHolder.getRequest(); final HttpServletResponse response = requestResponseHolder.getResponse(); final HttpSession httpSession = request.getSession(false); SecurityContext context = readSecurityContextFromSession(httpSession); if (context == null) { context = generateNewContext(); } requestResponseHolder.setResponse( new SaveToSessionResponseWrapper(response, request, httpSession != null, context.hashCode())); return context; } private SecurityContext readSecurityContextFromSession(final HttpSession httpSession) { final boolean debug = this.logger.isDebugEnabled(); if (httpSession == null) { return null; } Object contextFromSession = httpSession.getAttribute(this.springSecurityContextKey); if (contextFromSession == null) { return null; } if (!(contextFromSession instanceof SecurityContext)) { if (this.logger.isWarnEnabled()) { this.logger .warn("SPRING_SECURITY_CONTEXT did not contain a SecurityContext but contained: '" + contextFromSession + "'; are you improperly modifying the HttpSession directly " + "(you should always use SecurityContextHolder) or using the HttpSession attribute " + "reserved for this class?"); } return null; } if (this.cloneFromHttpSession) { contextFromSession = cloneContext(contextFromSession); } if (debug) { this.logger.debug("Obtained a valid SecurityContext from SPRING_SECURITY_CONTEXT: '" + contextFromSession + "'"); } return (SecurityContext)contextFromSession; } @Override public void saveContext(final SecurityContext context, final HttpServletRequest request, final HttpServletResponse response) { final SaveToSessionResponseWrapper responseWrapper = (SaveToSessionResponseWrapper)response; if (!responseWrapper.isContextSaved()) { responseWrapper.saveContext(context); } } public void setAllowSessionCreation(final boolean allowSessionCreation) { this.allowSessionCreation = allowSessionCreation; } public void setDisableUrlRewriting(final boolean disableUrlRewriting) { this.disableUrlRewriting = disableUrlRewriting; } public void setSpringSecurityContextKey(final String springSecurityContextKey) { this.springSecurityContextKey = springSecurityContextKey; } }