package io.netty.protocol.wamp; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.*; import io.netty.handler.codec.http.websocketx.*; import io.netty.util.CharsetUtil; import java.util.logging.Logger; import static io.netty.handler.codec.http.HttpHeaders.Names.HOST; import static io.netty.handler.codec.http.HttpHeaders.isKeepAlive; import static io.netty.handler.codec.http.HttpHeaders.setContentLength; import static io.netty.handler.codec.http.HttpMethod.GET; import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; /** * Handles handshakes and messages */ public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object> { private static final Logger logger = Logger.getLogger(WebSocketServerHandler.class.getName()); private static final String WEBSOCKET_PATH = "/websocket"; private WebSocketServerHandshaker handshaker; @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { // Suppress this event for WampServerHandler // and fire it only after successful WebSocket handshake } @Override public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { if (msg instanceof FullHttpRequest) { handleHttpRequest(ctx, (FullHttpRequest) msg); } else if (msg instanceof WebSocketFrame) { handleWebSocketFrame(ctx, (WebSocketFrame) msg); } } @Override public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { ctx.flush(); } private void handleHttpRequest(final ChannelHandlerContext ctx, FullHttpRequest req) throws Exception { // Handle a bad request. if (!req.getDecoderResult().isSuccess()) { sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST)); return; } // Allow only GET methods. if (req.getMethod() != HttpMethod.GET) { sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN)); return; } // Handshake WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(getWebSocketLocation(req), "wamp", false); handshaker = wsFactory.newHandshaker(req); if (handshaker == null) { WebSocketServerHandshakerFactory.sendUnsupportedWebSocketVersionResponse(ctx.channel()); } else { ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req); handshakeFuture.addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { // Fire channelActive only after successful handshake ctx.fireChannelActive(); } } }); } } private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) { // Check for closing frame if (frame instanceof CloseWebSocketFrame) handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain()); else if (frame instanceof PingWebSocketFrame) ctx.channel().write(new PongWebSocketFrame(frame.isFinalFragment(), frame.rsv(), frame.content().retain())); else if (frame instanceof PongWebSocketFrame) frame.release(); else if (frame instanceof BinaryWebSocketFrame) throw new UnsupportedOperationException(String.format("%s frame types not supported", frame.getClass().getName())); else if (frame instanceof ContinuationWebSocketFrame) throw new UnsupportedOperationException(String.format("%s frame types not supported", frame.getClass().getName())); else if (frame instanceof TextWebSocketFrame) ctx.fireChannelRead(frame.retain()); } private static void sendHttpResponse(ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) { // Generate an error page if response getStatus code is not OK (200). if (res.getStatus().code() != 200) { res.content().writeBytes(Unpooled.copiedBuffer(res.getStatus().toString(), CharsetUtil.UTF_8)); HttpHeaders.setContentLength(res, res.content().readableBytes()); } // Send the response and close the connection if necessary. ChannelFuture f = ctx.channel().write(res); if (!HttpHeaders.isKeepAlive(req) || res.getStatus().code() != 200) { f.addListener(ChannelFutureListener.CLOSE); } } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { cause.printStackTrace(); ctx.close(); } private static String getWebSocketLocation(FullHttpRequest req) { return "ws://" + req.headers().get(HttpHeaders.Names.HOST) + WEBSOCKET_PATH; } }