package info.ephyra.questionanalysis.atype.minorthird.hierarchical;
import info.ephyra.questionanalysis.atype.extractor.FeatureExtractor;
import info.ephyra.questionanalysis.atype.extractor.FeatureExtractorFactory;
import info.ephyra.util.Properties;
import java.io.File;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.Date;
import java.util.Formatter;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;
import libsvm.svm_parameter;
import org.apache.log4j.Logger;
import edu.cmu.lti.javelin.util.FileUtil;
import edu.cmu.lti.javelin.util.Language;
import edu.cmu.lti.javelin.util.MLToolkit;
import edu.cmu.lti.util.Pair;
import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.CascadingBinaryLearner;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.MostFrequentFirstLearner;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.OneVsAllLearner;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.algorithms.knn.KnnLearner;
import edu.cmu.minorthird.classify.algorithms.linear.BalancedWinnow;
import edu.cmu.minorthird.classify.algorithms.linear.KWayMixtureLearner;
import edu.cmu.minorthird.classify.algorithms.linear.MarginPerceptron;
import edu.cmu.minorthird.classify.algorithms.linear.MaxEntLearner;
import edu.cmu.minorthird.classify.algorithms.linear.NaiveBayes;
import edu.cmu.minorthird.classify.algorithms.linear.NegativeBinomialLearner;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.algorithms.random.RandomElement;
import edu.cmu.minorthird.classify.algorithms.svm.SVMLearner;
import edu.cmu.minorthird.classify.algorithms.trees.AdaBoost;
import edu.cmu.minorthird.classify.algorithms.trees.DecisionTreeLearner;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.experiments.CrossValidatedDataset;
import edu.cmu.minorthird.classify.experiments.Evaluation;
import edu.cmu.minorthird.classify.experiments.Tester;
import edu.cmu.minorthird.util.IOUtil;
import edu.cmu.minorthird.util.gui.ViewerFrame;
/**
 * Tool for training and evaluating hierarchical classifiers.
 *
 * <p>Configuration is read from a properties resource keyed by the language
 * pair given at construction time and may be overridden with
 * {@link #setProperties(Properties)}. After {@link #initialize()} the trainer
 * can either train and save a classifier ({@link #trainClassifier()},
 * {@link #saveClassifier()}) or run an evaluation ({@link #runExperiment()},
 * {@link #createReport()}).
 *
 * @author Justin Betteridge
 * @version 2008-02-10
 */
public class HierarchicalClassifierTrainer {

    private static final Logger log = Logger.getLogger(HierarchicalClassifierTrainer.class);

    private FeatureExtractor extractor;
    private String trainingFile;
    private String testingFile;
    /** Number of cross-validation folds; a negative value means "use the testing file instead". */
    private int crossValidationFolds;
    private String[] learnerNames;
    private boolean useClassLevels;
    /** Labels accepted by the experiment (from the "classLabels" property). */
    private HashSet<String> classLabels;
    /** Labels actually observed while loading the most recent training set. */
    private HashSet<String> trainingLabels;
    private HashSet<String> featureTypes;
    private String classifierDir;
    private Dataset trainingSet;
    private Dataset testingSet;
    private Classifier classifier;
    private Pair<Language,Language> languagePair;
    private Properties properties;
    private CrossValidatedDataset cvDataset;
    private Evaluation evaluation;
    /** Wall-clock duration in milliseconds of the last training or experiment run. */
    private long runTime;

    /**
     * Creates a trainer for the given language pair. {@link #initialize()} or
     * {@link #setProperties(Properties)} must be called before it is used.
     *
     * @param languagePair pair whose two members select the properties section to load
     */
    public HierarchicalClassifierTrainer(Pair<Language,Language> languagePair) {
        this.languagePair = languagePair;
    }

