package quickml.supervised; import com.google.common.collect.Lists; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import quickml.BenchmarkTest; import quickml.data.AttributesMap; import quickml.data.instances.ClassifierInstance; import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForest; import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForestBuilder; import quickml.supervised.crossValidation.ClassifierLossChecker; import quickml.supervised.crossValidation.SimpleCrossValidator; import quickml.supervised.crossValidation.data.FoldedData; import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierRMSELossFunction; import quickml.supervised.tree.decisionTree.DecisionTreeBuilder; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.util.List; import java.util.zip.GZIPInputStream; import static org.junit.Assert.assertTrue; /** * Created by ian on 7/4/14. */ public class PredictiveAccuracyTests { private static final Logger logger = LoggerFactory.getLogger(PredictiveAccuracyTests.class); @Test public void irisTest() throws Exception { final FoldedData<ClassifierInstance> data = new FoldedData<ClassifierInstance>(loadIrisDataset(), 4, 4); final SimpleCrossValidator<RandomDecisionForest, ClassifierInstance> validator = new SimpleCrossValidator<RandomDecisionForest, ClassifierInstance>( new RandomDecisionForestBuilder<ClassifierInstance>(new DecisionTreeBuilder<>().minAttributeValueOccurences(10).maxDepth(12)), new ClassifierLossChecker<ClassifierInstance, RandomDecisionForest>(new ClassifierRMSELossFunction()), data); final double crossValidatedLoss = validator.getLossForModel(); double previousLoss = 0.62;//0.673; logger.info("Cross Validated Lost: {}", crossValidatedLoss); assertTrue(String.format("Current loss is %s, but previous loss was %s, this is a regression", crossValidatedLoss, previousLoss), crossValidatedLoss <= previousLoss*1.15); } public static List<ClassifierInstance> loadIrisDataset() throws IOException { final BufferedReader br = new BufferedReader(new InputStreamReader((new GZIPInputStream(BenchmarkTest.class.getResourceAsStream("iris.data.gz"))))); final List<ClassifierInstance> instances = Lists.newLinkedList(); String[] headings = new String[]{"sepal-length", "sepal-width", "petal-length", "petal-width"}; String line = br.readLine(); while (line != null) { String[] splitLine = line.split(","); AttributesMap attributes = AttributesMap.newHashMap(); for (int x = 0; x < splitLine.length - 1; x++) { attributes.put(headings[x], Double.valueOf((String)splitLine[x])); } instances.add(new ClassifierInstance(attributes, splitLine[splitLine.length - 1])); line = br.readLine(); } return instances; } }