/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify.tui;
import java.io.*;
import java.util.*;
import java.util.logging.*;
import java.lang.reflect.*;
import cc.mallet.classify.*;
import cc.mallet.classify.evaluate.*;
import cc.mallet.types.*;
import cc.mallet.util.*;
import java.util.Random;
/**
* Classify documents, run trials, print statistics from a vector file.
@author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/
public abstract class Calo2Classify
{
// Classifier deserialized from the --load-model file in main(); when that option is
// given, it is used for every trial in place of a freshly trained classifier.
private static Classifier classifierL; // CPAL - added
// General-purpose logger for informational messages (files loaded, split proportions, trainers).
private static Logger logger = MalletLogger.getLogger(Calo2Classify.class.getName());
// Logger named with a "-pl" suffix; main() uses it to reach the root logger and its handlers.
private static Logger progressLogger = MalletProgressMessageLogger.getLogger(Calo2Classify.class.getName() + "-pl");
// Trainers collected from each --trainer occurrence (or the NaiveBayes default added in main()).
private static ArrayList classifierTrainers = new ArrayList();
// ReportOptions[dataset][report] is true when that dataset/report pair was requested via --report;
// indices are the int constants declared in ReportOption (3 datasets x 4 report kinds).
private static boolean[][] ReportOptions = new boolean[3][4];
// Optional argument of a report selector, same indexing as ReportOptions; only filled for f1 (the label).
private static String[][] ReportOptionArgs = new String[3][4]; //arg in dataset:reportOption=arg
/**
 * Hand-rolled enumeration tying the textual names accepted by <code>--report</code>
 * to the integer indices used for the ReportOptions / ReportOptionArgs tables.
 * The position of a name in each array equals the value of the matching constant.
 */
private static class ReportOption
{
    // First table dimension: which data set the report is about.
    static final int train = 0, test = 1, validation = 2;
    static final String[] dataOptions = {"train", "test", "validation"};

    // Second table dimension: which kind of report to print.
    static final int accuracy = 0, f1 = 1, confusion = 2, raw = 3;
    static final String[] reportOptions = {"accuracy", "f1", "confusion", "raw"};
}
/**
 * The <code>--report</code> option: a list of <code>dataset:report</code> selectors, where
 * dataset is one of train/test/validation and report is one of accuracy/f1/confusion/raw.
 * An f1 report additionally needs a label, written <code>dataset:f1=label</code>.
 * Defaults to test:accuracy, test:confusion and train:accuracy.
 */
static CommandOption.SpacedStrings report = new CommandOption.SpacedStrings
    (Calo2Classify.class, "report", "[train|test|validation]:[accuracy|f1|confusion|raw]",
     true, new String[] {"test:accuracy", "test:confusion", "train:accuracy"},
     "", null)
    {
        /**
         * Translates every selector in <code>this.value</code> into an entry of the
         * ReportOptions table (and, for f1, of ReportOptionArgs).
         *
         * @param list not used by this implementation (main() passes null for the default value)
         * @throws IllegalArgumentException if a selector lacks the dataset:report form, names an
         *         unknown dataset or report, omits the f1 label, or gives an argument to a
         *         report other than f1
         */
        public void postParsing (CommandOption.List list)
        {
            for (int argi = 0; argi < this.value.length; argi++) {
                // convert options like --report train:accuracy --report test:f1=labelA to
                // boolean array of options.
                java.lang.String arg = this.value[argi];
                // Split at ':' or '='.  String.split drops trailing empty fields,
                // so "test:f1=" yields no label and is rejected below.
                java.lang.String fields[] = arg.split("[:=]");
                if (fields.length < 2) {
                    // Previously surfaced as an ArrayIndexOutOfBoundsException.
                    throw new IllegalArgumentException("Expected [train|test|validation]:[accuracy|f1|confusion|raw] in --report " + arg);
                }
                java.lang.String dataSet = fields[0];
                java.lang.String reportOption = fields[1];
                java.lang.String reportOptionArg = (fields.length >= 3) ? fields[2] : null;

                // find the datasource (test,train,validation)
                int dataSetIndex = reportNameIndex (ReportOption.dataOptions, dataSet);
                if (dataSetIndex < 0) {
                    throw new IllegalArgumentException("Unknown argument = " + dataSet + " in --report " + arg);
                }

                // find the report option (accuracy, f1, confusion, raw)
                int reportIndex = reportNameIndex (ReportOption.reportOptions, reportOption);
                if (reportIndex < 0) {
                    throw new IllegalArgumentException("Unknown argument = " + reportOption + " in --report " + arg);
                }

                if (reportIndex == ReportOption.f1) {
                    // make sure a label was specified for f1
                    if (reportOptionArg == null) {
                        throw new IllegalArgumentException("F1 must have label argument in --report " + arg);
                    }
                    // Pass through the string argument
                    ReportOptionArgs[dataSetIndex][reportIndex] = reportOptionArg;
                } else if (reportOptionArg != null) {
                    throw new IllegalArgumentException("No arguments after = allowed in --report " + arg);
                }

                // Mark the (dataSet, reportOption) pair as selected
                ReportOptions[dataSetIndex][reportIndex] = true;
            }
        }

        /** Returns the position of <code>name</code> in <code>names</code>, or -1 if absent. */
        private int reportNameIndex (java.lang.String[] names, java.lang.String name)
        {
            for (int k = 0; k < names.length; k++) {
                if (name.equals(names[k])) {
                    return k;
                }
            }
            return -1;
        }
    };
