package edu.usc.cssl.tacit.classify.naivebayes.services;
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
import java.io.File;
import java.io.FileOutputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Random;
import java.util.logging.ConsoleHandler;
import java.util.logging.Handler;
import java.util.logging.Logger;
import org.apache.commons.math3.stat.inference.AlternativeHypothesis;
import org.apache.commons.math3.stat.inference.BinomialTest;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.classify.evaluate.ConfusionMatrix;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.BshInterpreter;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import cc.mallet.util.ProgressMessageLogFormatter;
import edu.usc.cssl.tacit.common.ui.views.ConsoleView;
/**
* Classify documents, run trials, print statistics from a vector file.
*
* @author Andrew McCallum <a
* href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/
public abstract class Vectors2Classify {
/** Beanshell interpreter used to eval --trainer constructor strings (see getTrainer). */
static BshInterpreter interpreter = new BshInterpreter();
/** Report / raw-classification lines accumulated during a run and returned by main(). */
static ArrayList<String> result = new ArrayList<String>();
private static Logger logger = MalletLogger
		.getLogger(Vectors2Classify.class.getName());
// Progress logger; main() may install a ProgressMessageLogFormatter on its
// root logger's console handler for overwrite-in-place progress messages.
private static Logger progressLogger = MalletProgressMessageLogger
		.getLogger(Vectors2Classify.class.getName() + "-pl");
// One constructor string per --trainer option (defaults to NaiveBayesTrainer).
private static ArrayList<String> classifierTrainerStrings = new ArrayList<String>();
// [dataset][reportKind] selection flags filled in by the --report option;
// indices are the constants declared in ReportOption below.
private static boolean[][] ReportOptions = new boolean[3][4];
private static String[][] ReportOptionArgs = new String[3][4]; // arg in
// dataset:reportOption=arg
// Essentially an enum mapping string names to enums to ints.
// The int constants index into ReportOptions / ReportOptionArgs: the first
// dimension is the data set (train/test/validation), the second the report
// kind (accuracy/f1/confusion/raw). dataOptions/reportOptions hold the
// spellings accepted on the command line, in the same order as the ints.
private static class ReportOption {
	static final String[] dataOptions = { "train", "test", "validation" };
	static final String[] reportOptions = { "accuracy", "f1", "confusion",
			"raw" };
	static final int train = 0;
	static final int test = 1;
	static final int validation = 2;
	static final int accuracy = 0;
	static final int f1 = 1;
	static final int confusion = 2;
	static final int raw = 3;
}
// --report option: each value is "dataset:reportKind" (optionally
// "dataset:f1=label"); postParsing translates the specs into the
// ReportOptions / ReportOptionArgs tables consulted by main().
static CommandOption.SpacedStrings report = new CommandOption.SpacedStrings(
		Vectors2Classify.class,
		"report",
		"[train|test|validation]:[accuracy|f1:label|confusion|raw]",
		true,
		new String[] { "test:accuracy", "test:confusion", "train:accuracy" },
		"", null) {
	@Override
	public void postParsing(CommandOption.List list) {
		// NOTE(review): defaultRawFormatting is never read — presumably a
		// leftover from the upstream MALLET raw-output formatting support.
		java.lang.String defaultRawFormatting = "siw";
		for (int argi = 0; argi < this.value.length; argi++) {
			// convert options like --report train:accuracy --report
			// test:f1=labelA to
			// boolean array of options.
			// first, split the argument at ':' or '=' (so "train:f1=labelA"
			// yields fields {"train", "f1", "labelA"}).
			java.lang.String arg = this.value[argi];
			java.lang.String fields[] = arg.split("[:=]");
			java.lang.String dataSet = fields[0];
			java.lang.String reportOption = fields[1];
			java.lang.String reportOptionArg = null;
			if (fields.length >= 3) {
				reportOptionArg = fields[2];
			}
			// find the datasource (test,train,validation)
			boolean foundDataSource = false;
			int i = 0;
			for (; i < ReportOption.dataOptions.length; i++) {
				if (dataSet.equals(ReportOption.dataOptions[i])) {
					foundDataSource = true;
					break;
				}
			}
			if (!foundDataSource) {
				throw new IllegalArgumentException("Unknown argument = "
						+ dataSet + " in --report " + this.value[argi]);
			}
			// find the report option (accuracy, f1, confusion, raw)
			boolean foundReportOption = false;
			int j = 0;
			for (; j < ReportOption.reportOptions.length; j++) {
				if (reportOption.equals(ReportOption.reportOptions[j])) {
					foundReportOption = true;
					break;
				}
			}
			if (!foundReportOption) {
				throw new IllegalArgumentException("Unknown argument = "
						+ reportOption + " in --report " + this.value[argi]);
			}
			// Mark the (dataSet, reportOption) pair as selected; i and j
			// carry the loop-exit indices of the matches found above.
			ReportOptions[i][j] = true;
			if (j == ReportOption.f1) {
				// make sure a label was specified for f1
				if (reportOptionArg == null) {
					throw new IllegalArgumentException(
							"F1 must have label argument in --report "
									+ this.value[argi]);
				}
				// Pass through the string argument
				ReportOptionArgs[i][j] = reportOptionArg;
			} else if (reportOptionArg != null) {
				throw new IllegalArgumentException(
						"No arguments after = allowed in --report "
								+ this.value[argi]);
			}
		}
	}
};
// --trainer option: may be given multiple times to compare classifiers;
// each occurrence appends its constructor string, which getTrainer() later
// evaluates (adding "new ...Trainer()" boilerplate when omitted).
static CommandOption.String trainerConstructor = new CommandOption.String(
		Vectors2Classify.class,
		"trainer",
		"ClassifierTrainer constructor",
		true,
		"new NaiveBayesTrainer()",
		"Java code for the constructor used to create a ClassifierTrainer. "
				+ "If no '(' appears, then \"new \" will be prepended and \"Trainer()\" will be appended."
				+ "You may use this option mutiple times to compare multiple classifiers.",
		null) {
	@Override
	public void postParsing(CommandOption.List list) {
		// Collect every --trainer occurrence, in order.
		classifierTrainerStrings.add(this.value);
	}
};
// ---------------------------------------------------------------------
// Remaining command-line options (files, split proportions, trial counts,
// logging verbosity, cross-validation folds). Declaration order mirrors
// the upstream MALLET Vectors2Classify tool.
// ---------------------------------------------------------------------
static CommandOption.String outputFile = new CommandOption.String(
		Vectors2Classify.class,
		"output-classifier",
		"FILENAME",
		true,
		"classifier.mallet",
		"The filename in which to write the classifier after it has been trained.",
		null);
