package com.intrbiz.bergamot.agent.server;
import static io.netty.handler.codec.http.HttpHeaders.*;
import static io.netty.handler.codec.http.HttpHeaders.Names.*;
import static io.netty.handler.codec.http.HttpResponseStatus.*;
import static io.netty.handler.codec.http.HttpVersion.*;

import java.net.SocketAddress;
import java.security.Principal;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Consumer;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLPeerUnverifiedException;

import org.apache.log4j.Logger;

import com.intrbiz.bergamot.crypto.util.CertInfo;
import com.intrbiz.bergamot.crypto.util.SerialNum;
import com.intrbiz.bergamot.io.BergamotAgentTranscoder;
import com.intrbiz.bergamot.model.message.agent.AgentMessage;
import com.intrbiz.bergamot.model.message.agent.error.GeneralError;
import com.intrbiz.bergamot.model.message.agent.hello.AgentHello;
import com.intrbiz.bergamot.model.message.agent.ping.AgentPing;
import com.intrbiz.bergamot.model.message.agent.ping.AgentPong;
import com.intrbiz.bergamot.model.message.agent.registration.AgentRegistrationFailed;
import com.intrbiz.bergamot.model.message.agent.registration.AgentRegistrationFailed.ErrorCode;
import com.intrbiz.bergamot.model.message.agent.registration.AgentRegistrationMessage;
import com.intrbiz.bergamot.model.message.agent.registration.AgentRegistrationRequest;
import com.intrbiz.bergamot.model.message.agent.registration.AgentRegistrationRequired;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.util.CharsetUtil;
/**
 * Netty handler for a single Bergamot Agent connection.
 *
 * A connection starts as an HTTP request. Once the TLS client certificate
 * chain has been validated, the request is upgraded to a WebSocket. Text
 * frames are then decoded as {@link AgentMessage}s and dispatched according
 * to the certificate verification result:
 * <ul>
 *   <li>agents with a GOOD certificate get the full agent protocol;</li>
 *   <li>agents with a TEMPLATE certificate may only perform registration.</li>
 * </ul>
 */
public class BergamotAgentServerHandler extends SimpleChannelInboundHandler<Object>
{
    private static final Logger logger = Logger.getLogger(BergamotAgentServerHandler.class);

    private static final String WEBSOCKET_PATH = "/agent";

    private final BergamotAgentServer server;

    private WebSocketServerHandshaker handshaker;

    private final BergamotAgentTranscoder transcoder = BergamotAgentTranscoder.getDefaultInstance();

    private AgentHello hello;

    private SocketAddress remoteAddress;

    private Channel channel;

    /** Callbacks awaiting a response from the agent, keyed by the id of the message sent */
    private final ConcurrentMap<String, Consumer<AgentMessage>> pendingRequests = new ConcurrentHashMap<>();

    private final SSLEngine engine;

    private Certificate agentCertificate;

    private CertInfo agentCertificateInfo;

    private Certificate siteCertificate;

    private CertInfo siteCertificateInfo;

    private SerialNum agentSerial;

    private SerialNum siteSerial;

    private UUID agentId;

    private UUID siteId;

    private AgentVerificationResult certificateVerification;

    /**
     * @param server the agent server, used to (un)register agents and handle registration
     * @param engine the SSL engine of this connection, used to read the client certificate chain
     */
    public BergamotAgentServerHandler(BergamotAgentServer server, SSLEngine engine)
    {
        super();
        this.server = server;
        this.engine = engine;
    }

    /**
     * The Agent Id as extracted from the certificate
     */
    public UUID getAgentId()
    {
        return this.agentId;
    }

    /**
     * The Site Id as extracted from the site certificate
     */
    public UUID getSiteId()
    {
        return this.siteId;
    }

    /**
     * Information about the agent certificate, or null before the certificate has been validated.
     */
    public CertInfo getAgentCertificateInfo()
    {
        return this.agentCertificateInfo;
    }

    /**
     * The agent name, which is the common name of the agent certificate.
     * This fails with a NullPointerException if called before the certificate has been validated.
     */
    public String getAgentName()
    {
        return this.agentCertificateInfo.getSubject().getCommonName();
    }

    /**
     * Information about the site CA certificate, or null before the certificate has been validated.
     */
    public CertInfo getSiteCertificateInfo()
    {
        return this.siteCertificateInfo;
    }

    /**
     * The hello message sent by the agent, or null if no hello has been received yet.
     */
    public AgentHello getHello()
    {
        return this.hello;
    }

    /**
     * The remote address of the agent, captured when the hello is received.
     */
    public SocketAddress getRemoteAddress()
    {
        return this.remoteAddress;
    }

    /**
     * The channel of this connection, set once the channel becomes active.
     */
    public Channel getChannel()
    {
        return this.channel;
    }

