/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* TestEnsembleClassifier
* Copyright (C) 2003 Prem Melville
*
*/
package weka.classifiers.meta;
import weka.classifiers.*;
import java.util.*;
import weka.core.*;
/**
* This class is for testing Ensemble evaluation
*/
public class TestEnsembleClassifier extends EnsembleClassifier{
protected int m_NumIterations=21;
protected Random random = new Random();
/**
*
* @param data the training data to be used for generating the
* bagged classifier.
* @exception Exception if the classifier could not be built successfully
*/
public void buildClassifier(Instances data) throws Exception {
//Initialize measures
initMeasures();
//initialize ensemble wts to be equal
m_EnsembleWts = new double [m_NumIterations];
for(int j=0; j<m_NumIterations; j++)
m_EnsembleWts[j] = 1.0;
computeEnsembleMeasures(data);
}
/**
* Calculates the class membership probabilities for the given test instance.
*
* @param instance the instance to be classified
* @return preedicted class probability distribution
* @exception Exception if distribution can't be computed successfully
*/
public double[] distributionForInstance(Instance instance) throws Exception {
double [] sums = new double [instance.numClasses()];
double [] preds = getEnsemblePredictions(instance);
for (int i = 0; i < m_NumIterations; i++) {
sums[(int)preds[i]]++;
}
Utils.normalize(sums);
return sums;
}
/** Returns class predictions of each ensemble member */
public double []getEnsemblePredictions(Instance instance) throws Exception{
double preds[] = new double [m_NumIterations];
double actualClass;
if(instance.classIsMissing()) {
actualClass = random.nextInt(instance.numClasses());
//for(int i=0; i<m_NumIterations; i++) preds[i] = actualClass;
for(int i=0; i<m_NumIterations; i++) preds[i] = 1.0;
}
else {
actualClass = instance.classValue();
for(int i=0; i<m_NumIterations; i++){
if(random.nextFloat()<0.4)
preds[i] = actualClass;
else
preds[i] = (actualClass+1)%instance.numClasses();
}
}
return preds;
}
/**
* Returns vote weights of ensemble members.
*
* @return vote weights of ensemble members
*/
public double []getEnsembleWts(){
return m_EnsembleWts;
}
/** Returns size of ensemble */
public double getEnsembleSize(){
return m_NumIterations;
}
/**
* Main method for testing this class.
*
* @param argv the options
*/
public static void main(String [] argv) {
try {
System.out.println(Evaluation.
evaluateModel(new Bagging(), argv));
} catch (Exception e) {
System.err.println(e.getMessage());
}
}
}