package water.init;
import water.fvec.Vec;
import water.*;
import water.util.*;
import java.util.Random;
public class NetworkTest extends Iced {
// Message sizes to test, in bytes (default: 1 B, 1 KiB, 1 MiB).
public int[] msg_sizes = new int[]{1, 1 << 10, 1 << 20}; //INPUT
// Number of times each message is sent per measurement; reported times are averaged over the repeats.
public int repeats = 10; //INPUT
// Whether to run the collective (broadcast/reduce) test.
public boolean collective = true;
// Whether to run the serial (node-by-node round-trip) test.
public boolean serial = true;
// Collective broadcast/reduce times in microseconds, indexed by [message size].
public double[] microseconds_collective; //OUTPUT
// Collective bandwidths in Bytes/sec, indexed by [message size] (one aggregate value, not per node).
public double[] bandwidths_collective; //OUTPUT
// Round-trip times in microseconds, indexed by [message size][node].
public double[][] microseconds; //OUTPUT
// Bi-directional bandwidths in Bytes/sec, indexed by [message size][node].
public double[][] bandwidths; //OUTPUT
// String form of every cloud member, in cloud member order.
public String[] nodes; //OUTPUT
// Human-readable summary of all results; built by fillTable().
public TwoDimTable table; //OUTPUT
/**
 * Runs the configured network tests from this node and fills in every output field.
 *
 * A {@link NetworkTester} task is submitted and joined; afterwards the measured times are
 * converted into bandwidths, the node names are collected, and the summary table is built
 * and logged.
 *
 * If the serial test is disabled, the per-node timing rows are filled with zeros (one entry
 * per cloud member) so that the outputs stay rectangular. Any measurement that is not a
 * positive time (skipped test, or no repeats) yields a bandwidth of 0 rather than
 * Infinity/NaN.
 *
 * @return this object, with all output fields populated
 */
public NetworkTest execImpl() {
  microseconds = new double[msg_sizes.length][];
  microseconds_collective = new double[msg_sizes.length];
  NetworkTester nt = new NetworkTester(msg_sizes, microseconds, microseconds_collective, repeats, serial, collective);
  H2O.submitTask(nt);
  nt.join();

  final int cloudSize = H2O.CLOUD.size();

  // compute bandwidths from timing results
  bandwidths = new double[msg_sizes.length][];
  for (int i = 0; i < bandwidths.length; ++i) {
    // The tester only assigns the per-node rows when the serial test is enabled.
    if (microseconds[i] == null) {
      microseconds[i] = new double[cloudSize];
    }
    bandwidths[i] = new double[microseconds[i].length];
    for (int j = 0; j < microseconds[i].length; ++j) {
      // send and receive the same message -> 2x; computed in double to avoid int overflow
      bandwidths[i][j] = toBytesPerSecond(2.0 * msg_sizes[i], microseconds[i][j]);
    }
  }

  bandwidths_collective = new double[msg_sizes.length];
  for (int i = 0; i < bandwidths_collective.length; ++i) {
    // broadcast and reduce the message to all nodes -> 2 x nodes; computed in double to avoid int overflow
    bandwidths_collective[i] = toBytesPerSecond(2.0 * cloudSize * msg_sizes[i], microseconds_collective[i]);
  }

  // populate node names
  nodes = new String[cloudSize];
  for (int i = 0; i < nodes.length; ++i) {
    nodes[i] = H2O.CLOUD._memary[i].toString();
  }

  fillTable();
  Log.info(table.toString());
  return this;
}

/** Converts a byte count and a duration in microseconds into Bytes/sec; returns 0 when the duration is not a positive number. */
private static double toBytesPerSecond(double bytes, double usecs) {
  if (!(usecs > 0)) { // also catches NaN
    return 0;
  }
  return bytes / (usecs / 1e6 /*Seconds*/);
}
/**
 * Task that performs the measurements. Results are written into the arrays handed to the
 * constructor, converted from nanoseconds to microseconds.
 */
public static class NetworkTester extends H2O.H2OCountedCompleter {
  double[][] microseconds;
  double[] microseconds_collective;
  int[] msg_sizes;
  public int repeats = 10;
  boolean serial;
  boolean collective;

  /**
   * @param msg            message sizes to test
   * @param res            receives the serial results, one row per message size
   * @param res_collective receives the collective results, one entry per message size
   * @param rep            number of repeats per measurement
   * @param serial         whether to run the serial test
   * @param collective     whether to run the collective test
   */
  public NetworkTester(int[] msg, double[][] res, double[] res_collective, int rep, boolean serial, boolean collective) {
    super((byte) (H2O.MIN_HI_PRIORITY - 1));
    this.microseconds = res;
    this.microseconds_collective = res_collective;
    this.msg_sizes = msg;
    this.repeats = rep;
    this.serial = serial;
    this.collective = collective;
  }

  /** Runs whichever tests are enabled (serial first, then collective) and completes the task. */
  @Override
  public void compute2() {
    if (serial) {
      measureSerial();
    }
    if (collective) {
      measureCollective();
    }
    tryComplete();
  }

  /** Fills one row of per-node round-trip times per message size. */
  private void measureSerial() {
    for (int sizeIdx = 0; sizeIdx < microseconds.length; sizeIdx++) {
      double[] perNode = send_recv_all(msg_sizes[sizeIdx], repeats);
      microseconds[sizeIdx] = perNode;
      ArrayUtils.div(perNode, 1e3f); // nanoseconds -> microseconds
    }
  }