/**
 * The <code>--trainer</code> option.  Accepts a trainer name or constructor expression,
 * optionally followed by comma-separated <code>name=value</code> pairs that are applied
 * through the trainer's one-argument <code>setName</code> methods, e.g.
 * <code>MaxEnt,gaussianPriorVariance=10,numIterations=20</code>.
 */
static CommandOption.Object trainerConstructor = new CommandOption.Object
    (Calo2Classify.class, "trainer", "ClassifierTrainer constructor", true, new NaiveBayesTrainer(),
     "Java code for the constructor used to create a ClassifierTrainer. "+
     "If no '(' appears, then \"new \" will be prepended and \"Trainer()\" will be appended. "+
     "You may use this option multiple times to compare multiple classifiers.", null)
    {
        /**
         * Builds the trainer and applies any <code>name=value</code> settings.
         *
         * @param arg the raw option value
         * @throws IllegalArgumentException if a setting is not of the form name=value, its value
         *         cannot be evaluated, no matching setter exists, or the setter call fails
         */
        public void parseArg (java.lang.String arg) {
            // parse something like Maxent,gaussianPriorVariance=10,numIterations=20
            // first, split the argument at commas.
            // NOTE(review): a constructor expression that itself contains commas is also split
            // here; its tail fields are then rejected below as malformed settings.
            java.lang.String fields[] = arg.split(",");

            // Massage constructor name, so that MaxEnt, MaxEntTrainer, new MaxEntTrainer()
            // all call new MaxEntTrainer()
            java.lang.String constructorName = fields[0];
            if (constructorName.indexOf('(') != -1) {
                super.parseArg(arg);  // if contains (), pass it through
            } else if (constructorName.endsWith("Trainer")) {
                super.parseArg("new " + constructorName + "()");  // add parens if they forgot
            } else {
                super.parseArg("new " + constructorName + "Trainer()");  // make trainer name from classifier name
            }

            // find methods associated with the class we just built
            Method methods[] = this.value.getClass().getMethods();

            // find setters corresponding to parameter names.
            for (int i = 1; i < fields.length; i++) {
                java.lang.String nameValuePair[] = fields[i].split("=");
                if (nameValuePair.length < 2
                        || nameValuePair[0].length() == 0
                        || nameValuePair[1].length() == 0) {
                    // Previously an ArrayIndexOutOfBounds / StringIndexOutOfBounds exception.
                    throw new IllegalArgumentException ("Expected name=value but found \"" + fields[i] +
                                                        "\" in --trainer " + arg);
                }
                java.lang.String parameterName = nameValuePair[0];
                java.lang.String parameterValue = nameValuePair[1];

                java.lang.Object parameterValueObject;
                try {
                    parameterValueObject = getInterpreter().eval(parameterValue);
                } catch (bsh.EvalError e) {
                    throw new IllegalArgumentException ("Java interpreter eval error on parameter "+
                                                        parameterName + "\n"+e, e);
                }

                // Bean-style setter name for this parameter, e.g. numIterations -> setNumIterations.
                java.lang.String setterName = "set" + Character.toUpperCase(parameterName.charAt(0)) +
                                              parameterName.substring(1);
                boolean foundSetter = false;
                for (int j = 0; j < methods.length && !foundSetter; j++) {
                    if (!setterName.equals(methods[j].getName())
                            || methods[j].getParameterTypes().length != 1) {
                        continue;
                    }
                    try {
                        methods[j].invoke(this.value, new java.lang.Object[] {parameterValueObject});
                    } catch (IllegalAccessException e) {
                        System.out.println("IllegalAccessException " + e);
                        throw new IllegalArgumentException ("Java access error calling setter\n"+e, e);
                    } catch (InvocationTargetException e) {
                        System.out.println("InvocationTargetException " + e);
                        throw new IllegalArgumentException ("Java target error calling setter\n"+e, e);
                    }
                    // Only the first matching one-argument setter is used.
                    foundSetter = true;
                }

                if (!foundSetter) {
                    System.out.println("Parameter " + parameterName + " not found on trainer " + constructorName);
                    System.out.println("Available parameters for " + constructorName);
                    for (int j = 0; j < methods.length; j++) {
                        java.lang.String methodName = methods[j].getName();
                        // length check: a method named exactly "set" has no property name to print
                        if (methodName.startsWith("set") && methodName.length() > 3
                                && methods[j].getParameterTypes().length == 1) {
                            System.out.println(Character.toLowerCase(methodName.charAt(3)) +
                                               methodName.substring(4));
                        }
                    }
                    throw new IllegalArgumentException ("no setter found for parameter " + parameterName);
                }
            }
        }

        /**
         * Records the constructed trainer so main() can run it.
         *
         * @param list not used by this implementation
         * @throws IllegalArgumentException if the evaluated expression is not a ClassifierTrainer;
         *         main() casts every recorded object to ClassifierTrainer, so fail here with a
         *         clear message instead of relying on an assert
         */
        public void postParsing (CommandOption.List list) {
            if (!(this.value instanceof ClassifierTrainer)) {
                throw new IllegalArgumentException ("--trainer did not produce a ClassifierTrainer: " + this.value);
            }
            classifierTrainers.add (this.value);
        }
    };
