package org.nd4j.parameterserver.distributed.messages;

import org.agrona.concurrent.UnsafeBuffer;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.enums.NodeRole;
import org.nd4j.parameterserver.distributed.logic.completion.Clipboard;
import org.nd4j.parameterserver.distributed.logic.Storage;
import org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage;
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
import org.nd4j.parameterserver.distributed.transport.Transport;

import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.Assert.*;

/**
 * Tests for {@link Frame}: stacking independent messages, and stacking one joinable message
 * into itself.
 *
 * @author raver119@gmail.com
 */
public class FrameTest {

    @Before
    public void setUp() throws Exception {
        // nothing to prepare for these tests
    }

    /**
     * Simple test for Frame functionality.
     *
     * <p>Ten non-joinable stub messages are stacked. Each stub reports a counter of 2.
     * Processing the frame is then expected to result in 20 calls to the stubs'
     * {@code processMessage()} in total. Presumably this is because Frame honours
     * {@code getCounter()}; see Frame for the exact mechanics.
     */
    @Test
    public void testFrame1() {
        final AtomicInteger processed = new AtomicInteger(0);
        final int messageCount = 10;

        Frame<TrainingMessage> frame = new Frame<>();
        for (int idx = 0; idx < messageCount; idx++) {
            frame.stackMessage(countingStub(processed));
        }

        assertEquals(10, frame.size());

        frame.processMessage();

        assertEquals(20, count(processed));
    }

    /**
     * Stacks the same {@link SkipGramRequestMessage} onto a frame ten times after constructing
     * the frame with it. The frame is expected to keep a single entry, and the message's counter
     * is expected to reach 11.
     */
    @Test
    public void testJoin1() throws Exception {
        SkipGramRequestMessage request = new SkipGramRequestMessage(0, 1, new int[] {3, 4, 5},
                        new byte[] {0, 1, 0}, (short) 0, 0.01, 119L);
        Frame<SkipGramRequestMessage> frame = new Frame<>(request);

        // push the very same message ten more times on top of the initial one
        int remaining = 10;
        while (remaining-- > 0) {
            frame.stackMessage(request);
        }

        // all messages should be stacked into one message
        assertEquals(1, frame.size());
        assertEquals(11, request.getCounter());
    }

    /** Reads the current value of the shared invocation counter. */
    private static int count(AtomicInteger invocations) {
        return invocations.get();
    }

    /**
     * Builds a stub {@link TrainingMessage} for {@link #testFrame1()}.
     *
     * <p>The stub reports a counter of 2, does not support joining, is non-blocking, and
     * increments {@code invocations} each time {@code processMessage()} is called. Every other
     * method is an inert stub returning zero, an empty value or {@code null}.
     *
     * <p>Being static, the stub holds no reference to the enclosing test instance.
     */
    private static TrainingMessage countingStub(final AtomicInteger invocations) {
        return new TrainingMessage() {
            @Override
            public void processMessage() {
                invocations.incrementAndGet();
            }

            @Override
            public byte getCounter() {
                return 2;
            }

            @Override
            public boolean isJoinSupported() {
                return false;
            }

            @Override
            public void joinMessage(VoidMessage message) {
                // no-op
            }

            @Override
            public boolean isBlockingMessage() {
                return false;
            }

            @Override
            public short getTargetId() {
                return 0;
            }

            @Override
            public void setTargetId(short id) {
                // stub: target id is ignored
            }

            @Override
            public int getRetransmitCount() {
                return 0;
            }

            @Override
            public void incrementRetransmitCount() {
                // stub: retransmits are not tracked
            }

            @Override
            public long getFrameId() {
                return 0;
            }

            @Override
            public void setFrameId(long frameId) {
                // stub: frame id is ignored
            }

            @Override
            public long getOriginatorId() {
                return 0;
            }

            @Override
            public void setOriginatorId(long id) {
                // stub: originator id is ignored
            }

            @Override
            public long getTaskId() {
                return 0;
            }

            @Override
            public int getMessageType() {
                return 0;
            }

            @Override
            public byte[] asBytes() {
                return new byte[0];
            }

            @Override
            public UnsafeBuffer asUnsafeBuffer() {
                return null;
            }

            @Override
            public void attachContext(VoidConfiguration voidConfiguration,
                            TrainingDriver<? extends TrainingMessage> trainer, Clipboard clipboard,
                            Transport transport, Storage storage, NodeRole role, short shardIndex) {
                // no-op intentionally
            }

            @Override
            public void extractContext(BaseVoidMessage message) {
                // no-op intentionally
            }
        };
    }
}