/* * This file is part of ALOE. * * ALOE is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * ALOE is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * You should have received a copy of the GNU General Public License * along with ALOE. If not, see <http://www.gnu.org/licenses/>. * * Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl) */ package etc.aloe.controllers; import etc.aloe.data.ExampleSet; import etc.aloe.data.FeatureSpecification; import etc.aloe.data.Model; import etc.aloe.data.SegmentSet; import etc.aloe.processes.Balancing; import etc.aloe.processes.FeatureExtraction; import etc.aloe.processes.FeatureGeneration; import etc.aloe.processes.FeatureWeighting; import etc.aloe.processes.Training; import java.util.List; import java.util.Map; import weka.core.Instances; /** * Class that trains a model on some segmented data. * * @author Michael Brooks <mjbrooks@uw.edu> */ public class TrainingController { private SegmentSet segmentSet; private FeatureSpecification featureSpecification; private Model model; private FeatureExtraction featureExtractionImpl; private Training trainingImpl; private FeatureGeneration featureGenerationImpl; private Balancing balancingImpl; private FeatureWeighting featureWeightingImpl; private static final int NUM_TOP_FEATURES = 10; private List<String> topFeatures; private List<Map.Entry<String, Double>> featureWeights; private Instances featureValues; public List<String> getTopFeatures() { return topFeatures; } public List<Map.Entry<String, Double>> getFeatureWeights() { return featureWeights; } public void setFeatureWeightingImpl(FeatureWeighting featureWeightingImpl) { this.featureWeightingImpl = featureWeightingImpl; } public FeatureWeighting getFeatureWeightingImpl() { return featureWeightingImpl; } public void setSegmentSet(SegmentSet segments) { this.segmentSet = segments; } public FeatureSpecification getFeatureSpecification() { return this.featureSpecification; } public Model getModel() { return this.model; } public void run() { System.out.println("== Training Final Model =="); SegmentSet trainingSegments = segmentSet.onlyLabeled(); if (getBalancingImpl() != null) { trainingSegments = getBalancingImpl().balance(trainingSegments); } ExampleSet basicExamples = trainingSegments.getBasicExamples(); //Generate the features FeatureGeneration generation = getFeatureGenerationImpl(); this.featureSpecification = generation.generateFeatures(basicExamples); //Extract features FeatureExtraction extraction = getFeatureExtractionImpl(); ExampleSet examples = extraction.extractFeatures(basicExamples, this.featureSpecification); //Train the model Training training = getTrainingImpl(); this.featureValues = examples.getInstances(); this.model = training.train(examples); //Get the top features this.topFeatures = getFeatureWeightingImpl().getTopFeatures(examples, this.model, NUM_TOP_FEATURES); this.featureWeights = getFeatureWeightingImpl().getFeatureWeights(examples, this.model); } public FeatureExtraction getFeatureExtractionImpl() { return this.featureExtractionImpl; } public void setFeatureExtractionImpl(FeatureExtraction featureExtractor) { this.featureExtractionImpl = featureExtractor; } public Training getTrainingImpl() { return this.trainingImpl; } public void setTrainingImpl(Training training) { this.trainingImpl = training; } public void setFeatureGenerationImpl(FeatureGeneration featureGenerationImpl) { this.featureGenerationImpl = featureGenerationImpl; } public FeatureGeneration getFeatureGenerationImpl() { return this.featureGenerationImpl; } public void setBalancingImpl(Balancing balancing) { this.balancingImpl = balancing; } public Balancing getBalancingImpl() { return balancingImpl; } public Instances getFeatureValues() { return this.featureValues; } }