package hex.deeplearning;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.util.Arrays;
import java.util.Random;
import org.junit.BeforeClass;
import org.junit.Test;
import water.util.ArrayUtils;
/**
 * Statistical checks for {@link Dropout}.
 *
 * <p>Over many independently seeded trials, the average number of units left active after
 * dropout with rate {@code r} is expected to be {@code UNITS * (1 - r)}.
 */
public class DropoutTest extends water.TestUtil {
  /** Number of units in each simulated layer. */
  private static final int UNITS = 1000;
  /** Number of independent trials averaged for each statistic. */
  private static final int LOOPS = 10000;

  @BeforeClass public static void setup() { stall_till_cloudsize(1); }

  /**
   * Checks that dropout leaves the expected number of units active, on average.
   *
   * <p>Four dropout rates are exercised:
   * <ul>
   *   <li>0.3 and 1.0 via {@code randomlySparsifyActivation} on an all-ones vector;</li>
   *   <li>0.0 via the same call (exact expectations for the 0.0 and 1.0 extremes);</li>
   *   <li>0.314 via {@code fillBytes} + {@code unit_active}.</li>
   * </ul>
   */
  @Test
  public void test() throws Exception {
    // Draw a single master seed and derive every per-trial seed from it. A failing run can then
    // be reproduced by hard-coding the seed reported in the assertion message.
    final long masterSeed = new Random().nextLong();
    final Random seeds = new Random(masterSeed);
    final String ctx = "masterSeed=" + masterSeed;

    final Storage.DenseVector a = new Storage.DenseVector(UNITS);
    double sum1 = 0, sum2 = 0, sum3 = 0, sum4 = 0;
    for (int l = 0; l < LOOPS; ++l) {
      final long seed = seeds.nextLong();
      sum1 += sparsifiedSum(new Dropout(UNITS, 0.3), a, seed);
      sum2 += sparsifiedSum(new Dropout(UNITS, 0.0), a, seed + 1);
      sum3 += sparsifiedSum(new Dropout(UNITS, 1.0), a, seed + 2);
      sum4 += activeUnits(new Dropout(UNITS, 0.314), UNITS, seed + 3);
    }
    sum1 /= LOOPS;
    sum2 /= LOOPS;
    sum3 /= LOOPS;
    sum4 /= LOOPS;

    // The statistical cases allow +-1 unit on the mean. The per-trial standard deviation is
    // about sqrt(UNITS * r * (1 - r)) ~= 14.5, so the std of the mean over LOOPS trials is ~0.15.
    assertEquals(ctx + " (rate 0.3)", UNITS * (1 - 0.3), sum1, 1);
    // Extremes are deterministic: no units dropped at rate 0, all dropped at rate 1.
    assertEquals(ctx + " (rate 0.0)", UNITS, sum2, 0.0);
    assertEquals(ctx + " (rate 1.0)", 0, sum3, 0.0);
    assertEquals(ctx + " (rate 0.314)", UNITS * (1 - 0.314), sum4, 1);
  }

  /**
   * Applies {@code d}'s dropout to an all-ones vector and returns the resulting sum.
   *
   * <p>The vector is first overwritten with ones. {@code d.randomlySparsifyActivation} is then
   * applied with {@code seed}. The returned value is the sum of the sparsified contents.
   */
  private static double sparsifiedSum(Dropout d, Storage.DenseVector a, long seed) {
    Arrays.fill(a.raw(), 1f);
    d.randomlySparsifyActivation(a, seed);
    return ArrayUtils.sum(a.raw());
  }

  /**
   * Counts the units {@code d} reports as active after seeding its mask.
   *
   * <p>The mask is seeded with {@code d.fillBytes(seed)}. The count covers the first
   * {@code units} indices for which {@code d.unit_active} returns true.
   */
  private static int activeUnits(Dropout d, int units, long seed) {
    d.fillBytes(seed);
    int active = 0;
    for (int i = 0; i < units; ++i) {
      if (d.unit_active(i)) {
        ++active;
      }
    }
    return active;
  }
}