/*
 * static CommandOption.String pipeFile = new CommandOption.String
 * (Vectors2Classify.class, "output-pipe", "FILENAME", true,
 * "classifier_pipe.mallet",
 * "The filename in which to write the classifier's instancePipe after it has been trained."
 * , null);
 */
static CommandOption.String inputFile = new CommandOption.String(
		Vectors2Classify.class,
		"input",
		"FILENAME",
		true,
		"text.vectors",
		"The filename from which to read the list of training instances. Use - for stdin.",
		null);
static CommandOption.String trainingFile = new CommandOption.String(
		Vectors2Classify.class,
		"training-file",
		"FILENAME",
		true,
		"text.vectors",
		"Read the training set instance list from this file. "
				+ "If this is specified, the input file parameter is ignored",
		null);
static CommandOption.String testFile = new CommandOption.String(
		Vectors2Classify.class,
		"testing-file",
		"FILENAME",
		true,
		"text.vectors",
		"Read the test set instance list to this file. "
				+ "If this option is specified, the training-file parameter must be specified and "
				+ " the input-file parameter is ignored", null);
static CommandOption.String validationFile = new CommandOption.String(
		Vectors2Classify.class,
		"validation-file",
		"FILENAME",
		true,
		"text.vectors",
		"Read the validation set instance list to this file."
				+ "If this option is specified, the training-file parameter must be specified and "
				+ "the input-file parameter is ignored", null);
static CommandOption.Double trainingProportionOption = new CommandOption.Double(
		Vectors2Classify.class, "training-portion", "DECIMAL", true, 1.0,
		"The fraction of the instances that should be used for training.",
		null);
static CommandOption.Double validationProportionOption = new CommandOption.Double(
		Vectors2Classify.class,
		"validation-portion",
		"DECIMAL",
		true,
		0.0,
		"The fraction of the instances that should be used for validation.",
		null);
static CommandOption.Double unlabeledProportionOption = new CommandOption.Double(
		Vectors2Classify.class,
		"unlabeled-portion",
		"DECIMAL",
		true,
		0.0,
		"The fraction of the training instances that should have their labels hidden. "
				+ "Note that these are taken out of the training-portion, not allocated separately.",
		null);
static CommandOption.Integer randomSeedOption = new CommandOption.Integer(
		Vectors2Classify.class,
		"random-seed",
		"INTEGER",
		true,
		0,
		"The random seed for randomly selecting a proportion of the instance list for training",
		null);
