/* * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.web.socket.messaging; import java.security.Principal; import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import org.springframework.context.ApplicationEvent; import org.springframework.context.event.SmartApplicationListener; import org.springframework.core.Ordered; import org.springframework.messaging.Message; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.user.DestinationUserNameProvider; import org.springframework.messaging.simp.user.SimpSession; import org.springframework.messaging.simp.user.SimpSubscription; import org.springframework.messaging.simp.user.SimpSubscriptionMatcher; import org.springframework.messaging.simp.user.SimpUser; import org.springframework.messaging.simp.user.SimpUserRegistry; import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.util.Assert; /** * A default implementation of {@link SimpUserRegistry} that relies on * {@link AbstractSubProtocolEvent} application context events to keep track of * connected users and their subscriptions. * * @author Rossen Stoyanchev * @since 4.2 */ public class DefaultSimpUserRegistry implements SimpUserRegistry, SmartApplicationListener { /* Primary lookup that holds all users and their sessions */ private final Map<String, LocalSimpUser> users = new ConcurrentHashMap<>(); /* Secondary lookup across all sessions by id */ private final Map<String, LocalSimpSession> sessions = new ConcurrentHashMap<>(); private final Object sessionLock = new Object(); @Override public int getOrder() { return Ordered.LOWEST_PRECEDENCE; } // SmartApplicationListener methods @Override public boolean supportsEventType(Class<? extends ApplicationEvent> eventType) { return AbstractSubProtocolEvent.class.isAssignableFrom(eventType); } @Override public void onApplicationEvent(ApplicationEvent event) { AbstractSubProtocolEvent subProtocolEvent = (AbstractSubProtocolEvent) event; Message<?> message = subProtocolEvent.getMessage(); SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class); String sessionId = accessor.getSessionId(); if (event instanceof SessionSubscribeEvent) { LocalSimpSession session = this.sessions.get(sessionId); if (session != null) { String id = accessor.getSubscriptionId(); String destination = accessor.getDestination(); session.addSubscription(id, destination); } } else if (event instanceof SessionConnectedEvent) { Principal user = subProtocolEvent.getUser(); if (user == null) { return; } String name = user.getName(); if (user instanceof DestinationUserNameProvider) { name = ((DestinationUserNameProvider) user).getDestinationUserName(); } synchronized (this.sessionLock) { LocalSimpUser simpUser = this.users.get(name); if (simpUser == null) { simpUser = new LocalSimpUser(name); this.users.put(name, simpUser); } LocalSimpSession session = new LocalSimpSession(sessionId, simpUser); simpUser.addSession(session); this.sessions.put(sessionId, session); } } else if (event instanceof SessionDisconnectEvent) { synchronized (this.sessionLock) { LocalSimpSession session = this.sessions.remove(sessionId); if (session != null) { LocalSimpUser user = session.getUser(); user.removeSession(sessionId); if (!user.hasSessions()) { this.users.remove(user.getName()); } } } } else if (event instanceof SessionUnsubscribeEvent) { LocalSimpSession session = this.sessions.get(sessionId); if (session != null) { String subscriptionId = accessor.getSubscriptionId(); session.removeSubscription(subscriptionId); } } } @Override public boolean supportsSourceType(Class<?> sourceType) { return true; } // SimpUserRegistry methods @Override public SimpUser getUser(String userName) { return this.users.get(userName); } @Override public Set<SimpUser> getUsers() { return new HashSet<>(this.users.values()); } @Override public int getUserCount() { return this.users.size(); } public Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher) { Set<SimpSubscription> result = new HashSet<>(); for (LocalSimpSession session : this.sessions.values()) { for (SimpSubscription subscription : session.subscriptions.values()) { if (matcher.match(subscription)) { result.add(subscription); } } } return result; } @Override public String toString() { return "users=" + this.users; } private static class LocalSimpUser implements SimpUser { private final String name; private final Map<String, SimpSession> userSessions = new ConcurrentHashMap<>(1); public LocalSimpUser(String userName) { Assert.notNull(userName, "User name must not be null"); this.name = userName; } @Override public String getName() { return this.name; } @Override public boolean hasSessions() { return !this.userSessions.isEmpty(); } @Override public SimpSession getSession(String sessionId) { return (sessionId != null ? this.userSessions.get(sessionId) : null); } @Override public Set<SimpSession> getSessions() { return new HashSet<>(this.userSessions.values()); } void addSession(SimpSession session) { this.userSessions.put(session.getId(), session); } void removeSession(String sessionId) { this.userSessions.remove(sessionId); } @Override public boolean equals(Object other) { return (this == other || (other instanceof SimpUser && this.name.equals(((SimpUser) other).getName()))); } @Override public int hashCode() { return this.name.hashCode(); } @Override public String toString() { return "name=" + this.name + ", sessions=" + this.userSessions; } } private static class LocalSimpSession implements SimpSession { private final String id; private final LocalSimpUser user; private final Map<String, SimpSubscription> subscriptions = new ConcurrentHashMap<>(4); public LocalSimpSession(String id, LocalSimpUser user) { Assert.notNull(id, "Id must not be null"); Assert.notNull(user, "User must not be null"); this.id = id; this.user = user; } @Override public String getId() { return this.id; } @Override public LocalSimpUser getUser() { return this.user; } @Override public Set<SimpSubscription> getSubscriptions() { return new HashSet<>(this.subscriptions.values()); } void addSubscription(String id, String destination) { this.subscriptions.put(id, new LocalSimpSubscription(id, destination, this)); } void removeSubscription(String id) { this.subscriptions.remove(id); } @Override public boolean equals(Object other) { return (this == other || (other instanceof SimpSubscription && this.id.equals(((SimpSubscription) other).getId()))); } @Override public int hashCode() { return this.id.hashCode(); } @Override public String toString() { return "id=" + this.id + ", subscriptions=" + this.subscriptions; } } private static class LocalSimpSubscription implements SimpSubscription { private final String id; private final LocalSimpSession session; private final String destination; public LocalSimpSubscription(String id, String destination, LocalSimpSession session) { Assert.notNull(id, "Id must not be null"); Assert.hasText(destination, "Destination must not be empty"); Assert.notNull(session, "Session must not be null"); this.id = id; this.destination = destination; this.session = session; } @Override public String getId() { return this.id; } @Override public LocalSimpSession getSession() { return this.session; } @Override public String getDestination() { return this.destination; } @Override public boolean equals(Object other) { if (this == other) { return true; } if (!(other instanceof SimpSubscription)) { return false; } SimpSubscription otherSubscription = (SimpSubscription) other; return (this.id.equals(otherSubscription.getId()) && getSession().getId().equals(otherSubscription.getSession().getId())); } @Override public int hashCode() { return this.id.hashCode() * 31 + getSession().getId().hashCode(); } @Override public String toString() { return "destination=" + this.destination; } } }