/*
* Copyright (C) 2012 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jboss.errai.bus.server.servlet;
import org.jboss.errai.bus.client.protocols.BusCommand;
import org.jboss.errai.bus.server.api.SessionProvider;
import org.jboss.errai.bus.server.service.ErraiConfigAttribs;
import org.jboss.errai.bus.server.service.ErraiService;
import org.jboss.errai.bus.server.service.ErraiServiceConfigurator;
import org.jboss.errai.common.client.protocols.MessageParts;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.servlet.FilterConfig;
import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.Charset;
/**
* The <tt>AbstractErraiServlet</tt> provides a starting point for creating Http-protocol gateway between the server
* bus and the client buses.
*/
public abstract class AbstractErraiServlet extends HttpServlet {
private static final long serialVersionUID = 1L;
private static final Charset UTF_8 = Charset.forName("UTF-8");
protected final byte[] SSE_TERMINATION_BYTES = "\n\n".getBytes();
protected final Logger log = LoggerFactory.getLogger(getClass());
/* New and configured errai service */
protected ErraiService<HttpSession> service;
/* A default Http session provider */
protected SessionProvider<HttpSession> sessionProvider;
private RequestSecurityCheck csrfSecurityCheck;
public enum ConnectionPhase {
NORMAL, CONNECTING, DISCONNECTING, UNKNOWN
}
private boolean longPollingEnabled;
private int longPollTimeout;
private int sseTimeout;
private void configureSettings() {
final ErraiServiceConfigurator config = service.getConfiguration();
final boolean hostedModeTesting = ErraiConfigAttribs.HOSTED_MODE_TESTING.getBoolean(config);
longPollingEnabled = !hostedModeTesting && ErraiConfigAttribs.DO_LONG_POLL.getBoolean(config);
longPollTimeout = ErraiConfigAttribs.LONG_POLL_TIMEOUT.getInt(config);
sseTimeout = ErraiConfigAttribs.SSE_TIMEOUT.getInt(config);
csrfSecurityCheck = (ErraiConfigAttribs.ENABLE_CSRF_BUS_TOKEN.getBoolean(config) ? CSRFTokenCheck.INSTANCE : RequestSecurityCheck.noCheck());
}
public static ConnectionPhase getConnectionPhase(final HttpServletRequest request) {
if (request.getParameter("phase") == null) return ConnectionPhase.NORMAL;
else {
final String phase = request.getParameter("phase");
if ("connection".equals(phase)) {
return ConnectionPhase.CONNECTING;
}
if ("disconnect".equals(phase)) {
return ConnectionPhase.DISCONNECTING;
}
return ConnectionPhase.UNKNOWN;
}
}
@Override
public void init(final ServletConfig config) throws ServletException {
service = ServletBootstrapUtil.getService(config);
sessionProvider = service.getSessionProvider();
configureSettings();
}
public void initAsFilter(final FilterConfig config) throws ServletException {
service = ServletBootstrapUtil.getService(config);
sessionProvider = service.getSessionProvider();
configureSettings();
}
@Override
public void destroy() {
service.stopService();
}
/**
* Writes the message to the output stream
*
* @param stream
* - the stream to write to
* @param encodedMessage
* - the message to write to the stream
*
* @throws java.io.IOException
* - is thrown if any input/output errors occur while writing to the stream
*/
public static void writeToOutputStream(final OutputStream stream, final String encodedMessage) throws IOException {
stream.write('[');
if (encodedMessage == null) {
stream.write('n');
stream.write('u');
stream.write('l');
stream.write('l');
}
else {
for (byte b : encodedMessage.getBytes(UTF_8)) {
stream.write(b);
}
}
stream.write(']');
}
protected void writeExceptionToOutputStream(
final HttpServletResponse httpServletResponse,
final Throwable t) throws IOException {
httpServletResponse.setHeader("Cache-Control", "no-cache");
httpServletResponse.addHeader("Payload-Size", "1");
httpServletResponse.setContentType("application/json");
final OutputStream stream = httpServletResponse.getOutputStream();
stream.write('[');
StringBuilder b = new StringBuilder("{\"ErrorMessage\":\"").append(t.getMessage()).append("\"," +
"\"AdditionalDetails\":\"");
for (StackTraceElement e : t.getStackTrace()) {
b.append(e.toString()).append("<br/>");
}
b.append("\"}");
writeToOutputStream(stream, b.toString());
stream.write(']');
}
protected void sendDisconnectWithReason(OutputStream stream, final String reason) throws IOException {
writeToOutputStream(stream,
reason != null ? "{\"" + MessageParts.ToSubject.name() + "\":\"ClientBus\", \""
+ MessageParts.CommandType.name() + "\":\"" + BusCommand.Disconnect + "\"," +
"\"Reason\":\"" + reason + "\"}"
: "{\"CommandType\":\"" + BusCommand.Disconnect + "\"}");
}
protected void sendDisconnectDueToSessionExpiry(final HttpServletResponse response) throws IOException {
response.setStatus(401);
writeToOutputStream(response.getOutputStream(),
"{\"" + MessageParts.ToSubject.name() + "\":\"ClientBus\", \"" + MessageParts.CommandType.name()
+ "\":\"" + BusCommand.SessionExpired.name() + "\"}");
}
protected static String getClientId(HttpServletRequest request) {
return request.getParameter("clientId");
}
protected int getLongPollTimeout() {
return longPollTimeout;
}
protected final int getSSETimeout() {
return sseTimeout;
}
public boolean isLongPollingEnabled() {
return longPollingEnabled;
}
protected boolean shouldWait(final HttpServletRequest request) {
return longPollingEnabled && "1".equals(request.getParameter("wait"));
}
protected boolean isSSERequest(final HttpServletRequest request) {
return request.getParameter("sse") != null;
}
protected void prepareCometPoll(final HttpServletResponse response) {
response.setContentType("application/json");
}
protected void prepareSSE(final HttpServletResponse response) throws IOException {
response.setContentType("text/event-stream");
response.getOutputStream().write("retry: 500\n\n".getBytes());
}
protected void prepareSSEContinue(final HttpServletResponse response) throws IOException {
response.getOutputStream().write("data: ".getBytes());
}
protected boolean failFromMissingCSRFToken(final HttpServletRequest httpServletRequest) {
return csrfSecurityCheck.isInsecure(httpServletRequest, log);
}
protected void prepareTokenChallenge(final HttpServletRequest request, final HttpServletResponse httpServletResponse) {
csrfSecurityCheck.prepareResponse(request, httpServletResponse, log);
}
}