/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* Copyright (C) 2002 University of Waikato
*/
package weka.classifiers.meta;
import weka.classifiers.AbstractClassifierTest;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.NominalPrediction;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instances;
import weka.core.NoSupportForMissingValuesException;
import weka.core.SelectedTag;
import weka.core.UnsupportedAttributeTypeException;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.RemoveType;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import junit.framework.Test;
import junit.framework.TestSuite;
/**
* Tests ThresholdSelector. Run from the command line with:<p>
* java weka.classifiers.meta.ThresholdSelectorTest
*
* @author <a href="mailto:len@reeltwo.com">Len Trigg</a>
* @author FracPete (fracpete at waikato dot ac dot nz)
* @version $Revision: 1.8 $
*/
public class ThresholdSelectorTest
  extends AbstractClassifierTest {

  /** Class distribution returned by the dummy base classifier by default. */
  private static final double[] DIST1 = new double [] {
    0.25,
    0.375,
    0.5,
    0.625,
    0.75,
    0.875,
    1.0
  };

  /** Classpath location of the ARFF file loaded before each test. */
  private static final String TEST_DATA =
    "weka/classifiers/data/ClassifierTest.arff";

  /** A set of instances to test with */
  protected transient Instances m_Instances;

  /** Used to generate various types of predictions */
  protected transient EvaluationUtils m_Evaluation;

  public ThresholdSelectorTest(String name) {
    super(name);
  }

  /**
   * Called by JUnit before each test method. This implementation creates
   * the default classifier to test and loads a test set of Instances.
   *
   * @exception Exception if the test data cannot be found on the classpath
   * or an error occurs reading the example instances.
   */
  protected void setUp() throws Exception {
    super.setUp();
    m_Evaluation = new EvaluationUtils();

    InputStream stream = ClassLoader.getSystemResourceAsStream(TEST_DATA);
    // getSystemResourceAsStream returns null (rather than throwing) when the
    // resource is missing; fail with a clear message instead of an NPE.
    if (stream == null) {
      throw new IllegalStateException(
        "Test data not found on classpath: " + TEST_DATA);
    }
    BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
    try {
      m_Instances = new Instances(reader);
    } finally {
      // Instances reads the whole file in its constructor, so the reader
      // can be released immediately afterwards.
      reader.close();
    }
  }

  /** Creates a default ThresholdSelector */
  public Classifier getClassifier() {
    return getClassifier(DIST1);
  }

  /** Called by JUnit after each test method; releases per-test state. */
  protected void tearDown() {
    super.tearDown();
    m_Evaluation = null;
    m_Instances = null;
  }

  /**
   * Creates a ThresholdSelector that returns predictions from a
   * given distribution.
   *
   * @param dist the distribution the dummy base classifier will return
   * @return a new <code>ThresholdSelector</code>
   */
  public Classifier getClassifier(double[] dist) {
    return getClassifier(new ThresholdSelectorDummyClassifier(dist));
  }

  /**
   * Creates a ThresholdSelector with the given subclassifier.
   *
   * @param classifier a <code>Classifier</code> to use as the
   * subclassifier
   * @return a new <code>ThresholdSelector</code>
   */
  public Classifier getClassifier(Classifier classifier) {
    ThresholdSelector t = new ThresholdSelector();
    t.setClassifier(classifier);
    return t;
  }

  /** Returns the classifier under test, typed as a ThresholdSelector. */
  private ThresholdSelector selector() {
    return (ThresholdSelector) m_Classifier;
  }

  /**
   * Builds a model using the current classifier using the first
   * half of the current data for training, and generates a bunch of
   * predictions using the remaining half of the data for testing.
   * <p>
   * If the classifier rejects the data, the data is adapted and the
   * attempt is retried in these ways:
   * <ul>
   * <li>attributes of an unsupported type are removed, for at most
   * three passes;</li>
   * <li>missing values are replaced;</li>
   * <li>the training data is duplicated when there are too few
   * instances.</li>
   * </ul>
   *
   * @return a <code>FastVector</code> containing the predictions.
   * @exception Exception if predictions cannot be generated
   */
  protected FastVector useClassifier() throws Exception {
    Classifier dc = null;
    int tot = m_Instances.numInstances();
    int mid = tot / 2;
    Instances train = null;
    Instances test = null;
    try {
      train = new Instances(m_Instances, 0, mid);
      test = new Instances(m_Instances, mid, tot - mid);
      dc = m_Classifier;
    } catch (Exception ex) {
      ex.printStackTrace();
      fail("Problem setting up to use classifier: " + ex);
    }

    int counter = 0;
    do {
      try {
        return m_Evaluation.getTrainTestPredictions(dc, train, test);
      } catch (UnsupportedAttributeTypeException ex) {
        SelectedTag tag = null;
        boolean invert = false;
        String msg = ex.getMessage();
        // Without a message we cannot tell which attributes to drop.
        if (msg == null) {
          throw ex;
        }
        if ((msg.indexOf("string") != -1) &&
            (msg.indexOf("attributes") != -1)) {
          System.err.println("\nDeleting string attributes.");
          tag = new SelectedTag(Attribute.STRING,
                                RemoveType.TAGS_ATTRIBUTETYPE);
        } else if ((msg.indexOf("only") != -1) &&
                   (msg.indexOf("nominal") != -1)) {
          System.err.println("\nDeleting non-nominal attributes.");
          tag = new SelectedTag(Attribute.NOMINAL,
                                RemoveType.TAGS_ATTRIBUTETYPE);
          invert = true;
        } else if ((msg.indexOf("only") != -1) &&
                   (msg.indexOf("numeric") != -1)) {
          System.err.println("\nDeleting non-numeric attributes.");
          tag = new SelectedTag(Attribute.NUMERIC,
                                RemoveType.TAGS_ATTRIBUTETYPE);
          invert = true;
        } else {
          throw ex;
        }
        RemoveType attFilter = new RemoveType();
        attFilter.setAttributeType(tag);
        attFilter.setInvertSelection(invert);
        attFilter.setInputFormat(train);
        train = Filter.useFilter(train, attFilter);
        attFilter.batchFinished();
        test = Filter.useFilter(test, attFilter);
        counter++;
        // Give up if attribute removal keeps failing to satisfy the classifier.
        if (counter > 2) {
          throw ex;
        }
      } catch (NoSupportForMissingValuesException ex2) {
        System.err.println("\nReplacing missing values.");
        ReplaceMissingValues rmFilter = new ReplaceMissingValues();
        rmFilter.setInputFormat(train);
        train = Filter.useFilter(train, rmFilter);
        rmFilter.batchFinished();
        test = Filter.useFilter(test, rmFilter);
      } catch (IllegalArgumentException ex3) {
        String msg = ex3.getMessage();
        if ((msg != null) && (msg.indexOf("Not enough instances") != -1)) {
          System.err.println("\nInflating training data.");
          // Duplicate every training instance.
          Instances trainNew = new Instances(train);
          for (int i = 0; i < train.numInstances(); i++) {
            trainNew.add(train.instance(i));
          }
          train = trainNew;
        } else {
          throw ex3;
        }
      }
    } while (true);
  }

  /**
   * With range correction disabled, the predicted probabilities for class 0
   * must stay within the bounds of DIST1 (0.25 to 1.0).
   */
  public void testRangeNone() throws Exception {
    int cind = 0;
    selector().setDesignatedClass(
      new SelectedTag(ThresholdSelector.OPTIMIZE_0, ThresholdSelector.TAGS_OPTIMIZE));
    selector().setRangeCorrection(
      new SelectedTag(ThresholdSelector.RANGE_NONE, ThresholdSelector.TAGS_RANGE));
    m_Instances.setClassIndex(1);
    FastVector result = useClassifier();
    assertTrue(result.size() != 0);
    double minp = 0;
    double maxp = 0;
    for (int i = 0; i < result.size(); i++) {
      NominalPrediction p = (NominalPrediction)result.elementAt(i);
      double prob = p.distribution()[cind];
      if ((i == 0) || (prob < minp)) minp = prob;
      if ((i == 0) || (prob > maxp)) maxp = prob;
    }
    assertTrue("Upper limit shouldn't increase", maxp <= 1.0);
    assertTrue("Lower limit shouldn't decrease", minp >= 0.25);
  }

  /** Every designated-class option must yield predictions. */
  public void testDesignatedClass() throws Exception {
    for (int i = 0; i < ThresholdSelector.TAGS_OPTIMIZE.length; i++) {
      selector().setDesignatedClass(
        new SelectedTag(ThresholdSelector.TAGS_OPTIMIZE[i].getID(), ThresholdSelector.TAGS_OPTIMIZE));
      m_Instances.setClassIndex(1);
      FastVector result = useClassifier();
      assertTrue(result.size() != 0);
    }
  }

  /** Every evaluation-mode option must yield predictions. */
  public void testEvaluationMode() throws Exception {
    for (int i = 0; i < ThresholdSelector.TAGS_EVAL.length; i++) {
      selector().setEvaluationMode(
        new SelectedTag(ThresholdSelector.TAGS_EVAL[i].getID(), ThresholdSelector.TAGS_EVAL));
      m_Instances.setClassIndex(1);
      FastVector result = useClassifier();
      assertTrue(result.size() != 0);
    }
  }

  /**
   * Zero folds must be rejected with an IllegalArgumentException. Even fold
   * counts from 2 to 18 must yield predictions.
   */
  public void testNumXValFolds() throws Exception {
    try {
      selector().setNumXValFolds(0);
      fail("Expected IllegalArgumentException");
    } catch (IllegalArgumentException e) {
      // OK
    }
    for (int i = 2; i < 20; i += 2) {
      selector().setNumXValFolds(i);
      m_Instances.setClassIndex(1);
      FastVector result = useClassifier();
      assertTrue(result.size() != 0);
    }
  }

  public static Test suite() {
    return new TestSuite(ThresholdSelectorTest.class);
  }

  public static void main(String[] args){
    junit.textui.TestRunner.run(suite());
  }
}