// Copyright 2016 Twitter. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package com.twitter.heron.common.network; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.channels.SelectableChannel; import java.nio.channels.SocketChannel; import java.time.Duration; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.logging.Level; import java.util.logging.Logger; import com.google.protobuf.Message; import com.twitter.heron.common.basics.ISelectHandler; import com.twitter.heron.common.basics.NIOLooper; /** * Implements this class could handle some following socket related behaviors: * 1. handleRead(SelectableChannel), which read data from a socket and convert into incomingPacket. * It could handle the conditions of closedConnection, normal Reading and partial Reading. When a * incomingPacket is read, it will be pass to handlePacket(), which will convert incomingPackets to * messages and call onIncomingMessage(message), which should be implemented by its child class. * <p> * 2. handleWrite(SelectableChannel), which will try to get outgoing message by calling getOutgoingMessage(), * pack the outgoing message into OutgoingPacket and write to the sockets. * <p> * 3. handleConnect(SelectableChannel), which handles some basic setup when this client connect to * remote endpoint. * <p> * 4. handleAccept(SelectableChannel). * <p> * 5. handleError(SelectableChannel). * Remember, the socket client will register Read when the socket is connectible. However, it will * register Write when having something to write since the socket in most cases is writable. * To implement this, we will add the check whether write is needed into persistent tasks. */ public abstract class HeronClient implements ISelectHandler { private static final Logger LOG = Logger.getLogger(HeronClient.class.getName()); // When we send a request, we need to: // record the the context for this particular RID, and prepare the response for that RID // Then when the response come back, we could handle it protected Map<REQID, Object> contextMap; protected Map<REQID, Message.Builder> responseMessageMap; // Map from protobuf message's name to protobuf message protected Map<String, Message.Builder> messageMap; private SocketChannel socketChannel; // Define the endpoint this socket client will communicate with private InetSocketAddress endpoint; private NIOLooper nioLooper; private SocketChannelHelper socketChannelHelper; private HeronSocketOptions socketOptions; // A flag to determine whether the socket is connected or not // We could not simply use socketChanel.isConnected() to tell whether the socketChannel // is connected or not, since: // SocketChannel.socket().isConnected() and SocketChannel.isConnected() // return false before the socket is connected. // Once the socket is connected they will return true, // they will not revert to false for any reason. // It violates what is documented in // http://docs.oracle.com/javase/7/docs/api/java/nio/channels/SocketChannel.html#isConnected() // Consider it is a JAVA bug private boolean isConnected; /** * Constructor * * @param s the NIOLooper bind with this socket client * @param host the host of remote endpoint to communicate with * @param port the port of remote endpoint to communicate with */ public HeronClient(NIOLooper s, String host, int port, HeronSocketOptions options) { nioLooper = s; endpoint = new InetSocketAddress(host, port); socketOptions = options; isConnected = false; contextMap = new HashMap<REQID, Object>(); responseMessageMap = new HashMap<REQID, Message.Builder>(); messageMap = new HashMap<String, Message.Builder>(); } // Register the protobuf Message's name with protobuf Message public void registerOnMessage(Message.Builder builder) { messageMap.put(builder.getDescriptorForType().getFullName(), builder); } public void start() { try { socketChannel = SocketChannel.open(); socketChannel.configureBlocking(false); // Set the maximum possible send and receive buffers socketChannel.socket().setSendBufferSize( (int) socketOptions.getSocketSendBufferSize().asBytes()); socketChannel.socket().setReceiveBufferSize( (int) socketOptions.getSocketReceivedBufferSize().asBytes()); socketChannel.socket().setTcpNoDelay(true); // If the socketChannel has already connect to endpoint, call handleConnect() // Otherwise, registerConnect(), which will call handleConnect() when it is connectible LOG.info("Connecting to endpoint: " + endpoint); if (socketChannel.connect(endpoint)) { handleConnect(socketChannel); } else { nioLooper.registerConnect(socketChannel, this); } } catch (IOException e) { // Call onConnect() with CONNECT_ERROR LOG.log(Level.SEVERE, "Error connecting to remote endpoint: " + endpoint, e); Runnable r = new Runnable() { public void run() { onConnect(StatusCode.CONNECT_ERROR); } }; nioLooper.registerTimerEvent(Duration.ZERO, r); } } public void stop() { if (!isConnected()) { return; } // Flush the data to socket with best effort forceFlushWithBestEffort(); LOG.info("To stop the HeronClient."); contextMap.clear(); responseMessageMap.clear(); messageMap.clear(); socketChannelHelper.clear(); nioLooper.removeAllInterest(socketChannel); try { socketChannel.close(); onClose(); } catch (IOException e) { LOG.log(Level.SEVERE, "Failed to stop Client", e); } } @Override public void handleRead(SelectableChannel channel) { List<IncomingPacket> packets = socketChannelHelper.read(); for (IncomingPacket ipt : packets) { handlePacket(ipt); } } @Override public void handleWrite(SelectableChannel channel) { socketChannelHelper.write(); } // Send a request to the server with a certain timeout in seconds // This function doesnt return anything. After this function returns, // does not mean that the request actually sent out, merely that the request // was successfully queued to be sent out. // Actual send occurs when the socket becomes readable and all prev // requests are sent. If the packet cannot be sent // out or the request is not retired by the client within the timeout // period, the HandleResponse is called with the appropriate status. // The request is now owned by the Client class. // The ctx is a user owned piece of context. // The response is a MessageBuilder to handle the response from server // A negative value of the timeout means no timeout. public void sendRequest(Message request, Object context, Message.Builder responseBuilder, Duration timeout) { // Pack it as a no-timeout request and send it! final REQID rid = REQID.generate(); contextMap.put(rid, context); responseMessageMap.put(rid, responseBuilder); // Add timeout for this request if necessary if (timeout.getSeconds() > 0) { registerTimerEvent(timeout, new Runnable() { @Override public void run() { handleTimeout(rid); } }); } OutgoingPacket opk = new OutgoingPacket(rid, request); socketChannelHelper.sendPacket(opk); } // Convenience method of the above method with no timeout or context public void sendRequest(Message request, Message.Builder responseBuilder) { sendRequest(request, null, responseBuilder, Duration.ZERO); } // This method is used if you want to communicate with the other end // on a non-request-response based communication. public void sendMessage(Message message) { OutgoingPacket opk = new OutgoingPacket(REQID.zeroREQID, message); socketChannelHelper.sendPacket(opk); } public boolean isConnected() { return isConnected; } public NIOLooper getNIOLooper() { return nioLooper; } // Add a timer to be invoked after timer duration. private void registerTimerEvent(Duration timer, Runnable task) { nioLooper.registerTimerEvent(timer, task); } @Override public void handleAccept(SelectableChannel channel) { throw new RuntimeException("Client does not implement accept"); } @Override public void handleConnect(SelectableChannel channel) { try { if (socketChannel.finishConnect()) { // If we finishConnect(), we have to unregisterConnect, otherwise there will be a bug // http://bugs.java.com/bugdatabase/view_bug.do?bug_id=4960791 nioLooper.unregisterConnect(channel); } } catch (IOException e) { LOG.log(Level.SEVERE, "Failed to FinishConnect to endpoint: " + endpoint, e); Runnable r = new Runnable() { public void run() { onConnect(StatusCode.CONNECT_ERROR); } }; nioLooper.registerTimerEvent(Duration.ZERO, r); return; } // Construct the ChannelHelper and by default it would: // 1. always read // 2. write if # of packets to send > 0 socketChannelHelper = new SocketChannelHelper(nioLooper, this, socketChannel, socketOptions); // Only when we fully connected, we set isConnected true isConnected = true; onConnect(StatusCode.OK); } /** * Handle an incomingPacket and if necessary, * convert it to Message and call onIncomingMessage() to handle it */ protected void handlePacket(IncomingPacket incomingPacket) { String typeName = incomingPacket.unpackString(); REQID rid = incomingPacket.unpackREQID(); if (contextMap.containsKey(rid)) { // This incomingPacket contains the response of Request Object ctx = contextMap.get(rid); Message.Builder bldr = responseMessageMap.get(rid); contextMap.remove(rid); responseMessageMap.remove(rid); incomingPacket.unpackMessage(bldr); // Call onResponse to handle it if (bldr.isInitialized()) { Message response = bldr.build(); onResponse(StatusCode.OK, ctx, response); return; } else { onResponse(StatusCode.INVALID_PACKET, ctx, null); return; } } else if (rid.equals(REQID.zeroREQID)) { // If rid is REQID.zeroREQID, this is a Message, e.g. no need send back response. // Convert it into message and call onIncomingMessage() to handle it Message.Builder bldr = messageMap.get(typeName); if (bldr != null) { bldr.clear(); incomingPacket.unpackMessage(bldr); if (bldr.isInitialized()) { onIncomingMessage(bldr.build()); } else { // We just need to log here // TODO:- log } } else { // We got a message but we didn't register // TODO:- log here } } else { // This might be a timeout response // TODO:- log here } } // Handle the timeout for a particular REQID protected void handleTimeout(REQID rid) { if (contextMap.containsKey(rid)) { Object ctx = contextMap.get(rid); contextMap.remove(rid); responseMessageMap.remove(rid); onResponse(StatusCode.TIMEOUT_ERROR, ctx, null); } else { // Since we dont do cancel timer, this is because we already have // the response. So just disregard this timeout // TODO:- implement cancel timer to avoid this overhead } } // Clean the stuff when meeting some errors public void handleError(SelectableChannel channel) { LOG.info("Handling Error. Cleaning states in HeronClient."); contextMap.clear(); responseMessageMap.clear(); messageMap.clear(); socketChannelHelper.clear(); nioLooper.removeAllInterest(channel); try { channel.close(); LOG.info("Successfully closed the channel: " + channel); } catch (IOException e) { LOG.log(Level.SEVERE, "Failed to close connection in handleError", e); } // Since we closed the channel, we set isConnected false isConnected = false; onError(); } public void startReading() { socketChannelHelper.enableReading(); } public void stopReading() { socketChannelHelper.disableReading(); } public void startWriting() { socketChannelHelper.enableWriting(); } public void stopWriting() { socketChannelHelper.disableWriting(); } public int getOutstandingPackets() { return socketChannelHelper.getOutstandingPackets(); } // Force to flush all data to be sent by HeronClient public void forceFlushWithBestEffort() { socketChannelHelper.forceFlushWithBestEffort(); } ///////////////////////////////////////////////////////// // This is the interface that needs to be implemented by // all Heron Clients. ///////////////////////////////////////////////////////// // What action do you want to take when the client meets errors public abstract void onError(); // What action do you want to take when connecting to a new server public abstract void onConnect(StatusCode status); // What action do you want to take when you get a new // response from a particular server public abstract void onResponse(StatusCode status, Object ctx, Message response); // What action do you want to take when you get a new // message from a particular server public abstract void onIncomingMessage(Message message); // What action do you want to take when we want to stop this client public abstract void onClose(); ///////////////////////////////////////////////////////// // Following protected methods are just used for testing ///////////////////////////////////////////////////////// protected Map<String, Message.Builder> getMessageMap() { return new HashMap<String, Message.Builder>(messageMap); } protected Map<REQID, Message.Builder> getResponseMessageMap() { return new HashMap<REQID, Message.Builder>(responseMessageMap); } protected Map<REQID, Object> getContextMap() { return new HashMap<REQID, Object>(contextMap); } protected SocketChannelHelper getSocketChannelHelper() { return socketChannelHelper; } protected SocketChannel getSocketChannel() { return socketChannel; } }