/*
* Copyright 2017 OmniFaces
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package org.omnifaces.cdi.push;
import static java.lang.String.format;
import static java.util.Collections.emptySet;
import static java.util.logging.Level.FINE;
import static java.util.logging.Level.WARNING;
import static javax.websocket.CloseReason.CloseCodes.NORMAL_CLOSURE;
import static org.omnifaces.cdi.push.SocketEndpoint.PARAM_CHANNEL;
import static org.omnifaces.util.Beans.getReference;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.logging.Logger;
import javax.enterprise.context.ApplicationScoped;
import javax.enterprise.util.AnnotationLiteral;
import javax.inject.Inject;
import javax.websocket.CloseReason;
import javax.websocket.Session;
import org.omnifaces.cdi.push.SocketEvent.Closed;
import org.omnifaces.cdi.push.SocketEvent.Opened;
import org.omnifaces.util.Beans;
import org.omnifaces.util.Hacks;
import org.omnifaces.util.Json;
/**
 * <p>
 * This web socket session manager holds all web socket sessions by their channel identifier.
 * <p>
 * The mapping is backed by concurrent collections: a {@link ConcurrentHashMap} keyed by channel identifier, whose
 * values are {@link ConcurrentLinkedQueue} instances holding the web socket sessions of that channel.
 *
 * @author Bauke Scholtz
 * @see SocketEndpoint
 * @since 2.3
 */
@ApplicationScoped
public class SocketSessionManager {

	// Constants ------------------------------------------------------------------------------------------------------

	private static final Logger logger = Logger.getLogger(SocketSessionManager.class.getName());

	private static final CloseReason REASON_EXPIRED = new CloseReason(NORMAL_CLOSURE, "Expired");

	private static final AnnotationLiteral<Opened> SESSION_OPENED = new AnnotationLiteral<Opened>() {
		private static final long serialVersionUID = 1L;
	};

	private static final AnnotationLiteral<Closed> SESSION_CLOSED = new AnnotationLiteral<Closed>() {
		private static final long serialVersionUID = 1L;
	};

	/** Key of the web socket session user property under which the associated user (if any) is stored. */
	private static final String USER_PROPERTY = "user";

	private static final String WARNING_TOMCAT_WEB_SOCKET_BOMBED =
		"Tomcat cannot handle concurrent push messages. A push message has been sent only after %s retries."
		+ " Consider rate limiting sending push messages. For example, once every 500ms.";

	private static volatile SocketSessionManager instance;

	// Properties -----------------------------------------------------------------------------------------------------

	/** Web socket sessions mapped by channel identifier. A channel identifier is "known" once it is a key here. */
	private final ConcurrentHashMap<String, Collection<Session>> socketSessions = new ConcurrentHashMap<>();

	@Inject
	private SocketUserManager socketUsers;

	// Actions --------------------------------------------------------------------------------------------------------

	/**
	 * Register given channel identifier.
	 * @param channelId The channel identifier to register.
	 */
	protected void register(String channelId) {
		// The containsKey() pre-check avoids allocating a new queue on every call for an already known channel;
		// putIfAbsent() keeps the registration atomic when two threads race past the pre-check.
		if (!socketSessions.containsKey(channelId)) {
			socketSessions.putIfAbsent(channelId, new ConcurrentLinkedQueue<Session>());
		}
	}

	/**
	 * Register given channel identifiers.
	 * @param channelIds The channel identifiers to register.
	 */
	protected void register(Iterable<String> channelIds) {
		for (String channelId : channelIds) {
			register(channelId);
		}
	}

	/**
	 * On open, add given web socket session to the mapping associated with its channel identifier and returns
	 * <code>true</code> if it's accepted (i.e. the channel identifier is known) and the same session hasn't been added
	 * before, otherwise <code>false</code>. When accepted, the user associated with the channel (if any) is stored as a
	 * session user property and a {@link Opened} event is fired.
	 * @param session The opened web socket session.
	 * @return <code>true</code> if given web socket session is accepted and is new, otherwise <code>false</code>.
	 */
	protected boolean add(Session session) {
		String channelId = getChannelId(session);
		Collection<Session> sessions = socketSessions.get(channelId);

		if (sessions != null && sessions.add(session)) {
			Serializable user = socketUsers.getUser(getChannel(session), channelId);

			if (user != null) {
				session.getUserProperties().put(USER_PROPERTY, user);
			}

			fireEvent(session, null, SESSION_OPENED);
			return true;
		}

		return false;
	}

	/**
	 * Encode the given message object as JSON and send it to all open web socket sessions associated with given web
	 * socket channel identifier.
	 * @param channelId The web socket channel identifier.
	 * @param message The push message object.
	 * @return The results of the send operation. If it returns an empty set, then there was no open session associated
	 * with given channel identifier. The returned futures will return <code>null</code> on {@link Future#get()} if the
	 * message was successfully delivered and otherwise throw {@link ExecutionException}.
	 */
	protected Set<Future<Void>> send(String channelId, Object message) {
		Collection<Session> sessions = (channelId != null) ? socketSessions.get(channelId) : null;

		if (sessions != null && !sessions.isEmpty()) {
			Set<Future<Void>> results = new HashSet<>(sessions.size());
			String json = Json.encode(message);

			for (Session session : sessions) {
				send(session, json, results, 0);
			}

			return results;
		}

		return emptySet();
	}