  /** Fills one collective broadcast/reduce time per message size. */
  private void measureCollective() {
    for (int sizeIdx = 0; sizeIdx < microseconds_collective.length; sizeIdx++) {
      microseconds_collective[sizeIdx] = send_recv_collective(msg_sizes[sizeIdx], repeats);
    }
    ArrayUtils.div(microseconds_collective, 1e3f); // nanoseconds -> microseconds
  }
}
/**
 * Distributed task whose only content is a payload; its compute2() does no work and
 * completes immediately. Executing it remotely therefore just ships the payload over
 * the wire.
 */
private static class PingPongTask extends DTask<PingPongTask> {
  // Never read by this class; it exists only to be carried along with the task.
  private final byte[] _payload;

  public PingPongTask(byte[] payload) {
    this._payload = payload;
  }

  @Override
  public void compute2() {
    tryComplete();
  }
}
/**
 * Sends a random message from this node to every cloud member in turn (including this
 * node itself), waiting for each reply before sending the next message.
 *
 * @param msg_size message size in bytes
 * @param repeats  number of messages sent to each node; the time is averaged over them
 * @return average time in nanoseconds per send/receive, one entry per node
 */
private static double[] send_recv_all(int msg_size, int repeats) {
  final byte[] message = new byte[msg_size];
  new Random().nextBytes(message);

  final int nodeCount = H2O.CLOUD.size();
  final double[] avgNanos = new double[nodeCount];

  for (int target = 0; target < nodeCount; target++) {
    final H2ONode destination = H2O.CLOUD._memary[target];
    final Timer stopwatch = new Timer();
    for (int remaining = repeats; remaining > 0; remaining--) {
      // every node gets the same payload; get() blocks until the reply arrives
      new RPC<>(destination, new PingPongTask(message)).call().get();
    }
    final double totalNanos = stopwatch.nanos();
    avgNanos[target] = totalNanos / repeats;
  }
  return avgNanos;
}
/**
 * MRTask that overrides neither map nor reduce and only carries a payload. Running it
 * across the cloud therefore just moves the payload over the wire (broadcast/reduce).
 */
private static class CollectiveTask extends MRTask<CollectiveTask> {
  // Never read by this class; it exists only to be carried along with the task.
  private final byte[] _payload;

  public CollectiveTask(byte[] payload) {
    this._payload = payload;
  }
}
/**
 * Broadcasts a random message from this node to all nodes and reduces it back, by running
 * a {@link CollectiveTask} over a trivial one-element Vec.
 *
 * The elapsed time is captured before the temporary Vec is removed, so the cleanup is not
 * part of the measurement, and the Vec is removed even if a task run throws.
 *
 * @param msg_size message size in bytes
 * @param repeats  number of broadcast/reduce rounds; the time is averaged over them
 * @return average time in nanoseconds per broadcast/reduce round
 */
private static double send_recv_collective(int msg_size, int repeats) {
  byte[] payload = new byte[msg_size];
  new Random().nextBytes(payload);
  Vec v = Vec.makeZero(1); //trivial Vec: 1 element with value 0.
  try {
    Timer t = new Timer();
    for (int l = 0; l < repeats; ++l) {
      new CollectiveTask(payload).doAll(v); //same payload for all nodes
    }
    // Read the timer before the cleanup in the finally block runs.
    final double elapsedNanos = t.nanos();
    return elapsedNanos / repeats;
  } finally {
    v.remove(new Futures()).blockForPending();
  }
}
/**
 * Builds {@link #table} from the timing and bandwidth results: one column per message
 * size, a first row for the collective test, and one further row per cloud member. Every
 * cell is a string of the form "time, bandwidth".
 */
public void fillTable() {
  final int sizeCount = msg_sizes.length;

  // Row 0 holds the collective result; rows 1..N hold the per-node results.
  String[] rowHeaders = new String[H2O.CLOUD.size() + 1];
  rowHeaders[0] = "all - collective bcast/reduce";
  for (int node = 0; node < H2O.CLOUD.size(); node++) {
    String locality;
    if (H2O.SELF.equals(H2O.CLOUD._memary[node])) {
      locality = "self";
    } else {
      locality = "remote";
    }
    rowHeaders[node + 1] = locality + " " + H2O.CLOUD._memary[node].toString();
  }

  // All columns are plain strings, one per message size.
  String[] colHeaders = new String[sizeCount];
  String[] colTypes = new String[sizeCount];
  String[] colFormats = new String[sizeCount];
  for (int col = 0; col < sizeCount; col++) {
    colHeaders[col] = msg_sizes[col] + " bytes";
    colTypes[col] = "string";
    colFormats[col] = "%s";
  }

  table = new TwoDimTable(
      "Network Test",
      "Launched from " + H2O.SELF._key,
      rowHeaders, colHeaders, colTypes, colFormats,
      "Destination");

  // Collective results go into the first row.
  for (int col = 0; col < msg_sizes.length; col++) {
    String time = PrettyPrint.usecs((long) microseconds_collective[col]);
    String rate = PrettyPrint.bytesPerSecond((long) bandwidths_collective[col]);
    table.set(0, col, time + ", " + rate);
  }

  // Per-node results follow, shifted down by one row.
  for (int node = 0; node < H2O.CLOUD._memary.length; node++) {
    for (int col = 0; col < msg_sizes.length; col++) {
      String time = PrettyPrint.usecs((long) microseconds[col][node]);
      String rate = PrettyPrint.bytesPerSecond((long) bandwidths[col][node]);
      table.set(node + 1, col, time + ", " + rate);
    }
  }
}
}