package quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders; import com.google.common.collect.Lists; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import java.util.Arrays; import java.util.List; import static org.junit.Assert.*; public class MonotonicConvergenceRecommenderTest { private MonotonicConvergenceRecommender recommender; @Before public void setUp() throws Exception { recommender = new MonotonicConvergenceRecommender(Arrays.asList(1, 5, 10, 20, 40), 0.1); } @Test public void testWeStopIfThresholdIsNotReached() throws Exception { List<Double> losses = Lists.newArrayList(); for (int i = 0; i < recommender.getValues().size(); i++) { double prevLoss = (i>0) ? losses.get(i-1) : 1.0; losses.add(prevLoss*2); if (!recommender.shouldContinue(losses)) break; } // assertEquals(5, losses.size()); } @Test public void testWeContinueIfWeHaventGoneOverTheTolerance() throws Exception { List<Double> losses = Lists.newArrayList(); double[] lossValue = new double[]{0.001, 0.002, 0.002001, 0.004, 0.005}; for (int i = 0; i < recommender.getValues().size(); i++) { losses.add(lossValue[i]); if (!recommender.shouldContinue(losses)) break; } System.out.println("losses = " + losses); assertEquals(3, losses.size()); } }