package com.psddev.cms.rtc;
import com.google.common.collect.ImmutableMap;
import com.psddev.cms.db.ToolUser;
import com.psddev.cms.tool.AuthenticationFilter;
import com.psddev.cms.tool.CmsTool;
import com.psddev.dari.db.Database;
import com.psddev.dari.db.Query;
import com.psddev.dari.util.AbstractFilter;
import com.psddev.dari.util.ObjectUtils;
import com.psddev.dari.util.TypeDefinition;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
/**
* Filter that handles the real-time communication between the server
* and the clients.
*/
public class RtcFilter extends AbstractFilter implements AbstractFilter.Auto {
public static final String PATH = "/_rtc";
public static final String ASYNC_CONTEXT_TIMEOUT_SETTING = "brightspot/rtc/asyncContextTimeout";
private static final String ATTRIBUTE_PREFIX = RtcFilter.class.getName() + ".";
private static final String USER_ID_ATTRIBUTE = ATTRIBUTE_PREFIX + "userId";
private final ConcurrentMap<UUID, RtcAsyncContext> contexts = new ConcurrentHashMap<>();
private volatile RtcAsyncContextPingRunnable pingRunnable;
private volatile ScheduledExecutorService pingExecutor;
private volatile RtcSessionUpdateNotifier sessionUpdateNotifier;
private volatile RtcEventUpdateNotifier eventUpdateNotifier;
public static UUID getUserId(HttpServletRequest request) {
return (UUID) request.getAttribute(USER_ID_ATTRIBUTE);
}
public static void setUserId(HttpServletRequest request, UUID userId) {
request.setAttribute(USER_ID_ATTRIBUTE, userId);
}
@Override
public void updateDependencies(Class<? extends AbstractFilter> filterClass, List<Class<? extends Filter>> dependencies) {
dependencies.add(getClass());
}
@Override
protected void doInit() throws ServletException {
// Ping all clients every 5 seconds to detect disconnects.
pingRunnable = new RtcAsyncContextPingRunnable(contexts);
pingExecutor = Executors.newSingleThreadScheduledExecutor();
pingExecutor.scheduleWithFixedDelay(pingRunnable, 0, 5, TimeUnit.SECONDS);
Database database = Database.Static.getDefault();
sessionUpdateNotifier = new RtcSessionUpdateNotifier(contexts);
database.addUpdateNotifier(sessionUpdateNotifier);
eventUpdateNotifier = new RtcEventUpdateNotifier(contexts);
database.addUpdateNotifier(eventUpdateNotifier);
}
@Override
protected void doDestroy() {
contexts.values().forEach(RtcAsyncContext::disconnect);
contexts.clear();
pingRunnable.stop();
pingRunnable = null;
pingExecutor.shutdownNow();
pingExecutor = null;
Database database = Database.Static.getDefault();
if (sessionUpdateNotifier != null) {
database.removeUpdateNotifier(sessionUpdateNotifier);
sessionUpdateNotifier = null;
}
if (eventUpdateNotifier != null) {
database.removeUpdateNotifier(eventUpdateNotifier);
eventUpdateNotifier = null;
}
}
@Override
protected void doRequest(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, ServletException {
if (!request.getServletPath().startsWith(PATH)) {
chain.doFilter(request, response);
return;
}
// RTC disabled?
CmsTool cms = Query.from(CmsTool.class).first();
if (cms != null && cms.isDisableRtc()) {
chain.doFilter(request, response);
return;
}
// Make sure that the user is available.
UUID userId = getUserId(request);
if (userId == null) {
ToolUser user = AuthenticationFilter.Static.getUser(request);
if (user != null) {
userId = user.getId();
}
}
if (userId == null) {
return;
}
// On GET, start the RTC connection.
String method = request.getMethod();
if ("get".equalsIgnoreCase(method)) {
new RtcAsyncContext(contexts, request, userId);
return;
}
// On POST...
if (!"post".equalsIgnoreCase(method)) {
throw new UnsupportedOperationException(String.format(
"[%s] method isn't supported!",
method));
}
// Make sure that the session is available.
RtcSession session = Query
.from(RtcSession.class)
.where("_id = ?", ObjectUtils.to(UUID.class, request.getParameter("sessionId")))
.first();
if (session == null) {
throw new IllegalArgumentException("Can't process RTC request without a session!");
}
String message = request.getParameter("message");
@SuppressWarnings("unchecked")
Map<String, Object> messageJson = (Map<String, Object>) ObjectUtils.fromJson(message);
String messageType = (String) messageJson.get("type");
// Ping from the client to prevent RtcSessionTask from deleting the
// session.
if ("ping".equals(messageType)) {
session.setLastPing(Database.Static.getDefault().now());
session.save();
return;
}
@SuppressWarnings("unchecked")
Map<String, Object> messageData = (Map<String, Object>) messageJson.get("data");
if ("restore".equals(messageType)) {
Iterable<?> restores = createInstance(RtcState.class, messageJson).create(messageData);
if (restores != null) {
UUID currentUserId = userId;
List<Map<String, Object>> items = new ArrayList<>();
restores.forEach(event ->
RtcBroadcast.forEachBroadcast(event, (broadcast, data) -> {
if (broadcast.shouldBroadcast(data, currentUserId)) {
items.add(ImmutableMap.of(
"broadcast", broadcast.getClass().getName(),
"data", data));
}
}));
if (!items.isEmpty()) {
response.setContentType("application/json");
response.setCharacterEncoding(StandardCharsets.UTF_8.name());
response.getWriter().write(ObjectUtils.toJson(items));
response.flushBuffer();
}
}
return;
}
if ("execute".equals(messageType)) {
createInstance(RtcAction.class, messageJson).execute(messageData, userId, session.getId());
return;
}
if ("disconnect".equals(messageType)) {
Iterable<?> disconnects = createInstance(RtcState.class, messageJson).close(messageData, userId);
if (disconnects != null) {
Database database = Database.Static.getDefault();
database.beginWrites();
try {
disconnects.forEach(event -> {
if (event instanceof RtcEvent) {
((RtcEvent) event).onDisconnect();
}
});
database.commitWrites();
} finally {
database.endWrites();
}
}
return;
}
throw new UnsupportedOperationException(String.format(
"[%s] type isn't supported!",
messageType));
}
private <T> T createInstance(Class<T> returnClass, Map<String, Object> messageJson) {
String className = (String) messageJson.get("className");
Class<?> c = ObjectUtils.getClassByName(className);
if (c == null) {
throw new IllegalArgumentException(String.format(
"[%s] isn't a valid class name!",
className));
} else if (!returnClass.isAssignableFrom(c)) {
throw new IllegalArgumentException(String.format(
"[%s] isn't assignable from [%s]!",
returnClass.getName(),
c.getName()));
}
return (T) TypeDefinition.getInstance(c).newInstance();
}
}