package org.nd4j.jita.concurrency; import org.apache.commons.lang3.ArrayUtils; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.factory.Nd4j; import java.util.concurrent.atomic.AtomicBoolean; import static org.junit.Assert.*; /** * @author raver119@gmail.com */ @Ignore public class CudaAffinityManagerTest { @Before public void setUp() throws Exception { } @Test public void getDeviceForCurrentThread() throws Exception { CudaAffinityManager manager = new CudaAffinityManager(); Integer deviceId = manager.getDeviceForCurrentThread(); assertEquals(0, deviceId.intValue()); manager.attachThreadToDevice(Thread.currentThread().getId(), 1); assertEquals(1, manager.getDeviceForCurrentThread().intValue()); manager.attachThreadToDevice(Thread.currentThread().getId(), 0); assertEquals(0, manager.getDeviceForCurrentThread().intValue()); } @Test public void getDeviceForAnotherThread() throws Exception { CudaAffinityManager manager = new CudaAffinityManager(); Integer deviceId = manager.getDeviceForCurrentThread(); assertEquals(0, deviceId.intValue()); manager.attachThreadToDevice(1731L, 0); assertEquals(0, manager.getDeviceForThread(1731L).intValue()); } @Test public void getDeviceForAnotherThread2() throws Exception { CudaAffinityManager manager = new CudaAffinityManager(); Integer deviceId = manager.getDeviceForCurrentThread(); assertEquals(0, deviceId.intValue()); System.out.println("Current threadId: " + Thread.currentThread().getId()); Thread thread = new Thread(); long threadIdPrior = thread.getId(); System.out.println("Next threadId: " + thread.getId()); assertNotEquals(Thread.currentThread().getId(), thread.getId()); thread.start(); System.out.println("Current threadId: " + thread.getId()); assertEquals(threadIdPrior, thread.getId()); } /** * This is special test for multi-threaded environment * @throws Exception */ @Test public void getDeviceForAnotherThread3() throws Exception { final int limit = 10; final CudaAffinityManager manager = new CudaAffinityManager(); final Thread threads[] = new Thread[limit]; final AtomicBoolean[] results = new AtomicBoolean[limit]; for (int cnt = 0; cnt < limit; cnt++) { final int c = cnt; results[cnt] = new AtomicBoolean(false); threads[cnt] = new Thread(new Runnable() { @Override public void run() { assertEquals(0, manager.getDeviceForCurrentThread().intValue()); results[c].set(true); } }); manager.attachThreadToDevice(threads[cnt], 0); threads[cnt].start(); } for (int cnt = 0; cnt < limit; cnt++) { threads[cnt].join(); assertTrue("Failed for thread ["+ cnt+"]", results[cnt].get()); } } /** * This is special test for multi-threaded multi-gpu environment * @throws Exception */ @Test public void getDeviceForAnotherThread4() throws Exception { final int limit = 10; final CudaAffinityManager manager = new CudaAffinityManager(); final Thread threads[] = new Thread[limit]; final AtomicBoolean[] results = new AtomicBoolean[limit]; final int cards[] = new int[limit]; for (int cnt = 0; cnt < limit; cnt++) { final int c = cnt; results[cnt] = new AtomicBoolean(false); threads[cnt] = new Thread(new Runnable() { @Override public void run() { // this is pseudo-master thread final int deviceId = manager.getDeviceForCurrentThread(); Thread thread = new Thread(new Runnable() { @Override public void run() { int cdev = manager.getDeviceForCurrentThread(); assertEquals(deviceId, cdev); results[c].set(true); cards[c] = cdev; } }); manager.attachThreadToDevice(thread, deviceId); thread.start(); try { thread.join(); } catch (Exception e) { ; } } }); threads[cnt].start(); } for (int cnt = 0; cnt < limit; cnt++) { threads[cnt].join(); assertTrue("Failed for thread ["+ cnt+"]", results[cnt].get()); } int numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); for (int c = 0; c < numDevices; c++) { assertTrue("Failed to find device ["+ c +"] in used devices", ArrayUtils.contains(cards, c)); } } }