static CommandOption.Integer numTrialsOption = new CommandOption.Integer(
		Vectors2Classify.class, "num-trials", "INTEGER", true, 1,
		"The number of random train/test splits to perform", null);
static CommandOption.Object classifierEvaluatorOption = new CommandOption.Object(
		Vectors2Classify.class, "classifier-evaluator", "CONSTRUCTOR",
		true, null,
		"Java code for constructing a ClassifierEvaluating object", null);
static CommandOption.Integer verbosityOption = new CommandOption.Integer(
		Vectors2Classify.class,
		"verbosity",
		"INTEGER",
		true,
		-1,
		"The level of messages to print: 0 is silent, 8 is most verbose. "
				+ "Levels 0-8 correspond to the java.logger predefined levels "
				+ "off, severe, warning, info, config, fine, finer, finest, all. "
				+ "The default value is taken from the mallet logging.properties file,"
				+ " which currently defaults to INFO level (3)", null);
static CommandOption.Boolean noOverwriteProgressMessagesOption = new CommandOption.Boolean(
		Vectors2Classify.class,
		"noOverwriteProgressMessages",
		"true|false",
		false,
		false,
		"Suppress writing-in-place on terminal for progess messages - repetitive messages "
				+ "of which only the latest is generally of interest", null);
static CommandOption.Integer crossValidation = new CommandOption.Integer(
		Vectors2Classify.class, "cross-validation", "INT", true, 0,
		"The number of folds for cross-validation (DEFAULT=0).", null);
/**
 * Runs the classification experiment described by the MALLET-style
 * {@code args}: loads instance lists (a single --input file split by
 * proportions, separate train/test/validation files, or cross-validation
 * folds), trains one classifier per --trainer option for each trial, and
 * reports accuracy / F1 / confusion / raw output as selected via --report.
 *
 * @param args MALLET command-line options (see the CommandOption fields)
 * @return accumulated report lines (also consumed by the TACIT UI);
 *         the shared {@link #result} list is cleared on each invocation
 * @throws bsh.EvalError if a trainer constructor string fails to evaluate
 * @throws java.io.IOException on failure reading an instance-list file
 */