    /**
     * Overrides default properties with those given and re-initializes the trainer.
     *
     * <p>May be called before {@link #initialize()}: the defaults are then loaded
     * first so that there is something to override. Errors raised while loading
     * defaults or re-initializing are logged, not thrown.
     *
     * @param properties the overriding properties; must not be null
     * @throws IllegalArgumentException if {@code properties} is null
     */
    public void setProperties(Properties properties) {
        if (properties == null) {
            throw new IllegalArgumentException("properties must not be null");
        }
        if (this.properties == null) {
            try {
                loadDefaultProperties();
            } catch (Exception e) {
                log.error("Error loading default properties: ", e);
                return;
            }
        }
        for (String property : properties.keySet()) {
            this.properties.put(property, properties.get(property));
        }
        try {
            initialize();
        } catch (Exception e) {
            log.error("Error re-initializing: ", e);
        }
    }

    /**
     * Reads the configuration and loads the training set (and, when
     * "crossValidationFolds" is negative, the testing set). Default properties
     * are loaded on the first call if none are present yet.
     *
     * @throws Exception if no language pair was set, if no properties exist for
     *         the language pair, or if a property cannot be parsed
     */
    public void initialize() throws Exception {
        requireLanguagePair();
        if (properties == null) {
            loadDefaultProperties();
        }

        trainingFile = properties.getProperty("trainingFile");
        testingFile = properties.getProperty("testingFile");
        crossValidationFolds = Integer.parseInt(properties.getProperty("crossValidationFolds"));

        learnerNames = properties.getProperty("learners").split(",");
        for (int i = 0; i < learnerNames.length; i++) {
            learnerNames[i] = learnerNames[i].trim();
        }
        useClassLevels = Boolean.parseBoolean(properties.getProperty("useClassLevels"));
        if (!useClassLevels && learnerNames.length > 1) {
            // Without class levels only the first configured learner is used.
            learnerNames = new String[] { learnerNames[0] };
        }

        // NOTE(review): unlike learners and feature types, labels are not trimmed
        // here; confirm getHierarchicalClassName tolerates surrounding whitespace.
        classLabels = new HashSet<String>();
        String[] labels = properties.getProperty("classLabels").split(",");
        for (int i = 0; i < labels.length; i++) {
            classLabels.add(HierarchicalClassifier.getHierarchicalClassName(
                    labels[i], learnerNames.length, useClassLevels));
        }

        featureTypes = new HashSet<String>();
        String[] types = properties.getProperty("featureTypes").split(",");
        for (int i = 0; i < types.length; i++) {
            featureTypes.add(types[i].trim());
        }

        classifierDir = properties.getProperty("classifierDir");

        // The training set must be loaded first: it defines the labels that
        // testing examples are filtered against.
        trainingSet = makeDataset(trainingFile, true);
        if (crossValidationFolds < 0) {
            testingSet = makeDataset(testingFile, false);
        }
    }

    /** Fails with the same exception type initialize() has always used when no language pair is set. */
    private void requireLanguagePair() throws Exception {
        if (languagePair == null) {
            throw new Exception("Language pair must be set before calling initialize");
        }
    }

    /** Loads the default properties section for the language pair and creates the matching feature extractor. */
    private void loadDefaultProperties() throws Exception {
        requireLanguagePair();
        Properties defaults = Properties.loadFromClassName(this.getClass().getName());
        String pairKey = languagePair.getFirst() + "_" + languagePair.getSecond();
        properties = defaults.mapProperties().get(pairKey);
        if (properties == null) {
            throw new Exception("No properties defined for language pair " + pairKey);
        }
        extractor = FeatureExtractorFactory.getInstance(languagePair.getFirst());
    }

