package hex.deeplearning;
import static hex.deeplearning.Neurons.*;
import org.junit.*;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.PrettyPrint;
import java.util.Random;
/**
 * Manual micro-benchmark for the matrix-vector kernels that are statically imported from
 * {@link Neurons} ({@code gemv_naive} and {@code gemv_row_optimized}).
 *
 * <p>The single test is annotated with {@code @Ignore}, so it only runs when enabled by hand.
 * It prints timings and checksums to standard output and contains no assertions.
 */
public class NeuronsTest extends water.TestUtil {
  /** Blocks until a cloud of size 1 is available before any test in this class runs. */
  @BeforeClass public static void setup() { stall_till_cloudsize(1); }

  /**
   * Fills a sparse-ish {@code rows x cols} float matrix and a length-{@code cols} vector from a
   * fixed seed, then times three kernels on them: the array-based {@code gemv_naive}, the
   * {@code Storage}-based {@code gemv_naive}, and the array-based {@code gemv_row_optimized}.
   *
   * <p>For every kernel the elapsed time is captured immediately after the timed loop, so the
   * reported duration does not include building and printing the checksum line.
   */
  @Ignore
  @Test
  public void matrixVecTest() {
    final int rows = 2048;
    final int cols = 8192;
    final int loops = 5;
    final int warmup_loops = 5;
    final long seed = 0x533D;
    final float nnz_ratio_vec = 0.01f; // fraction of non-zeroes for vector
    final float nnz_ratio_mat = 0.1f;  // fraction of non-zeroes for matrix

    // Java zero-initialises freshly allocated arrays, so y and res need no explicit clearing.
    final float[] a = new float[rows * cols];
    final double[] x = new double[cols];
    final double[] y = new double[rows];
    final double[] res = new double[rows];

    // One byte per row, cycling through the ASCII letters 'a'..'z'.
    // NOTE(review): how the kernels interpret these bytes is not visible here - confirm in Neurons.
    final byte[] bits = new byte[rows];
    for (int row = 0; row < rows; ++row) {
      bits[row] = (byte) ('a' + row % 26);
    }

    // The vector is populated before the matrix; keep this order so the same seed keeps
    // producing the same data.
    final Random rng = new Random(seed);
    for (int col = 0; col < cols; ++col) {
      if (rng.nextFloat() < nnz_ratio_vec) {
        x[col] = ((float) col) / cols;
      }
    }
    for (int row = 0; row < rows; ++row) {
      final int off = row * cols; // start of this row in the row-major backing array
      for (int col = 0; col < cols; ++col) {
        if (rng.nextFloat() < nnz_ratio_mat) {
          a[off + col] = ((float) (row + col)) / cols;
        }
      }
    }

    final Storage.DenseRowMatrix dra = new Storage.DenseRowMatrix(a, rows, cols);
    // The column-major and sparse variants are only constructed in this method; building them
    // still exercises their conversion constructors.
    final Storage.DenseColMatrix dca = new Storage.DenseColMatrix(dra, rows, cols);
    final Storage.SparseRowMatrix sra = new Storage.SparseRowMatrix(dra, rows, cols);
    final Storage.SparseColMatrix sca = new Storage.SparseColMatrix(dca, rows, cols);
    final Storage.DenseVector dx = new Storage.DenseVector(x);
    final Storage.DenseVector dy = new Storage.DenseVector(y);
    // NOTE(review): the loops below read res after writing through dres, which presumes that
    // DenseVector wraps the given array instead of copying it - confirm in Storage.
    final Storage.DenseVector dres = new Storage.DenseVector(res);

    // Warm-up: run every kernel a few times before any timing is taken.
    System.out.println("warming up.");
    float sum = 0; // float accumulator: the compound assignment narrows each double it adds
    for (int l = 0; l < warmup_loops; ++l) {
      gemv_naive(res, a, x, y, bits);
      sum += res[rows / 2];
    }
    for (int l = 0; l < warmup_loops; ++l) {
      gemv_naive(dres, dra, dx, dy, bits);
      sum += res[rows / 2];
    }
    for (int l = 0; l < warmup_loops; ++l) {
      gemv_row_optimized(res, a, x, y, bits);
      sum += res[rows / 2];
    }

    // Naive version on raw arrays.
    System.out.println("\nstarting naive.");
    sum = 0;
    long start = System.nanoTime();
    for (int l = 0; l < loops; ++l) {
      gemv_naive(res, a, x, y, bits);
      sum += res[rows / 2]; // consume the result
    }
    long elapsedMs = (System.nanoTime() - start) / 1000000L; // stop the clock before any printing
    System.out.println("result: " + sum + " and " + ArrayUtils.sum(res));
    System.out.println("naive time: " + PrettyPrint.msecs(elapsedMs, true));

    // Naive version on the Storage wrappers.
    System.out.println("\nstarting dense row * dense.");
    sum = 0;
    start = System.nanoTime();
    for (int l = 0; l < loops; ++l) {
      gemv_naive(dres, dra, dx, dy, bits);
      sum += res[rows / 2]; // consume the result
    }
    elapsedMs = (System.nanoTime() - start) / 1000000L;
    System.out.println("result: " + sum + " and " + ArrayUtils.sum(res));
    System.out.println("dense row * dense time: " + PrettyPrint.msecs(elapsedMs, true));

    // Row-optimized version on raw arrays.
    System.out.println("\nstarting optimized dense row * dense.");
    sum = 0;
    start = System.nanoTime();
    for (int l = 0; l < loops; ++l) {
      gemv_row_optimized(res, a, x, y, bits);
      sum += res[rows / 2]; // consume the result
    }
    elapsedMs = (System.nanoTime() - start) / 1000000L;
    System.out.println("result: " + sum + " and " + ArrayUtils.sum(res));
    System.out.println("optimized dense row * dense time: " + PrettyPrint.msecs(elapsedMs, true));
  }
}