// CPAL - added this to load a classifier from a file.
// When --load-model is given, main() deserializes the classifier from this file and
// evaluates it in every trial instead of training a new one.
static CommandOption.String loadmodelFile = new CommandOption.String
    (Calo2Classify.class, "load-model", "FILENAME", true, "classifier.mallet",
     "The filename from which to read a previously trained classifier. "+
     "If this option is specified, the loaded classifier is evaluated instead of training a new one.", null);
// Destination for the serialized classifier. main() writes it only when --output-classifier
// was given, and appends the trainer name and/or ".trial<N>" when there are several
// trainers or trials.
static CommandOption.String outputFile = new CommandOption.String
(Calo2Classify.class, "output-classifier", "FILENAME", true, "classifier.mallet",
"The filename in which to write the classifier after it has been trained.", null);
// Single instance-list file that main() splits into train/test/validation portions.
// NOTE(review): the help text advertises "-" for stdin, but main() passes the value straight
// to new File(...); whether "-" is handled elsewhere is not visible here - confirm.
static CommandOption.String inputFile = new CommandOption.String
(Calo2Classify.class, "input", "FILENAME", true, "text.vectors",
"The filename from which to read the list of training instances. Use - for stdin.", null);
// Pre-split training set. Giving any of --training-file, --testing-file or --validation-file
// makes main() load separate lists and skip the random split of --input.
static CommandOption.String trainingFile = new CommandOption.String
(Calo2Classify.class, "training-file", "FILENAME", true, "text.vectors",
"Read the training set instance list from this file. " +
"If this is specified, the input file parameter is ignored", null);
// Pre-split test set; main() loads it only when --testing-file was given.
static CommandOption.String testFile = new CommandOption.String
    (Calo2Classify.class, "testing-file", "FILENAME", true, "text.vectors",
     "Read the test set instance list from this file. " +
     "If this option is specified, the training-file parameter must be specified and " +
     "the input-file parameter is ignored", null);
// Pre-split validation set; main() loads it only when --validation-file was given.
static CommandOption.String validationFile = new CommandOption.String
    (Calo2Classify.class, "validation-file", "FILENAME", true, "text.vectors",
     "Read the validation set instance list from this file. " +
     "If this option is specified, the training-file parameter must be specified and " +
     "the input-file parameter is ignored", null);
