package org.deeplearning4j.ui.weights;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.math.BigDecimal;
import static org.junit.Assert.assertEquals;
/**
 * Tests for {@link HistogramBin}: checks the reported minimum and maximum, the
 * per-bin counts and the size and contents of the data map for a few small inputs.
 *
 * @author raver119@gmail.com
 */
public class HistogramBinTest {

    /** Absolute tolerance used for every floating-point comparison in this class. */
    private static final double TOLERANCE = 0.001;

    /** No shared fixture is needed; every test builds its own histogram. */
    @Before
    public void setUp() throws Exception {
    }

    /**
     * Ten bins over eleven values in [0.1, 1.0]: the bin at index 9 is expected
     * to hold a count of 2.
     */
    @Test
    public void testGetBins() throws Exception {
        HistogramBin bin = buildHistogram(10,
                0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.0);

        assertRange(bin, 0.1, 1.0);

        System.out.println("Result: " + bin.getBins());
        assertEquals(2, bin.getBins().getDouble(9), TOLERANCE);
    }

    /**
     * Ten bins over eleven values in [0.1, 1.0]: the data map is expected to
     * contain ten entries.
     */
    @Test
    public void testGetData1() throws Exception {
        HistogramBin bin = buildHistogram(10,
                0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.0);

        assertRange(bin, 0.1, 1.0);

        System.out.println("Result: " + bin.getData());
        assertEquals(10, bin.getData().size());
    }

    /**
     * Ten bins over ten values in [-1.0, 1.0]: the data map is expected to
     * contain ten entries, and the entry keyed by 1.00 a value of 2.
     */
    @Test
    public void testGetData2() throws Exception {
        HistogramBin bin = buildHistogram(10,
                -1.0f, -0.50f, 0.0f, 0.50f, 1.0f, -1.0f, -0.50f, 0.0f, 0.50f, 1.0f);

        assertRange(bin, -1.0, 1.0);

        System.out.println("Result: " + bin.getData());
        assertEquals(10, bin.getData().size());

        BigDecimal upperKey = new BigDecimal("1.00");
        assertEquals(2, bin.getData().get(upperKey).get());
    }

    /**
     * Fifty bins over ten values in [-1.0, 1.0]: the data map is expected to
     * contain fifty entries, and the entry keyed by 1.00 a value of 2.
     */
    @Test
    public void testGetData4() throws Exception {
        HistogramBin bin = buildHistogram(50,
                -1.0f, -0.50f, 0.0f, 0.50f, 1.0f, -1.0f, -0.50f, 0.0f, 0.50f, 1.0f);

        assertRange(bin, -1.0, 1.0);

        System.out.println("Result: " + bin.getData());
        assertEquals(50, bin.getData().size());

        BigDecimal upperKey = new BigDecimal("1.00");
        assertEquals(2, bin.getData().get(upperKey).get());
    }

    /**
     * Wraps the given samples in an {@link INDArray} and builds a histogram
     * over them with the requested number of bins.
     *
     * @param binCount number of bins passed to the builder
     * @param samples  values the histogram is built from
     * @return the built histogram
     */
    private static HistogramBin buildHistogram(int binCount, double... samples) {
        INDArray values = Nd4j.create(samples);
        return new HistogramBin.Builder(values)
                .setBinCount(binCount)
                .build();
    }

    /**
     * Asserts the histogram's reported minimum, then its maximum, within
     * {@link #TOLERANCE}.
     *
     * @param bin         histogram under test
     * @param expectedMin expected result of {@code getMin()}
     * @param expectedMax expected result of {@code getMax()}
     */
    private static void assertRange(HistogramBin bin, double expectedMin, double expectedMax) {
        assertEquals(expectedMin, bin.getMin(), TOLERANCE);
        assertEquals(expectedMax, bin.getMax(), TOLERANCE);
    }
}