    /**
     * Loads examples from a file, keeping only configured class labels and
     * configured feature types.
     *
     * @param fileName   file handed to the feature extractor
     * @param isTraining true to load a training set (rebuilds the set of seen
     *                   training labels); false to load a testing set, whose
     *                   examples are dropped when their label was not seen in training
     * @return the filtered dataset
     */
    private Dataset makeDataset(String fileName, boolean isTraining) {
        if (isTraining) {
            // Start from scratch so a re-initialization does not filter the new
            // training data against labels from a previous run.
            trainingLabels = new HashSet<String>();
        }
        Dataset set = new BasicDataset();
        extractor.setUseClassLevels(useClassLevels);
        extractor.setClassLevels(learnerNames.length);
        Example[] examples = extractor.loadFile(fileName);
        for (Example source : examples) {
            String label = source.getLabel().bestClassName();
            if (!classLabels.contains(label)) {
                MLToolkit.println("Discarding example for Class: " + label);
                continue;
            }
            Example example = new Example(filterFeatures(source), source.getLabel());
            MLToolkit.println(example);
            if (isTraining) {
                trainingLabels.add(label);
                set.add(example);
            } else if (trainingLabels.contains(label)) {
                set.add(example);
            } else {
                MLToolkit.println("Label of test example not found in training set (discarding): " + label);
            }
        }
        MLToolkit.println("Loaded " + set.size() + " examples for experiment from " + fileName);
        return set;
    }

    /** Copies an example's binary and numeric features whose first name part is a configured feature type. */
    private MutableInstance filterFeatures(Example source) {
        MutableInstance instance = new MutableInstance(source.getSource(), source.getSubpopulationId());
        Feature.Looper binaryFeatures = source.binaryFeatureIterator();
        while (binaryFeatures.hasNext()) {
            Feature f = binaryFeatures.nextFeature();
            if (featureTypes.contains(f.getPart(0))) {
                instance.addBinary(f);
            }
        }
        Feature.Looper numericFeatures = source.numericFeatureIterator();
        while (numericFeatures.hasNext()) {
            Feature f = numericFeatures.nextFeature();
            if (featureTypes.contains(f.getPart(0))) {
                instance.addNumeric(f, source.getWeight(f));
            }
        }
        return instance;
    }

    /**
     * Builds a hierarchical learner from one prototype learner per name.
     *
     * @param learners learner names as accepted by {@link #createLearnerByName(String)};
     *                 an unrecognized name yields a null prototype at that position
     * @return the hierarchical learner wrapping the prototypes
     */
    public HierarchicalClassifierLearner createHierarchicalClassifierLearner(String[] learners) {
        ClassifierLearner[] prototypes = new ClassifierLearner[learners.length];
        for (int i = 0; i < prototypes.length; i++) {
            prototypes[i] = createLearnerByName(learners[i]);
        }
        return new HierarchicalClassifierLearner(prototypes);
    }