// Split proportions used by main() when a single --input file is randomly partitioned;
// the test portion is whatever remains (1 - training - validation).
static CommandOption.Double trainingProportionOption = new CommandOption.Double
(Calo2Classify.class, "training-portion", "DECIMAL", true, 1.0,
"The fraction of the instances that should be used for training.", null);
static CommandOption.Double validationProportionOption = new CommandOption.Double
(Calo2Classify.class, "validation-portion", "DECIMAL", true, 0.0,
"The fraction of the instances that should be used for validation.", null);
// When > 0, main() hides the labels of a random subset of the training list for the
// duration of training and restores them afterwards.
static CommandOption.Double unlabeledProportionOption = new CommandOption.Double
(Calo2Classify.class, "unlabeled-portion", "DECIMAL", true, 0.0,
"The fraction of the training instances that should have their labels hidden. "
+"Note that these are taken out of the training-portion, not allocated separately.", null);
// Seed for the java.util.Random that drives the splits; main() only uses it when the
// option was given explicitly, otherwise an unseeded Random is created.
static CommandOption.Integer randomSeedOption = new CommandOption.Integer
(Calo2Classify.class, "random-seed", "INTEGER", true, 0,
"The random seed for randomly selecting a proportion of the instance list for training", null);
// Number of iterations of the train/evaluate loop in main().
static CommandOption.Integer numTrialsOption = new CommandOption.Integer
(Calo2Classify.class, "num-trials", "INTEGER", true, 1,
"The number of random train/test splits to perform", null);
// NOTE(review): this option is declared but not read anywhere in the code visible in this file.
static CommandOption.Object classifierEvaluatorOption = new CommandOption.Object
(Calo2Classify.class, "classifier-evaluator", "CONSTRUCTOR", true, null,
"Java code for constructing a ClassifierEvaluating object", null);
// Retired options, superseded by --report; kept for reference only.
// static CommandOption.Boolean printTrainAccuracyOption = new CommandOption.Boolean
// (Vectors2Classify.class, "print-train-accuracy", "true|false", true, true,
// "After training, run the resulting classifier on the instances included in training, "
// +"and print the accuracy", null);
//
// static CommandOption.Boolean printTestAccuracyOption = new CommandOption.Boolean
// (Vectors2Classify.class, "print-test-accuracy", "true|false", true, true,
// "After training, run the resulting classifier on the instances not included in training, "
// +"and print the accuracy", null);
// Index into MalletLogger.LoggingLevels; main() applies it to the root logger only when
// the option was given explicitly.
static CommandOption.Integer verbosityOption = new CommandOption.Integer
(Calo2Classify.class, "verbosity", "INTEGER", true, -1,
"The level of messages to print: 0 is silent, 8 is most verbose. " +
"Levels 0-8 correspond to the java.logger predefined levels "+
"off, severe, warning, info, config, fine, finer, finest, all. " +
"The default value is taken from the mallet logging.properties file," +
" which currently defaults to INFO level (3)", null);
// When false (the default), main() installs ProgressMessageLogFormatter on every
// ConsoleHandler of the root logger.
static CommandOption.Boolean noOverwriteProgressMessagesOption = new CommandOption.Boolean
(Calo2Classify.class, "noOverwriteProgressMessages", "true|false", false, false,
"Suppress writing-in-place on terminal for progess messages - repetitive messages "
+"of which only the latest is generally of interest", null);
/**
 * Command-line entry point: parses the options, loads the instance list(s), then for each
 * trial and each trainer trains (or reuses the --load-model classifier), evaluates on the
 * train/test/validation lists, optionally saves the classifier, and prints the reports
 * selected with --report followed by per-trainer accuracy summaries.
 *
 * @param args command-line arguments, handed to CommandOption.process
 * @throws bsh.EvalError declared for callers; presumably raised while evaluating option expressions
 * @throws IOException declared for callers; presumably raised while reading input files
 * @throws IllegalArgumentException if --verbosity is out of range or a classifier file
 *         cannot be read or written
 */