public static ArrayList<String> main(String[] args) throws bsh.EvalError,
		java.io.IOException {
	// Reset static state so repeated invocations (from the UI) start clean.
	result.clear();
	classifierTrainerStrings = new ArrayList<String>();
	ReportOptions = new boolean[][]{{false, false, false, false}, {false, false, false, false}, {false, false, false, false}};
	double pvalue = 0;
	// Process the command-line options
	CommandOption
			.setSummary(
					Vectors2Classify.class,
					"A tool for training, saving and printing diagnostics from a classifier on vectors.");
	CommandOption.process(Vectors2Classify.class, args);
	// handle default trainer here for now; default argument processing
	// doesn't work
	if (!trainerConstructor.wasInvoked()) {
		classifierTrainerStrings.add("new NaiveBayesTrainer()");
	}
	if (!report.wasInvoked()) {
		// Default report: train accuracy, test accuracy + test confusion.
		ReportOptions = new boolean[][]{{true, false, false, false}, {true, false, true, false}, {false, false, false, false}};
	}
	int verbosity = verbosityOption.value;
	Logger rootLogger = ((MalletLogger) progressLogger).getRootLogger();
	if (verbosityOption.wasInvoked()) {
		rootLogger.setLevel(MalletLogger.LoggingLevels[verbosity]);
	}
	if (noOverwriteProgressMessagesOption.value == false) {
		// Install special formatting for progress messages: find the
		// console handler on the root logger and swap in a formatter
		// that knows about progress messages.
		Handler[] handlers = rootLogger.getHandlers();
		for (int i = 0; i < handlers.length; i++) {
			if (handlers[i] instanceof ConsoleHandler) {
				handlers[i].setFormatter(new ProgressMessageLogFormatter());
			}
		}
	}
	boolean separateIlists = testFile.wasInvoked()
			|| trainingFile.wasInvoked() || validationFile.wasInvoked();
	InstanceList ilist = null;
	InstanceList testFileIlist = null;
	InstanceList trainingFileIlist = null;
	InstanceList validationFileIlist = null;
	if (!separateIlists) { // normal case, --input-file specified
		// Read in the InstanceList, from stdin if the input filename is "-".
		ilist = InstanceList.load(new File(inputFile.value));
	} else { // user specified separate files for testing and training sets.
		trainingFileIlist = InstanceList.load(new File(trainingFile.value));
		logger.info("Training vectors loaded from " + trainingFile.value);
		if (testFile.wasInvoked()) {
			testFileIlist = InstanceList.load(new File(testFile.value));
			logger.info("Testing vectors loaded from " + testFile.value);
			if (!testFileIlist.getPipe().alphabetsMatch(
					trainingFileIlist.getPipe())) {
				throw new RuntimeException(trainingFileIlist.getPipe()
						.getDataAlphabet()
						+ "\n"
						+ testFileIlist.getPipe().getDataAlphabet()
						+ "\n"
						+ trainingFileIlist.getPipe().getTargetAlphabet()
						+ "\n"
						+ testFileIlist.getPipe().getTargetAlphabet()
						+ "\n"
						+ "Training and testing alphabets don't match!\n");
			}
		}
		if (validationFile.wasInvoked()) {
			validationFileIlist = InstanceList.load(new File(
					validationFile.value));
			logger.info("validation vectors loaded from "
					+ validationFile.value);
			if (!validationFileIlist.getPipe().alphabetsMatch(
					trainingFileIlist.getPipe())) {
				throw new RuntimeException(
						trainingFileIlist.getPipe().getDataAlphabet()
								+ "\n"
								+ validationFileIlist.getPipe()
										.getDataAlphabet()
								+ "\n"
								+ trainingFileIlist.getPipe()
										.getTargetAlphabet()
								+ "\n"
								+ validationFileIlist.getPipe()
										.getTargetAlphabet()
								+ "\n"
								+ "Training and validation alphabets don't match!\n");
			}
		} else {
			// No validation file: use an empty list so ilists[2] is non-null.
			validationFileIlist = new InstanceList(
					new cc.mallet.pipe.Noop());
		}
	}
	// Cross-validation overrides the split/trial options; warn when both given.
	if (crossValidation.wasInvoked()
			&& trainingProportionOption.wasInvoked()) {
		logger.warning("Both --cross-validation and --training-portion were invoked. Using cross validation with "
				+ crossValidation.value + " folds.");
	}
	if (crossValidation.wasInvoked()
			&& validationProportionOption.wasInvoked()) {
		logger.warning("Both --cross-validation and --validation-portion were invoked. Using cross validation with "
				+ crossValidation.value + " folds.");
	}
	if (crossValidation.wasInvoked() && numTrialsOption.wasInvoked()) {
		logger.warning("Both --cross-validation and --num-trials were invoked. Using cross validation with "
				+ crossValidation.value + " folds.");
	}
	int numTrials;
	if (crossValidation.wasInvoked()) {
		numTrials = crossValidation.value;
	} else {
		numTrials = numTrialsOption.value;
	}
	Random r = randomSeedOption.wasInvoked() ? new Random(
			randomSeedOption.value) : new Random();
	int numTrainers = classifierTrainerStrings.size();
	// Per-(trainer, trial) metrics gathered inside the loops below.
	double trainAccuracy[][] = new double[numTrainers][numTrials];
	double testAccuracy[][] = new double[numTrainers][numTrials];
	double validationAccuracy[][] = new double[numTrainers][numTrials];
	String trainConfusionMatrix[][] = new String[numTrainers][numTrials];
	String testConfusionMatrix[][] = new String[numTrainers][numTrials];
	String validationConfusionMatrix[][] = new String[numTrainers][numTrials];
	double t = trainingProportionOption.value;
	double v = validationProportionOption.value;
	if (!separateIlists) {
		if (crossValidation.wasInvoked()) {
			logger.info("Cross-validation folds = " + crossValidation.value);
		} else {
			logger.info("Training portion = " + t);
			logger.info(" Unlabeled training sub-portion = "
					+ unlabeledProportionOption.value);
			logger.info("Validation portion = " + v);
			logger.info("Testing portion = " + (1 - v - t));
		}
	}
	CrossValidationIterator cvIter;
	if (crossValidation.wasInvoked()) {
		if (crossValidation.value < 2) {
			throw new RuntimeException(
					"At least two folds (set with --cross-validation) are required for cross validation");
		}
		// NOTE(review): ilist is null when separate train/test files were
		// supplied; cross-validation appears to assume --input was used —
		// confirm callers never combine the two.
		cvIter = new CrossValidationIterator(ilist, crossValidation.value,
				r);
	} else {
		cvIter = null;
	}
	String[] trainerNames = new String[numTrainers];
	for (int trialIndex = 0; trialIndex < numTrials; trialIndex++) {
		System.out.println("\n-------------------- Trial " + trialIndex
				+ " --------------------\n");
		InstanceList[] ilists;
		BitSet unlabeledIndices = null;
		if (!separateIlists) {
			if (crossValidation.wasInvoked()) {
				// Fold split: [0]=train, [1]=test, [2]=empty validation.
				InstanceList[] cvSplit = cvIter.next();
				ilists = new InstanceList[3];
				ilists[0] = cvSplit[0];
				ilists[1] = cvSplit[1];
				ilists[2] = cvSplit[0].cloneEmpty();
			} else {
				ilists = ilist.split(r, new double[] { t, 1 - t - v, v });
			}
		} else {
			ilists = new InstanceList[3];
			ilists[0] = trainingFileIlist;
			ilists[1] = testFileIlist;
			ilists[2] = validationFileIlist;
		}
		if (unlabeledProportionOption.value > 0)
			unlabeledIndices = new cc.mallet.util.Randoms(r.nextInt())
					.nextBitSet(ilists[0].size(),
							unlabeledProportionOption.value);
		long time[] = new long[numTrainers];
		for (int c = 0; c < numTrainers; c++) {
			time[c] = System.currentTimeMillis();
			ClassifierTrainer trainer = getTrainer(classifierTrainerStrings
					.get(c));
			trainer.setValidationInstances(ilists[2]);
			ConsoleView.printlInConsoleln("Training " + trainer + " with " + ilists[0].size() + " instances");
			// Temporarily hide a portion of training labels if requested.
			if (unlabeledProportionOption.value > 0)
				ilists[0].hideSomeLabels(unlabeledIndices);
			Classifier classifier = trainer.train(ilists[0]);
			if (unlabeledProportionOption.value > 0)
				ilists[0].unhideAllLabels();
			ConsoleView.printlInConsoleln("Training " + trainer.toString() + " finished");
			time[c] = System.currentTimeMillis() - time[c];
			Trial trainTrial = new Trial(classifier, ilists[0]);
			Trial testTrial = new Trial(classifier, ilists[1]);
			Trial validationTrial = new Trial(classifier, ilists[2]);
			// Only build confusion matrices when requested in report options.
			if (ReportOptions[ReportOption.train][ReportOption.confusion]
					&& ilists[0].size() > 0)
				trainConfusionMatrix[c][trialIndex] = new ConfusionMatrix(
						trainTrial).toString();
			if (ReportOptions[ReportOption.test][ReportOption.confusion]
					&& ilists[1].size() > 0)
				testConfusionMatrix[c][trialIndex] = new ConfusionMatrix(
						testTrial).toString();
			if (ReportOptions[ReportOption.validation][ReportOption.confusion]
					&& ilists[2].size() > 0)
				validationConfusionMatrix[c][trialIndex] = new ConfusionMatrix(
						validationTrial).toString();
			// Only compute accuracies when requested in report options.
			if (ReportOptions[ReportOption.train][ReportOption.accuracy])
				trainAccuracy[c][trialIndex] = trainTrial.getAccuracy();
			if (ReportOptions[ReportOption.test][ReportOption.accuracy])
				testAccuracy[c][trialIndex] = testTrial.getAccuracy();
			if (ReportOptions[ReportOption.validation][ReportOption.accuracy])
				validationAccuracy[c][trialIndex] = validationTrial
						.getAccuracy();
			if (outputFile.wasInvoked()) {
				// Disambiguate the filename across trainers and trials.
				String filename = outputFile.value;
				if (numTrainers > 1)
					filename = filename + trainer.toString();
				if (numTrials > 1)
					filename = filename + ".trial" + trialIndex;
				try {
					ObjectOutputStream oos = new ObjectOutputStream(
							new FileOutputStream(filename));
					oos.writeObject(classifier);
					oos.close();
				} catch (Exception e) {
					e.printStackTrace();
					throw new IllegalArgumentException(
							"Couldn't write classifier to filename "
									+ filename);
				}
			}
			// Raw per-instance output, when requested.
			if (ReportOptions[ReportOption.train][ReportOption.raw]) {
				System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
				System.out.println(" Raw Training Data");
				printTrialClassification(trainTrial);
			}
			if (ReportOptions[ReportOption.test][ReportOption.raw]) {
				System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
				System.out.println(" Raw Testing Data");
				printTrialClassification(testTrial);
			}
			if (ReportOptions[ReportOption.validation][ReportOption.raw]) {
				System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
				System.out.println(" Raw Validation Data");
				printTrialClassification(validationTrial);
			}
			// FIX: parenthesize before casting — previously the (int) cast
			// bound to getAccuracy() alone, truncating the accuracy (a value
			// in [0,1]) to 0 before multiplying, so the printed success
			// count was almost always 0.
			System.out.println("Bino test vars size " + ilists[1].size()
					+ "and accuracy + " + testTrial.getAccuracy()
					+ " then success "
					+ (int) (testTrial.getAccuracy() * ilists[1].size()));
			BinomialTest binomtest = new BinomialTest();
			double p = 0.5;
			// train
			if (ReportOptions[ReportOption.train][ReportOption.confusion]) {
				ConsoleView.printlInConsoleln(trainer.toString() + " Training Data Confusion Matrix");
				if (ilists[0].size() > 0)
					ConsoleView.printlInConsoleln(trainConfusionMatrix[c][trialIndex]);
			}
			if (ReportOptions[ReportOption.train][ReportOption.accuracy]) {
				pvalue = binomtest
						.binomialTest(ilists[0].size(), (int) (trainTrial
								.getAccuracy() * ilists[0].size()), p,
								AlternativeHypothesis.TWO_SIDED);
				if (pvalue != 0) {
					// Fold p-values above 0.5 back into [0, 0.5].
					if (pvalue > 0.5)
						pvalue = Math.abs(pvalue - 1);
					ConsoleView.printlInConsoleln("Binomial 2-Sided P value = " + pvalue + "\n");
				}
				ConsoleView.printlInConsoleln(trainer.toString() + " training data accuracy= " + trainAccuracy[c][trialIndex]);
			}
			if (ReportOptions[ReportOption.train][ReportOption.f1]) {
				String label = ReportOptionArgs[ReportOption.train][ReportOption.f1];
				ConsoleView.printlInConsoleln(trainer.toString() + " training data F1(" + label + ") = " + trainTrial.getF1(label));
			}
			// validation
			if (ReportOptions[ReportOption.validation][ReportOption.confusion]) {
				ConsoleView.printlInConsoleln(trainer.toString() + " Validation Data Confusion Matrix");
				if (ilists[2].size() > 0)
					ConsoleView.printlInConsoleln(validationConfusionMatrix[c][trialIndex]);
			}
			if (ReportOptions[ReportOption.validation][ReportOption.accuracy]) {
				ConsoleView.printlInConsoleln(trainer.toString() + " validation data accuracy= " + validationAccuracy[c][trialIndex]);
			}
			if (ReportOptions[ReportOption.validation][ReportOption.f1]) {
				String label = ReportOptionArgs[ReportOption.validation][ReportOption.f1];
				ConsoleView.printlInConsoleln(trainer.toString() + " validation data F1(" + label + ") = " + validationTrial.getF1(label));
			}
			// test
			if (ReportOptions[ReportOption.test][ReportOption.confusion]) {
				ConsoleView.printlInConsoleln(trainer.toString() + " Test Data Confusion Matrix");
				if (ilists[1].size() > 0)
					ConsoleView.printlInConsoleln(testConfusionMatrix[c][trialIndex]);
			}
			if (ReportOptions[ReportOption.test][ReportOption.accuracy]) {
				pvalue = binomtest.binomialTest(ilists[1].size(),
						(int) (testTrial.getAccuracy() * ilists[1].size()),
						0.5, AlternativeHypothesis.TWO_SIDED);
				if (pvalue != 0) {
					if (pvalue > 0.5)
						pvalue = Math.abs(pvalue - 1);
					ConsoleView.printlInConsoleln("Binomial 2-Sided P value = " + pvalue + " \n");
				}
				ConsoleView.printlInConsoleln(trainer.toString() + " test data accuracy= " + testAccuracy[c][trialIndex]);
			}
			if (ReportOptions[ReportOption.test][ReportOption.f1]) {
				String label = ReportOptionArgs[ReportOption.test][ReportOption.f1];
				ConsoleView.printlInConsoleln(trainer.toString() + " test data F1(" + label + ") = " + testTrial.getF1(label));
			}
			if (trialIndex == 0)
				trainerNames[c] = trainer.toString();
		} // end for each trainer
	} // end for each trial
	// Summary reporting across trials:
	// "[train|test|validation]:[accuracy|f1|confusion|raw]"
	for (int c = 0; c < numTrainers; c++) {
		ConsoleView.printlInConsole("\n" + trainerNames[c].toString() + "\n");
		if (ReportOptions[ReportOption.train][ReportOption.accuracy]) {
			// Both pvalue branches of the original were identical; collapsed.
			String trainResult = "Summary. train accuracy = "
					+ MatrixOps.mean(trainAccuracy[c]);
			if (numTrials > 1) {
				trainResult += " stddev = " + MatrixOps.stddev(trainAccuracy[c]) + " stderr = " + MatrixOps.stderr(trainAccuracy[c]);
			}
			ConsoleView.printlInConsoleln(trainResult);
		}
		if (ReportOptions[ReportOption.validation][ReportOption.accuracy]) {
			// Both pvalue branches of the original were identical; collapsed.
			String validationResult = "Summary. validation accuracy = "
					+ MatrixOps.mean(validationAccuracy[c]);
			if (numTrials > 1) {
				validationResult += " stddev = " + MatrixOps.stddev(validationAccuracy[c]) + " stderr = " + MatrixOps.stderr(validationAccuracy[c]);
			}
			ConsoleView.printlInConsoleln(validationResult);
		}
		if (ReportOptions[ReportOption.test][ReportOption.accuracy]) {
			String testResult = "";
			if (pvalue != 0)
				testResult += "Summary. test accuracy = " + MatrixOps.mean(testAccuracy[c]) + " Binomial 2-Sided Pvalue = " + pvalue;
			else
				// pvalue == 0 means the exact p-value underflowed a double.
				testResult += "Summary. test accuracy = " + MatrixOps.mean(testAccuracy[c]) + " Pvalue < 10^(-1022)\n";
			if (numTrials > 1) {
				testResult += " stddev = " + MatrixOps.stddev(testAccuracy[c]) + " stderr = " + MatrixOps.stderr(testAccuracy[c]);
			}
			ConsoleView.printlInConsoleln(testResult);
		}
		// If we are testing the classifier with two folders, result will be
		// empty - no report is generated
		if (result.isEmpty()) {
			if (pvalue != 0)
				result.add("Summary. test accuracy = " + MatrixOps.mean(testAccuracy[c]) + " Binomial 2-Sided Pvalue = " + pvalue);
			else
				result.add("Summary. test accuracy = " + MatrixOps.mean(testAccuracy[c]) + " Pvalue < 10^(-1022)\n");
			if (numTrials > 1) {
				result.add(" stddev = " + MatrixOps.stddev(testAccuracy[c]) + " stderr = " + MatrixOps.stderr(testAccuracy[c]));
			}
		}
	} // end for each trainer
	return result;
}
/**
 * Prints each instance's name, target and per-label scores to stdout and
 * appends one summary line per instance to {@link #result}.
 *
 * Line format appended to result:
 *   {@code name,target topLabel,topScore,otherLabel(score); ...}
 * Labels whose string form is empty are skipped entirely.
 *
 * @param trial the classifications to report
 */
