/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.messaging;
import java.security.Principal;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.event.SmartApplicationListener;
import org.springframework.core.Ordered;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.SimpSession;
import org.springframework.messaging.simp.user.SimpSubscription;
import org.springframework.messaging.simp.user.SimpSubscriptionMatcher;
import org.springframework.messaging.simp.user.SimpUser;
import org.springframework.messaging.simp.user.SimpUserRegistry;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert;
/**
* A default implementation of {@link SimpUserRegistry} that relies on
* {@link AbstractSubProtocolEvent} application context events to keep track of
* connected users and their subscriptions.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public class DefaultSimpUserRegistry implements SimpUserRegistry, SmartApplicationListener {
/* Primary lookup that holds all users and their sessions */
private final Map<String, LocalSimpUser> users = new ConcurrentHashMap<>();
/* Secondary lookup across all sessions by id */
private final Map<String, LocalSimpSession> sessions = new ConcurrentHashMap<>();
private final Object sessionLock = new Object();
@Override
public int getOrder() {
return Ordered.LOWEST_PRECEDENCE;
}
// SmartApplicationListener methods
@Override
public boolean supportsEventType(Class<? extends ApplicationEvent> eventType) {
return AbstractSubProtocolEvent.class.isAssignableFrom(eventType);
}
@Override
public void onApplicationEvent(ApplicationEvent event) {
AbstractSubProtocolEvent subProtocolEvent = (AbstractSubProtocolEvent) event;
Message<?> message = subProtocolEvent.getMessage();
SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class);
String sessionId = accessor.getSessionId();
if (event instanceof SessionSubscribeEvent) {
LocalSimpSession session = this.sessions.get(sessionId);
if (session != null) {
String id = accessor.getSubscriptionId();
String destination = accessor.getDestination();
session.addSubscription(id, destination);
}
}
else if (event instanceof SessionConnectedEvent) {
Principal user = subProtocolEvent.getUser();
if (user == null) {
return;
}
String name = user.getName();
if (user instanceof DestinationUserNameProvider) {
name = ((DestinationUserNameProvider) user).getDestinationUserName();
}
synchronized (this.sessionLock) {
LocalSimpUser simpUser = this.users.get(name);
if (simpUser == null) {
simpUser = new LocalSimpUser(name);
this.users.put(name, simpUser);
}
LocalSimpSession session = new LocalSimpSession(sessionId, simpUser);
simpUser.addSession(session);
this.sessions.put(sessionId, session);
}
}
else if (event instanceof SessionDisconnectEvent) {
synchronized (this.sessionLock) {
LocalSimpSession session = this.sessions.remove(sessionId);
if (session != null) {
LocalSimpUser user = session.getUser();
user.removeSession(sessionId);
if (!user.hasSessions()) {
this.users.remove(user.getName());
}
}
}
}
else if (event instanceof SessionUnsubscribeEvent) {
LocalSimpSession session = this.sessions.get(sessionId);
if (session != null) {
String subscriptionId = accessor.getSubscriptionId();
session.removeSubscription(subscriptionId);
}
}
}
@Override
public boolean supportsSourceType(Class<?> sourceType) {
return true;
}
// SimpUserRegistry methods
@Override
public SimpUser getUser(String userName) {
return this.users.get(userName);
}
@Override
public Set<SimpUser> getUsers() {
return new HashSet<>(this.users.values());
}
@Override
public int getUserCount() {
return this.users.size();
}
public Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher) {
Set<SimpSubscription> result = new HashSet<>();
for (LocalSimpSession session : this.sessions.values()) {
for (SimpSubscription subscription : session.subscriptions.values()) {
if (matcher.match(subscription)) {
result.add(subscription);
}
}
}
return result;
}
@Override
public String toString() {
return "users=" + this.users;
}
private static class LocalSimpUser implements SimpUser {
private final String name;
private final Map<String, SimpSession> userSessions = new ConcurrentHashMap<>(1);
public LocalSimpUser(String userName) {
Assert.notNull(userName, "User name must not be null");
this.name = userName;
}
@Override
public String getName() {
return this.name;
}
@Override
public boolean hasSessions() {
return !this.userSessions.isEmpty();
}
@Override
public SimpSession getSession(String sessionId) {
return (sessionId != null ? this.userSessions.get(sessionId) : null);
}
@Override
public Set<SimpSession> getSessions() {
return new HashSet<>(this.userSessions.values());
}
void addSession(SimpSession session) {
this.userSessions.put(session.getId(), session);
}
void removeSession(String sessionId) {
this.userSessions.remove(sessionId);
}
@Override
public boolean equals(Object other) {
return (this == other ||
(other instanceof SimpUser && this.name.equals(((SimpUser) other).getName())));
}
@Override
public int hashCode() {
return this.name.hashCode();
}
@Override
public String toString() {
return "name=" + this.name + ", sessions=" + this.userSessions;
}
}
private static class LocalSimpSession implements SimpSession {
private final String id;
private final LocalSimpUser user;
private final Map<String, SimpSubscription> subscriptions = new ConcurrentHashMap<>(4);
public LocalSimpSession(String id, LocalSimpUser user) {
Assert.notNull(id, "Id must not be null");
Assert.notNull(user, "User must not be null");
this.id = id;
this.user = user;
}
@Override
public String getId() {
return this.id;
}
@Override
public LocalSimpUser getUser() {
return this.user;
}
@Override
public Set<SimpSubscription> getSubscriptions() {
return new HashSet<>(this.subscriptions.values());
}
void addSubscription(String id, String destination) {
this.subscriptions.put(id, new LocalSimpSubscription(id, destination, this));
}
void removeSubscription(String id) {
this.subscriptions.remove(id);
}
@Override
public boolean equals(Object other) {
return (this == other ||
(other instanceof SimpSubscription && this.id.equals(((SimpSubscription) other).getId())));
}
@Override
public int hashCode() {
return this.id.hashCode();
}
@Override
public String toString() {
return "id=" + this.id + ", subscriptions=" + this.subscriptions;
}
}
private static class LocalSimpSubscription implements SimpSubscription {
private final String id;
private final LocalSimpSession session;
private final String destination;
public LocalSimpSubscription(String id, String destination, LocalSimpSession session) {
Assert.notNull(id, "Id must not be null");
Assert.hasText(destination, "Destination must not be empty");
Assert.notNull(session, "Session must not be null");
this.id = id;
this.destination = destination;
this.session = session;
}
@Override
public String getId() {
return this.id;
}
@Override
public LocalSimpSession getSession() {
return this.session;
}
@Override
public String getDestination() {
return this.destination;
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof SimpSubscription)) {
return false;
}
SimpSubscription otherSubscription = (SimpSubscription) other;
return (this.id.equals(otherSubscription.getId()) &&
getSession().getId().equals(otherSubscription.getSession().getId()));
}
@Override
public int hashCode() {
return this.id.hashCode() * 31 + getSession().getId().hashCode();
}
@Override
public String toString() {
return "destination=" + this.destination;
}
}
}