    /**
     * Creates a learner from its symbolic name (case-insensitive). Suffixes
     * select the multi-class wrapper: {@code _OVA} one-vs-all, {@code _CB}
     * cascading binary, {@code _MFF} most-frequent-first.
     *
     * @param name learner name, e.g. "KNN", "MAX_ENT", "SVM_OVA"
     * @return a new learner, or null (after printing a message to standard
     *         error) if the name is null or not recognized
     */
    public ClassifierLearner createLearnerByName(String name) {
        // Literal-first comparisons so that a null name reaches the
        // "unrecognized" branch instead of throwing.
        if ("KNN".equalsIgnoreCase(name)) {
            return new KnnLearner();
        } else if ("KWAY_MIX".equalsIgnoreCase(name)) {
            return new KWayMixtureLearner();
        } else if ("MAX_ENT".equalsIgnoreCase(name)) {
            return new MaxEntLearner();
        } else if ("BWINNOW_OVA".equalsIgnoreCase(name)) {
            return new OneVsAllLearner(new BalancedWinnow());
        } else if ("MPERCEPTRON_OVA".equalsIgnoreCase(name)) {
            return new OneVsAllLearner(new MarginPerceptron());
        } else if ("NBAYES_OVA".equalsIgnoreCase(name)) {
            return new OneVsAllLearner(new NaiveBayes());
        } else if ("VPERCEPTRON_OVA".equalsIgnoreCase(name)) {
            return new OneVsAllLearner(new VotedPerceptron());
        } else if ("ADABOOST_OVA".equalsIgnoreCase(name)) {
            return new OneVsAllLearner(new AdaBoost());
        } else if ("ADABOOST_CB".equalsIgnoreCase(name)) {
            return new CascadingBinaryLearner(new AdaBoost());
        } else if ("ADABOOST_MFF".equalsIgnoreCase(name)) {
            return new MostFrequentFirstLearner(new AdaBoost());
        } else if ("ADABOOSTL_OVA".equalsIgnoreCase(name)) {
            // AdaBoost.L: logistic regression version
            return new OneVsAllLearner(new AdaBoost.L());
        } else if ("ADABOOSTL_CB".equalsIgnoreCase(name)) {
            return new CascadingBinaryLearner(new AdaBoost.L());
        } else if ("ADABOOSTL_MFF".equalsIgnoreCase(name)) {
            return new MostFrequentFirstLearner(new AdaBoost.L());
        } else if ("DTREE_OVA".equalsIgnoreCase(name)) {
            return new OneVsAllLearner(new DecisionTreeLearner());
        } else if ("DTREE_CB".equalsIgnoreCase(name)) {
            return new CascadingBinaryLearner(new DecisionTreeLearner());
        } else if ("DTREE_MFF".equalsIgnoreCase(name)) {
            return new MostFrequentFirstLearner(new DecisionTreeLearner());
        } else if ("NEGBI_OVA".equalsIgnoreCase(name)) {
            return new OneVsAllLearner(new NegativeBinomialLearner());
        } else if ("NEGBI_CB".equalsIgnoreCase(name)) {
            return new CascadingBinaryLearner(new NegativeBinomialLearner());
        } else if ("NEGBI_MFF".equalsIgnoreCase(name)) {
            return new MostFrequentFirstLearner(new NegativeBinomialLearner());
        } else if ("SVM_OVA".equalsIgnoreCase(name)) {
            return new OneVsAllLearner(new SVMLearner());
        } else if ("SVM_OVA_CONF1".equalsIgnoreCase(name)) {
            // SVM with explicit testing parameters instead of the library defaults
            svm_parameter param = new svm_parameter();
            param.svm_type = svm_parameter.C_SVC;
            param.kernel_type = svm_parameter.POLY;
            param.degree = 2;
            param.gamma = 1; // 1/k
            param.coef0 = 0;
            param.nu = 0.5;
            param.cache_size = 40;
            param.C = 1;
            param.eps = 1e-3;
            param.p = 0.1;
            param.shrinking = 1;
            param.nr_weight = 0;
            param.weight_label = new int[0];
            param.weight = new double[0];
            return new OneVsAllLearner(new SVMLearner(param));
        } else if ("SVM_CB".equalsIgnoreCase(name)) {
            return new CascadingBinaryLearner(new SVMLearner());
        } else if ("SVM_MFF".equalsIgnoreCase(name)) {
            return new MostFrequentFirstLearner(new SVMLearner());
        }
        System.err.println("Unrecognized learner name: " + name);
        return null;
    }

    /**
     * Runs the configured experiment: evaluation against the testing set when
     * "crossValidationFolds" is negative, cross validation on the training set
     * otherwise. The elapsed time is recorded for the report.
     *
     * @return the resulting evaluation (also kept for {@link #createReport()})
     */
    public Evaluation runExperiment() {
        long start = System.currentTimeMillis();
        ClassifierLearner learner = createHierarchicalClassifierLearner(learnerNames);
        if (crossValidationFolds < 0) {
            evaluation = Tester.evaluate(learner, trainingSet, testingSet);
        } else {
            Splitter splitter = new CrossValSplitter(
                    new RandomElement(System.currentTimeMillis()), crossValidationFolds);
            cvDataset = new CrossValidatedDataset(learner, trainingSet, splitter, true);
            evaluation = cvDataset.getEvaluation();
        }
        runTime = System.currentTimeMillis() - start;
        return evaluation;
    }

    /**
     * Trains a classifier on the training set with the configured learners and
     * records the elapsed time.
     */
    public void trainClassifier() {
        long start = System.currentTimeMillis();
        ClassifierLearner learner = createHierarchicalClassifierLearner(learnerNames);
        classifier = new DatasetClassifierTeacher(trainingSet).train(learner);
        runTime = System.currentTimeMillis() - start;
    }

