package org.radargun; import java.io.IOException; import java.io.Serializable; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; import org.radargun.config.Cluster; import org.radargun.config.Configuration; import org.radargun.config.Scenario; import org.radargun.logging.Log; import org.radargun.logging.LogFactory; import org.radargun.reporting.Timeline; import org.radargun.utils.SlaveConnectionInfo; import org.radargun.utils.TimeService; /** * Connection to slaves in different JVMs * * @author Radim Vansa <rvansa@redhat.com> */ public class RemoteSlaveConnection { private static final long CONNECT_TIMEOUT = TimeUnit.MINUTES.toMillis(5); private static Log log = LogFactory.getLog(RemoteSlaveConnection.class); private static final int UUID_BYTES = 16; private static final int EXPECTED_SIZE_BYTES = 4; private static final int DEFAULT_WRITE_BUFF_CAPACITY = 1024; private static final int DEFAULT_READ_BUFF_CAPACITY = 1024; public static final int DEFAULT_PORT = 2103; private ServerSocketChannel serverSocketChannel; private SlaveRecord[] slaves; private SlaveAddresses slaveAddresses; private ByteBuffer mcastBuffer; private Map<SocketChannel, ByteBuffer> writeBufferMap = new HashMap<SocketChannel, ByteBuffer>(); private Map<SocketChannel, ByteBuffer> readBufferMap = new HashMap<SocketChannel, ByteBuffer>(); private List<Object> responses = new ArrayList<Object>(); private Selector communicationSelector; private Selector discoverySelector; private Map<SocketChannel, Integer> channel2Index = new HashMap<>(); private int reconnections = 0; private String host; private int port; private static class SlaveRecord { private UUID uuid; // key unique for given series of generations of this slave private SocketChannel channel; public SlaveRecord(int index, UUID uuid, SocketChannel channel) { this.uuid = uuid; this.channel = channel; } } /** * Holds information about interfaces and IP addresses of individual slaves. * This information is collected from individual slaves and re-distributed to the other * slaves in the cluster. */ public static class SlaveAddresses implements Serializable { public Map<Integer, SlaveConnectionInfo> slaveConnections; public SlaveAddresses() { this.slaveConnections = new HashMap<>(); } public void addSlaveAddresses(int index, SlaveConnectionInfo connectionInfo) { slaveConnections.put(index, connectionInfo); } public SlaveConnectionInfo getSlaveAddresses(int index) { return slaveConnections.get(index); } } public RemoteSlaveConnection(int numSlaves, String host, int port) throws IOException { this.host = host; this.port = port > 0 && port < 65536 ? port : DEFAULT_PORT; slaves = new SlaveRecord[numSlaves]; slaveAddresses = new SlaveAddresses(); communicationSelector = Selector.open(); startServerSocket(); } public void establish() throws IOException { discoverySelector = Selector.open(); serverSocketChannel.register(discoverySelector, SelectionKey.OP_ACCEPT); mcastBuffer = ByteBuffer.allocate(DEFAULT_WRITE_BUFF_CAPACITY); int slaveCount = 0; long deadline = TimeService.currentTimeMillis() + CONNECT_TIMEOUT; while (slaveCount < slaves.length) { long timeout = deadline - TimeService.currentTimeMillis(); if (timeout <= 0) { throw new IOException((slaves.length - slaveCount) + " slaves haven't connected within timeout!"); } log.info("Awaiting registration from " + (slaves.length - slaveCount) + " slaves."); slaveCount += connectSlaves(timeout); } log.info("Connection established from " + slaveCount + " slaves."); } private int connectSlaves(long timeout) throws IOException { discoverySelector.select(timeout); Set<SelectionKey> keySet = discoverySelector.selectedKeys(); Iterator<SelectionKey> it = keySet.iterator(); int slaveCount = 0; while (it.hasNext()) { SelectionKey selectionKey = it.next(); it.remove(); if (!selectionKey.isValid()) { continue; } ServerSocketChannel srvSocketChannel = (ServerSocketChannel) selectionKey.channel(); SocketChannel socketChannel = srvSocketChannel.accept(); int slaveIndex = readInt(socketChannel); ByteBuffer uuidBytes = readBytes(socketChannel, UUID_BYTES); UUID uuid = new UUID(uuidBytes.getLong(), uuidBytes.getLong()); if (slaveIndex < 0) { for (int i = 0; i < slaves.length; ++i) { if (slaves[i] == null) { slaveIndex = i; break; } } if (slaveIndex < 0) { throw new IllegalArgumentException("All slaves are already connected."); } } else if (slaveIndex >= slaves.length) { throw new IllegalArgumentException("Slave requests invalid slaveIndex " + slaveIndex + " (expected " + slaves.length + " slaves)"); } if (slaves[slaveIndex] == null) { if (uuid.getLeastSignificantBits() != 0 || uuid.getMostSignificantBits() != 0) { throw new IllegalArgumentException("We are expecting 0th generation slave " + slaveIndex + " but it already has UUID set!"); } slaves[slaveIndex] = new RemoteSlaveConnection.SlaveRecord(slaveIndex, uuid, socketChannel); } else if (slaves[slaveIndex] != null) { RemoteSlaveConnection.SlaveRecord record = slaves[slaveIndex]; if (!uuid.equals(record.uuid)) { throw new IllegalArgumentException(String.format("For slave %d expecting UUID %s but new generation (%s) has UUID %s", slaveIndex, record.uuid, socketChannel, uuid)); } record.channel = socketChannel; } writeInt(socketChannel, slaveIndex); writeInt(socketChannel, slaves.length); slaveCount++; channel2Index.put(socketChannel, slaveIndex); readBufferMap.put(socketChannel, ByteBuffer.allocate(DEFAULT_READ_BUFF_CAPACITY)); socketChannel.configureBlocking(false); log.trace("Added new slave connection " + slaveIndex + " from: " + socketChannel.socket().getInetAddress()); } return slaveCount; } public void sendScenario(Scenario scenario, int clusterSize) throws IOException { mcastObject(scenario, clusterSize); flushBuffers(0); } public void sendConfiguration(Configuration configuration) throws IOException { mcastObject(configuration, slaves.length); flushBuffers(0); } public void sendCluster(Cluster cluster) throws IOException { mcastObject(cluster, cluster.getSize()); flushBuffers(0); } private void clearBuffer() { if (!writeBufferMap.isEmpty()) { throw new IllegalStateException("Something not sent to slaves yet: " + writeBufferMap); } mcastBuffer.clear(); } private void mcastBuffer(int numSlaves) throws IOException { for (int i = 0; i < numSlaves; ++i) { SocketChannel channel = slaves[i].channel; if (channel == null) throw new IOException("Slave " + i + " disconnected"); writeBufferMap.put(channel, ByteBuffer.wrap(mcastBuffer.array(), 0, mcastBuffer.position())); channel.register(communicationSelector, SelectionKey.OP_WRITE); } } private void mcastObject(Serializable object, int numSlaves) throws IOException { clearBuffer(); mcastBuffer = SerializationHelper.serializeObjectWithLength(object, mcastBuffer); mcastBuffer(numSlaves); } public List<DistStageAck> runStage(int stageId, Map<String, Object> masterData, int numSlaves) throws IOException { responses.clear(); clearBuffer(); mcastBuffer.putInt(stageId); mcastBuffer = SerializationHelper.serializeObjectWithLength((Serializable) masterData, mcastBuffer); mcastBuffer(numSlaves); flushBuffers(numSlaves); ArrayList<DistStageAck> list = new ArrayList<>(responses.size()); for (Object o : responses) { list.add((DistStageAck) o); } return list; } public List<Timeline> receiveTimelines(int numSlaves) throws IOException { responses.clear(); mcastObject(new Timeline.Request(), numSlaves); flushBuffers(numSlaves); return Arrays.asList(responses.toArray(new Timeline[numSlaves])); } public void receiveSlaveAddresses() throws IOException { responses.clear(); mcastObject(new SlaveConnectionInfo.Request(), slaves.length); flushBuffers(slaves.length); List<SlaveConnectionInfo> connections = Arrays.asList(responses.toArray(new SlaveConnectionInfo[slaves.length])); for (SlaveConnectionInfo connectionInfo : connections) { slaveAddresses.addSlaveAddresses(connectionInfo.getSlaveIndex(), connectionInfo); } } public void sendSlaveAddresses() throws IOException { mcastObject(slaveAddresses, slaves.length); flushBuffers(0); } private void flushBuffers(int numResponses) throws IOException { while (!writeBufferMap.isEmpty() || responses.size() < numResponses) { communicationSelector.select(); Set<SelectionKey> keys = communicationSelector.selectedKeys(); if (keys.size() > 0) { Iterator<SelectionKey> keysIt = keys.iterator(); while (keysIt.hasNext()) { SelectionKey key = keysIt.next(); keysIt.remove(); if (!key.isValid()) { continue; } if (key.isWritable()) { sendData(key); } else if (key.isReadable()) { readResponse(key); } else { log.warn("Unknown selection on key " + key); } } } } long deadline = TimeService.currentTimeMillis() + CONNECT_TIMEOUT; while (reconnections > 0) { log.infof("Waiting for %d reconnecting slaves", reconnections); long timeout = deadline - TimeService.currentTimeMillis(); if (timeout <= 0) { throw new IOException(reconnections + " slaves haven't connected within timeout!"); } reconnections -= connectSlaves(timeout); } } private void sendData(SelectionKey key) throws IOException { SocketChannel socketChannel = (SocketChannel) key.channel(); ByteBuffer buf = writeBufferMap.get(socketChannel); socketChannel.write(buf); if (buf.remaining() == 0) { key.interestOps(SelectionKey.OP_READ); writeBufferMap.remove(socketChannel); log.trace("Finished writing entire buffer, " + writeBufferMap.size() + " write buffers remaining."); } } private void readResponse(SelectionKey key) throws IOException { SocketChannel socketChannel = (SocketChannel) key.channel(); ByteBuffer byteBuffer = readBufferMap.get(socketChannel); int value = socketChannel.read(byteBuffer); if (byteBuffer.position() >= EXPECTED_SIZE_BYTES) { int expectedSize = byteBuffer.getInt(0); if ((expectedSize + EXPECTED_SIZE_BYTES + UUID_BYTES) > byteBuffer.capacity()) { ByteBuffer replacer = ByteBuffer.allocate(expectedSize + EXPECTED_SIZE_BYTES + UUID_BYTES); replacer.put(byteBuffer.array(), 0, byteBuffer.position()); readBufferMap.put(socketChannel, replacer); if (log.isTraceEnabled()) log.trace("Expected size(" + expectedSize + ")" + " is > ByteBuffer's capacity(" + byteBuffer.capacity() + ")" + ".Replacing " + byteBuffer + " with " + replacer); byteBuffer = replacer; } if (log.isTraceEnabled()) log.trace("Expected size: " + expectedSize + ". byteBuffer.position() == " + byteBuffer.position()); if (byteBuffer.position() >= expectedSize + EXPECTED_SIZE_BYTES + UUID_BYTES) { log.trace("Received response from " + socketChannel.getRemoteAddress()); Object response = SerializationHelper.deserialize(byteBuffer.array(), EXPECTED_SIZE_BYTES, expectedSize); long uuidMsb = byteBuffer.getLong(EXPECTED_SIZE_BYTES + expectedSize); long uuidLsb = byteBuffer.getLong(EXPECTED_SIZE_BYTES + expectedSize + 8); if (uuidMsb != 0 && uuidLsb != 0) { // we should expect reconnection int index = channel2Index.get(socketChannel); UUID uuid = new UUID(uuidMsb, uuidLsb); log.tracef("Slave %d (%s) is going to restart with UUID %s", index, socketChannel.getRemoteAddress(), uuid); SlaveRecord record = slaves[index]; record.uuid = uuid; record.channel.close(); record.channel = null; channel2Index.remove(socketChannel); readBufferMap.remove(socketChannel); reconnections++; } byteBuffer.clear(); responses.add(response); } } if (value < 0) { Integer slaveIndex = channel2Index.get(socketChannel); key.cancel(); if (slaveIndex == null) { throw new IllegalStateException("Unknown slave for socket " + socketChannel); } SlaveRecord record = slaves[slaveIndex]; if (record.channel == null) { // this channel was closed correctly return; } else if (record.channel != socketChannel) { throw new IllegalStateException("Unexpected socket channel " + socketChannel + "; should be " + record.channel); } else { log.warn("Slave stopped! Index: " + slaveIndex + ". Remote socket is: " + socketChannel); throw new IOException("Slave unexpectedly stopped"); } } } public void release() { if (mcastBuffer != null) { try { mcastObject(null, slaves.length); flushBuffers(0); } catch (Exception e) { log.warn("Failed to send termination to slaves.", e); } } if (discoverySelector != null) { try { discoverySelector.close(); } catch (Throwable e) { log.warn("Error closing discovery selector", e); } } if (communicationSelector != null) { try { communicationSelector.close(); } catch (Throwable e) { log.warn("Error closing comunication selector", e); } } for (SlaveRecord record : slaves) { if (record != null && record.channel != null) { try { record.channel.close(); } catch (Throwable e) { log.warn("Error closing channel", e); } } } if (serverSocketChannel != null) { try { serverSocketChannel.socket().close(); } catch (Throwable e) { log.warn("Error closing server socket channel", e); } } } public void restartSlaves(int numSlaves) throws IOException { responses.clear(); mcastObject(new Restart(), numSlaves); flushBuffers(numSlaves); } private void startServerSocket() throws IOException { serverSocketChannel = ServerSocketChannel.open(); serverSocketChannel.configureBlocking(false); InetSocketAddress address; if (host == null) { address = new InetSocketAddress(port); } else { address = new InetSocketAddress(host, port); } serverSocketChannel.socket().bind(address); log.info("Master started and listening for connection on: " + address); log.info("Waiting 5 seconds for server socket to open completely"); try { Thread.sleep(5000); } catch (InterruptedException ex) { // ignore } } private int readInt(SocketChannel socketChannel) throws IOException { ByteBuffer buffer = ByteBuffer.allocate(EXPECTED_SIZE_BYTES); while (buffer.hasRemaining()) { socketChannel.read(buffer); } buffer.flip(); return buffer.getInt(); } private ByteBuffer readBytes(SocketChannel socketChannel, int numBytes) throws IOException { ByteBuffer buffer = ByteBuffer.allocate(numBytes); while (buffer.hasRemaining()) { socketChannel.read(buffer); } buffer.flip(); return buffer; } private void writeInt(SocketChannel socketChannel, int slaveIndex) throws IOException { ByteBuffer buffer = ByteBuffer.allocate(EXPECTED_SIZE_BYTES); buffer.putInt(slaveIndex); buffer.flip(); while (buffer.hasRemaining()) { socketChannel.write(buffer); } } public static class Restart implements Serializable {} }