package lsr.paxos.network; import static lsr.common.ProcessDescriptor.processDescriptor; import java.io.IOException; import java.net.DatagramPacket; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.MulticastSocket; import java.nio.ByteBuffer; import java.util.BitSet; import java.util.HashMap; import lsr.common.KillOnExceptionHandler; import lsr.paxos.messages.Message; import lsr.paxos.messages.MessageFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class MulticastNetwork extends Network { private static final int myHeaderSize = 1 + 8 + 4 + 4; private final MulticastSocket socket; private final Network unicastNetwork; private final InetSocketAddress groupAddress; private final Thread readThread; public MulticastNetwork(Network unicastNetwork, long runId) throws IOException { super(); this.sendMessageId = runId << 32 + localId; if (processDescriptor.numReplicas > 256) { throw new RuntimeException("Multicast network supports up to 256 relicas"); } this.unicastNetwork = unicastNetwork; socket = new MulticastSocket(processDescriptor.multicastPort); // MulticastSocket constructor by default sets SO_REUSEADDR groupAddress = new InetSocketAddress( InetAddress.getByName(processDescriptor.multicastIpAddress), processDescriptor.multicastPort); socket.joinGroup(groupAddress.getAddress()); readThread = new ReceiveThread(); logger.info("Multicast network created on {} to {} and joined {}", socket.getLocalAddress(), socket.getInetAddress(), groupAddress); } protected void send(Message message, int destination) { unicastNetwork.send(message, destination); } protected void send(Message message, BitSet destinations) { if (shouldMulticast(destinations)) { try { multicast(message); } catch (IOException e) { throw new RuntimeException("Connectionless socket refused a send operation", e); } } else { logger.trace("Multicast network - passing message to unicast netwrok: {}", message); unicastNetwork.send(message, destinations); } } private final Object sendLock = new Object(); private ByteBuffer sendBuffer = ByteBuffer.allocate(1024); private final int payloadSize = processDescriptor.mtu - 8 /* udp header size */- myHeaderSize; private long sendMessageId; private final HashMap<Long, Message> recentMessages = new HashMap<Long, Message>(); private final byte[] messageBA = new byte[processDescriptor.mtu - 8]; private final ByteBuffer messageBuffer = ByteBuffer.wrap(messageBA); private void multicast(Message message) throws IOException { synchronized (sendLock) { if (sendBuffer.capacity() < message.byteSize()) sendBuffer = ByteBuffer.allocate(message.byteSize()); assert sendBuffer.position() == 0 && sendBuffer.limit() == sendBuffer.capacity(); message.writeTo(sendBuffer); sendBuffer.flip(); int size = sendBuffer.limit(); int part = 0; messageBuffer.rewind(); messageBuffer.put((byte) localId); messageBuffer.putLong(sendMessageId); messageBuffer.putInt(size); recentMessages.put(sendMessageId, message); logger.trace("Multicasting as {}: {}", sendMessageId, message); sendMessageId += processDescriptor.numReplicas; while (sendBuffer.remaining() > payloadSize) { sendBuffer.get(messageBA, myHeaderSize, payloadSize); multicast(part++, payloadSize + myHeaderSize); } int remaining = sendBuffer.remaining(); if (remaining > 0) { sendBuffer.get(messageBA, myHeaderSize, remaining); multicast(part++, myHeaderSize + remaining); } sendBuffer.clear(); } } private void multicast(int part, int size) throws IOException { logger.trace("Sending part {} size {}", part, size); messageBuffer.putInt(1 + 8 + 4, part); DatagramPacket dp = new DatagramPacket(messageBA, size); dp.setSocketAddress(groupAddress); socket.send(dp); } private boolean shouldMulticast(BitSet destinations) { return destinations.cardinality() > 1; } public void start() { assert !readThread.isAlive(); readThread.start(); unicastNetwork.start(); } public class ReceiveThread extends Thread { private class MessageParts { private final BitSet bs = new BitSet(); private final byte[] contents; private final int parts; public MessageParts(int length) { contents = new byte[length]; parts = (length + payloadSize - 1) / payloadSize; } public void addPart(int part) { assert part < parts; if (bs.get(part)) { logger.debug("Received part {} multiple times for some message", part); } bs.set(part); receiveBuffer.get(contents, part * payloadSize, receiveBuffer.remaining()); } public boolean hasAllParts() { return bs.cardinality() == parts; } public byte[] get() { assert hasAllParts(); return contents; } } private final byte[] receiveBA = new byte[processDescriptor.mtu - 8]; private final ByteBuffer receiveBuffer = ByteBuffer.wrap(receiveBA); private final DatagramPacket packet = new DatagramPacket(receiveBA, receiveBA.length); private final HashMap<Long, MessageParts> messageParts = new HashMap<Long, MulticastNetwork.ReceiveThread.MessageParts>(); public ReceiveThread() { super("MulticastReceive"); setUncaughtExceptionHandler(new KillOnExceptionHandler()); setDaemon(true); } public void run() { try { while (true) { socket.receive(packet); int senderId = receiveBuffer.get(); if (senderId == localId) { receiveBuffer.clear(); continue; } long messageId = receiveBuffer.getLong(); int totalSize = receiveBuffer.getInt(); int partNo = receiveBuffer.getInt(); receiveBuffer.limit(packet.getLength()); assert (partNo + 1) * payloadSize > totalSize ? partNo * payloadSize + receiveBuffer.remaining() == totalSize : receiveBuffer.remaining() == payloadSize; MessageParts parts = messageParts.get(messageId); if (parts == null) { parts = new MessageParts(totalSize); messageParts.put(messageId, parts); } if (logger.isTraceEnabled()) logger.trace("Multicast received from {} part {} of {}, size {}", senderId, partNo, messageId, receiveBuffer.remaining()); parts.addPart(partNo); if (parts.hasAllParts()) { received(messageId, senderId, parts); } receiveBuffer.clear(); } } catch (IOException e) { throw new RuntimeException( "Connectionless socket refused a receive operation or read from memory failed", e); } } private void received(long messageId, int sender, MessageParts parts) throws IOException { Message msg = MessageFactory.readByteArray(parts.get()); if (logger.isTraceEnabled()) logger.trace("Multicast received {} from {}: {}", messageId, sender, msg); fireReceiveMessage(msg, sender); recentMessages.remove(messageId); } } private final static Logger logger = LoggerFactory.getLogger(MulticastNetwork.class); }