    /**
     * Serializes the current classifier to the given file. Failures are logged,
     * not thrown.
     *
     * @param fileName destination file
     */
    public void saveClassifier(String fileName) {
        try {
            IOUtil.saveSerialized((Serializable) classifier, new File(fileName));
        } catch (Exception e) {
            log.error("Error saving classifier to " + fileName, e);
        }
    }

    /**
     * Saves the current classifier under a generated name of the form
     * {@code <classifierDir><epochSeconds>-<learner>...[-HC]-<trainingFileName>}.
     * The "classifierDir" property is used as a raw prefix, so it has to end
     * with a path separator to denote a directory.
     */
    public void saveClassifier() {
        StringBuilder fileName = new StringBuilder();
        fileName.append(classifierDir).append(System.currentTimeMillis() / 1000);
        for (int i = 0; i < learnerNames.length; i++) {
            fileName.append("-").append(learnerNames[i]);
        }
        if (useClassLevels) {
            fileName.append("-HC");
        }
        fileName.append("-").append(new File(trainingFile).getName());
        saveClassifier(fileName.toString());
    }

    /**
     * Loads a serialized classifier from the given file. Failures are logged,
     * not thrown, and leave the current classifier unchanged.
     *
     * @param fileName file to read
     */
    public void loadClassifier(String fileName) {
        try {
            classifier = (Classifier) IOUtil.loadSerialized(new File(fileName));
        } catch (Exception e) {
            log.error("Error loading classifier from " + fileName, e);
        }
    }

    /**
     * @return the trained or loaded classifier, or null if there is none yet
     */
    public Classifier getClassifier() {
        return classifier;
    }

    /**
     * Formats the configuration and the results of the last experiment as a
     * plain-text report.
     *
     * @return the report text
     * @throws IllegalStateException if {@link #runExperiment()} has not produced an evaluation
     */
    public String createReport() {
        if (evaluation == null) {
            throw new IllegalStateException("runExperiment() must be called before createReport()");
        }
        DecimalFormat format = new DecimalFormat("#0.00");
        StringBuilder b = new StringBuilder();

        b.append("Question Answer Type Classification Report\n");
        b.append((new Date()) + "\n");
        b.append("\n");
        b.append("Training Data File: " + trainingFile + "\n");
        if (crossValidationFolds < 0) {
            b.append("Testing Data File: " + testingFile + "\n");
        } else {
            b.append("Testing using " + crossValidationFolds + "-fold cross validation" + "\n");
        }
        b.append("\n");

        b.append("Valid Class Labels:");
        for (String label : classLabels) {
            b.append(" " + label);
        }
        b.append("\n");
        b.append("\n");

        if (useClassLevels) {
            b.append("Using Hierarchical Classifier Learners:\n");
            for (int i = 0; i < learnerNames.length; i++) {
                b.append("\t" + learnerNames[i] + "\n");
            }
        } else {
            b.append("Using Simple Classifier Learner: " + learnerNames[0] + "\n");
        }
        b.append("\n");

        b.append("Feature Selection:\n");
        for (String type : featureTypes) {
            b.append("\t" + type + "\n");
        }
        b.append("\n");

        b.append("Experiment Results:\n");
        b.append("\n");
        b.append("\tAccuracy: " + format.format((1.0 - evaluation.errorRate()) * 100)
                + "% [" + evaluation.numExamples() + " example(s)]\n");
        b.append("\n");
        b.append("\tAccuracy by Class:\n");
        b.append("\n");
        String[] classNames = evaluation.getClasses();
        double[] numExamples = evaluation.numberOfExamplesByClass();
        double[] errorRates = evaluation.errorRateByClass();
        double total = 0;
        for (int i = 0; i < classNames.length; i++) {
            double accuracy = (1.0 - errorRates[i]) * 100;
            b.append("\t\t" + classNames[i] + " " + format.format(accuracy)
                    + "% [" + (int) numExamples[i] + " example(s)]\n");
            total += accuracy;
        }
        b.append("\n");
        // Unweighted mean over classes, independent of class sizes.
        b.append("\tAverage Class Accuracy: " + format.format(total / classNames.length) + "%\n");
        b.append("\n");
        b.append("Run Time: " + runTime + " ms\n");
        b.append("\n");
        b.append("Confusion Matrix:\n");
        b.append("\n");
        b.append(prettyPrintCM(evaluation.confusionMatrix(), evaluation.getClasses()));
        b.append("\n");
        return b.toString();
    }

