package org.nd4j.parameterserver.distributed.logic;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.RandomUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.parameterserver.distributed.logic.completion.Clipboard;
import org.nd4j.parameterserver.distributed.messages.aggregations.InitializationAggregation;
import org.nd4j.parameterserver.distributed.messages.aggregations.VectorAggregation;
import org.nd4j.parameterserver.distributed.messages.VoidAggregation;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.*;
/**
 * Unit tests for {@link Clipboard}: pinning of partial aggregations, detection of
 * complete stacks, and handling of a single initialization aggregation.
 *
 * @author raver119@gmail.com
 */
@Slf4j
public class ClipboardTest {

    /** No per-test setup is currently required. */
    @Before
    public void setUp() throws Exception {
    }

    /** No per-test cleanup is currently required. */
    @After
    public void tearDown() throws Exception {
    }

    /**
     * Pins 100 aggregations, each created with its own random task id, and checks that
     * every one ends up in a separate pinned stack with none of them complete.
     */
    @Test
    public void testPin1() throws Exception {
        Clipboard clipboard = new Clipboard();
        // Fixed seed keeps the generated task ids reproducible between runs.
        Random rng = new Random(12345L);

        for (int i = 0; i < 100; i++) {
            VectorAggregation aggregation =
                    new VectorAggregation(rng.nextLong(), (short) 100, (short) i, Nd4j.create(5));
            clipboard.pin(aggregation);
        }

        assertFalse(clipboard.hasCandidates());
        assertEquals(0, clipboard.getNumberOfCompleteStacks());
        assertEquals(100, clipboard.getNumberOfPinnedStacks());
    }

    /**
     * Pins 300 aggregations.
     *
     * <p>The first 100 even-indexed ones are rewritten to share {@code validId}, with shard
     * indices 0..99. The remaining 200 keep random task ids. The test then expects exactly
     * one complete stack, retrievable for {@code validId} with no missing chunks.
     * {@code (short) 100} is presumably the expected shard count; confirm against
     * VectorAggregation.
     */
    @Test
    public void testPin2() throws Exception {
        Clipboard clipboard = new Clipboard();
        Random rng = new Random(12345L);
        long validId = 123L;
        short shardIdx = 0;

        for (int i = 0; i < 300; i++) {
            VectorAggregation aggregation =
                    new VectorAggregation(rng.nextLong(), (short) 100, (short) 1, Nd4j.create(5));

            // Imitate one shard of the valid aggregation on every even iteration
            // until shard indices 0..99 have all been issued.
            if (i % 2 == 0 && shardIdx < 100) {
                aggregation.setTaskId(validId);
                aggregation.setShardIndex(shardIdx++);
            }

            clipboard.pin(aggregation);
        }

        VoidAggregation aggregation = clipboard.getStackFromClipboard(0L, validId);
        assertNotNull(aggregation);
        assertEquals(0, aggregation.getMissingChunks());
        assertTrue(clipboard.hasCandidates());
        assertEquals(1, clipboard.getNumberOfCompleteStacks());
    }

    /**
     * Checks how the clipboard handles a singular aggregation.
     *
     * <p>A single pinned {@link InitializationAggregation} is expected to be tracked and
     * reported ready immediately.
     *
     * @throws Exception if pinning or lookup fails
     */
    @Test
    public void testPin3() throws Exception {
        Clipboard clipboard = new Clipboard();
        InitializationAggregation aggregation = new InitializationAggregation(1, 0);
        clipboard.pin(aggregation);

        assertTrue(clipboard.isTracking(0L, aggregation.getTaskId()));
        assertTrue(clipboard.isReady(0L, aggregation.getTaskId()));
    }
}