	/**
	 * Send given text to given web socket session, if it is open, and collect the resulting future in given set.
	 * <p>
	 * When the send is rejected with an {@link IllegalStateException} that {@code Hacks.isTomcatWebSocketBombed}
	 * recognizes, the send is retried while holding the monitor of the session, until it succeeds or the session is
	 * no longer open. Any other {@link IllegalStateException} is rethrown. Retries are performed in a loop rather than
	 * by recursion, so that prolonged rejection cannot exhaust the call stack.
	 * @param session The web socket session to send to.
	 * @param text The text to send.
	 * @param results The set to collect the send result future in.
	 * @param retries The number of retries already performed before this call.
	 */
	private void send(Session session, String text, Set<Future<Void>> results, int retries) {
		if (!session.isOpen()) {
			return;
		}

		try {
			sendText(session, text, results, retries);
		}
		catch (IllegalStateException e) {
			if (!Hacks.isTomcatWebSocketBombed(session, e)) {
				throw e;
			}

			// Serialize the retries on the session, so that concurrent senders to the same session queue up here.
			synchronized (session) {
				for (int retry = retries + 1; session.isOpen(); retry++) {
					try {
						sendText(session, text, results, retry);
						return;
					}
					catch (IllegalStateException retryException) {
						if (!Hacks.isTomcatWebSocketBombed(session, retryException)) {
							throw retryException;
						}
					}
				}
			}
		}
	}

	/**
	 * Perform a single asynchronous send attempt, collect its future and, if it only succeeded after retries, log a
	 * warning about it.
	 * @param session The web socket session to send to.
	 * @param text The text to send.
	 * @param results The set to collect the send result future in.
	 * @param retries The number of retries performed before this attempt.
	 */
	private static void sendText(Session session, String text, Set<Future<Void>> results, int retries) {
		results.add(session.getAsyncRemote().sendText(text));

		if (retries > 0 && logger.isLoggable(WARNING)) {
			logger.log(WARNING, format(WARNING_TOMCAT_WEB_SOCKET_BOMBED, retries));
		}
	}

	/**
	 * On close, remove given web socket session from the mapping and, if it was present, fire a {@link Closed} event.
	 * @param session The closed web socket session.
	 * @param reason The close reason.
	 */
	protected void remove(Session session, CloseReason reason) {
		Collection<Session> sessions = socketSessions.get(getChannelId(session));

		if (sessions != null && sessions.remove(session)) {
			fireEvent(session, reason, SESSION_CLOSED);
		}
	}

	/**
	 * Deregister given channel identifiers and explicitly close all open web socket sessions associated with them.
	 * Note that the channel identifiers themselves are not removed from the mapping by this method; sessions leave
	 * the mapping individually via {@link #remove(Session, CloseReason)}.
	 * @param channelIds The channel identifiers to deregister.
	 */
	protected void deregister(Iterable<String> channelIds) {
		for (String channelId : channelIds) {
			Collection<Session> sessions = socketSessions.get(channelId);

			if (sessions != null) {
				for (Session session : sessions) {
					close(session);
				}
			}
		}
	}

	/**
	 * Close given web socket session with an "Expired" normal closure reason, if it is still open. An
	 * {@link IOException} while closing is logged at FINE level and otherwise ignored.
	 * @param session The web socket session to close.
	 */
	private void close(Session session) {
		if (session.isOpen()) {
			try {
				session.close(REASON_EXPIRED);
			}
			catch (IOException ignore) {
				logger.log(FINE, "Ignoring thrown exception; there is nothing more we could do here.", ignore);
			}
		}
	}

	// Internal -------------------------------------------------------------------------------------------------------

	/**
	 * Internal usage only. Awkward workaround for it being unavailable via @Inject in endpoint in Tomcat+Weld/OWB.
	 * The reference is looked up lazily and cached; concurrent first calls may each perform the lookup, which is
	 * presumably harmless as this is an application scoped bean (NOTE(review): confirm getReference() semantics).
	 */
	static SocketSessionManager getInstance() {
		if (instance == null) {
			instance = getReference(SocketSessionManager.class);
		}

		return instance;
	}

	// Helpers --------------------------------------------------------------------------------------------------------

	/** Returns the channel name from the path parameters of given web socket session. */
	private static String getChannel(Session session) {
		return session.getPathParameters().get(PARAM_CHANNEL);
	}

	/** Returns the channel identifier of given web socket session: the channel name plus "?" plus the query string. */
	private static String getChannelId(Session session) {
		return getChannel(session) + "?" + session.getQueryString();
	}

	/** Fire a {@link SocketEvent} with given qualifier for given web socket session and optional close reason. */
	private static void fireEvent(Session session, CloseReason reason, AnnotationLiteral<?> qualifier) {
		Serializable user = (Serializable) session.getUserProperties().get(USER_PROPERTY);
		Beans.fireEvent(new SocketEvent(getChannel(session), user, (reason != null) ? reason.getCloseCode() : null), qualifier);
	}

}