    /**
     * Renders a confusion matrix as a fixed-width text table with abbreviated
     * class names as row and column headers.
     *
     * @param matrix  matrix whose {@code values[i][j]} cells are printed row by row
     * @param classes class names, in the order of the matrix rows and columns
     * @return the formatted table
     */
    private String prettyPrintCM(Evaluation.Matrix matrix, String[] classes) {
        double[][] values = matrix.values;
        String[] classAbb = new String[classes.length];
        StringBuilder res = new StringBuilder();
        Formatter formatter = new Formatter(res, Locale.US);

        // Abbreviate each name and find the widest abbreviation to size the columns.
        int max = 0;
        for (int i = 0; i < classes.length; i++) {
            classAbb[i] = classes[i].replaceAll("\\B(.{1,2}).*?(.)\\b", "$1$2");
            if (classAbb[i].length() > max) {
                max = classAbb[i].length();
            }
        }
        max++; // one blank between columns
        String formatStr = "%-" + max + "s";

        // Header row: empty corner cell followed by the column labels.
        formatter.format(formatStr, "");
        for (int i = 0; i < classes.length; i++) {
            formatter.format(formatStr, classAbb[i]);
        }
        res.append("\n\n");

        for (int i = 0; i < classes.length; i++) {
            formatter.format(formatStr, classAbb[i]);
            for (int j = 0; j < classes.length; j++) {
                formatter.format(formatStr, Double.toString(values[i][j]));
            }
            res.append("\n\n");
        }
        return res.toString();
    }

    /**
     * Command-line entry point.
     *
     * <p>Usage: {@code java HierarchicalClassifierTrainer [--train] <questionLang> <corpusLang>}.
     * With {@code --train} a classifier is trained and saved; otherwise an
     * experiment is run, a report file is written and the evaluation is shown
     * in a viewer window. Exactly two language arguments are required after the
     * optional flag; anything else prints the usage text and exits.
     *
     * @param args command-line arguments
     * @throws Exception if initialization or report writing fails
     */
    public static void main(String[] args) throws Exception {
        boolean train = args.length > 0 && args[0].equals("--train");
        int langPairInd = train ? 1 : 0;
        // Checked against the flag-adjusted count so that "--train <lang>" with
        // a missing second language is rejected instead of indexing past the array.
        if (args.length != langPairInd + 2) {
            System.err.println("Usage:");
            System.err.println("java HierarchicalClassifierTrainer [--train] <questionLang> <corpusLang>\n");
            System.err.println(" - <questionLang> and <corpusLang> must be one of the following:");
            System.err.println(" en_US, ja_JP, jp_JP, zh_TW, zh_CN");
            System.err.println(" - Outputs a trained model in the current directory if --train is used.");
            System.err.println(" - Otherwise, performs an evaluation using the configuration in the");
            System.err.println(" properties file and outputs a report describing the results.");
            System.exit(0);
        }

        Pair<Language,Language> languagePair = new Pair<Language,Language>(
                Language.valueOf(args[langPairInd]),
                Language.valueOf(args[langPairInd + 1]));
        HierarchicalClassifierTrainer qct = new HierarchicalClassifierTrainer(languagePair);
        qct.initialize();

        if (train) {
            System.out.println("Training classifier...");
            qct.trainClassifier();
            qct.saveClassifier();
            System.out.println("Classifier saved.");
        } else {
            System.out.println("Running experiment...");
            Evaluation eval = qct.runExperiment();
            // In this branch args[0] is the question language.
            FileUtil.writeFile(qct.createReport(),
                    args[0] + ".report" + System.currentTimeMillis() + ".txt", "UTF-8");
            ViewerFrame frame = new ViewerFrame(args[0], eval.toGUI());
            frame.setVisible(true);
        }
    }
}