package quickml.supervised.crossValidation.attributeImportance;
import org.junit.Before;
import org.junit.Test;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLogCVLossFunction;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLossFunction;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierMSELossFunction;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierRMSELossFunction;
import java.util.ArrayList;
import static com.google.common.collect.Lists.newArrayList;
import static org.junit.Assert.assertEquals;
/**
 * Unit tests for constructing a {@link LossFunctionTracker}.
 *
 * <p>Verifies that a tracker built from N loss functions reports N loss-function names via
 * {@code lossFunctionNames()}, and that building one from an empty list is rejected with an
 * {@link IllegalArgumentException}.
 */
public class LossFunctionTrackerTest {

    private ClassifierLossFunction mse;
    private ClassifierLossFunction rmse;
    private ClassifierLossFunction logCv;

    /** Creates fresh loss-function instances before every test so tests share no state. */
    @Before
    public void setUp() throws Exception {
        mse = new ClassifierMSELossFunction();
        rmse = new ClassifierRMSELossFunction();
        // NOTE(review): the meaning of this constructor argument is defined by
        // ClassifierLogCVLossFunction, not visible here — confirm before changing the value.
        logCv = new ClassifierLogCVLossFunction(0.0000001);
    }

    /** A tracker built from a single loss function tracks exactly one name. */
    @Test
    public void testLossFunctionsWithAtLeastOneFunctionAreValid() throws Exception {
        LossFunctionTracker tracker = new LossFunctionTracker(newArrayList(mse));
        assertEquals(1, tracker.lossFunctionNames().size());
    }

    /** A tracker built from three distinct loss functions tracks three names. */
    @Test
    public void testLossFunctionsWithTwoOrMoreAreValid() throws Exception {
        LossFunctionTracker tracker = new LossFunctionTracker(newArrayList(mse, rmse, logCv));
        assertEquals(3, tracker.lossFunctionNames().size());
    }

    /** Constructing a tracker with no loss functions must throw IllegalArgumentException. */
    @Test(expected = IllegalArgumentException.class)
    public void testEmptyListIsInvalid() throws Exception {
        new LossFunctionTracker(new ArrayList<ClassifierLossFunction>());
    }
}