    /**
     * Send a message to the agent, optionally registering a callback for its response.
     *
     * If the message has no id, a random UUID is assigned so that the response can be correlated.
     *
     * @param message the message to send
     * @param onResponse invoked with the agent's response, or with a {@link GeneralError} if the
     *        channel closes first; may be null if no response is expected
     */
    public void sendMessageToAgent(AgentMessage message, Consumer<AgentMessage> onResponse)
    {
        // ensure the message has an id
        if (message.getId() == null) message.setId(UUID.randomUUID().toString());
        // stash the callback, ConcurrentHashMap does not permit null values
        if (onResponse != null) this.pendingRequests.put(message.getId(), onResponse);
        // send the message
        this.channel.writeAndFlush(new TextWebSocketFrame(this.transcoder.encodeAsString(message)));
    }

    /**
     * Ping the agent and report the round trip time in milliseconds.
     *
     * @param onPong invoked with the round trip time once the pong arrives; it is not invoked if
     *        the agent responds with anything other than a pong (e.g. the channel closed)
     */
    public void sendOnePingAndOnePingOnly(Consumer<Long> onPong)
    {
        AgentPing ping = new AgentPing(UUID.randomUUID().toString(), System.currentTimeMillis());
        this.sendMessageToAgent(ping, (response) -> {
            if (response instanceof AgentPong)
            {
                onPong.accept(System.currentTimeMillis() - ((AgentPong) response).getTimestamp());
            }
            else
            {
                // e.g. the GeneralError delivered by channelInactive, casting it would throw
                if (logger.isDebugEnabled()) logger.debug("Ping was not answered with a pong, got: " + response);
            }
        });
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception
    {
        this.channel = ctx.channel();
        super.channelActive(ctx);
    }

    /**
     * Unregister the agent (if it said hello) and fail every outstanding request.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception
    {
        // unregister this agent, only agents which have said hello were registered
        if (this.hello != null)
        {
            this.server.unregisterAgent(this);
        }
        // fail any pending requests, removing each so it is invoked at most once
        for (String id : this.pendingRequests.keySet())
        {
            Consumer<AgentMessage> callback = this.pendingRequests.remove(id);
            if (callback != null)
            {
                try
                {
                    callback.accept(new GeneralError("Channel closed"));
                }
                catch (Exception e)
                {
                    // one misbehaving callback must not prevent the others from being notified
                    logger.warn("Error invoking callback for pending request " + id + " on channel close", e);
                }
            }
        }
        super.channelInactive(ctx);
    }

    @Override
    public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception
    {
        if (msg instanceof FullHttpRequest)
        {
            FullHttpRequest http = (FullHttpRequest) msg;
            handleHttpRequest(ctx, http);
        }
        else if (msg instanceof WebSocketFrame)
        {
            WebSocketFrame frame = (WebSocketFrame) msg;
            handleWebSocketFrame(ctx, frame);
        }
        else
        {
            throw new IllegalStateException("Unexpected message, got: " + msg);
        }
    }

    /**
     * Validate the client certificate and, if acceptable, perform the WebSocket handshake.
     * Otherwise respond with 400 (malformed request) or 403 (bad certificate).
     */
    private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) throws Exception
    {
        // Handle a bad request.
        if (!req.getDecoderResult().isSuccess())
        {
            sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST));
            return;
        }
        // validate the certificate
        // allow the WebSocket channel to open even if the certificate presented
        // is a template certificate, this allows the registration protocol to happen
        this.certificateVerification = this.verifyPeerCertificate();
        if (this.certificateVerification == AgentVerificationResult.GOOD || this.certificateVerification == AgentVerificationResult.TEMPLATE)
        {
            // got a good client certificate, start the WS handshake
            if (logger.isTraceEnabled()) logger.trace("Handshaking websocket request url: " + req.getUri());
            WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(getWebSocketLocation(req), null, false);
            this.handshaker = wsFactory.newHandshaker(req);
            if (this.handshaker == null)
            {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            }
            else
            {
                this.handshaker.handshake(ctx.channel(), req);
            }
        }
        else
        {
            // bad client certificate, terminate the connection
            sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
        }
    }

    /**
     * Read the peer certificate chain from the SSL session and validate it.
     *
     * SSLSession.getPeerPrincipal() throws rather than returning null when the
     * client did not present a verified certificate, so treat that as BAD
     * instead of letting it escape and skip the FORBIDDEN response.
     */
    private AgentVerificationResult verifyPeerCertificate()
    {
        Principal peerPrincipal;
        Certificate[] peerCertificates;
        try
        {
            peerPrincipal = this.engine.getSession().getPeerPrincipal();
            peerCertificates = this.engine.getSession().getPeerCertificates();
        }
        catch (SSLPeerUnverifiedException e)
        {
            if (logger.isDebugEnabled()) logger.debug("Agent did not present a verified client certificate", e);
            return AgentVerificationResult.BAD;
        }
        return this.validateAgentCertificate(peerPrincipal, peerCertificates);
    }

    /**
     * Handle control frames (close, ping) and decode text frames into agent messages.
     * Any other frame type is rejected with an IllegalStateException.
     */
    private void handleWebSocketFrame(final ChannelHandlerContext ctx, WebSocketFrame frame)
    {
        // Check for closing frame
        if (frame instanceof CloseWebSocketFrame)
        {
            this.handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain());
            return;
        }
        // ping pong
        if (frame instanceof PingWebSocketFrame)
        {
            ctx.channel().writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
            return;
        }
        // only support text frames
        if (!(frame instanceof TextWebSocketFrame)) throw new IllegalStateException(frame.getClass().getName() + " frame types not supported");
        // get the frame
        try
        {
            AgentMessage request = this.transcoder.decodeFromString(((TextWebSocketFrame) frame).text(), AgentMessage.class);
            // process the message and respond
            this.processMessage(ctx, request);
        }
        catch (Exception e)
        {
            logger.error("Failed to decode request", e);
            ctx.close();
        }
    }

    /**
     * Dispatch a decoded message to the protocol permitted by the certificate verification result.
     */
    private void processMessage(final ChannelHandlerContext ctx, final AgentMessage request) throws Exception
    {
        if (this.certificateVerification == AgentVerificationResult.GOOD)
        {
            // allow the full agent protocol
            this.processAgentMessage(ctx, request);
        }
        else if (this.certificateVerification == AgentVerificationResult.TEMPLATE)
        {
            this.processRegistrationMessage(ctx, request);
        }
        else
        {
            // should never get here
            throw new IllegalStateException("WebSocket established given a bad certificate, not processing messaged");
        }
    }

    /**
     * Handle a message from an agent holding a template certificate: only registration is allowed.
     */
    private void processRegistrationMessage(final ChannelHandlerContext ctx, final AgentMessage request) throws Exception
    {
        if (request instanceof AgentRegistrationMessage)
        {
            if (request instanceof AgentRegistrationRequest)
            {
                // start the registration process
                this.server.requestAgentRegistration(this.agentSerial.getId(), (AgentRegistrationRequest) request, (response) -> {
                    try
                    {
                        if (response != null)
                        {
                            writeMessage(ctx, response);
                        }
                        else
                        {
                            writeMessage(ctx, new AgentRegistrationFailed(request, ErrorCode.NOT_AVAILABLE, null));
                        }
                    }
                    catch (Exception e)
                    {
                        ctx.fireExceptionCaught(e);
                    }
                });
            }
            else
            {
                if (logger.isDebugEnabled()) logger.debug("Ignoring unhandled registration message: " + request);
            }
        }
        else
        {
            // tell the agent it needs to register
            writeMessage(ctx, new AgentRegistrationRequired(request));
        }
    }

    /**
     * Handle a message from a fully verified agent.
     */
    private void processAgentMessage(final ChannelHandlerContext ctx, final AgentMessage request) throws Exception
    {
        if (request instanceof AgentHello)
        {
            this.hello = (AgentHello) request;
            this.remoteAddress = ctx.channel().remoteAddress();
            if (logger.isInfoEnabled()) logger.info("Got hello from " + this.remoteAddress + " " + this.agentId + " " + this.agentCertificateInfo.getSubject().getCommonName());
            // register ourselves
            this.server.registerAgent(this);
        }
        else if (request instanceof AgentPing)
        {
            if (logger.isTraceEnabled()) logger.trace("Got ping from agent");
            this.server.fireAgentPing(this);
            writeMessage(ctx, new AgentPong((AgentPing) request));
        }
        else if (request instanceof AgentPong)
        {
            // a pong to one of our pings, complete the pending ping if there is one
            if (!this.completePendingRequest(request))
            {
                if (logger.isTraceEnabled()) logger.trace("Got pong from agent");
            }
        }
        else if (request.getId() == null)
        {
            logger.warn("Unhandled message, no request id: " + request);
        }
        else if (!this.completePendingRequest(request))
        {
            logger.warn("Unhandled message: " + request);
        }
    }

    /**
     * Invoke and remove the pending callback registered for the id of the given response.
     *
     * @return true if a callback was found and invoked, false otherwise (including when the response has no id)
     */
    private boolean completePendingRequest(final AgentMessage response)
    {
        // ConcurrentHashMap throws on null keys, so a response without an id can never match
        if (response.getId() == null) return false;
        Consumer<AgentMessage> callback = this.pendingRequests.remove(response.getId());
        if (callback == null) return false;
        callback.accept(response);
        return true;
    }

    /**
     * Encode the message and write it to the channel as a text frame.
     */
    private void writeMessage(final ChannelHandlerContext ctx, final AgentMessage message) throws Exception
    {
        ctx.channel().writeAndFlush(new TextWebSocketFrame(this.transcoder.encodeAsString(message)));
    }

    /**
     * Send a plain HTTP response. Non-200 responses get the status text as the body, and the
     * connection is closed for non-200 responses or when keep-alive was not requested.
     */
    private static void sendHttpResponse(ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res)
    {
        // Generate an error page if response getStatus code is not OK (200).
        if (res.getStatus().code() != 200)
        {
            ByteBuf buf = Unpooled.copiedBuffer(res.getStatus().toString(), CharsetUtil.UTF_8);
            res.content().writeBytes(buf);
            buf.release();
            setContentLength(res, res.content().readableBytes());
        }
        // Send the response and close the connection if necessary.
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!isKeepAlive(req) || res.getStatus().code() != 200)
        {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * The outcome of validating an agent's client certificate chain.
     */
    public enum AgentVerificationResult { GOOD, BAD, TEMPLATE }

    /**
     * Validate the agent client auth certificate.
     *
     * The Bergamot Agent encodes and signs important information
     * into the certificate:
     *
     * 1) The Agent UUID - encoded in the agent certificate serial number
     * 2) The Agent common name - the common name of the agent certificate
     * 3) The Site UUID - encoded in the site CA certificate serial number
     *
     * As a side effect this populates the agent/site certificate, info, serial and id fields.
     *
     * @param clientPrincipal the peer principal, may be null
     * @param clientCertificates the peer certificate chain, agent certificate first, may be null
     * @return GOOD for a valid agent certificate, TEMPLATE for a valid template certificate,
     *         BAD otherwise (including on any error while inspecting the chain)
     */
    private AgentVerificationResult validateAgentCertificate(Principal clientPrincipal, Certificate[] clientCertificates)
    {
        // assert that we have a certificate
        if (clientPrincipal == null || clientCertificates == null || clientCertificates.length < 2)
        {
            if (logger.isDebugEnabled()) logger.debug("Invalid agent certificate chain, not valid!");
            return AgentVerificationResult.BAD;
        }
        try
        {
            // the client auth certificate chain will be:
            // 0 - agent certificate
            // 1 - site authority certificate
            // 2 - root authority certificate
            // agent cert
            this.agentCertificate = clientCertificates[0];
            this.agentCertificateInfo = CertInfo.fromCertificate(this.agentCertificate);
            // site CA cert
            this.siteCertificate = clientCertificates[1];
            this.siteCertificateInfo = CertInfo.fromCertificate(this.siteCertificate);
            // check the serial numbers
            this.agentSerial = SerialNum.fromBigInt(((X509Certificate) this.agentCertificate).getSerialNumber());
            this.agentId = this.agentSerial.getId();
            this.siteSerial = SerialNum.fromBigInt(((X509Certificate) this.siteCertificate).getSerialNumber());
            this.siteId = this.siteSerial.getId();
            // validate that the Agent Id is masked by the Site Id (the top 48 bits must match)
            if ((this.agentId.getMostSignificantBits() & 0xFFFFFFFF_FFFF0000L) != (this.siteId.getMostSignificantBits() & 0xFFFFFFFF_FFFF0000L))
            {
                logger.warn("The agent id " + this.agentId + " is not masked by the site id " + this.siteId + ", refusing: " + this.agentCertificateInfo.getSubject().getCommonName());
                return AgentVerificationResult.BAD;
            }
            // is the presented certificate a template rather than an actual agent
            if (this.agentSerial.isVersion2() && this.agentSerial.isTemplate())
            {
                logger.info("Got agent connection with template id: " + this.agentId);
                return AgentVerificationResult.TEMPLATE;
            }
            // assert that the agent serial number is an actual agent
            if (this.agentSerial.isVersion2() && (! this.agentSerial.isAgent()))
            {
                logger.warn("The agent id " + this.agentId + " has the wrong mode");
                return AgentVerificationResult.BAD;
            }
            // log
            if (logger.isInfoEnabled()) logger.info("Connection from client: " + this.agentCertificateInfo.getSubject().getCommonName() + " of site " + this.siteCertificateInfo.getSubject().getCommonName());
            return AgentVerificationResult.GOOD;
        }
        catch (Exception e)
        {
            logger.error("Error validating client certificate", e);
        }
        return AgentVerificationResult.BAD;
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception
    {
        logger.error("Error processing request", cause);
        ctx.close();
    }

    /**
     * Build the wss:// location of the agent WebSocket endpoint from the request Host header.
     */
    private static String getWebSocketLocation(FullHttpRequest req)
    {
        return "wss://" + req.headers().get(HOST) + WEBSOCKET_PATH;
    }
}