package hex.deeplearning;
import static hex.deeplearning.Neurons.*;
import hex.deeplearning.Neurons.*;
import org.junit.Ignore;
import org.junit.Test;
import water.PrettyPrint;
import water.util.Log;
import water.util.Utils;
import java.util.Random;
public class NeuronsTest {
@Test
@Ignore
public void matrixVecTest() {
int rows = 2048;
int cols = 8192;
int loops = 50;
int warmup_loops = 50;
long seed = 0x533D;
float nnz_ratio_vec = 0.01f; //fraction of non-zeroes for vector
float nnz_ratio_mat = 0.1f; //fraction of non-zeroes for matrix
float [] a = new float[rows*cols];
float [] x = new float[cols];
float [] y = new float[rows];
float [] res = new float[rows];
byte [] bits = new byte[rows];
for (int row=0;row<rows;++row) {
y[row] = 0;
res[row] = 0;
bits[row] = (byte)(new String("abcdefghijklmnopqrstuvwxyz").toCharArray()[row%26]);
}
Random rng = new Random(seed);
for (int col=0;col<cols;++col)
if (rng.nextFloat() < nnz_ratio_vec)
x[col] = ((float)col)/cols;
for (int row=0;row<rows;++row) {
int off = row*cols;
for (int col=0;col<cols;++col) {
if (rng.nextFloat() < nnz_ratio_mat)
a[off+col] = ((float)(row+col))/cols;
}
}
DenseRowMatrix dra = new DenseRowMatrix(a, rows, cols);
DenseColMatrix dca = new DenseColMatrix(dra, rows, cols);
SparseRowMatrix sra = new SparseRowMatrix(dra, rows, cols);
SparseColMatrix sca = new SparseColMatrix(dca, rows, cols);
DenseVector dx = new DenseVector(x);
DenseVector dy = new DenseVector(y);
DenseVector dres = new DenseVector(res);
SparseVector sx = new SparseVector(x);
/**
* warmup
*/
System.out.println("warming up.");
float sum = 0;
for (int l=0;l<warmup_loops;++l) {
gemv_naive(res, a, x, y, bits);
sum += res[rows/2];
}
for (int l=0;l<warmup_loops;++l) {
gemv_naive(dres, dra, dx, dy, bits);
sum += res[rows/2];
}
for (int l=0;l<warmup_loops;++l) {
gemv_row_optimized(res, a, x, y, bits);
sum += res[rows/2];
}
for (int l=0;l<warmup_loops;++l) {
gemv(dres, dca, dx, dy, bits);
sum += res[rows/2];
}
for (int l=0;l<warmup_loops;++l) {
gemv(dres, dra, sx, dy, bits);
sum += res[rows/2];
}
for (int l=0;l<warmup_loops;++l) {
gemv(dres, dca, sx, dy, bits);
sum += res[rows/2];
}
for (int l=0;l<warmup_loops;++l) {
gemv(dres, sra, sx, dy, bits);
sum += res[rows/2];
}
for (int l=0;l<warmup_loops;++l) {
gemv(dres, sca, sx, dy, bits);
sum += res[rows/2];
}
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
/**
* naive version
*/
System.out.println("\nstarting naive.");
sum = 0;
long start = System.currentTimeMillis();
for (int l=0;l<loops;++l) {
gemv_naive(res, a, x, y, bits);
sum += res[rows/2]; //do something useful
}
System.out.println("result: " + sum + " and " + Utils.sum(res));
System.out.println("naive time: " + PrettyPrint.msecs(System.currentTimeMillis()-start, true));
System.out.println("\nstarting dense row * dense.");
sum = 0;
start = System.currentTimeMillis();
for (int l=0;l<loops;++l) {
gemv_naive(dres, dra, dx, dy, bits);
sum += res[rows/2]; //do something useful
}
System.out.println("result: " + sum + " and " + Utils.sum(res));
System.out.println("dense row * dense time: " + PrettyPrint.msecs(System.currentTimeMillis()-start, true));
System.out.println("\nstarting optimized dense row * dense.");
sum = 0;
start = System.currentTimeMillis();
for (int l=0;l<loops;++l) {
gemv_row_optimized(res, a, x, y, bits);
sum += res[rows/2]; //do something useful
}
System.out.println("result: " + sum + " and " + Utils.sum(res));
System.out.println("optimized dense row * dense time: " + PrettyPrint.msecs(System.currentTimeMillis()-start, true));
System.out.println("\nstarting dense col * dense.");
sum = 0;
start = System.currentTimeMillis();
for (int l=0;l<loops;++l) {
gemv(dres, dca, dx, dy, bits);
sum += res[rows/2]; //do something useful
}
System.out.println("result: " + sum + " and " + Utils.sum(res));
System.out.println("dense col * dense time: " + PrettyPrint.msecs(System.currentTimeMillis()-start, true));
System.out.println("\nstarting dense row * sparse.");
sum = 0;
start = System.currentTimeMillis();
for (int l=0;l<loops;++l) {
gemv(dres, dra, sx, dy, bits);
sum += res[rows/2]; //do something useful
}
System.out.println("result: " + sum + " and " + Utils.sum(res));
System.out.println("dense row * sparse time: " + PrettyPrint.msecs(System.currentTimeMillis()-start, true));
System.out.println("\nstarting dense col * sparse.");
sum = 0;
start = System.currentTimeMillis();
for (int l=0;l<loops;++l) {
gemv(dres, dca, sx, dy, bits);
sum += res[rows/2]; //do something useful
}
System.out.println("result: " + sum + " and " + Utils.sum(res));
System.out.println("dense col * sparse time: " + PrettyPrint.msecs(System.currentTimeMillis()-start, true));
System.out.println("\nstarting sparse row * sparse.");
sum = 0;
start = System.currentTimeMillis();
for (int l=0;l<loops;++l) {
gemv(dres, sra, sx, dy, bits);
sum += res[rows/2]; //do something useful
}
System.out.println("result: " + sum + " and " + Utils.sum(res));
System.out.println("sparse row * sparse time: " + PrettyPrint.msecs(System.currentTimeMillis()-start, true));
System.out.println("\nstarting sparse col * sparse.");
sum = 0;
start = System.currentTimeMillis();
for (int l=0;l<loops;++l) {
gemv(dres, sca, sx, dy, bits);
sum += res[rows/2]; //do something useful
}
System.out.println("result: " + sum + " and " + Utils.sum(res));
System.out.println("sparse col * sparse time: " + PrettyPrint.msecs(System.currentTimeMillis()-start, true));
}
@Test
public void sparseTester() {
DenseVector dv = new DenseVector(20);
dv.set(3,0.21f);
dv.set(7,0.13f);
dv.set(18,0.14f);
SparseVector sv = new SparseVector(dv);
assert(sv.size() == 20);
assert(sv.nnz() == 3);
// dense treatment
for (int i=0;i<sv.size();++i)
Log.info("sparse [" + i + "] = " + sv.get(i));
// sparse treatment
for (SparseVector.Iterator it=sv.begin(); !it.equals(sv.end()); it.next()) {
// Log.info(it.toString());
Log.info(it.index() + " -> " + it.value());
}
DenseColMatrix dcm = new DenseColMatrix(3,5);
dcm.set(2,1,3.2f);
dcm.set(1,3,-1.2f);
assert(dcm.get(2,1)==3.2f);
assert(dcm.get(1,3)==-1.2f);
assert(dcm.get(0,0)==0f);
DenseRowMatrix drm = new DenseRowMatrix(3,5);
drm.set(2,1,3.2f);
drm.set(1,3,-1.2f);
assert(drm.get(2,1)==3.2f);
assert(drm.get(1,3)==-1.2f);
assert(drm.get(0,0)==0f);
SparseColMatrix scm = new SparseColMatrix(3,5);
scm.set(2,1,3.2f);
scm.set(1,3,-1.2f);
assert(scm.get(2,1)==3.2f);
assert(scm.get(1,3)==-1.2f);
assert(scm.get(0,0)==0f);
SparseRowMatrix srm = new SparseRowMatrix(3,5);
srm.set(2,1,3.2f);
srm.set(1,3,-1.2f);
assert(srm.get(2,1)==3.2f);
assert(srm.get(1,3)==-1.2f);
assert(srm.get(0,0)==0f);
}
}