package org.nd4j.parameterserver.distributed;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.RandomUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.enums.NodeRole;
import org.nd4j.parameterserver.distributed.logic.sequence.BasicSequenceProvider;
import org.nd4j.parameterserver.distributed.messages.Frame;
import org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage;
import org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage;
import org.nd4j.parameterserver.distributed.logic.ClientRouter;
import org.nd4j.parameterserver.distributed.training.impl.CbowTrainer;
import org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer;
import org.nd4j.parameterserver.distributed.transport.MulticastTransport;
import org.nd4j.parameterserver.distributed.transport.RoutedTransport;
import org.nd4j.parameterserver.distributed.transport.Transport;
import org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicLong;
import static org.junit.Assert.*;
/**
 * Stress tests for the parameter server. These tests don't assert anything about computed results
 * (only node roles are sanity-checked): all we care about here is performance and availability.
 *
 * Tests for each environment are paired: one test for blocking messages, the other one for non-blocking messages.
 *
 * @author raver119@gmail.com
 */
@Slf4j
public class VoidParameterServerStressTest {
    private static final int NUM_WORDS = 100000;

    /** Number of training messages packed into every benchmark {@link Frame}. */
    private static final int MESSAGES_PER_FRAME = 128;

    /**
     * Body of one benchmark worker thread.
     */
    @FunctionalInterface
    private interface Worker {
        /**
         * Runs the benchmark loop of a single worker.
         *
         * @param index     zero-based index of this worker thread
         * @param latencies thread-confined list the worker appends per-operation latencies (nanoseconds) to
         */
        void run(int index, List<Long> latencies) throws Exception;
    }

    @Before
    public void setUp() throws Exception {
        // nothing to prepare: every test builds its own nodes
    }

    @After
    public void tearDown() throws Exception {
        // nothing to clean up: every test shuts down its own nodes in a finally block
    }

    /**
     * This test measures performance of blocking messages processing, VectorRequestMessage in this case
     */
    @Test
    @Ignore
    public void testPerformanceStandalone1() {
        VoidConfiguration voidConfiguration =
                VoidConfiguration.builder().networkMask("192.168.0.0/16").numberOfShards(1).build();
        voidConfiguration.setShardAddresses("192.168.1.35");

        VoidParameterServer parameterServer = new VoidParameterServer();
        try {
            parameterServer.init(voidConfiguration);
            parameterServer.initializeSeqVec(100, NUM_WORDS, 123, 10, true, false);

            logMedian(benchmarkGetVector(parameterServer, 8, 1000000, 1000));
        } finally {
            shutdownServers(parameterServer);
        }
    }

    /**
     * This test measures performance of non-blocking messages processing, SkipGramRequestMessage in this case
     */
    @Test
    @Ignore
    public void testPerformanceStandalone2() {
        VoidConfiguration voidConfiguration =
                VoidConfiguration.builder().networkMask("192.168.0.0/16").numberOfShards(1).build();
        voidConfiguration.setShardAddresses("192.168.1.35");

        VoidParameterServer parameterServer = new VoidParameterServer();
        try {
            parameterServer.init(voidConfiguration);
            parameterServer.initializeSeqVec(100, NUM_WORDS, 123, 10, true, false);

            List<Long> times = runWorkers(8, (index, latencies) -> {
                for (int i = 0; i < 100000; i++) {
                    SkipGramRequestMessage sgrm = getSGRM();
                    long time1 = System.nanoTime();
                    parameterServer.execDistributed(sgrm);
                    long time2 = System.nanoTime();
                    latencies.add(time2 - time1);
                    if ((i + 1) % 1000 == 0)
                        log.info("Thread {} cnt {}", index, i + 1);
                }
            });
            logMedian(times);
        } finally {
            shutdownServers(parameterServer);
        }
    }

