/*
* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.core.session;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.context.ApplicationListener;
import org.springframework.util.Assert;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArraySet;
/**
 * Default implementation of
 * {@link org.springframework.security.core.session.SessionRegistry SessionRegistry} which
 * listens for {@link org.springframework.security.core.session.SessionDestroyedEvent
 * SessionDestroyedEvent}s published in the Spring application context.
 * <p>
 * For this class to function correctly in a web application, it is important that you
 * register an <a href="{@docRoot}/org/springframework/security/web/session/HttpSessionEventPublisher.html">HttpSessionEventPublisher</a>
 * in the <code>web.xml</code> file so that this class is notified of sessions that expire.
 * <p>
 * The registry keeps two indexes, both backed by concurrent maps: session id to
 * {@link SessionInformation}, and principal to the set of session ids registered for it.
 * Adding an id to a principal's set, and unmapping that set once it becomes empty, are both
 * done while holding the set's monitor. A session registered at the same moment as the
 * removal of the principal's last other session is therefore not dropped from the principal
 * index.
 *
 * @author Ben Alex
 * @author Luke Taylor
 */
public class SessionRegistryImpl implements SessionRegistry,
        ApplicationListener<SessionDestroyedEvent> {

    // ~ Instance fields
    // ================================================================================================

    protected final Log logger = LogFactory.getLog(SessionRegistryImpl.class);

    /** Maps each principal to the set of session ids currently registered for it. */
    private final ConcurrentMap<Object, Set<String>> principals = new ConcurrentHashMap<Object, Set<String>>();

    /** Maps each registered session id to its {@link SessionInformation}. */
    private final Map<String, SessionInformation> sessionIds = new ConcurrentHashMap<String, SessionInformation>();

    // ~ Methods
    // ========================================================================================================

    /**
     * Returns a snapshot of the principals currently present in the principal index.
     *
     * @return a new, mutable list; later registry changes are not reflected in it
     */
    public List<Object> getAllPrincipals() {
        return new ArrayList<Object>(principals.keySet());
    }

    /**
     * Returns the session information for each session id registered for the given principal.
     * Ids whose {@link SessionInformation} is no longer present in the registry are skipped.
     *
     * @param principal the principal whose sessions are requested
     * @param includeExpiredSessions whether sessions reporting {@code isExpired()} are included
     * @return the matching sessions; an empty (immutable) list if the principal is unknown
     */
    public List<SessionInformation> getAllSessions(Object principal,
            boolean includeExpiredSessions) {
        final Set<String> sessionsUsedByPrincipal = principals.get(principal);

        if (sessionsUsedByPrincipal == null) {
            return Collections.emptyList();
        }

        List<SessionInformation> list = new ArrayList<SessionInformation>(
                sessionsUsedByPrincipal.size());

        for (String sessionId : sessionsUsedByPrincipal) {
            SessionInformation sessionInformation = getSessionInformation(sessionId);

            if (sessionInformation == null) {
                continue;
            }

            if (includeExpiredSessions || !sessionInformation.isExpired()) {
                list.add(sessionInformation);
            }
        }

        return list;
    }

    /**
     * Looks up the information registered for a session id.
     *
     * @param sessionId the session id; must contain text
     * @return the registered information, or {@code null} if the id is not registered
     */
    public SessionInformation getSessionInformation(String sessionId) {
        Assert.hasText(sessionId, "SessionId required as per interface contract");

        return sessionIds.get(sessionId);
    }

    /**
     * Removes the session identified by the destroyed-session event from the registry.
     *
     * @param event the event carrying the id of the destroyed session
     */
    public void onApplicationEvent(SessionDestroyedEvent event) {
        String sessionId = event.getId();
        removeSessionInformation(sessionId);
    }

    /**
     * Calls {@link SessionInformation#refreshLastRequest()} on the given session if it is
     * registered; does nothing otherwise.
     *
     * @param sessionId the session id; must contain text
     */
    public void refreshLastRequest(String sessionId) {
        Assert.hasText(sessionId, "SessionId required as per interface contract");

        SessionInformation info = getSessionInformation(sessionId);

        if (info != null) {
            info.refreshLastRequest();
        }
    }

    /**
     * Registers a session id for a principal. If the id is already registered, its previous
     * registration is removed first.
     *
     * @param sessionId the session id; must contain text
     * @param principal the owning principal; must not be {@code null}
     */
    public void registerNewSession(String sessionId, Object principal) {
        Assert.hasText(sessionId, "SessionId required as per interface contract");
        Assert.notNull(principal, "Principal required as per interface contract");

        if (logger.isDebugEnabled()) {
            logger.debug("Registering session " + sessionId + ", for principal "
                    + principal);
        }

        // Re-registering an existing id first detaches it from whichever principal held it.
        if (getSessionInformation(sessionId) != null) {
            removeSessionInformation(sessionId);
        }

        sessionIds.put(sessionId,
                new SessionInformation(principal, sessionId, new Date()));

        Set<String> sessionsUsedByPrincipal = addSessionToPrincipal(principal, sessionId);

        if (logger.isTraceEnabled()) {
            logger.trace("Sessions used by '" + principal + "' : "
                    + sessionsUsedByPrincipal);
        }
    }

    /**
     * Adds {@code sessionId} to the principal's session-id set, creating and publishing the
     * set if none is mapped.
     * <p>
     * The add is performed under the set's monitor and only if the set is still the one
     * mapped for the principal. {@link #removeSessionInformation(String)} unmaps a set only
     * while holding the same monitor, after seeing it empty. So an id is never added to a set
     * that has been discarded. If the set was discarded in the meantime, the loop retries
     * with a freshly mapped set.
     *
     * @return the set the id was added to
     */
    private Set<String> addSessionToPrincipal(Object principal, String sessionId) {
        while (true) {
            Set<String> sessionsUsedByPrincipal = principals.get(principal);

            if (sessionsUsedByPrincipal == null) {
                Set<String> newSet = new CopyOnWriteArraySet<String>();
                Set<String> prevSessionsUsedByPrincipal = principals.putIfAbsent(principal,
                        newSet);
                sessionsUsedByPrincipal = (prevSessionsUsedByPrincipal != null)
                        ? prevSessionsUsedByPrincipal : newSet;
            }

            synchronized (sessionsUsedByPrincipal) {
                if (principals.get(principal) == sessionsUsedByPrincipal) {
                    sessionsUsedByPrincipal.add(sessionId);
                    return sessionsUsedByPrincipal;
                }
            }
            // The set was unmapped concurrently; retry against the current mapping.
        }
    }

    /**
     * Removes a session id from both indexes. The principal's entry is dropped once its last
     * session id has been removed. Does nothing if the id is not registered.
     *
     * @param sessionId the session id; must contain text
     */
    public void removeSessionInformation(String sessionId) {
        Assert.hasText(sessionId, "SessionId required as per interface contract");

        SessionInformation info = getSessionInformation(sessionId);

        if (info == null) {
            return;
        }

        // Guard with the same level the message is actually logged at.
        if (logger.isDebugEnabled()) {
            logger.debug("Removing session " + sessionId
                    + " from set of registered sessions");
        }

        sessionIds.remove(sessionId);

        Object principal = info.getPrincipal();
        Set<String> sessionsUsedByPrincipal = principals.get(principal);

        if (sessionsUsedByPrincipal == null) {
            return;
        }

        if (logger.isDebugEnabled()) {
            logger.debug("Removing session " + sessionId
                    + " from principal's set of registered sessions");
        }

        boolean principalRemoved;
        synchronized (sessionsUsedByPrincipal) {
            sessionsUsedByPrincipal.remove(sessionId);
            // The emptiness check and the unmapping must be atomic with respect to
            // addSessionToPrincipal. The conditional remove only unmaps this exact set.
            principalRemoved = sessionsUsedByPrincipal.isEmpty()
                    && principals.remove(principal, sessionsUsedByPrincipal);
        }

        // No need to keep object in principals Map anymore
        if (principalRemoved && logger.isDebugEnabled()) {
            logger.debug("Removing principal " + principal
                    + " from registry");
        }

        if (logger.isTraceEnabled()) {
            logger.trace("Sessions used by '" + principal + "' : "
                    + sessionsUsedByPrincipal);
        }
    }
}