package org.deeplearning4j.parallelism;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.*;
/**
* @author raver119@gmail.com
*/
@Slf4j
public class MagicQueueTest {
@Before
public void setUp() throws Exception {
}
@Test
public void addDataSet1() throws Exception {
MagicQueue queue = new MagicQueue.Builder().setNumberOfBuckets(1).build();
int numDevices = 1; // Force single device
DataSet dataSet_1 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_2 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_3 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_4 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_5 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_6 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_7 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_8 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
queue.add(dataSet_1);
queue.add(dataSet_2);
queue.add(dataSet_3);
queue.add(dataSet_4);
queue.add(dataSet_5);
queue.add(dataSet_6);
queue.add(dataSet_7);
queue.add(dataSet_8);
Thread.sleep(500);
assertEquals(8 / numDevices, queue.size());
int cnt = 0;
while (!queue.isEmpty()) {
DataSet ds = (DataSet) queue.poll();
assertNotEquals("Failed on iteration: " + cnt, null, ds);
cnt++;
}
assertEquals(8, cnt);
}
/**
* This test will fail on single-gpu system
*
* @throws Exception
*/
@Test
public void addDataSet2() throws Exception {
MagicQueue queue = new MagicQueue.Builder().build();
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
for (int i = 0; i < numDevices * 4; i++) {
DataSet dataSet = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
queue.add(dataSet);
}
Thread.sleep(500);
if (numDevices == 1)
assertEquals(4, queue.size());
else
assertEquals(8 / numDevices, queue.size());
int cnt = 0;
while (!queue.isEmpty()) {
DataSet ds = (DataSet) queue.poll();
if (cnt < 4) {
assertNotEquals("Failed on iteration: " + cnt, null, ds);
cnt++;
} else {
break;
}
}
assertEquals(4, cnt);
}
/**
* THIS TEST REQUIRES CUDA BACKEND AND MULTI-GPU ENVIRONMENT
* TO USE THIS TEST EFFICIENTLY - ENABLE ND4J-CUDA BACKEND FOR THIS MODULE
*
* In this test we check actual data relocation within MagicQueue
*
* @throws Exception
*/
@Test
public void test_cuda_multiGPU_testAffinityChange1() throws Exception {
MagicQueue queue = new MagicQueue.Builder().build();
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
DataSet dataSet_1 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_2 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_3 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_4 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_5 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_6 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_7 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_8 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
// All arrays are located on same initial device
assertEquals(0, Nd4j.getAffinityManager().getDeviceForArray(dataSet_1.getFeatures()).intValue());
assertEquals(0, Nd4j.getAffinityManager().getDeviceForArray(dataSet_2.getFeatures()).intValue());
assertEquals(0, Nd4j.getAffinityManager().getDeviceForArray(dataSet_3.getFeatures()).intValue());
assertEquals(0, Nd4j.getAffinityManager().getDeviceForArray(dataSet_4.getFeatures()).intValue());
queue.add(dataSet_1);
queue.add(dataSet_2);
queue.add(dataSet_3);
queue.add(dataSet_4);
queue.add(dataSet_5);
queue.add(dataSet_6);
queue.add(dataSet_7);
queue.add(dataSet_8);
Thread.sleep(500);
assertEquals(8 / numDevices, queue.size());
log.info("Checking first device...");
// All arrays are spread over all available devices
assertEquals(0, Nd4j.getAffinityManager().getDeviceForArray(dataSet_1.getFeatures()).intValue());
assertEquals(0, Nd4j.getAffinityManager().getDeviceForArray(dataSet_1.getLabels()).intValue());
int nextDev = 0;
if (numDevices > 1) {
log.info("Checking second device...");
nextDev++;
}
assertEquals(nextDev, Nd4j.getAffinityManager().getDeviceForArray(dataSet_2.getFeatures()).intValue());
assertEquals(nextDev, Nd4j.getAffinityManager().getDeviceForArray(dataSet_2.getLabels()).intValue());
if (numDevices > 2) {
log.info("Checking third device...");
nextDev++;
} else {
log.info("Checking first device...");
nextDev = 0;
}
assertEquals(nextDev, Nd4j.getAffinityManager().getDeviceForArray(dataSet_3.getFeatures()).intValue());
assertEquals(nextDev, Nd4j.getAffinityManager().getDeviceForArray(dataSet_3.getLabels()).intValue());
if (numDevices > 2) {
log.info("Checking fourth device...");
nextDev++;
} else {
if (numDevices > 1) {
log.info("Checking second device...");
nextDev = 1;
} else {
log.info("Checking first device...");
nextDev = 0;
}
}
assertEquals(nextDev, Nd4j.getAffinityManager().getDeviceForArray(dataSet_4.getFeatures()).intValue());
assertEquals(nextDev, Nd4j.getAffinityManager().getDeviceForArray(dataSet_4.getLabels()).intValue());
}
@Test
public void testSequential() throws Exception {
MagicQueue queue = new MagicQueue.Builder().setMode(MagicQueue.Mode.SEQUENTIAL).build();
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
DataSet dataSet_1 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_2 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_3 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_4 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_5 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_6 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_7 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
DataSet dataSet_8 = new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f}));
// All arrays are located on same initial device
assertEquals(0, Nd4j.getAffinityManager().getDeviceForArray(dataSet_1.getFeatures()).intValue());
assertEquals(0, Nd4j.getAffinityManager().getDeviceForArray(dataSet_2.getFeatures()).intValue());
assertEquals(0, Nd4j.getAffinityManager().getDeviceForArray(dataSet_3.getFeatures()).intValue());
assertEquals(0, Nd4j.getAffinityManager().getDeviceForArray(dataSet_4.getFeatures()).intValue());
queue.add(dataSet_1);
queue.add(dataSet_2);
queue.add(dataSet_3);
queue.add(dataSet_4);
queue.add(dataSet_5);
queue.add(dataSet_6);
queue.add(dataSet_7);
queue.add(dataSet_8);
assertEquals(8, queue.size());
int cnt = 0;
while (!queue.isEmpty()) {
DataSet ds = (DataSet) queue.poll(2, TimeUnit.SECONDS);
// making sure dataset isn't null
assertNotEquals("Failed on round " + cnt, null, ds);
// making sure device for this array is a "next one"
assertEquals(cnt % numDevices, Nd4j.getAffinityManager().getDeviceForArray(ds.getFeatures()).intValue());
assertEquals(cnt % numDevices, Nd4j.getAffinityManager().getDeviceForArray(ds.getLabels()).intValue());
cnt++;
}
assertEquals(8, cnt);
}
@Test
public void testSequentialIterable() throws Exception {
List<DataSet> list = new ArrayList<>();
for (int i = 0; i < 1024; i++)
list.add(new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}), Nd4j.create(new float[] {1f, 2f, 3f})));
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
ExistingDataSetIterator edsi = new ExistingDataSetIterator(list);
MagicQueue queue = new MagicQueue.Builder().setMode(MagicQueue.Mode.SEQUENTIAL).setCapacityPerFlow(32).build();
AsyncDataSetIterator adsi = new AsyncDataSetIterator(edsi, 10, queue);
int cnt = 0;
while (adsi.hasNext()) {
DataSet ds = adsi.next();
// making sure dataset isn't null
assertNotEquals("Failed on round " + cnt, null, ds);
// making sure device for this array is a "next one"
assertEquals(cnt % numDevices, Nd4j.getAffinityManager().getDeviceForArray(ds.getFeatures()).intValue());
assertEquals(cnt % numDevices, Nd4j.getAffinityManager().getDeviceForArray(ds.getLabels()).intValue());
cnt++;
}
assertEquals(list.size(), cnt);
}
}