/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * EnsembleClassifier.java * Copyright (C) 2003 Prem Melville * */ package weka.classifiers; import weka.core.*; import com.jmage.*; import java.util.*; /** * Abstract class for Ensemble Classifiers * * @author Prem Melville * @version $Revision: 1.3 $ */ public abstract class EnsembleClassifier extends DistributionClassifier implements AdditionalMeasureProducer{ /** the error on the training data */ protected double m_TrainError=0; /** the average error of the ensemble on the training data */ protected double m_TrainEnsembleError=0; /** the ensemble diversity computed in the training data */ protected double m_TrainEnsembleDiversity=0; /** Sum of ensemble weights */ protected double m_SumEnsembleWts=0; /** Vote weights of ensemble members */ protected double []m_EnsembleWts; /** Returns class predictions of each ensemble member */ public abstract double []getEnsemblePredictions(Instance instance) throws Exception; /** * Returns vote weights of ensemble members. * * @return vote weights of ensemble members */ public abstract double []getEnsembleWts(); /** Returns size of ensemble */ public abstract double getEnsembleSize(); /** * Returns an enumeration of the additional measure names * @return an enumeration of the measure names */ public Enumeration enumerateMeasures() { Vector newVector = new Vector(3); newVector.addElement("measureTrainError"); newVector.addElement("measureTrainEnsembleError"); newVector.addElement("measureTrainEnsembleDiversity"); return newVector.elements(); } /** * Returns the value of the named measure * @param measureName the name of the measure to query for its value * @return the value of the named measure * @exception IllegalArgumentException if the named measure is not supported */ public double getMeasure(String additionalMeasureName) { if (additionalMeasureName.compareTo("measureTrainError") == 0) { return measureTrainError(); } else if (additionalMeasureName.compareTo("measureTrainEnsembleError") == 0) { return measureTrainEnsembleError(); } else if (additionalMeasureName.compareTo("measureTrainEnsembleDiversity") == 0) { return measureTrainEnsembleDiversity(); } else { throw new IllegalArgumentException(additionalMeasureName + " not supported (DEC)"); } } /** * @return the error on the training data **/ public double measureTrainError(){ return m_TrainError; } /** * @return the average error of the ensemble on the training data */ public double measureTrainEnsembleError(){ return m_TrainEnsembleError; } /** * @return the ensemble diversity computed in the training data */ public double measureTrainEnsembleDiversity(){ return m_TrainEnsembleDiversity; } /** Initialize measures */ protected void initMeasures(){ m_SumEnsembleWts=0; m_TrainError=0; m_TrainEnsembleError=0; m_TrainEnsembleDiversity=0; } /** * Compute ensemble measures. * @param data training instances */ protected void computeEnsembleMeasures(Instances data) throws Exception{ for(int j=0; j<getEnsembleSize(); j++) m_SumEnsembleWts += m_EnsembleWts[j]; //DEBUG //System.out.println("Ensemble size = "+getEnsembleSize()); if(m_SumEnsembleWts == 0.0){ System.out.println("Ensemble wts sum to 0!"); for(int j=0; j<m_EnsembleWts.length; j++) System.out.print("\t"+m_EnsembleWts[j]); System.out.println(); } double totalInstanceWt=0; Instance curr; for (int i = 0; i < data.numInstances(); i++) { curr = data.instance(i); totalInstanceWt += curr.weight(); if(curr.weight() != 1.0) System.out.println(">>> Instance Weight = "+curr.weight()); updateEnsembleStats(classifyInstance(curr), curr, getEnsemblePredictions(curr)); } //DEBUG Assert.that(totalInstanceWt==data.numInstances(),"Instance wts don't total to num of instances!"); m_TrainError = 100.0 * (m_TrainError/totalInstanceWt); m_TrainEnsembleError = 100.0 * m_TrainEnsembleError/totalInstanceWt; m_TrainEnsembleDiversity = 100.0 * m_TrainEnsembleDiversity/totalInstanceWt; } /** * Update statistics for ensemble classifiers. * * @param pred ensemble prediction * @param instance training instance * @param ensemblePreds predictions of ensemble members */ protected void updateEnsembleStats(double pred, Instance instance, double []ensemblePreds){ //System.out.print("Updating Ensemble Stats..."); double sumEnsembleError = 0, sumEnsembleDiversity = 0; double actualClass = instance.classValue(); for(int i=0; i<getEnsembleSize(); i++){ if(actualClass != ensemblePreds[i]) sumEnsembleError += m_EnsembleWts[i]; //if member's prediction differs from the ensemble prediction, diversity increases if(pred != ensemblePreds[i]) sumEnsembleDiversity += m_EnsembleWts[i]; } if(pred != actualClass) m_TrainError += instance.weight(); m_TrainEnsembleError += ((sumEnsembleError/m_SumEnsembleWts)*instance.weight()); m_TrainEnsembleDiversity += ((sumEnsembleDiversity/m_SumEnsembleWts)*instance.weight()); } }