/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* ExperimentDemo.java
* Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
*/
package wekaexamples.experiment;
import weka.classifiers.Classifier;
import weka.core.Instances;
import weka.core.Range;
import weka.core.Utils;
import weka.experiment.ClassifierSplitEvaluator;
import weka.experiment.CrossValidationResultProducer;
import weka.experiment.Experiment;
import weka.experiment.InstancesResultListener;
import weka.experiment.PairedCorrectedTTester;
import weka.experiment.PairedTTester;
import weka.experiment.PropertyNode;
import weka.experiment.RandomSplitResultProducer;
import weka.experiment.RegressionSplitEvaluator;
import weka.experiment.ResultMatrix;
import weka.experiment.ResultMatrixPlainText;
import weka.experiment.SplitEvaluator;
import java.beans.IntrospectionException;
import java.beans.PropertyDescriptor;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import javax.swing.DefaultListModel;
/**
 * A class for demonstrating the use of the Experiment class for one classifier
 * and one or more datasets. Cross-validation or random splits are possible, as
 * well as classification or regression.
 *
 * @author fracpete (fracpete at waikato dot ac dot nz)
 * @version $Revision: 5628 $
 */
public class ExperimentDemo {

  /**
   * Prints the commandline usage of this demo to stdout.
   */
  private static void printUsage() {
    System.out.println(
        "\nUsage: " + ExperimentDemo.class.getName() + "\n"
        + "\t -classifier <classifier incl. parameters>\n"
        + "\t -exptype <classification|regression>\n"
        + "\t -splittype <crossvalidation|randomsplit>\n"
        + "\t -runs <# of runs>\n"
        + "\t -folds <folds for CV>\n"
        + "\t -percentage <percentage for randomsplit>\n"
        + "\t -result <ARFF file for storing the results>\n"
        + "\t -t dataset (can be supplied multiple times)\n");
  }

  /**
   * Retrieves a mandatory option from the commandline.
   *
   * @param flag the name of the option (without leading dash)
   * @param args the commandline arguments to search
   * @param errorMessage the message to fail with if the option is missing
   * @return the (non-empty) value of the option
   * @throws Exception if the option is missing or cannot be parsed
   */
  private static String requireOption(String flag, String[] args, String errorMessage)
    throws Exception {

    String value = Utils.getOption(flag, args);
    if (value.length() == 0)
      throw new IllegalArgumentException(errorMessage);
    return value;
  }

  /**
   * Builds the two-element property path "result producer -> split evaluator
   * -> classifier" that the experiment iterates over.
   *
   * @param se the split evaluator in use
   * @param sec the classifier currently held by the split evaluator
   * @param producerClass the class of the result producer owning the evaluator
   * @return the property path
   * @throws IntrospectionException if one of the bean properties cannot be resolved
   */
  private static PropertyNode[] buildPropertyPath(SplitEvaluator se, Classifier sec, Class<?> producerClass)
    throws IntrospectionException {

    PropertyNode[] propertyPath = new PropertyNode[2];
    propertyPath[0] = new PropertyNode(
        se,
        new PropertyDescriptor("splitEvaluator", producerClass),
        producerClass);
    propertyPath[1] = new PropertyNode(
        sec,
        new PropertyDescriptor("classifier", se.getClass()),
        se.getClass());
    return propertyPath;
  }

  /**
   * Loads the experiment results from the given ARFF file, making sure the
   * underlying reader gets closed again.
   *
   * @param file the ARFF file with the results
   * @return the loaded results
   * @throws Exception if reading or parsing fails
   */
  private static Instances loadResults(File file) throws Exception {
    BufferedReader reader = new BufferedReader(new FileReader(file));
    try {
      return new Instances(reader);
    }
    finally {
      reader.close();
    }
  }

