package com.intrbiz.bergamot.agent; import java.net.URI; import java.util.Timer; import java.util.TimerTask; import java.util.UUID; import org.apache.log4j.Logger; import com.intrbiz.bergamot.io.BergamotAgentTranscoder; import com.intrbiz.bergamot.model.message.agent.AgentMessage; import com.intrbiz.bergamot.model.message.agent.error.GeneralError; import com.intrbiz.bergamot.model.message.agent.hello.AgentHello; import com.intrbiz.bergamot.model.message.agent.ping.AgentPing; import com.intrbiz.bergamot.util.AgentUtil; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker; import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketVersion; public abstract class AgentClientHandler extends SimpleChannelInboundHandler<Object> { private static final long AGENT_PING_INTERVAL_MS = 30L * 1000L; private Logger logger = Logger.getLogger(AgentClientHandler.class); private final WebSocketClientHandshaker handshaker; private final Timer timer; private final BergamotAgentTranscoder transcoder = BergamotAgentTranscoder.getDefaultInstance(); private AgentHello hello; public AgentClientHandler(Timer timer, URI server) { super(); this.timer = timer; HttpHeaders headers = new DefaultHttpHeaders(); headers.set(HttpHeaders.Names.USER_AGENT, BergamotAgent.AGENT_PRODUCT + "/" + BergamotAgent.AGENT_VERSION); this.handshaker = WebSocketClientHandshakerFactory.newHandshaker(server, WebSocketVersion.V13, null, false, headers); } protected AgentHello getHello() { if (this.hello == null) { this.hello = new AgentHello(UUID.randomUUID().toString()); hello.setAgentName(BergamotAgent.AGENT_PRODUCT); hello.setAgentVariant(BergamotAgent.AGENT_VENDOR); hello.setAgentVersion(BergamotAgent.AGENT_VERSION); hello.setNonce(AgentUtil.newNonce()); hello.setTimestamp(System.currentTimeMillis()); hello.setProtocolVersion(1); } return this.hello; } protected abstract AgentMessage processAgentMessage(final ChannelHandlerContext ctx, final AgentMessage request); @Override public void channelActive(ChannelHandlerContext ctx) { logger.trace("Connected, starting handshake"); handshaker.handshake(ctx.channel()); } public void channelHandshaked(ChannelHandlerContext ctx) { logger.trace("Handshake done"); final Channel channel = ctx.channel(); // hello logger.debug("Sending hello to server"); channel.writeAndFlush(new TextWebSocketFrame(this.transcoder.encodeAsString(this.getHello()))); // schedule ping this.timer.scheduleAtFixedRate(new TimerTask() { @Override public void run() { if (channel.isActive()) { logger.trace("Sending ping to server"); channel.writeAndFlush(new TextWebSocketFrame(transcoder.encodeAsString(new AgentPing(UUID.randomUUID().toString(), System.currentTimeMillis())))); } else { this.cancel(); } } }, AGENT_PING_INTERVAL_MS, AGENT_PING_INTERVAL_MS); } @Override public void channelRead0(final ChannelHandlerContext ctx, Object msg) { if (msg instanceof FullHttpResponse) { FullHttpResponse http = (FullHttpResponse) msg; // complete the handshake if (!handshaker.isHandshakeComplete()) { handshaker.finishHandshake(ctx.channel(), http); this.channelHandshaked(ctx); return; } } else if (msg instanceof WebSocketFrame) { // process the frame WebSocketFrame frame = (WebSocketFrame) msg; if (frame instanceof TextWebSocketFrame) { String message = ((TextWebSocketFrame) frame).text(); if (logger.isDebugEnabled()) logger.debug("Got message from agent server: " + message); try { final AgentMessage request = this.transcoder.decodeFromString(message, AgentMessage.class); // process the request if (request instanceof AgentMessage) { // process the message ctx.executor().execute(new Runnable() { public void run() { try { AgentMessage response = processAgentMessage(ctx, request); // respond if (response != null) { ctx.channel().writeAndFlush(new TextWebSocketFrame(transcoder.encodeAsString(response))); } } catch (Exception e) { ctx.channel().writeAndFlush(new TextWebSocketFrame(transcoder.encodeAsString(new GeneralError("Failed to process message")))); } } }); } else { ctx.channel().writeAndFlush(new TextWebSocketFrame(this.transcoder.encodeAsString(new GeneralError(request, "Bad request")))); } } catch (Exception e) { logger.error("Failed to decode request", e); ctx.channel().writeAndFlush(new TextWebSocketFrame(this.transcoder.encodeAsString(new GeneralError("Failed to decode request")))); } } else if (frame instanceof PongWebSocketFrame) { logger.trace("Got pong, whoop"); } else if (frame instanceof CloseWebSocketFrame) { logger.trace("Closing connection"); ctx.close(); } } else { throw new IllegalStateException("Unexpected message, got: " + msg); } } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable e) { logger.error("Unhandled error communicating with Bergamot server", e); ctx.close(); } }