package org.apache.s4.model;
import java.util.Random;
import org.ejml.data.DenseMatrix64F;
import junit.framework.Assert;
import junit.framework.TestCase;
public class TestGaussianModel extends TestCase {
private int NUM_VECTORS = 100000;
private double mean[] = { 153, 10.0, 5.0, 0.1 };
private double std[] = { 30, 2.0, 1.0, 5.5 };
private int numElements = mean.length;
private DenseMatrix64F vectors[] = new DenseMatrix64F[NUM_VECTORS];
private double doubleArrays[][] = new double[NUM_VECTORS][numElements];
private float floatArrays[][] = new float[NUM_VECTORS][numElements];
private Random random = new Random(0);
protected void setUp() {
/* Generate the data set. */
for (int i = 0; i < NUM_VECTORS; i++) {
vectors[i] = new DenseMatrix64F(numElements, 1);
for (int j = 0; j < numElements; j++) {
double v = mean[j] + std[j] * random.nextGaussian();
vectors[i].set(j, v);
doubleArrays[i][j] = v;
floatArrays[i][j] = (float)v;
}
}
}
public void testTrainerUsingEJML() {
GaussianModel gm = new GaussianModel(numElements, true);
for (int i = 0; i < NUM_VECTORS; i++) {
gm.update(vectors[i]);
}
gm.estimate();
System.out.println(gm);
double[] actualMean = gm.getMean();
for (int j = 0; j < mean.length; j++) {
Assert.assertEquals("Assert mean.", mean[j], actualMean[j], std[j]);
}
}
public void testTrainerUsingDoubleArray() {
GaussianModel gm = new GaussianModel(numElements, true);
for (int i = 0; i < NUM_VECTORS; i++) {
gm.update(doubleArrays[i]);
}
gm.estimate();
System.out.println(gm);
double[] actualMean = gm.getMean();
for (int j = 0; j < mean.length; j++) {
Assert.assertEquals("Assert mean.", mean[j], actualMean[j], std[j]);
}
}
public void testTrainerUsingFloatArray() {
GaussianModel gm = new GaussianModel(numElements, true);
for (int i = 0; i < NUM_VECTORS; i++) {
gm.update(floatArrays[i]);
}
gm.estimate();
System.out.println(gm);
double[] actualMean = gm.getMean();
for (int j = 0; j < mean.length; j++) {
Assert.assertEquals("Assert mean.", mean[j], actualMean[j], std[j]);
}
}
}