package com.intrbiz.bergamot.updater;
import static io.netty.handler.codec.http.HttpHeaders.*;
import static io.netty.handler.codec.http.HttpHeaders.Names.*;
import static io.netty.handler.codec.http.HttpResponseStatus.*;
import static io.netty.handler.codec.http.HttpVersion.*;
import org.apache.log4j.Logger;
import com.intrbiz.Util;
import com.intrbiz.balsa.BalsaApplication;
import com.intrbiz.balsa.BalsaContext;
import com.intrbiz.balsa.engine.session.BalsaSession;
import com.intrbiz.balsa.error.BalsaSecurityException;
import com.intrbiz.bergamot.io.BergamotTranscoder;
import com.intrbiz.bergamot.model.Contact;
import com.intrbiz.bergamot.model.message.api.APIObject;
import com.intrbiz.bergamot.model.message.api.APIRequest;
import com.intrbiz.bergamot.model.message.api.error.APIError;
import com.intrbiz.bergamot.updater.context.ClientContext;
import com.intrbiz.bergamot.updater.handler.RequestHandler;
import com.intrbiz.bergamot.updater.util.CookieJar;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.util.CharsetUtil;
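/**
 * Netty handler for update-server connections.
 *
 * It authenticates the HTTP upgrade request using the Balsa session cookie and then performs the WebSocket handshake.
 * After that, it decodes incoming text frames into API requests and dispatches each one to the matching RequestHandler.
 */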
public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object>
{
private static final Logger logger = Logger.getLogger(WebSocketServerHandler.class);
private static final String WEBSOCKET_PATH = "/websocket";
private final UpdateServer server;
private ClientContext context;
private WebSocketServerHandshaker handshaker;
private BergamotTranscoder transcoder = new BergamotTranscoder();
public WebSocketServerHandler(UpdateServer server)
{
super();
this.server = server;
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception
{
final Channel channel = ctx.channel();
// setup the context
this.context = new ClientContext()
{
@Override
public void send(APIObject value)
{
channel.writeAndFlush(new TextWebSocketFrame(transcoder.encodeAsString(value)));
}
};
super.channelActive(ctx);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception
{
// shutdown the context
if (this.context != null) this.context.close();
super.channelInactive(ctx);
}
@Override
public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception
{
if (msg instanceof FullHttpRequest)
{
handleHttpRequest(ctx, (FullHttpRequest) msg);
}
else if (msg instanceof WebSocketFrame)
{
handleWebSocketFrame(ctx, (WebSocketFrame) msg);
}
}
private void authenticateContext(FullHttpRequest request) throws BalsaSecurityException
{
// extract the Balsa session cookie
CookieJar cookies = CookieJar.parseCookies(HttpHeaders.getHeader(request, HttpHeaders.Names.COOKIE));
String sessionId = cookies.cookie(BalsaSession.COOKIE_NAME);
if (Util.isEmpty(sessionId))
throw new BalsaSecurityException("No Blasa session cookie found");
// lookup the session
final BalsaApplication application = BalsaApplication.getInstance();
final BalsaSession session = application.getSessionEngine().getSession(sessionId);
if (session == null)
throw new BalsaSecurityException("Invalid session id");
// lookup the site and principal
try
{
BalsaContext.withContext(application, session, () -> {
Contact contact = session.authenticationState().currentPrincipal();
if (contact == null) return null;
WebSocketServerHandler.this.context.setSite(contact.getSite());
WebSocketServerHandler.this.context.setPrincipal(contact);
return null;
});
}
catch (Exception e)
{
if (e instanceof BalsaSecurityException)
throw (BalsaSecurityException) e;
throw new BalsaSecurityException("Failed to get site and principal from session", e);
}
// finally validate
if (this.context.getSite() == null || this.context.getPrincipal() == null)
throw new BalsaSecurityException("Failed to get site and principal");
}
private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) throws Exception
{
// Handle a bad request.
if (!req.getDecoderResult().isSuccess())
{
sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST));
return;
}
// Authenticate this connection
try
{
this.authenticateContext(req);
logger.debug("Authenticated websock connection for principal: " + this.context.getPrincipal());
}
catch (BalsaSecurityException e)
{
logger.debug("Denying access for websocket", e);
sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
return;
}
// Handshake
WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(getWebSocketLocation(req), null, false);
this.handshaker = wsFactory.newHandshaker(req);
if (this.handshaker == null)
{
WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
}
else
{
this.handshaker.handshake(ctx.channel(), req);
}
}
@SuppressWarnings({ "rawtypes", "unchecked" })
private void handleWebSocketFrame(final ChannelHandlerContext ctx, WebSocketFrame frame)
{
if (frame instanceof CloseWebSocketFrame)
{
// Check for closing frame
this.handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain());
return;
}
else if (frame instanceof PingWebSocketFrame)
{
// ping pong
ctx.channel().writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
return;
}
else if (frame instanceof TextWebSocketFrame)
{
// only support text frames
try
{
APIObject request = this.transcoder.decodeFromString(((TextWebSocketFrame) frame).text(), APIObject.class);
// process the request
if (request instanceof APIRequest)
{
APIRequest apiRequest = (APIRequest) request;
// process the request
RequestHandler<?> handler = this.server.getHandler(request.getClass());
if (handler != null)
{
try
{
((RequestHandler) handler).onRequest(this.context, apiRequest);
}
catch (Exception e)
{
ctx.channel().writeAndFlush(new TextWebSocketFrame(this.transcoder.encodeAsString(new APIError(apiRequest, e.getMessage()))));
}
}
else
{
ctx.channel().writeAndFlush(new TextWebSocketFrame(this.transcoder.encodeAsString(new APIError(apiRequest, "Not found"))));
}
}
else
{
ctx.channel().writeAndFlush(new TextWebSocketFrame(this.transcoder.encodeAsString(new APIError("Bad request"))));
ctx.channel().close();
}
}
catch (Exception e)
{
logger.error("Failed to decode request", e);
// send an error response
ctx.channel().writeAndFlush(new TextWebSocketFrame(this.transcoder.encodeAsString(new APIError("Failed to decode request"))));
ctx.channel().close();
}
}
}
private static void sendHttpResponse(ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res)
{
// Generate an error page if response getStatus code is not OK (200).
if (res.getStatus().code() != 200)
{
ByteBuf buf = Unpooled.copiedBuffer(res.getStatus().toString(), CharsetUtil.UTF_8);
res.content().writeBytes(buf);
buf.release();
setContentLength(res, res.content().readableBytes());
}
// Send the response and close the connection if necessary.
ChannelFuture f = ctx.channel().writeAndFlush(res);
if (!isKeepAlive(req) || res.getStatus().code() != 200)
{
f.addListener(ChannelFutureListener.CLOSE);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception
{
logger.error("Error processing request", cause);
ctx.close();
}
private static String getWebSocketLocation(FullHttpRequest req)
{
return "wss://" + req.headers().get(HOST) + WEBSOCKET_PATH;
}
}