private static void printTrialClassification(Trial trial) {
	for (Classification c : trial) {
		Instance instance = c.getInstance();
		System.out.print(instance.getName() + " " + instance.getTarget()
				+ " ");
		// Build the result line with StringBuilder instead of repeated
		// String concatenation (same output, no quadratic copying).
		StringBuilder classification = new StringBuilder();
		classification.append(instance.getName()).append(",")
				.append(instance.getTarget()).append(" ");
		Labeling labeling = c.getLabeling();
		boolean foundPredictedClass = false;
		for (int j = 0; j < labeling.numLocations(); j++) {
			String label = labeling.getLabelAtRank(j).toString();
			if (label.isEmpty()) {
				continue; // skip labels that render as empty strings
			}
			if (!foundPredictedClass) {
				// First non-empty label is the top-ranked prediction.
				classification.append(label).append(",")
						.append(labeling.getValueAtRank(j)).append(",");
				foundPredictedClass = true;
			} else {
				classification.append(label).append("(")
						.append(labeling.getValueAtRank(j)).append(")")
						.append("; ");
			}
			System.out.print(label + ":" + labeling.getValueAtRank(j) + " ");
		}
		result.add(classification.toString());
		System.out.print("\n");
	}
}
/**
 * Evaluates {@code arg} as Java source via the shared bsh interpreter and
 * returns the resulting object (expected to be a ClassifierTrainer).
 *
 * @param arg a Java constructor expression, e.g. "new NaiveBayesTrainer()"
 * @return the object the interpreter produced
 * @throws IllegalArgumentException wrapping any interpreter EvalError
 */