public static void main (String[] args) throws bsh.EvalError, IOException
{
    // Process the command-line options
    CommandOption.setSummary (Calo2Classify.class,
                              "A tool for training, saving and printing diagnostics from a classifier on vectors.");
    CommandOption.process (Calo2Classify.class, args);

    // handle default trainer here for now; default argument processing doesn't work
    if (!trainerConstructor.wasInvoked()) {
        classifierTrainers.add (new NaiveBayesTrainer());
    }
    if (!report.wasInvoked()) {
        report.postParsing(null);  // force postprocessing of default value
    }

    int verbosity = verbosityOption.value;
    Logger rootLogger = ((MalletLogger)progressLogger).getRootLogger();
    if (verbosityOption.wasInvoked()) {
        // The value is used as an array index; reject out-of-range values with a clear message.
        if (verbosity < 0 || verbosity >= MalletLogger.LoggingLevels.length) {
            throw new IllegalArgumentException ("--verbosity must be between 0 and " +
                                                (MalletLogger.LoggingLevels.length - 1) +
                                                " but was " + verbosity);
        }
        rootLogger.setLevel( MalletLogger.LoggingLevels[verbosity]);
    }

    if (noOverwriteProgressMessagesOption.value == false) {
        // install special formatting for progress messages
        // find console handler on root logger; change formatter to one
        // that knows about progress messages
        Handler[] handlers = rootLogger.getHandlers();
        for (int i = 0; i < handlers.length; i++) {
            if (handlers[i] instanceof ConsoleHandler) {
                handlers[i].setFormatter(new ProgressMessageLogFormatter());
            }
        }
    }

    boolean separateIlists = testFile.wasInvoked() || trainingFile.wasInvoked() ||
                             validationFile.wasInvoked();
    InstanceList ilist = null;
    InstanceList testFileIlist = null;
    InstanceList trainingFileIlist = null;
    InstanceList validationFileIlist = null;

    if (!separateIlists) {
        // normal case: one instance list, split randomly for every trial
        ilist = InstanceList.load (new File(inputFile.value));
    } else {
        // user specified separate files for testing and training sets.
        trainingFileIlist = InstanceList.load (new File(trainingFile.value));
        logger.info("Training vectors loaded from " + trainingFile.value);
        if (testFile.wasInvoked()) {
            testFileIlist = InstanceList.load (new File(testFile.value));
            logger.info("Testing vectors loaded from " + testFile.value);
        }
        if (validationFile.wasInvoked()) {
            validationFileIlist = InstanceList.load (new File(validationFile.value));
            logger.info("validation vectors loaded from " + validationFile.value);
        }
    }

    int numTrials = numTrialsOption.value;
    Random r = randomSeedOption.wasInvoked() ? new Random (randomSeedOption.value) : new Random ();

    ClassifierTrainer[] trainers = new ClassifierTrainer[classifierTrainers.size()];
    for (int i = 0; i < classifierTrainers.size(); i++) {
        trainers[i] = (ClassifierTrainer) classifierTrainers.get(i);
        logger.fine ("Trainer specified = "+trainers[i].toString());
    }

    double trainAccuracy[][] = new double[trainers.length][numTrials];
    double testAccuracy[][] = new double[trainers.length][numTrials];
    double validationAccuracy[][] = new double[trainers.length][numTrials];

    String trainConfusionMatrix[][] = new String[trainers.length][numTrials];
    String testConfusionMatrix[][] = new String[trainers.length][numTrials];
    String validationConfusionMatrix[][] = new String[trainers.length][numTrials];

    double t = trainingProportionOption.value;
    double v = validationProportionOption.value;

    if (!separateIlists) {
        logger.info("Training portion = " + t);
        logger.info(" Unlabeled training sub-portion = "+unlabeledProportionOption.value);
        logger.info("Validation portion = " + v);
        logger.info("Testing portion = " + (1 - v - t));
    }

    // CPAL - load a classifier once, to be reused for each trial
    if (loadmodelFile.wasInvoked()) {
        classifierL = readClassifier (loadmodelFile.value);
    }

    for (int trialIndex = 0; trialIndex < numTrials; trialIndex++) {
        System.out.println("\n-------------------- Trial " + trialIndex + " --------------------\n");
        InstanceList[] ilists;
        BitSet unlabeledIndices = null;
        if (!separateIlists) {
            ilists = ilist.split (r, new double[] {t, 1-t-v, v});
        } else {
            // Slot order matches the split above: train, test, validation.
            // (The validation slot used to receive the test list by mistake.)
            ilists = new InstanceList[3];
            ilists[0] = trainingFileIlist;
            ilists[1] = testFileIlist;
            ilists[2] = validationFileIlist;
        }

        if (unlabeledProportionOption.value > 0) {
            unlabeledIndices = new cc.mallet.util.Randoms(r.nextInt())
                .nextBitSet(ilists[0].size(),
                            unlabeledProportionOption.value);
        }

        for (int c = 0; c < trainers.length; c++) {
            System.out.println ("Trial " + trialIndex + " Training " + trainers[c].toString() + " with "+ilists[0].size()+" instances");
            if (unlabeledProportionOption.value > 0) {
                ilists[0].hideSomeLabels(unlabeledIndices);
            }
            Classifier classifier;
            if (loadmodelFile.wasInvoked()) {
                classifier = classifierL;
            } else {
                classifier = trainers[c].train (ilists[0]);
            }
            if (unlabeledProportionOption.value > 0) {
                ilists[0].unhideAllLabels();
            }
            System.out.println ("Trial " + trialIndex + " Training " + trainers[c].toString() + " finished");

            // In separate-file mode the test and/or validation list may be absent (null);
            // such a data set gets no Trial and reports NaN instead of crashing.
            Trial trainTrial = new Trial (classifier, ilists[0]);
            assert (ilists[1] == null || ilists[1].size() > 0);
            Trial testTrial = (ilists[1] == null) ? null : new Trial (classifier, ilists[1]);
            Trial validationTrial = (ilists[2] == null) ? null : new Trial (classifier, ilists[2]);

            if (hasInstances (ilists[0])) trainConfusionMatrix[c][trialIndex] = new ConfusionMatrix (trainTrial).toString();
            if (hasInstances (ilists[1])) testConfusionMatrix[c][trialIndex] = new ConfusionMatrix (testTrial).toString();
            if (hasInstances (ilists[2])) validationConfusionMatrix[c][trialIndex] = new ConfusionMatrix (validationTrial).toString();

            trainAccuracy[c][trialIndex] = trainTrial.getAccuracy();
            testAccuracy[c][trialIndex] = (testTrial == null) ? Double.NaN : testTrial.getAccuracy();
            validationAccuracy[c][trialIndex] = (validationTrial == null) ? Double.NaN : validationTrial.getAccuracy();

            if (outputFile.wasInvoked()) {
                String filename = outputFile.value;
                if (trainers.length > 1) filename = filename+trainers[c].toString();
                if (numTrials > 1) filename = filename+".trial"+trialIndex;
                writeClassifier (classifier, filename);
            }

            // Reporting: raw listings first (train, test, validation), then the
            // confusion/accuracy/f1 group for train, validation and test in that order.
            String prefix = "Trial " + trialIndex + " Trainer " + trainers[c].toString();
            printRawReport (ReportOption.train, prefix, " Raw Training Data", trainTrial);
            printRawReport (ReportOption.test, prefix, " Raw Testing Data", testTrial);
            printRawReport (ReportOption.validation, prefix, " Raw Validation Data", validationTrial);

            printDataSetReport (ReportOption.train, prefix, "Training", "training", trainTrial,
                                trainConfusionMatrix[c][trialIndex], trainAccuracy[c][trialIndex]);
            printDataSetReport (ReportOption.validation, prefix, "Validation", "validation", validationTrial,
                                validationConfusionMatrix[c][trialIndex], validationAccuracy[c][trialIndex]);
            printDataSetReport (ReportOption.test, prefix, "Test", "test", testTrial,
                                testConfusionMatrix[c][trialIndex], testAccuracy[c][trialIndex]);
        }  // end for each trainer
    }  // end for each trial

    // Per-trainer summaries over all trials
    //"[train|test|validation]:[accuracy|f1|confusion|raw]"
    for (int c = 0; c < trainers.length; c++) {
        System.out.println ("\n"+trainers[c].toString());
        if (ReportOptions[ReportOption.train][ReportOption.accuracy])
            System.out.println ("Summary. train accuracy mean = "+ MatrixOps.mean (trainAccuracy[c])+
                                " stddev = "+ MatrixOps.stddev (trainAccuracy[c])+
                                " stderr = "+ MatrixOps.stderr (trainAccuracy[c]));

        if (ReportOptions[ReportOption.validation][ReportOption.accuracy])
            System.out.println ("Summary. validation accuracy mean = "+ MatrixOps.mean (validationAccuracy[c])+
                                " stddev = "+ MatrixOps.stddev (validationAccuracy[c])+
                                " stderr = "+ MatrixOps.stderr (validationAccuracy[c]));

        if (ReportOptions[ReportOption.test][ReportOption.accuracy])
            System.out.println ("Summary. test accuracy mean = "+ MatrixOps.mean (testAccuracy[c])+
                                " stddev = "+ MatrixOps.stddev (testAccuracy[c])+
                                " stderr = "+ MatrixOps.stderr (testAccuracy[c]));
    }  // end for each trainer
}

