/* * This file is part of ALOE. * * ALOE is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * ALOE is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * You should have received a copy of the GNU General Public License * along with ALOE. If not, see <http://www.gnu.org/licenses/>. * * Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl) */ package etc.aloe.controllers; import etc.aloe.data.EvaluationReport; import etc.aloe.data.ExampleSet; import etc.aloe.data.FeatureSpecification; import etc.aloe.data.Model; import etc.aloe.data.Predictions; import etc.aloe.data.Segment; import etc.aloe.data.SegmentSet; import etc.aloe.processes.Balancing; import etc.aloe.processes.CrossValidationPrep; import etc.aloe.processes.CrossValidationSplit; import etc.aloe.processes.FeatureExtraction; import etc.aloe.processes.FeatureGeneration; import etc.aloe.processes.LabelMapping; import etc.aloe.processes.Training; /** * Class that performs cross validation on segmented data and produces an * evaluation report. 
* * @author Michael Brooks <mjbrooks@uw.edu> */ public class CrossValidationController { private int folds; private EvaluationReport evaluationReport; private SegmentSet segmentSet; private FeatureGeneration featureGenerationImpl; private FeatureExtraction featureExtractionImpl; private Training trainingImpl; private Balancing balancingImpl; private double falsePositiveCost = 1; private double falseNegativeCost = 1; private boolean balanceTestSet; private LabelMapping mappingImpl; public EvaluationReport getEvaluationReport() { return evaluationReport; } public void setSegmentSet(SegmentSet segments) { this.segmentSet = segments; } public CrossValidationController() { } public void setFolds(int folds) { this.folds = folds; } public void setCosts(double falsePositiveCost, double falseNegativeCost) { this.falsePositiveCost = falsePositiveCost; this.falseNegativeCost = falseNegativeCost; } public void run() { if (this.folds > 0) { System.out.println("== " + this.folds + "-Fold Cross Validation =="); segmentSet = segmentSet.onlyLabeled(); //Prepare for cross validation System.out.println("Randomizing and stratifying segments."); CrossValidationPrep<Segment> validationPrep = new CrossValidationPrep<Segment>(); validationPrep.randomize(segmentSet.getSegments()); segmentSet.setSegments(validationPrep.stratify(segmentSet.getSegments(), folds)); evaluationReport = new EvaluationReport(this.folds + " Cross Validation", falsePositiveCost, falseNegativeCost); for (int foldIndex = 0; foldIndex < this.folds; foldIndex++) { System.out.println("- Starting fold " + (foldIndex + 1)); //Split the data CrossValidationSplit<Segment> split = new CrossValidationSplit<Segment>(); System.out.println("- Splitting out training set"); SegmentSet trainingSegments = new SegmentSet(); trainingSegments.setSegments(split.getTrainingForFold(segmentSet.getSegments(), foldIndex, this.folds)); if (getBalancingImpl() != null) { trainingSegments = getBalancingImpl().balance(trainingSegments); } 
System.out.println("- Extracting basic features from training set"); ExampleSet basicTrainingExamples = trainingSegments.getBasicExamples(); trainingSegments = null; FeatureGeneration generation = getFeatureGenerationImpl(); System.out.println("- Generating features"); FeatureSpecification spec = generation.generateFeatures(basicTrainingExamples); FeatureExtraction extraction = getFeatureExtractionImpl(); System.out.println("- Extracting features from training set"); ExampleSet trainingSet = extraction.extractFeatures(basicTrainingExamples, spec); basicTrainingExamples = null; Training training = getTrainingImpl(); Model model = training.train(trainingSet); trainingSet = null; System.out.println("- Splitting out test set"); SegmentSet testingSegments = new SegmentSet(); testingSegments.setSegments(split.getTestingForFold(segmentSet.getSegments(), foldIndex, this.folds)); if (getBalancingImpl() != null && balanceTestSet) { testingSegments = getBalancingImpl().balance(testingSegments); } System.out.println("- Extracting basic features from test set"); ExampleSet basicTestingExamples = testingSegments.getBasicExamples(); System.out.println("- Extracting features from test set"); ExampleSet testingSet = extraction.extractFeatures(basicTestingExamples, spec); basicTestingExamples = null; Predictions predictions = model.getPredictions(testingSet); EvaluationReport report = new EvaluationReport("Fold " + (foldIndex + 1), falsePositiveCost, falseNegativeCost); report.addPredictions(predictions); LabelMapping mapping = getMappingImpl(); mapping.map(predictions, testingSegments); report.addLabeledTestData(testingSegments); evaluationReport.addPartial(report); int numCorrect = report.getTrueNegativeCount() + report.getTruePositiveCount(); System.out.println("- Fold " + (foldIndex + 1) + " completed (" + numCorrect + "/" + testingSet.size() + " correct)."); System.out.println(); } } else { System.out.println("== Skipping Cross Validation =="); } } public FeatureGeneration 
getFeatureGenerationImpl() { return this.featureGenerationImpl; } public void setFeatureGenerationImpl(FeatureGeneration featureGenerator) { this.featureGenerationImpl = featureGenerator; } public FeatureExtraction getFeatureExtractionImpl() { return this.featureExtractionImpl; } public void setFeatureExtractionImpl(FeatureExtraction featureExtractor) { this.featureExtractionImpl = featureExtractor; } public Training getTrainingImpl() { return this.trainingImpl; } public void setTrainingImpl(Training training) { this.trainingImpl = training; } public Balancing getBalancingImpl() { return this.balancingImpl; } public void setBalancingImpl(Balancing balancing) { this.balancingImpl = balancing; } public void setBalanceTestSet(boolean balanceTestSet) { this.balanceTestSet = balanceTestSet; } public LabelMapping getMappingImpl() { return this.mappingImpl; } public void setMappingImpl(LabelMapping mapping) { this.mappingImpl = mapping; } }