  /**
   * Expects the following parameters:
   * <ul>
   * <li>-classifier "classifier incl. parameters"</li>
   * <li>-exptype "classification|regression"</li>
   * <li>-splittype "crossvalidation|randomsplit"</li>
   * <li>-runs "# of runs"</li>
   * <li>-folds "# of cross-validation folds" (only read for crossvalidation)</li>
   * <li>-percentage "percentage for randomsplit" (only read for randomsplit)</li>
   * <li>-result "arff file for storing the results"</li>
   * <li>-t "dataset" (can be supplied multiple times)</li>
   * </ul>
   * If no arguments are supplied, the usage is printed and the JVM exits
   * with status 1.
   *
   * @param args the commandline arguments
   * @throws Exception if something goes wrong, e.g., a mandatory option is
   *         missing, a dataset does not exist or the experiment fails
   */
  public static void main(String[] args) throws Exception {
    // parameters provided?
    if (args.length == 0) {
      printUsage();
      System.exit(1);
    }

    // 1. setup the experiment
    System.out.println("Setting up...");
    Experiment exp = new Experiment();
    exp.setPropertyArray(new Classifier[0]);
    exp.setUsePropertyIterator(true);

    String option;

    // classification or regression
    option = requireOption("exptype", args, "No experiment type provided!");
    SplitEvaluator se;
    Classifier sec;
    boolean classification = false;
    if (option.equals("classification")) {
      classification = true;
      se = new ClassifierSplitEvaluator();
      sec = ((ClassifierSplitEvaluator) se).getClassifier();
    }
    else if (option.equals("regression")) {
      se = new RegressionSplitEvaluator();
      sec = ((RegressionSplitEvaluator) se).getClassifier();
    }
    else {
      throw new IllegalArgumentException("Unknown experiment type '" + option + "'!");
    }

    // crossvalidation or randomsplit
    // (an IntrospectionException while building the property path is no
    // longer swallowed: continuing with a half-filled path only defers the
    // failure to a less meaningful place)
    option = requireOption("splittype", args, "No split type provided!");
    if (option.equals("crossvalidation")) {
      CrossValidationResultProducer cvrp = new CrossValidationResultProducer();
      option = requireOption("folds", args, "No folds provided!");
      cvrp.setNumFolds(Integer.parseInt(option));
      cvrp.setSplitEvaluator(se);

      exp.setResultProducer(cvrp);
      exp.setPropertyPath(
          buildPropertyPath(se, sec, CrossValidationResultProducer.class));
    }
    else if (option.equals("randomsplit")) {
      RandomSplitResultProducer rsrp = new RandomSplitResultProducer();
      rsrp.setRandomizeData(true);
      option = requireOption("percentage", args, "No percentage provided!");
      rsrp.setTrainPercent(Double.parseDouble(option));
      rsrp.setSplitEvaluator(se);

      exp.setResultProducer(rsrp);
      exp.setPropertyPath(
          buildPropertyPath(se, sec, RandomSplitResultProducer.class));
    }
    else {
      throw new IllegalArgumentException("Unknown split type '" + option + "'!");
    }

    // runs
    option = requireOption("runs", args, "No runs provided!");
    exp.setRunLower(1);
    exp.setRunUpper(Integer.parseInt(option));

    // classifier
    option = requireOption("classifier", args, "No classifier provided!");
    String[] options = Utils.splitOptions(option);
    // guard against a value that contains no tokens at all
    if (options.length == 0)
      throw new IllegalArgumentException("No classifier provided!");
    String classname = options[0];
    options[0] = "";
    Classifier c = (Classifier) Utils.forName(Classifier.class, classname, options);
    exp.setPropertyArray(new Classifier[]{c});

    // datasets: "-t" may be given repeatedly, so keep reading until exhausted
    boolean data = false;
    DefaultListModel model = new DefaultListModel();
    do {
      option = Utils.getOption("t", args);
      if (option.length() > 0) {
        File file = new File(option);
        if (!file.exists())
          throw new IllegalArgumentException("File '" + option + "' does not exist!");
        data = true;
        model.addElement(file);
      }
    }
    while (option.length() > 0);
    if (!data)
      throw new IllegalArgumentException("No data files provided!");
    exp.setDatasets(model);

    // result
    option = requireOption("result", args, "No result file provided!");
    InstancesResultListener irl = new InstancesResultListener();
    irl.setOutputFile(new File(option));
    exp.setResultListener(irl);

    // 2. run experiment
    System.out.println("Initializing...");
    exp.initialize();
    System.out.println("Running...");
    exp.runExperiment();
    System.out.println("Finishing...");
    exp.postProcess();

    // 3. calculate statistics and output them
    System.out.println("Evaluating...");
    PairedTTester tester = new PairedCorrectedTTester();
    Instances result = loadResults(irl.getOutputFile());
    tester.setInstances(result);
    tester.setSortColumn(-1);
    tester.setRunColumn(result.attribute("Key_Run").index());
    // Whether a fold column exists depends on the result producer, not on
    // classification vs. regression, so only set it if it is actually there.
    // NOTE(review): random-split results are presumably written without a
    // "Key_Fold" attribute - confirm against RandomSplitResultProducer.
    if (result.attribute("Key_Fold") != null)
      tester.setFoldColumn(result.attribute("Key_Fold").index());
    // Range expects 1-based indices, hence the "+ 1"
    tester.setResultsetKeyColumns(
        new Range(
            ""
            + (result.attribute("Key_Dataset").index() + 1)));
    tester.setDatasetKeyColumns(
        new Range(
            ""
            + (result.attribute("Key_Scheme").index() + 1)
            + ","
            + (result.attribute("Key_Scheme_options").index() + 1)
            + ","
            + (result.attribute("Key_Scheme_version_ID").index() + 1)));
    tester.setResultMatrix(new ResultMatrixPlainText());
    tester.setDisplayedResultsets(null);
    tester.setSignificanceLevel(0.05);
    tester.setShowStdDevs(true);

    // fill result matrix (but discarding the output)
    if (classification)
      tester.multiResultsetFull(0, result.attribute("Percent_correct").index());
    else
      tester.multiResultsetFull(0, result.attribute("Correlation_coefficient").index());

    // output results for each dataset; the label has to match the column
    // that was compared above
    String meanLabel = classification ? " Perc. correct: " : " Corr. coeff.: ";
    System.out.println("\nResult:");
    ResultMatrix matrix = tester.getResultMatrix();
    for (int i = 0; i < matrix.getColCount(); i++) {
      System.out.println(matrix.getColName(i));
      System.out.println(meanLabel + matrix.getMean(i, 0));
      System.out.println(" StdDev: " + matrix.getStdDev(i, 0));
    }
  }
}