/** Returns true when the list is present and holds at least one instance. */
private static boolean hasInstances (InstanceList list)
{
    return list != null && list.size() > 0;
}

/**
 * Deserializes a classifier from the named file, always releasing the file handle.
 * NOTE(review): this is plain Java deserialization; only load model files from trusted sources.
 *
 * @throws IllegalArgumentException (with the underlying cause attached) if reading fails
 */
private static Classifier readClassifier (String filename)
{
    FileInputStream fileIn = null;
    try {
        fileIn = new FileInputStream (filename);
        ObjectInputStream ois = new ObjectInputStream (fileIn);
        return (Classifier) ois.readObject();
    } catch (Exception e) {
        throw new IllegalArgumentException ("Couldn't read classifier from filename "+
                                            filename, e);
    } finally {
        if (fileIn != null) {
            try {
                fileIn.close();
            } catch (IOException ignored) {
                // nothing useful to do: the data has been read or the failure is already reported
            }
        }
    }
}

/**
 * Serializes the classifier to the named file, always releasing the file handle.
 *
 * @throws IllegalArgumentException (with the underlying cause attached) if writing fails
 */
private static void writeClassifier (Classifier classifier, String filename)
{
    FileOutputStream fileOut = null;
    try {
        fileOut = new FileOutputStream (filename);
        ObjectOutputStream oos = new ObjectOutputStream (fileOut);
        oos.writeObject (classifier);
        oos.close();  // flushes; a failure here is reported through the catch below
    } catch (Exception e) {
        throw new IllegalArgumentException ("Couldn't write classifier to filename "+
                                            filename, e);
    } finally {
        if (fileOut != null) {
            try {
                fileOut.close();
            } catch (IOException ignored) {
                // already closed above, or the write failure is already being reported
            }
        }
    }
}

