package org.hsweb.web.socket.message; import org.hsweb.web.bean.po.user.User; import org.hsweb.web.core.exception.AuthorizeException; import org.hsweb.web.core.session.HttpSessionManager; import org.hsweb.web.socket.utils.SessionUtils; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import java.io.IOException; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentMap; /** * Created by zhouhao on 16-5-29. */ public class SimpleWebSocketMessageManager implements WebSocketMessageManager { private static final ConcurrentMap<String, Map<String, WebSocketSession>> session_map = new ConcurrentHashMap<>(); private static final ConcurrentMap<String, Subscribe> subscribe_map = new ConcurrentHashMap<>(); private static final ConcurrentMap<String, Map<String, Queue<WebSocketMessage>>> message_map = new ConcurrentHashMap<>(); private HttpSessionManager httpSessionManager; public void setHttpSessionManager(HttpSessionManager httpSessionManager) { this.httpSessionManager = httpSessionManager; } public Map<String, WebSocketSession> getSessionMap(String userId) { Map<String, WebSocketSession> map = session_map.get(userId); if (map == null) { map = Collections.synchronizedMap(new HashMap<>()); session_map.put(userId, map); } return map; } @Override public boolean publish(WebSocketMessage message) throws IOException { String to = message.getTo(); Subscribe subscribe = subscribe_map.get(to); Map<String, WebSocketSession> socketSession = getSessionMap(message.getTo()); if (!socketSession.isEmpty() && subscribe != null) { if (message.getSessionId() == null) socketSession.values().forEach(session -> { try { if (subscribe.getTopic(session.getId()).contains(message.getType())) session.sendMessage(new TextMessage(message.toString())); } catch (IOException e) { e.printStackTrace(); saveMessage(message); } }); else { WebSocketSession session = socketSession.get(message.getSessionId()); if (session != null && session.isOpen()) { session.sendMessage(new TextMessage(message.toString())); } } return true; } return false; } protected Queue<WebSocketMessage> getMessageQueue(String userId, String type) { Map<String, Queue<WebSocketMessage>> message_type_map = message_map.get(userId); if (message_type_map == null) { message_type_map = new ConcurrentHashMap<>(); message_map.putIfAbsent(userId, message_type_map); } Queue<WebSocketMessage> queue = message_type_map.get(type); if (queue == null) { queue = new ConcurrentLinkedQueue<>(); message_type_map.putIfAbsent(type, queue); } return queue; } protected void saveMessage(WebSocketMessage message) { getMessageQueue(message.getTo(), message.getType()).offer(message); } @Override public boolean deSubscribe(String type, String userId, WebSocketSession socketSession) { return getSubscribe(userId).getTopic(socketSession.getId()).remove(type); } protected Subscribe getSubscribe(String userId) { Subscribe subscribe = subscribe_map.get(userId); synchronized (subscribe_map) { if (subscribe == null) { subscribe = new Subscribe(); subscribe.setUserId(userId); subscribe_map.put(userId, subscribe); } } return subscribe; } @Override public boolean subscribe(String type, String userId, WebSocketSession socketSession) { getSubscribe(userId).getTopic(socketSession.getId()).add(type); //推送未读消息 Queue<WebSocketMessage> queue = getMessageQueue(userId, type); while (!queue.isEmpty()) { try { publish(queue.poll()); } catch (IOException e) { } } return true; } class Subscribe { private String userId; private Map<String, Set<String>> topic = Collections.synchronizedMap(new HashMap<>()); public String getUserId() { return userId; } public void setUserId(String userId) { this.userId = userId; } public void cancelTopic(String sessionId) { topic.remove(sessionId); } public Set<String> getTopic(String sessionId) { Set<String> tp = topic.get(sessionId); if (tp == null) { tp = Collections.synchronizedSet(new HashSet<>()); topic.putIfAbsent(sessionId, tp); } return tp; } } @Override public void onSessionConnect(WebSocketSession session) throws Exception { User user = getUser(session); if (user == null) { throw new AuthorizeException("未登录"); } getSessionMap(user.getId()).put(session.getId(), session); } @Override public void onSessionClose(WebSocketSession session) throws Exception { User user = getUser(session); if (user == null) { return; } Subscribe subscribe = subscribe_map.get(user.getId()); if (subscribe != null) subscribe.cancelTopic(session.getId()); getSessionMap(user.getId()).remove(session.getId()); } protected User getUser(WebSocketSession session) { return SessionUtils.getUser(session, httpSessionManager); } }