/* * Galaxy * Copyright (C) 2012 Parallel Universe Software Co. * * This file is part of Galaxy. * * Galaxy is free software: you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation, either version 3 of * the License, or (at your option) any later version. * * Galaxy is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with Galaxy. If not, see <http://www.gnu.org/licenses/>. */ package co.paralleluniverse.galaxy.netty; import co.paralleluniverse.galaxy.Cluster; import co.paralleluniverse.galaxy.core.Comm; import co.paralleluniverse.galaxy.core.Message; import co.paralleluniverse.galaxy.core.Message.LineMessage; import co.paralleluniverse.galaxy.core.MessageReceiver; import co.paralleluniverse.galaxy.cluster.NodeChangeListener; import co.paralleluniverse.galaxy.cluster.NodeInfo; import static co.paralleluniverse.galaxy.core.MessageMatchers.*; import co.paralleluniverse.galaxy.core.ServerComm; import static co.paralleluniverse.galaxy.netty.MessagePacketMatchers.*; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; import org.junit.Test; import static org.junit.Assert.*; import org.junit.BeforeClass; import static org.hamcrest.CoreMatchers.*; import static org.junit.matchers.JUnitMatchers.*; import org.mockito.ArgumentCaptor; import org.mockito.InOrder; import org.mockito.Mockito; import static org.mockito.Mockito.*; import static org.mockito.Matchers.*; import static org.mockito.Matchers.any; import static co.paralleluniverse.galaxy.test.LogMock.startLogging; import static co.paralleluniverse.galaxy.test.LogMock.stopLogging; import static co.paralleluniverse.galaxy.test.LogMock.when; import static co.paralleluniverse.galaxy.test.LogMock.doAnswer; import static co.paralleluniverse.galaxy.test.LogMock.doNothing; import static co.paralleluniverse.galaxy.test.LogMock.doReturn; import static co.paralleluniverse.galaxy.test.LogMock.doThrow; import static co.paralleluniverse.galaxy.test.LogMock.mock; import static co.paralleluniverse.galaxy.test.LogMock.spy; import co.paralleluniverse.galaxy.test.ClonesArguments; import static co.paralleluniverse.galaxy.test.MockitoUtil.*; import com.google.common.primitives.Shorts; import java.net.Inet4Address; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashSet; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import org.jboss.netty.channel.socket.DatagramChannel; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; /** * Unfortunately, these tests are based on timings, and are therefore a bit flaky. * Also, there should be more tests. * @author pron */ public class UDPCommTest { static final int PORT = 100; static final InetSocketAddress GROUP; static final InetSocketAddress node2Address; static final InetSocketAddress node3Address; static final InetSocketAddress node4Address; static final long MAX_RESERVED_REF_ID = 0xffffffffL; static { try { GROUP = new InetSocketAddress(InetAddress.getByName("1.2.3.4"), PORT); node2Address = new InetSocketAddress(InetAddress.getByName("1.1.1.2"), PORT); node3Address = new InetSocketAddress(InetAddress.getByName("1.1.1.3"), PORT); node4Address = new InetSocketAddress(InetAddress.getByName("1.1.1.4"), PORT); } catch (Exception e) { throw new AssertionError(e); } } UDPComm comm; Cluster cluster; DatagramChannel channel; ServerComm serverComm; Collection<NodeInfo> masters = new ArrayList<NodeInfo>(); MessageReceiver receiver; public UDPCommTest() { } @Before public void setUp() throws Exception { cluster = mock(Cluster.class); when(cluster.isMaster()).thenReturn(true); // otherwise messages aren't passed to the receiver. when(cluster.getMyNodeId()).thenReturn(sh(1)); when(cluster.getNodes()).thenReturn(new HashSet<Short>(Shorts.asList(sh(0, 2, 3, 4)))); when(cluster.getMasters()).thenReturn(masters); addNodeInfo(sh(2), node2Address); addNodeInfo(sh(3), node3Address); addNodeInfo(sh(4), node4Address); serverComm = mock(ServerComm.class); channel = mock(DatagramChannel.class, new ClonesArguments()); // we must clone the arguments because the message packet keeps changing receiver = mock(MessageReceiver.class); comm = new UDPComm("comm", cluster, serverComm, PORT); comm.setChannel(channel); comm.setReceiver(receiver); comm.setSendToServerInsteadOfMulticast(false); comm.setMinimumNodesToMulticast(3); comm.setMulticastGroup(GROUP); comm.setResendPeriodMillisecs(20); comm.setTimeout(350); comm.setMinDelayMicrosecs(5000); // 5 millis comm.setMaxDelayMicrosecs(15000); // 15 millis comm.setExponentialBackoff(false); comm.setJitter(false); for (short node : sh(0, 2, 3, 4)) comm.nodeAdded(node); } @After public void tearDown() { } /////////////////////////////////////////////// @Test public void testSimpleSendMessage() throws Exception { final Message m = Message.GET(sh(2), 1234L); comm.send(m); await(); verify(channel, atLeastOnce()).write(argThat(is(packetThatContains(m))), eq(node2Address)); } @Test public void whenSeveralMessagesThenAggregateInPacketUntilMaxDelay() throws Exception { // we test on responses because requests actually send immediately when comm.send() is called final Message m1 = Message.INVACK(Message.INV(sh(2), id(1111L), sh(10))).setMessageId(10001); final Message m2 = Message.INVACK(Message.INV(sh(2), id(2222L), sh(10))).setMessageId(10002); final Message m3 = Message.INVACK(Message.INV(sh(2), id(3333L), sh(10))).setMessageId(10003); final Message m4 = Message.INVACK(Message.INV(sh(2), id(4444L), sh(10))).setMessageId(10004); comm.send(m1); comm.send(m2); comm.send(m3); sleep(10); // more than min delay since last comm.send(m4); await(); verify(channel, never()).write(argThat(equalTo(packet(m1))), eq(node2Address)); verify(channel, never()).write(argThat(equalTo(packet(m1, m2))), eq(node2Address)); verify(channel, atLeastOnce()).write(argThat(equalTo(packet(m1, m2, m3))), eq(node2Address)); verify(channel, atLeastOnce()).write(argThat(equalTo(packet(m1, m2, m3, m4))), eq(node2Address)); // verify(channel).write(argThat(is(allOf( // packetThatContains(m1), // packetThatContains(m2), // packetThatContains(m3), // not(packetThatContains(m4))))), eq(node2Address)); } @Test public void whenSendRequestThenResendUntilResponse() throws Exception { final Message m = Message.INV(sh(2), id(1234L), sh(10)); comm.send(m); sleep(100); await(); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(node2Address)); } @Test public void whenNoResponseThenTimeout() throws Exception { final LineMessage m = Message.GET(sh(2), id(1234L)); comm.send(m); sleep(400); await(); verify(receiver).receive(argThat(equalTo(Message.TIMEOUT(m)))); } @Test public void whenNoResponseForINVThenNoTimeout() throws Exception { final LineMessage m = Message.INV(sh(2), id(1234L), sh(10)); comm.send(m); sleep(400); await(); verify(receiver, never()).receive(argThat(equalTo(Message.TIMEOUT(m)))); } @Test public void whenReceiveResponseThenStopResendingAndNoTimeout() throws Exception { final LineMessage m = Message.INV(sh(2), id(1234L), sh(10)); comm.send(m); sleep(200); comm.messageReceived(packet(Message.INVACK(m).setIncoming())); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(node2Address)); sleep(400); await(); verifyNoMoreInteractions(channel); verify(receiver, never()).receive(argThat(equalTo(Message.TIMEOUT(m)))); } @Test public void whenSendRespondThenSendOnlyOnce() throws Exception { final Message m = Message.INVACK(Message.INV(sh(2), id(1234L), sh(10))).setMessageId(10001); comm.send(m); sleep(100); await(); verify(channel, times(1)).write(argThat(is(packetThatContains(m))), eq(node2Address)); } @Test public void whenRequestAgainAndHasResponseThenResendResponse() throws Exception { final LineMessage m = Message.INV(sh(2), id(1234L), sh(10)).setMessageId(10001).setIncoming(); comm.messageReceived(packet(m)); comm.send(Message.INVACK(m)); sleep(100); verify(channel, times(1)).write(argThat(is(packetThatContains(Message.INVACK(m)))), eq(node2Address)); comm.messageReceived(packet(m)); await(); verify(channel, times(2)).write(argThat(is(packetThatContains(Message.INVACK(m)))), eq(node2Address)); } @Test public void testSimpleBroadcast() throws Exception { final LineMessage m = Message.GET(sh(-1), id(1234L)); comm.send(m); await(); verify(channel, atLeastOnce()).write(argThat(is(packetThatContains(m))), eq(GROUP)); } @Test public void whenBroadcastRequestThenResendUntilResponse() throws Exception { final LineMessage m = Message.INV(sh(-1), id(1234L), sh(10)); comm.send(m); sleep(200); await(); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(GROUP)); } @Test public void whenBroadcastNoReplyThenTimeout() throws Exception { final LineMessage m = Message.GET(sh(-1), id(1234L)); comm.send(m); sleep(400); await(); verify(receiver).receive(argThat(equalTo(Message.TIMEOUT(m)))); } @Test public void whenBroadcastAndReceiveReplyThenStopResendingAndNoTimeout() throws Exception { final LineMessage m = Message.INV(sh(-1), id(1234L), sh(10)); comm.send(m); sleep(200); comm.messageReceived(packet(Message.INVACK(m).setNode(sh(3)).setIncoming())); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(GROUP)); sleep(400); await(); verifyNoMoreInteractions(channel); verify(receiver, never()).receive(argThat(equalTo(Message.TIMEOUT(m)))); } @Test public void whenBroadcastAndReceiveAcksThenStopResendingAndNotFound() throws Exception { final LineMessage m = Message.INV(sh(-1), id(1234L), sh(10)); comm.send(m); sleep(200); comm.messageReceived(packet(Message.ACK(m).setNode(sh(2)).setIncoming())); comm.messageReceived(packet(Message.ACK(m).setNode(sh(3)).setIncoming())); comm.messageReceived(packet(Message.ACK(m).setNode(sh(4)).setIncoming())); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(GROUP)); sleep(400); await(); verify(receiver).receive(argThat(equalTo(Message.NOT_FOUND(m)))); verify(receiver, never()).receive(argThat(equalTo(Message.TIMEOUT(m)))); } @Test public void whenUnicastBroadcastThenResendUnicastUntilResponse() throws Exception { comm.setMinimumNodesToMulticast(10); final LineMessage m = Message.INV(sh(-1), id(1234L), sh(10)); comm.send(m); sleep(200); await(); verify(channel, never()).write(any(MessagePacket.class), eq(GROUP)); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(node2Address)); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(node3Address)); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(node4Address)); verifyNoMoreInteractions(channel); } @Test public void whenBroadcastRequestAndSomeAcksThenResendUnicastUntilResponse() throws Exception { final LineMessage m = Message.INV(sh(-1), id(1234L), sh(10)); comm.send(m); sleep(200); comm.messageReceived(packet(Message.ACK(m).setNode(sh(3)).setIncoming())); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(GROUP)); sleep(200); await(); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(node2Address)); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(node4Address)); verifyNoMoreInteractions(channel); } @Test public void whenUnicastBroadcastAndNoReplyThenTimeout() throws Exception { final LineMessage m = Message.GET(sh(-1), id(1234L)); comm.send(m); sleep(100); comm.messageReceived(packet(Message.ACK(m).setNode(sh(3)).setIncoming())); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(GROUP)); sleep(200); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(node2Address)); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(node4Address)); sleep(400); await(); verify(receiver).receive(argThat(equalTo(Message.TIMEOUT(m)))); verifyNoMoreInteractions(receiver); } @Test public void whenUnicastBroadcastAndReceiveReplyThenNoTimeout() throws Exception { final LineMessage m = Message.INV(sh(-1), id(1234L), sh(10)); comm.send(m); sleep(200); comm.messageReceived(packet(Message.ACK(m).setNode(sh(3)).setIncoming())); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(GROUP)); sleep(100); comm.messageReceived(packet(Message.INVACK(m).setNode(sh(4)).setIncoming())); verify(channel, atLeast(1)).write(argThat(equalTo(packet(m))), eq(node2Address)); verify(channel, atLeast(1)).write(argThat(equalTo(packet(m))), eq(node4Address)); sleep(400); await(); // verifyNoMoreInteractions(channel); - node 2 will continue resending. we don't care about that verify(receiver).receive(argThat(equalTo(Message.INVACK(m).setNode(sh(4)).setIncoming()))); verify(receiver, never()).receive(argThat(equalTo(Message.TIMEOUT(m)))); } @Test public void whenUnicastBroadcastAndReceiveAcksThenNotFound() throws Exception { final LineMessage m = Message.INV(sh(-1), id(1234L), sh(10)); comm.send(m); sleep(150); comm.messageReceived(packet(Message.ACK(m).setNode(sh(3)).setIncoming())); verify(channel, atLeast(3)).write(argThat(equalTo(packet(m))), eq(GROUP)); sleep(150); comm.messageReceived(packet(Message.ACK(m).setNode(sh(2)).setIncoming())); comm.messageReceived(packet(Message.ACK(m).setNode(sh(4)).setIncoming())); verify(channel, atLeast(2)).write(argThat(equalTo(packet(m))), eq(node2Address)); verify(channel, atLeast(2)).write(argThat(equalTo(packet(m))), eq(node4Address)); sleep(400); await(); verify(receiver).receive(argThat(equalTo(Message.NOT_FOUND(m)))); verify(receiver, never()).receive(argThat(equalTo(Message.TIMEOUT(m)))); } /////////////////////////////////////////////// static NodeChangeListener getNodeChangeListener(Cluster mock) { try { return (NodeChangeListener) capture(mock, "addNodeChangeListener", arg(NodeChangeListener.class)); } catch (Exception e) { return null; } } static short sh(int x) { return (short) x; } static short[] sh(int... args) { final short[] array = new short[args.length]; for (int i = 0; i < args.length; i++) array[i] = (short) args[i]; return array; } private long id(long id) { return MAX_RESERVED_REF_ID + id; } void await() { try { ExecutorService executor = comm.getExecutor(); executor.shutdown(); executor.awaitTermination(Long.MAX_VALUE, TimeUnit.MILLISECONDS); } catch (InterruptedException e) { System.err.println("Interrupted"); } } static void sleep(int millis) { try { Thread.sleep(millis); } catch (InterruptedException e) { System.err.println("Interrupted"); } } private static MessagePacket packet(Message... ms) { MessagePacket packet = new MessagePacket(); for (Message m : ms) packet.addMessage(m); return packet; } private void addNodeInfo(short node, InetSocketAddress address) { NodeInfo ni = mock(NodeInfo.class); when(ni.getNodeId()).thenReturn(node); when(ni.get(IpConstants.IP_ADDRESS)).thenReturn(address.getAddress()); when(ni.get(IpConstants.IP_COMM_PORT)).thenReturn(address.getPort()); when(cluster.getMaster(node)).thenReturn(ni); masters.add(ni); } private MessagePacket captureMessagePacket() throws Exception { ArgumentCaptor<MessagePacket> captor = (ArgumentCaptor) ArgumentCaptor.forClass(MessagePacket.class); verify(channel, times(1)).write(captor.capture(), any(SocketAddress.class)); return captor.getValue(); } }