/** Prints the raw per-instance listing for one data set if that report was requested. */
private static void printRawReport (int dataSet, String prefix, String title, Trial trial)
{
    if (!ReportOptions[dataSet][ReportOption.raw]) {
        return;
    }
    System.out.println(prefix);
    System.out.println(title);
    if (trial != null) {
        printTrialClassification(trial);
    }
}

/**
 * Prints the confusion matrix, accuracy and F1 lines for one data set, each only if requested.
 * A null confusion matrix (empty or absent data set) prints the heading only; a null trial
 * reports its F1 as NaN.
 */
private static void printDataSetReport (int dataSet, String prefix, String titleName, String lowerName,
                                        Trial trial, String confusionMatrix, double accuracy)
{
    if (ReportOptions[dataSet][ReportOption.confusion]) {
        System.out.println(prefix + " " + titleName + " Data Confusion Matrix");
        if (confusionMatrix != null) System.out.println (confusionMatrix);
    }
    if (ReportOptions[dataSet][ReportOption.accuracy]) {
        System.out.println (prefix + " " + lowerName + " data accuracy= "+ accuracy);
    }
    if (ReportOptions[dataSet][ReportOption.f1]) {
        String label = ReportOptionArgs[dataSet][ReportOption.f1];
        String f1 = (trial == null) ? "NaN" : String.valueOf (trial.getF1(label));
        System.out.println (prefix + " " + lowerName + " data F1(" + label + ") = "+ f1);
    }
}
/**
 * Dumps a trial to standard output, one line per classified instance: the instance name,
 * its target, then every label of the predicted labeling as <code>label:value</code> in
 * rank order, each item followed by a single space.
 *
 * @param trial the classifications to print
 */
private static void printTrialClassification(Trial trial)
{
    int row = 0;
    while (row < trial.size()) {
        StringBuilder line = new StringBuilder();

        Instance instance = trial.get(row).getInstance();
        line.append(instance.getName()).append(" ").append(instance.getTarget()).append(" ");

        Labeling labeling = trial.get(row).getLabeling();
        for (int rank = 0; rank < labeling.numLocations(); rank++) {
            line.append(labeling.getLabelAtRank(rank).toString())
                .append(":")
                .append(labeling.getValueAtRank(rank))
                .append(" ");
        }

        System.out.println(line.toString());
        row++;
    }
}
}