package org.sdnplatform.sync.internal.rpc; import java.io.IOException; import java.net.ConnectException; import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.util.Arrays; import java.util.List; import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; import javax.xml.bind.DatatypeConverter; import net.floodlightcontroller.core.annotations.LogMessageCategory; import net.floodlightcontroller.core.annotations.LogMessageDoc; import net.floodlightcontroller.core.annotations.LogMessageDocs; import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.ChannelHandlerContext; import org.jboss.netty.channel.ChannelStateEvent; import org.jboss.netty.channel.Channels; import org.jboss.netty.channel.ExceptionEvent; import org.jboss.netty.channel.MessageEvent; import org.jboss.netty.handler.timeout.IdleStateAwareChannelHandler; import org.jboss.netty.handler.timeout.IdleStateEvent; import org.jboss.netty.handler.timeout.ReadTimeoutException; import org.sdnplatform.sync.error.AuthException; import org.sdnplatform.sync.error.HandshakeTimeoutException; import org.sdnplatform.sync.error.SyncException; import org.sdnplatform.sync.internal.config.AuthScheme; import org.sdnplatform.sync.internal.util.CryptoUtil; import org.sdnplatform.sync.thrift.AsyncMessageHeader; import org.sdnplatform.sync.thrift.AuthChallengeResponse; import org.sdnplatform.sync.thrift.ClusterJoinRequestMessage; import org.sdnplatform.sync.thrift.ClusterJoinResponseMessage; import org.sdnplatform.sync.thrift.SyncError; import org.sdnplatform.sync.thrift.SyncMessage; import org.sdnplatform.sync.thrift.CursorRequestMessage; import org.sdnplatform.sync.thrift.CursorResponseMessage; import org.sdnplatform.sync.thrift.DeleteRequestMessage; import org.sdnplatform.sync.thrift.DeleteResponseMessage; import org.sdnplatform.sync.thrift.EchoReplyMessage; import org.sdnplatform.sync.thrift.EchoRequestMessage; import org.sdnplatform.sync.thrift.ErrorMessage; import org.sdnplatform.sync.thrift.FullSyncRequestMessage; import org.sdnplatform.sync.thrift.GetRequestMessage; import org.sdnplatform.sync.thrift.GetResponseMessage; import org.sdnplatform.sync.thrift.HelloMessage; import org.sdnplatform.sync.thrift.MessageType; import org.sdnplatform.sync.thrift.PutRequestMessage; import org.sdnplatform.sync.thrift.PutResponseMessage; import org.sdnplatform.sync.thrift.RegisterRequestMessage; import org.sdnplatform.sync.thrift.RegisterResponseMessage; import org.sdnplatform.sync.thrift.SyncOfferMessage; import org.sdnplatform.sync.thrift.SyncRequestMessage; import org.sdnplatform.sync.thrift.SyncValueMessage; import org.sdnplatform.sync.thrift.SyncValueResponseMessage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Abstract base class for implementing the RPC protocol. The protocol is * defined by a thrift specification; all protocol messages are delivered in * a {@link SyncMessage} which will provide specific type information. * @author readams */ @LogMessageCategory("State Synchronization") public abstract class AbstractRPCChannelHandler extends IdleStateAwareChannelHandler { protected static final Logger logger = LoggerFactory.getLogger(AbstractRPCChannelHandler.class); protected String currentChallenge; protected enum ChannelState { OPEN, CONNECTED, AUTHENTICATED; } protected ChannelState channelState = ChannelState.OPEN; public AbstractRPCChannelHandler() { super(); } // **************************** // IdleStateAwareChannelHandler // **************************** @Override public void channelConnected(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception { channelState = ChannelState.CONNECTED; HelloMessage m = new HelloMessage(); if (getLocalNodeId() != null) m.setNodeId(getLocalNodeId()); AsyncMessageHeader header = new AsyncMessageHeader(); header.setTransactionId(getTransactionId()); m.setHeader(header); switch (getAuthScheme()) { case NO_AUTH: channelState = ChannelState.AUTHENTICATED; m.setAuthScheme(org.sdnplatform.sync.thrift. AuthScheme.NO_AUTH); break; case CHALLENGE_RESPONSE: AuthChallengeResponse cr = new AuthChallengeResponse(); cr.setChallenge(generateChallenge()); m.setAuthScheme(org.sdnplatform.sync.thrift. AuthScheme.CHALLENGE_RESPONSE); m.setAuthChallengeResponse(cr); break; } SyncMessage bsm = new SyncMessage(MessageType.HELLO); bsm.setHello(m); ctx.getChannel().write(bsm); } @Override public void channelIdle(ChannelHandlerContext ctx, IdleStateEvent e) throws Exception { // send an echo request EchoRequestMessage m = new EchoRequestMessage(); AsyncMessageHeader header = new AsyncMessageHeader(); header.setTransactionId(getTransactionId()); m.setHeader(header); SyncMessage bsm = new SyncMessage(MessageType.ECHO_REQUEST); bsm.setEchoRequest(m); ctx.getChannel().write(bsm); } @Override @LogMessageDocs({ @LogMessageDoc(level="ERROR", message="[{id}->{id}] Disconnecting client due to read timeout", explanation="The connected client has failed to send any " + "messages or respond to echo requests", recommendation=LogMessageDoc.CHECK_CONTROLLER), @LogMessageDoc(level="ERROR", message="[{id}->{id}] Disconnecting RPC node due to " + "handshake timeout", explanation="The remote node did not complete the handshake", recommendation=LogMessageDoc.CHECK_CONTROLLER), @LogMessageDoc(level="ERROR", message="[{id}->{id}] IOException: {message}", explanation="There was an error communicating with the " + "remote client", recommendation=LogMessageDoc.GENERIC_ACTION), @LogMessageDoc(level="ERROR", message="[{id}->{id}] ConnectException: {message} {error}", explanation="There was an error connecting to the " + "remote node", recommendation=LogMessageDoc.GENERIC_ACTION), @LogMessageDoc(level="ERROR", message="[{}->{}] An error occurred on RPC channel", explanation="An error occurred processing the message", recommendation=LogMessageDoc.GENERIC_ACTION), }) public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception { if (e.getCause() instanceof ReadTimeoutException) { // read timeout logger.error("[{}->{}] Disconnecting RPC node due to read timeout", getLocalNodeIdString(), getRemoteNodeIdString()); ctx.getChannel().close(); } else if (e.getCause() instanceof HandshakeTimeoutException) { // read timeout logger.error("[{}->{}] Disconnecting RPC node due to " + "handshake timeout", getLocalNodeIdString(), getRemoteNodeIdString()); ctx.getChannel().close(); } else if (e.getCause() instanceof ConnectException || e.getCause() instanceof IOException) { logger.debug("[{}->{}] {}: {}", new Object[] {getLocalNodeIdString(), getRemoteNodeIdString(), e.getCause().getClass().getName(), e.getCause().getMessage()}); } else { logger.error("[{}->{}] An error occurred on RPC channel", new Object[]{getLocalNodeIdString(), getRemoteNodeIdString(), e.getCause()}); ctx.getChannel().close(); } } @Override public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception { Object message = e.getMessage(); if (message instanceof SyncMessage) { handleSyncMessage((SyncMessage)message, ctx.getChannel()); } else if (message instanceof List) { for (Object i : (List<?>)message) { if (i instanceof SyncMessage) { try { handleSyncMessage((SyncMessage)i, ctx.getChannel()); } catch (Exception ex) { Channels.fireExceptionCaught(ctx, ex); } } } } else { handleUnknownMessage(ctx, message); } } // **************** // Message Handlers // **************** /** * A handler for messages on the channel that are not of type * {@link SyncMessage} * @param ctx the context * @param message the message object */ @LogMessageDoc(level="WARN", message="[{id}->{id}] Unhandled message: {message type}", explanation="An unrecognized event occurred", recommendation=LogMessageDoc.REPORT_CONTROLLER_BUG) protected void handleUnknownMessage(ChannelHandlerContext ctx, Object message) { logger.warn("[{}->{}] Unhandled message: {}", new Object[]{getLocalNodeIdString(), getRemoteNodeIdString(), message.getClass().getCanonicalName()}); } /** * Handle a generic {@link SyncMessage} and dispatch to an appropriate * handler * @param bsm the message * @param channel the channel on which the message arrived */ protected void handleSyncMessage(SyncMessage bsm, Channel channel) { switch (channelState) { case OPEN: case CONNECTED: switch (bsm.getType()) { case HELLO: handshake(bsm.getHello(), channel); break; case ECHO_REQUEST: handleEchoRequest(bsm.getEchoRequest(), channel); break; case ERROR: handleError(bsm.getError(), channel); break; default: // ignore } break; case AUTHENTICATED: handleSMAuthenticated(bsm, channel); break; } } /** * Handle a generic {@link SyncMessage} and dispatch to an appropriate * handler * @param bsm the message * @param channel the channel on which the message arrived */ protected void handleSMAuthenticated(SyncMessage bsm, Channel channel) { switch (bsm.getType()) { case HELLO: handleHello(bsm.getHello(), channel); break; case ECHO_REQUEST: handleEchoRequest(bsm.getEchoRequest(), channel); break; case GET_REQUEST: handleGetRequest(bsm.getGetRequest(), channel); break; case GET_RESPONSE: handleGetResponse(bsm.getGetResponse(), channel); break; case PUT_REQUEST: handlePutRequest(bsm.getPutRequest(), channel); break; case PUT_RESPONSE: handlePutResponse(bsm.getPutResponse(), channel); break; case DELETE_REQUEST: handleDeleteRequest(bsm.getDeleteRequest(), channel); break; case DELETE_RESPONSE: handleDeleteResponse(bsm.getDeleteResponse(), channel); break; case SYNC_VALUE_RESPONSE: handleSyncValueResponse(bsm.getSyncValueResponse(), channel); break; case SYNC_VALUE: handleSyncValue(bsm.getSyncValue(), channel); break; case SYNC_OFFER: handleSyncOffer(bsm.getSyncOffer(), channel); break; case FULL_SYNC_REQUEST: handleFullSyncRequest(bsm.getFullSyncRequest(), channel); break; case SYNC_REQUEST: handleSyncRequest(bsm.getSyncRequest(), channel); break; case CURSOR_REQUEST: handleCursorRequest(bsm.getCursorRequest(), channel); break; case CURSOR_RESPONSE: handleCursorResponse(bsm.getCursorResponse(), channel); break; case REGISTER_REQUEST: handleRegisterRequest(bsm.getRegisterRequest(), channel); break; case REGISTER_RESPONSE: handleRegisterResponse(bsm.getRegisterResponse(), channel); break; case CLUSTER_JOIN_REQUEST: handleClusterJoinRequest(bsm.getClusterJoinRequest(), channel); break; case CLUSTER_JOIN_RESPONSE: handleClusterJoinResponse(bsm.getClusterJoinResponse(), channel); break; case ERROR: handleError(bsm.getError(), channel); break; case ECHO_REPLY: // do nothing; just the read will have reset our read timeout // handler break; default: logger.warn("[{}->{}] Unhandled message: {}", new Object[]{getLocalNodeIdString(), getRemoteNodeIdString(), bsm.getType()}); break; } } @LogMessageDoc(level="WARN", message="Failed to authenticate connection from {remote}: {message}", explanation="Challenge/Response authentication failed", recommendation="Check the included error message, and " + "verify the shared secret is correctly-configured") protected void handshake(HelloMessage request, Channel channel) { try { switch (getAuthScheme()) { case CHALLENGE_RESPONSE: handshakeChallengeResponse(request, channel); break; case NO_AUTH: // shouldn't get here break; } } catch (AuthException e) { logger.warn("[{}->{}] Failed to authenticate connection: {}", new Object[]{getLocalNodeIdString(), getRemoteNodeIdString(), e.getMessage()}); channel.write(getError(request.getHeader().getTransactionId(), e, MessageType.HELLO)); channel.close(); } } protected void handshakeChallengeResponse(HelloMessage request, Channel channel) throws AuthException { AuthChallengeResponse cr = request.getAuthChallengeResponse(); if (cr == null) { throw new AuthException("No authentication data in " + "handshake message"); } if (cr.isSetResponse()) { authenticateResponse(currentChallenge, cr.getResponse()); currentChallenge = null; channelState = ChannelState.AUTHENTICATED; handleHello(request, channel); } else if (cr.isSetChallenge()) { HelloMessage m = new HelloMessage(); if (getLocalNodeId() != null) m.setNodeId(getLocalNodeId()); AsyncMessageHeader header = new AsyncMessageHeader(); header.setTransactionId(getTransactionId()); m.setHeader(header); SyncMessage bsm = new SyncMessage(MessageType.HELLO); bsm.setHello(m); AuthChallengeResponse reply = new AuthChallengeResponse(); reply.setResponse(generateResponse(cr.getChallenge())); m.setAuthChallengeResponse(reply); channel.write(bsm); } else { throw new AuthException("No authentication data in " + "handshake message"); } } protected void error(ErrorMessage error, Channel channel) { if (MessageType.HELLO.equals(error.getType())) { } } protected void handleHello(HelloMessage request, Channel channel) { unexpectedMessage(request.getHeader().getTransactionId(), MessageType.HELLO, channel); } protected void handleEchoRequest(EchoRequestMessage request, Channel channel) { EchoReplyMessage m = new EchoReplyMessage(); AsyncMessageHeader header = new AsyncMessageHeader(); header.setTransactionId(request.getHeader().getTransactionId()); m.setHeader(header); SyncMessage bsm = new SyncMessage(MessageType.ECHO_REPLY); bsm.setEchoReply(m); channel.write(bsm); } protected void handleGetRequest(GetRequestMessage request, Channel channel) { unexpectedMessage(request.getHeader().getTransactionId(), MessageType.GET_REQUEST, channel); } protected void handleGetResponse(GetResponseMessage response, Channel channel) { unexpectedMessage(response.getHeader().getTransactionId(), MessageType.GET_RESPONSE, channel); } protected void handlePutRequest(PutRequestMessage request, Channel channel) { unexpectedMessage(request.getHeader().getTransactionId(), MessageType.PUT_REQUEST, channel); } protected void handlePutResponse(PutResponseMessage response, Channel channel) { unexpectedMessage(response.getHeader().getTransactionId(), MessageType.PUT_RESPONSE, channel); } protected void handleDeleteRequest(DeleteRequestMessage request, Channel channel) { unexpectedMessage(request.getHeader().getTransactionId(), MessageType.DELETE_REQUEST, channel); } protected void handleDeleteResponse(DeleteResponseMessage response, Channel channel) { unexpectedMessage(response.getHeader().getTransactionId(), MessageType.PUT_RESPONSE, channel); } protected void handleSyncValue(SyncValueMessage message, Channel channel) { unexpectedMessage(message.getHeader().getTransactionId(), MessageType.SYNC_VALUE, channel); } protected void handleSyncValueResponse(SyncValueResponseMessage message, Channel channel) { unexpectedMessage(message.getHeader().getTransactionId(), MessageType.SYNC_VALUE_RESPONSE, channel); } protected void handleSyncOffer(SyncOfferMessage message, Channel channel) { unexpectedMessage(message.getHeader().getTransactionId(), MessageType.SYNC_OFFER, channel); } protected void handleSyncRequest(SyncRequestMessage request, Channel channel) { unexpectedMessage(request.getHeader().getTransactionId(), MessageType.SYNC_REQUEST, channel); } protected void handleFullSyncRequest(FullSyncRequestMessage request, Channel channel) { unexpectedMessage(request.getHeader().getTransactionId(), MessageType.FULL_SYNC_REQUEST, channel); } protected void handleCursorRequest(CursorRequestMessage request, Channel channel) { unexpectedMessage(request.getHeader().getTransactionId(), MessageType.CURSOR_REQUEST, channel); } protected void handleCursorResponse(CursorResponseMessage response, Channel channel) { unexpectedMessage(response.getHeader().getTransactionId(), MessageType.CURSOR_RESPONSE, channel); } protected void handleRegisterRequest(RegisterRequestMessage request, Channel channel) { unexpectedMessage(request.getHeader().getTransactionId(), MessageType.REGISTER_REQUEST, channel); } protected void handleRegisterResponse(RegisterResponseMessage response, Channel channel) { unexpectedMessage(response.getHeader().getTransactionId(), MessageType.REGISTER_RESPONSE, channel); } protected void handleClusterJoinRequest(ClusterJoinRequestMessage request, Channel channel) { unexpectedMessage(request.getHeader().getTransactionId(), MessageType.CLUSTER_JOIN_REQUEST, channel); } protected void handleClusterJoinResponse(ClusterJoinResponseMessage response, Channel channel) { unexpectedMessage(response.getHeader().getTransactionId(), MessageType.CLUSTER_JOIN_RESPONSE, channel); } @LogMessageDoc(level="ERROR", message="[{id}->{id}] Error for message {id} ({type}): " + "{message} {error code}", explanation="Remote client sent an error", recommendation=LogMessageDoc.GENERIC_ACTION) protected void handleError(ErrorMessage error, Channel channel) { logger.error("[{}->{}] Error for message {} ({}): {} ({})", new Object[]{getLocalNodeIdString(), getRemoteNodeIdString(), error.getHeader().getTransactionId(), error.getType(), error.getError().getMessage(), error.getError().getErrorCode()}); } // ***************** // Utility functions // ***************** /** * Generate an error message from the provided transaction ID and * exception * @param transactionId the transaction Id * @param error the exception * @param type the type of the message that generated the error * @return the {@link SyncError} message */ @LogMessageDoc(level="ERROR", message="Unexpected error processing message {} ({})", explanation="An error occurred while processing an " + "RPC message", recommendation=LogMessageDoc.GENERIC_ACTION) protected SyncMessage getError(int transactionId, Exception error, MessageType type) { int ec = SyncException.ErrorType.GENERIC.getValue(); if (error instanceof SyncException) { ec = ((SyncException)error).getErrorType().getValue(); } else { logger.error("Unexpected error processing message " + transactionId + "(" + type + ")", error); } SyncError m = new SyncError(); m.setErrorCode(ec); m.setMessage(error.getMessage()); ErrorMessage em = new ErrorMessage(); em.setError(m); em.setType(type); AsyncMessageHeader header = new AsyncMessageHeader(); header.setTransactionId(transactionId); em.setHeader(header); SyncMessage bsm = new SyncMessage(MessageType.ERROR); bsm.setError(em); return bsm; } /** * Send an error to the channel indicating that we got an unexpected * message for this type of RPC client * @param transactionId the transaction ID for the message that generated * the error * @param type The type of the message that generated the error * @param channel the channel to write the error */ @LogMessageDoc(level="WARN", message="[{id}->{id}] Received unexpected message: {type}", explanation="A inappriopriate message was sent by the remote" + "client", recommendation=LogMessageDoc.REPORT_CONTROLLER_BUG) protected void unexpectedMessage(int transactionId, MessageType type, Channel channel) { String message = "Received unexpected message: " + type; logger.warn("[{}->{}] {}", new Object[]{getLocalNodeIdString(), getRemoteNodeIdString(), message}); channel.write(getError(transactionId, new SyncException(message), type)); } /** * Get a transaction ID suitable for sending an async message * @return the unique transaction ID */ protected abstract int getTransactionId(); /** * Get the node ID for the remote node if its connected * @return the node ID */ protected abstract Short getRemoteNodeId(); /** * Get the node ID for the remote node if its connected as a string * for use output * @return the node ID */ protected String getRemoteNodeIdString() { return ""+getRemoteNodeId(); } /** * Get the node ID for the local node if appropriate * @return the node ID. Null if this is a client */ protected abstract Short getLocalNodeId(); /** * Get the node ID for the local node as a string for use output * @return the node ID */ protected String getLocalNodeIdString() { return ""+getLocalNodeId(); } /** * Get the type of authentication to use for this connection */ protected abstract AuthScheme getAuthScheme(); /** * Get a shared secret to be used for authentication handshake. * Throwing an exception will cause authentication to fail * @return the shared secret */ protected abstract byte[] getSharedSecret() throws AuthException; // ************* // Local methods // ************* private String generateChallenge() { byte[] challengeBytes = CryptoUtil.secureRandom(16); currentChallenge = DatatypeConverter.printBase64Binary(challengeBytes); return currentChallenge; } private void authenticateResponse(String challenge, String response) throws AuthException { String expected = generateResponse(challenge); if (expected == null) return; byte[] expectedBytes = DatatypeConverter.parseBase64Binary(expected); byte[] reponseBytes = DatatypeConverter.parseBase64Binary(response); if (!Arrays.equals(expectedBytes, reponseBytes)) { throw new AuthException("Challenge response does not match " + "expected response"); } } private String generateResponse(String challenge) throws AuthException { byte[] secretBytes = getSharedSecret(); if (secretBytes == null) return null; SecretKeySpec signingKey = new SecretKeySpec(secretBytes, "HmacSHA1"); Mac mac; try { mac = Mac.getInstance("HmacSHA1"); } catch (NoSuchAlgorithmException e) { throw new AuthException("Could not initialize HmacSHA1 algorithm", e); } try { mac.init(signingKey); byte[] output = mac.doFinal(DatatypeConverter.parseBase64Binary(challenge)); return DatatypeConverter.printBase64Binary(output); } catch (InvalidKeyException e) { throw new AuthException("Invalid shared secret; could not " + "authenticate response", e); } } }