package hex.deeplearning;

import junit.framework.Assert;
import org.junit.Test;

import java.util.Arrays;
import java.util.Random;

/**
 * Statistical sanity check for {@link Dropout}.
 *
 * <p>Over many randomly seeded repetitions the test fills an activation vector with ones,
 * lets a {@code Dropout} instance sparsify it, and compares the average sum of the
 * remaining activations against the value expected for the configured rate. The
 * expectations encoded below are: rate 0.3 leaves about 700 of 1000 units, rate 0.0
 * leaves all of them, rate 1.0 leaves none, and rate 0.314 leaves about 686 active units.
 */
public class DropoutTest {

  /**
   * Runs the four dropout scenarios {@code loops} times each and checks the averaged results.
   *
   * <p>All per-iteration seeds are drawn from one generator whose master seed is included in
   * every failure message, so a failing run can be replayed by hard-coding that seed.
   *
   * @throws Exception if any of the exercised project code throws
   */
  @Test
  public void test() throws Exception {
    final int units = 1000;
    final int loops = 10000;
    Neurons.DenseVector a = new Neurons.DenseVector(units);
    double sum1 = 0, sum2 = 0, sum3 = 0, sum4 = 0;

    // One generator for the whole test instead of a fresh unseeded Random per iteration;
    // the master seed is reported on failure to make the run reproducible.
    final long masterSeed = new Random().nextLong();
    final Random seedSource = new Random(masterSeed);
    final String seedInfo = " (master seed " + masterSeed + ")";

    for (int l = 0; l < loops; ++l) {
      final long seed = seedSource.nextLong();

      // Scenario 1: rate 0.3
      Dropout d = new Dropout(units, 0.3);
      Arrays.fill(a.raw(), 1f);
      d.randomlySparsifyActivation(a, seed);
      sum1 += water.util.Utils.sum(a.raw());

      // Scenario 2: rate 0.0
      d = new Dropout(units, 0.0);
      Arrays.fill(a.raw(), 1f);
      d.randomlySparsifyActivation(a, seed + 1);
      sum2 += water.util.Utils.sum(a.raw());

      // Scenario 3: rate 1.0
      d = new Dropout(units, 1.0);
      Arrays.fill(a.raw(), 1f);
      d.randomlySparsifyActivation(a, seed + 2);
      sum3 += water.util.Utils.sum(a.raw());

      // Scenario 4: rate 0.314, counted through the per-unit mask rather than the activations.
      d = new Dropout(units, 0.314);
      d.fillBytes(seed + 3);
      for (int i = 0; i < units; ++i) {
        final boolean active = d.unit_active(i);
        // Querying the same unit twice must give the same answer. This uses a JUnit failure
        // rather than the Java `assert` keyword, which is a no-op unless the JVM runs with -ea.
        // The message is only built on failure to keep the 10M-iteration loop cheap.
        if (active != d.unit_active(i)) {
          Assert.fail("unit_active(" + i + ") is not stable across calls" + seedInfo);
        }
        if (active) {
          sum4++;
        }
      }
    }

    sum1 /= loops;
    sum2 /= loops;
    sum3 /= loops;
    sum4 /= loops;

    Assert.assertTrue("rate 0.3: expected ~700 remaining, got " + sum1 + seedInfo,
        Math.abs(sum1 - 700) < 1);
    // Delta 0 keeps these two checks exact, as in the original == comparisons.
    Assert.assertEquals("rate 0.0: expected every unit to remain" + seedInfo, units, sum2, 0.0);
    Assert.assertEquals("rate 1.0: expected no unit to remain" + seedInfo, 0.0, sum3, 0.0);
    Assert.assertTrue("rate 0.314: expected ~686 active units, got " + sum4 + seedInfo,
        Math.abs(sum4 - 686) < 1);
  }
}