package com.psddev.cms.rtc;
import com.google.common.base.MoreObjects;
import com.google.common.collect.ImmutableMap;
import com.psddev.dari.db.Database;
import com.psddev.dari.db.Query;
import com.psddev.dari.util.ObjectUtils;
import com.psddev.dari.util.Settings;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.servlet.AsyncContext;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * Server-sent event stream ({@code text/event-stream}) for a single RTC
 * client session, backed by a servlet {@link AsyncContext}.
 *
 * <p>Construction saves a new {@link RtcSession}, starts async processing on
 * the request, sends the session ID to the client as the first event, and
 * registers this instance in the shared {@code contexts} map keyed by session
 * ID. {@link #disconnect()} tears that down at most once.</p>
 */
class RtcAsyncContext {

    private static final Logger LOGGER = LoggerFactory.getLogger(RtcAsyncContext.class);

    private final ConcurrentMap<UUID, RtcAsyncContext> contexts;
    private final AsyncContext context;
    private final UUID userId;
    private final UUID sessionId;
    private final AtomicBoolean disconnected = new AtomicBoolean();

    /**
     * Creates a session for the given user and starts streaming events to
     * the client on the given request.
     *
     * @param contexts Shared registry that this instance adds itself to,
     *                 keyed by session ID, and removes itself from on
     *                 {@link #disconnect()}.
     * @param request  Request to start async processing on. All of its
     *                 attributes are removed.
     * @param userId   ID of the user that owns the new session.
     */
    public RtcAsyncContext(
            ConcurrentMap<UUID, RtcAsyncContext> contexts,
            HttpServletRequest request,
            UUID userId)
            throws IOException {

        // Remove all request attributes to minimize memory usage. Snapshot
        // the names first: removing attributes while walking the live
        // enumeration is not safe on every container or request wrapper.
        Enumeration<String> attributeNames = request.getAttributeNames();
        for (String name : Collections.list(attributeNames)) {
            request.removeAttribute(name);
        }

        // Create the session first so that if there's a database error,
        // the underlying context isn't started.
        RtcSession session = new RtcSession();
        session.setUserId(userId);
        session.setLastPing(Database.Static.getDefault().now());
        session.save();

        // Assign every field before `this` is handed to the listener or the
        // shared map, so no callback can observe a null ID.
        this.contexts = contexts;
        this.userId = userId;
        this.sessionId = session.getId();
        this.context = request.startAsync();

        // Forcibly close the underlying context after some time (milliseconds,
        // 15 minutes by default) to prevent potential connection leaks.
        context.setTimeout(Settings.getOrDefault(long.class, RtcFilter.ASYNC_CONTEXT_TIMEOUT_SETTING, 15L * 60 * 1000));

        // Make sure everything's cleaned up when the underlying context
        // goes away for any reason.
        context.addListener(new RtcAsyncContextListener(this));

        // Start the event stream and send the session ID to the client.
        ServletResponse response = context.getResponse();
        response.setContentType("text/event-stream");
        response.setCharacterEncoding(StandardCharsets.UTF_8.name());
        writeEvent(ImmutableMap.of(
                "_first", true,
                "sessionId", sessionId.toString()));

        // Register only after the headers and the first event are out, so
        // other threads looking this context up can't write ahead of them.
        // If that write (or the listener) already disconnected us, undo the
        // registration so a dead context isn't left in the map.
        contexts.put(sessionId, this);
        if (disconnected.get()) {
            contexts.remove(sessionId, this);
        }
    }

    /**
     * @return ID of the user that owns this session.
     */
    public UUID getUserId() {
        return userId;
    }

    /**
     * @return ID of the saved {@link RtcSession} backing this context.
     */
    public UUID getSessionId() {
        return sessionId;
    }

    /**
     * Sends {@code data} to the client as a single event: a {@code data:}
     * line containing its JSON form, followed by a blank line.
     *
     * <p>Any write failure disconnects this context instead of propagating.
     * Calls made after {@link #disconnect()} are ignored.</p>
     *
     * <p>NOTE(review): a {@link ServletResponse} is not thread-safe. If
     * callers can invoke this concurrently for the same instance, events may
     * interleave; confirm callers serialize access.</p>
     *
     * @param data Event payload, serialized with {@link ObjectUtils#toJson}.
     */
    public void writeEvent(Map<String, Object> data) {
        // After disconnect() the async context has been completed, so there
        // is no response left to write to.
        if (disconnected.get()) {
            return;
        }

        try {
            ServletResponse response = context.getResponse();
            PrintWriter writer = response.getWriter();

            writer.write("data:");
            writer.write(ObjectUtils.toJson(data));
            writer.write("\n\n");

            // This is important to force an exception when the client
            // disconnects.
            response.flushBuffer();

            // PrintWriter never throws IOException itself; it only records
            // the failure, so surface it explicitly.
            if (writer.checkError()) {
                throw new IOException("Response writer reported an error");
            }

        } catch (IOException | RuntimeException error) {
            // Log before disconnecting so the original error is recorded
            // even if cleanup fails.
            LOGGER.debug("Can't write [{}] to [{}]!", data, this, error);
            disconnect();
        }
    }

    /**
     * Tears down this context at most once. It unregisters from the shared
     * map, completes the async context, and calls
     * {@link RtcSession#disconnect()} on the backing session if it still
     * exists. Later calls do nothing.
     */
    public void disconnect() {
        if (!disconnected.compareAndSet(false, true)) {
            return;
        }

        contexts.remove(sessionId);

        try {
            context.complete();

        } catch (RuntimeException error) {
            LOGGER.debug("Can't complete [{}]!", this, error);
        }

        RtcSession session = Query
                .from(RtcSession.class)
                .where("_id = ?", sessionId)
                .first();

        if (session != null) {
            session.disconnect();
        }
    }

    @Override
    public String toString() {
        return MoreObjects.toStringHelper(this)
                .add("userId", getUserId())
                .add("sessionId", getSessionId())
                .toString();
    }
}