    /**
     * Multicast scenario: five shards on {@link MulticastTransport} plus one client node that is hammered with
     * blocking getVector() requests from 8 threads.
     */
    @Test
    @Ignore
    public void testPerformanceMulticast1() throws Exception {
        VoidConfiguration voidConfiguration =
                VoidConfiguration.builder().networkMask("192.168.0.0/16").numberOfShards(1).build();
        List<String> addresses = new ArrayList<>();
        for (int s = 0; s < 5; s++) {
            addresses.add("192.168.1.35:" + (37890 + s));
        }
        voidConfiguration.setShardAddresses(addresses);
        voidConfiguration.setForcedRole(NodeRole.CLIENT);

        VoidParameterServer[] shards = new VoidParameterServer[addresses.size()];
        VoidParameterServer parameterServer = null;
        try {
            for (int s = 0; s < shards.length; s++) {
                int port = 37890 + s;
                VoidConfiguration shardConfiguration = VoidConfiguration.builder().unicastPort(port)
                        .networkMask("192.168.0.0/16").build();
                shardConfiguration.setShardAddresses(addresses);

                MulticastTransport transport = new MulticastTransport();
                transport.setIpAndPort("192.168.1.35", port);

                shards[s] = new VoidParameterServer(false);
                shards[s].setShardIndex((short) s);
                shards[s].init(shardConfiguration, transport, new SkipGramTrainer());
                assertEquals(NodeRole.SHARD, shards[s].getNodeRole());
            }

            // this is going to be our Client shard
            parameterServer = new VoidParameterServer();
            parameterServer.init(voidConfiguration);
            assertEquals(NodeRole.CLIENT, VoidParameterServer.getInstance().getNodeRole());
            log.info("Instantiation finished...");

            parameterServer.initializeSeqVec(100, NUM_WORDS, 123, 20, true, false);
            log.info("Initialization finished...");

            logMedian(benchmarkGetVector(parameterServer, 8, 100000, 1000));
        } finally {
            shutdownServers(parameterServer);
            shutdownServers(shards);
        }
    }

    /**
     * This is one of the MOST IMPORTANT tests: a client routes blocking getVector() requests from 4 threads
     * over unicast ({@link RoutedTransport} + {@link InterleavedRouter}) to a local shard.
     */
    @Test
    public void testPerformanceUnicast1() {
        List<String> list = new ArrayList<>();
        for (int t = 0; t < 1; t++) {
            list.add("127.0.0.1:" + (38380 + t));
        }
        VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823).numberOfShards(list.size())
                .shardAddresses(list).build();

        VoidParameterServer[] shards = new VoidParameterServer[list.size()];
        VoidParameterServer clientNode = null;
        try {
            for (int t = 0; t < shards.length; t++) {
                shards[t] = new VoidParameterServer(NodeRole.SHARD);
                Transport transport = new RoutedTransport();
                transport.setIpAndPort("127.0.0.1", 38380 + t);
                shards[t].setShardIndex((short) t);
                shards[t].init(voidConfiguration, transport, new SkipGramTrainer());
                assertEquals(NodeRole.SHARD, shards[t].getNodeRole());
            }

            clientNode = new VoidParameterServer(NodeRole.CLIENT);
            RoutedTransport transport = new RoutedTransport();
            ClientRouter router = new InterleavedRouter(0);
            transport.setRouter(router);
            transport.setIpAndPort("127.0.0.1", voidConfiguration.getUnicastPort());
            router.init(voidConfiguration, transport);
            clientNode.init(voidConfiguration, transport, new SkipGramTrainer());
            assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());

            // at this point, everything should be started, time for tests
            clientNode.initializeSeqVec(100, NUM_WORDS, 123, 25, true, false);
            log.info("Initialization finished, going to tests...");