private static Object createTrainer(String arg) {
	try {
		return interpreter.eval(arg);
	} catch (bsh.EvalError e) {
		// FIX: preserve the original EvalError as the cause instead of
		// flattening it into the message and discarding the stack trace.
		throw new IllegalArgumentException("Java interpreter eval error\n"
				+ e, e);
	}
}
/**
 * Builds a ClassifierTrainer from a spec string such as
 * "MaxEnt,gaussianPriorVariance=10,numIterations=20". The first
 * comma-separated field names the trainer — massaged so that "MaxEnt",
 * "MaxEntTrainer" and "new MaxEntTrainer()" are equivalent — and every
 * subsequent "name=value" pair is applied through the matching
 * single-argument setter found by reflection.
 *
 * @param arg trainer constructor spec, optionally followed by
 *            comma-separated name=value parameter assignments
 * @return the configured trainer
 * @throws IllegalArgumentException if the constructor or a parameter value
 *         fails to evaluate, a parameter is missing its "=value", or no
 *         matching setter exists
 */
private static ClassifierTrainer getTrainer(String arg) {
	// First, split the argument at commas.
	java.lang.String fields[] = arg.split(",");
	// Massage the constructor name so that MaxEnt, MaxEntTrainer and
	// "new MaxEntTrainer()" all end up calling new MaxEntTrainer().
	java.lang.String constructorName = fields[0];
	Object trainer;
	if (constructorName.indexOf('(') != -1) {
		// Contains parens: assume it is already a full constructor call
		// and pass the whole argument through unchanged.
		trainer = createTrainer(arg);
	} else if (constructorName.endsWith("Trainer")) {
		// Add parens if they forgot them.
		trainer = createTrainer("new " + constructorName + "()");
	} else {
		// Make the trainer name from the classifier name.
		trainer = createTrainer("new " + constructorName + "Trainer()");
	}
	// Find methods associated with the class we just built.
	Method methods[] = trainer.getClass().getMethods();
	// Find setters corresponding to parameter names.
	for (int i = 1; i < fields.length; i++) {
		java.lang.String nameValuePair[] = fields[i].split("=");
		java.lang.String parameterName = nameValuePair[0];
		// FIX: resolve the old "todo: check for val present!" — report a
		// missing value explicitly instead of throwing a raw
		// ArrayIndexOutOfBoundsException.
		if (nameValuePair.length < 2) {
			throw new IllegalArgumentException("Missing value for parameter "
					+ parameterName + " in trainer argument " + arg);
		}
		java.lang.String parameterValue = nameValuePair[1];
		java.lang.Object parameterValueObject;
		try {
			parameterValueObject = interpreter.eval(parameterValue);
		} catch (bsh.EvalError e) {
			// Preserve the EvalError as the cause.
			throw new IllegalArgumentException(
					"Java interpreter eval error on parameter "
							+ parameterName + "\n" + e, e);
		}
		// Expected setter: "set" + capitalized parameter name, 1 argument.
		java.lang.String setterName = "set"
				+ Character.toUpperCase(parameterName.charAt(0))
				+ parameterName.substring(1);
		boolean foundSetter = false;
		for (int j = 0; j < methods.length; j++) {
			if (setterName.equals(methods[j].getName())
					&& methods[j].getParameterTypes().length == 1) {
				try {
					methods[j].invoke(trainer,
							new java.lang.Object[] { parameterValueObject });
				} catch (IllegalAccessException e) {
					System.out.println("IllegalAccessException " + e);
					throw new IllegalArgumentException(
							"Java access error calling setter\n" + e, e);
				} catch (InvocationTargetException e) {
					System.out.println("IllegalTargetException " + e);
					throw new IllegalArgumentException(
							"Java target error calling setter\n" + e, e);
				}
				foundSetter = true;
				break;
			}
		}
		if (!foundSetter) {
			// List the available single-argument setters to help the user
			// correct the spec, then fail.
			System.out.println("Parameter " + parameterName
					+ " not found on trainer " + constructorName);
			System.out.println("Available parameters for "
					+ constructorName);
			for (int j = 0; j < methods.length; j++) {
				if (methods[j].getName().startsWith("set")
						&& methods[j].getParameterTypes().length == 1) {
					System.out.println(Character.toLowerCase(methods[j]
							.getName().charAt(3))
							+ methods[j].getName().substring(4));
				}
			}
			throw new IllegalArgumentException(
					"no setter found for parameter " + parameterName);
		}
	}
	assert (trainer instanceof ClassifierTrainer);
	return ((ClassifierTrainer) trainer);
}
}