            logMedian(benchmarkGetVector(clientNode, 4, 200, 100));
        } finally {
            // shutdown everything, even if setup or the benchmark failed, so the ports are released
            shutdownTransports(shards);
            shutdownTransports(clientNode);
        }
    }

    /**
     * This is second super-important test for unicast transport.
     * Here we send non-blocking messages: frames of SkipGram requests from 4 threads to 5 local shards.
     */
    @Test
    @Ignore
    public void testPerformanceUnicast2() {
        List<String> list = new ArrayList<>();
        for (int t = 0; t < 5; t++) {
            list.add("127.0.0.1:" + (38380 + t));
        }
        VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823).numberOfShards(list.size())
                .shardAddresses(list).build();

        VoidParameterServer[] shards = new VoidParameterServer[list.size()];
        VoidParameterServer clientNode = null;
        try {
            for (int t = 0; t < shards.length; t++) {
                shards[t] = new VoidParameterServer(NodeRole.SHARD);
                Transport transport = new RoutedTransport();
                transport.setIpAndPort("127.0.0.1", 38380 + t);
                shards[t].setShardIndex((short) t);
                shards[t].init(voidConfiguration, transport, new SkipGramTrainer());
                assertEquals(NodeRole.SHARD, shards[t].getNodeRole());
            }

            clientNode = new VoidParameterServer();
            RoutedTransport transport = new RoutedTransport();
            ClientRouter router = new InterleavedRouter(0);
            transport.setRouter(router);
            transport.setIpAndPort("127.0.0.1", voidConfiguration.getUnicastPort());
            router.init(voidConfiguration, transport);
            clientNode.init(voidConfiguration, transport, new SkipGramTrainer());
            assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());

            // at this point, everything should be started, time for tests
            clientNode.initializeSeqVec(100, NUM_WORDS, 123, 25, true, false);
            log.info("Initialization finished, going to tests...");

            // effectively-final alias so the worker lambda can capture the client
            final VoidParameterServer client = clientNode;
            List<Long> times = runWorkers(4, (index, latencies) -> {
                for (int i = 0; i < 200; i++) {
                    Frame<SkipGramRequestMessage> frame =
                            newSkipGramFrame(BasicSequenceProvider.getInstance().getNextValue());
                    long time1 = System.nanoTime();
                    client.execDistributed(frame);
                    long time2 = System.nanoTime();
                    latencies.add(time2 - time1);
                    if ((i + 1) % 100 == 0)
                        log.info("Thread {} cnt {}", index, i + 1);
                }
            });
            logMedian(times);
        } finally {
            // shutdown everything, even if setup or the benchmark failed, so the ports are released
            shutdownTransports(shards);
            shutdownTransports(clientNode);
        }
    }

    /**
     * This test checks for single Shard scenario, when Shard is also a Client
     *
     * @throws Exception
     */
    @Test
    public void testPerformanceUnicast3() throws Exception {
        VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823).numberOfShards(1)
                .shardAddresses(Arrays.asList("127.0.0.1:49823")).build();
        Transport transport = new RoutedTransport();
        transport.setIpAndPort("127.0.0.1", 49823);

        VoidParameterServer parameterServer = new VoidParameterServer(NodeRole.SHARD);
        try {
            parameterServer.setShardIndex((short) 0);
            parameterServer.init(voidConfiguration, transport, new CbowTrainer());
            parameterServer.initializeSeqVec(100, NUM_WORDS, 123L, 100, true, false);

            List<Long> times = new ArrayList<>();
            log.info("Starting loop...");
            for (int i = 0; i < 200; i++) {
                Frame<CbowRequestMessage> frame = new Frame<>(BasicSequenceProvider.getInstance().getNextValue());
                for (int f = 0; f < MESSAGES_PER_FRAME; f++) {
                    frame.stackMessage(getCRM());
                }
                long time1 = System.nanoTime();
                parameterServer.execDistributed(frame);
                long time2 = System.nanoTime();
                times.add(time2 - time1);
                if ((i + 1) % 50 == 0)
                    log.info("{} frames passed...", i + 1);
            }
            logMedian(times);
        } finally {
            shutdownServers(parameterServer);
        }
    }

    /**
     * This test checks multiple Clients hammering single Shard
     *
     * @throws Exception
     */
    @Test
    public void testPerformanceUnicast4() throws Exception {
        VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823).numberOfShards(1)
                .shardAddresses(Arrays.asList("127.0.0.1:49823")).build();
        Transport transport = new RoutedTransport();
        transport.setIpAndPort("127.0.0.1", 49823);

        VoidParameterServer parameterServer = new VoidParameterServer(NodeRole.SHARD);
        VoidParameterServer[] clients = new VoidParameterServer[1];
        try {
            parameterServer.setShardIndex((short) 0);
            parameterServer.init(voidConfiguration, transport, new SkipGramTrainer());
            parameterServer.initializeSeqVec(100, NUM_WORDS, 123L, 100, true, false);

            for (int c = 0; c < clients.length; c++) {
                clients[c] = new VoidParameterServer(NodeRole.CLIENT);
                Transport clientTransport = new RoutedTransport();
                clientTransport.setIpAndPort("127.0.0.1", 48720 + c);
                clients[c].init(voidConfiguration, clientTransport, new SkipGramTrainer());
                assertEquals(NodeRole.CLIENT, clients[c].getNodeRole());
            }

            log.info("Starting loop...");
            // one worker thread per client; each thread numbers its own frames starting from 1
            List<Long> times = runWorkers(clients.length, (index, latencies) -> {
                AtomicLong sequence = new AtomicLong(0);
                for (int i = 0; i < 500; i++) {
                    Frame<SkipGramRequestMessage> frame = newSkipGramFrame(sequence.incrementAndGet());
                    long time1 = System.nanoTime();
                    clients[index].execDistributed(frame);
                    long time2 = System.nanoTime();
                    latencies.add(time2 - time1);
                    if ((i + 1) % 50 == 0)
                        log.info("Thread_{} finished {} frames...", index, i + 1);
                }
            });
            logMedian(times);
        } finally {
            shutdownServers(clients);
            shutdownServers(parameterServer);
        }
    }

    /**
     * Hammers {@code server.getVector(...)} from several threads. Each thread queries random words from its own
     * contiguous slice of the vocabulary.
     *
     * @param server     node to query
     * @param numThreads number of concurrent worker threads
     * @param iterations getVector() calls per thread
     * @param logEvery   progress is logged every {@code logEvery} calls
     * @return latencies of all getVector() calls, in nanoseconds
     */
    private static List<Long> benchmarkGetVector(VoidParameterServer server, int numThreads, int iterations,
                    int logEvery) {
        final int chunk = NUM_WORDS / numThreads;
        return runWorkers(numThreads, (index, latencies) -> {
            int start = index * chunk;
            int end = start + chunk;
            for (int i = 0; i < iterations; i++) {
                // pick the word before starting the clock, so only the request itself is timed
                int word = RandomUtils.nextInt(start, end);
                long time1 = System.nanoTime();
                server.getVector(word);
                long time2 = System.nanoTime();
                latencies.add(time2 - time1);
                if ((i + 1) % logEvery == 0)
                    log.info("Thread {} cnt {}", index, i + 1);
            }
        });
    }

    /**
     * Starts {@code numThreads} daemon threads running {@code worker}, waits for all of them and returns the
     * latencies they recorded.
     *
     * @param numThreads number of worker threads to start
     * @param worker     benchmark body executed by every thread
     * @return all recorded latencies, in nanoseconds (a fresh, mutable list)
     * @throws AssertionError        if any worker threw; the first failure is the cause, the rest are suppressed
     * @throws IllegalStateException if the calling thread is interrupted while waiting (interrupt flag is restored)
     */
    private static List<Long> runWorkers(int numThreads, Worker worker) {
        final List<Long> times = new CopyOnWriteArrayList<>();
        final List<Throwable> failures = new CopyOnWriteArrayList<>();
        Thread[] threads = new Thread[numThreads];
        for (int t = 0; t < numThreads; t++) {
            final int index = t;
            threads[t] = new Thread(() -> {
                List<Long> latencies = new ArrayList<>();
                try {
                    worker.run(index, latencies);
                    // publish once at the end: one bulk copy instead of a CopyOnWriteArrayList copy per sample
                    times.addAll(latencies);
                } catch (Throwable th) {
                    // daemon threads would otherwise lose the exception and let the test pass silently
                    failures.add(th);
                }
            });
            threads[t].setDaemon(true);
            threads[t].start();
        }

        for (Thread thread : threads) {
            try {
                thread.join();
            } catch (InterruptedException ex) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for benchmark workers", ex);
            }
        }

        if (!failures.isEmpty()) {
            AssertionError error = new AssertionError("Benchmark worker failed: " + failures.get(0), failures.get(0));
            for (int f = 1; f < failures.size(); f++) {
                error.addSuppressed(failures.get(f));
            }
            throw error;
        }
        return new ArrayList<>(times);
    }

    /**
     * Logs the median (p50) of the given latencies in microseconds.
     *
     * @param latencies latencies in nanoseconds; must not be empty
     */
    private static void logMedian(List<Long> latencies) {
        assertFalse("No latencies were recorded", latencies.isEmpty());
        List<Long> sorted = new ArrayList<>(latencies);
        Collections.sort(sorted);
        log.info("p50: {} us", sorted.get(sorted.size() / 2) / 1000);
    }

    /**
     * Builds a frame with the given id, holding {@link #MESSAGES_PER_FRAME} random SkipGram requests.
     */
    private static Frame<SkipGramRequestMessage> newSkipGramFrame(long frameId) {
        Frame<SkipGramRequestMessage> frame = new Frame<>(frameId);
        for (int f = 0; f < MESSAGES_PER_FRAME; f++) {
            frame.stackMessage(getSGRM());
        }
        return frame;
    }

    /**
     * Best-effort shutdown of the given servers. {@code null} entries (nodes never created) are skipped, and
     * failures are logged rather than thrown so one broken node neither leaks the others nor masks the exception
     * that aborted the test.
     */
    private static void shutdownServers(VoidParameterServer... servers) {
        for (VoidParameterServer server : servers) {
            if (server == null)
                continue;
            try {
                server.shutdown();
            } catch (RuntimeException ex) {
                log.warn("Failed to shut down parameter server", ex);
            }
        }
    }

    /**
     * Best-effort shutdown of the transports of the given nodes. Nodes that were never created, or that have no
     * transport, are skipped; failures are logged rather than thrown.
     */
    private static void shutdownTransports(VoidParameterServer... nodes) {
        for (VoidParameterServer node : nodes) {
            if (node == null || node.getTransport() == null)
                continue;
            try {
                node.getTransport().shutdown();
            } catch (RuntimeException ex) {
                log.warn("Failed to shut down transport", ex);
            }
        }
    }

    /**
     * This method just produces random SGRM requests, for testing purposes.
     * No real sense could be found here: two random words plus 15..44 alternating codes and random points.
     *
     * @return a random SkipGram request
     */
    protected static SkipGramRequestMessage getSGRM() {
        int w1 = RandomUtils.nextInt(0, NUM_WORDS);
        int w2 = RandomUtils.nextInt(0, NUM_WORDS);
        byte[] codes = new byte[RandomUtils.nextInt(15, 45)];
        int[] points = new int[codes.length];
        for (int e = 0; e < codes.length; e++) {
            codes[e] = (byte) (e % 2 == 0 ? 0 : 1);
            points[e] = RandomUtils.nextInt(0, NUM_WORDS);
        }
        return new SkipGramRequestMessage(w1, w2, points, codes, (short) 0, 0.025, 213412L);
    }

    /**
     * Produces random CBOW requests for testing purposes: a random target word, 5 random context words and
     * 15..44 alternating codes with random points.
     *
     * @return a random CBOW request
     */
    protected static CbowRequestMessage getCRM() {
        int w1 = RandomUtils.nextInt(0, NUM_WORDS);
        int[] syn0 = new int[5];
        for (int e = 0; e < syn0.length; e++) {
            syn0[e] = RandomUtils.nextInt(0, NUM_WORDS);
        }
        byte[] codes = new byte[RandomUtils.nextInt(15, 45)];
        int[] points = new int[codes.length];
        for (int e = 0; e < codes.length; e++) {
            codes[e] = (byte) (e % 2 == 0 ? 0 : 1);
            points[e] = RandomUtils.nextInt(0, NUM_WORDS);
        }
        return new CbowRequestMessage(syn0, points, w1, codes, 0, 0.025, 119);
    }
}