/* Copyright 2003, Carnegie Mellon, All Rights Reserved */

package edu.cmu.minorthird.classify;

import java.awt.BorderLayout;
import java.util.Iterator;

import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;

import org.apache.log4j.Logger;

import edu.cmu.minorthird.classify.algorithms.linear.MaxEntLearner;
import edu.cmu.minorthird.classify.algorithms.trees.AdaBoost;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;

/**
 * Stacked generalization. This implementation is based on Wolpert,
 * D.H. (1992), Stacked Generalization, Neural Networks, Vol. 5,
 * pp. 241-259, Pergamon Press. http://citeseer.nj.nec.com/wolpert92stacked.html
 *
 * <p>The level-0 ("inner") learners are trained on cross-validation folds of
 * the data; their held-out-fold predictions become the features of a
 * transformed dataset on which the level-1 ("final") learner is trained.
 *
 * @author William Cohen
 */
public class StackedLearner extends BatchClassifierLearner {

    private static Logger log = Logger.getLogger(StackedLearner.class);
    private static final boolean DEBUG = false;

    private ExampleSchema schema;
    /** Level-0 learners whose predictions feed the final learner. */
    private BatchClassifierLearner[] innerLearners;
    /** Level-1 learner trained on the transformed (prediction-feature) data. */
    private BatchClassifierLearner finalLearner;
    /** Splitter that produces the folds used to build the level-1 training data. */
    private Splitter<Example> splitter;

    /** Use stacked learning to calibrate the predictions of the inner learner
     * using logistic regression. */
    public StackedLearner(BatchClassifierLearner innerLearner, Splitter<Example> splitter) {
        this(new BatchClassifierLearner[]{innerLearner}, new MaxEntLearner(), splitter);
    }

    /** Use stacked learning to calibrate the predictions of the inner learner
     * using logistic regression, using 3-CV to split. */
    public StackedLearner(BatchClassifierLearner innerLearner) {
        this(new BatchClassifierLearner[]{innerLearner}, new MaxEntLearner(),
             new CrossValSplitter<Example>(3));
    }

    /** Use stacked learning to calibrate the predictions of AdaBoost
     * using logistic regression, using 3-CV to split. */
    public StackedLearner() {
        this(new BatchClassifierLearner[]{new AdaBoost()}, new MaxEntLearner(),
             new CrossValSplitter<Example>(3));
    }

    /** Create a stacked learner.
     *
     * @param innerLearners the level-0 learners
     * @param finalLearner the level-1 learner trained on the inner learners' predictions
     * @param splitter produces the train/test folds used to generate level-1 training data
     */
    public StackedLearner(
            BatchClassifierLearner[] innerLearners,
            BatchClassifierLearner finalLearner,
            Splitter<Example> splitter) {
        this.innerLearners = innerLearners;
        this.finalLearner = finalLearner;
        this.splitter = splitter;
    }

    public Splitter<Example> getSplitter() {
        return splitter;
    }

    public void setSplitter(Splitter<Example> splitter) {
        this.splitter = splitter;
    }

    /** Replace the level-0 learners with a single learner. */
    public void setInnerLearner(BatchClassifierLearner learner) {
        this.innerLearners = new BatchClassifierLearner[]{learner};
    }

    /** Return the single level-0 learner.
     * @throws IllegalStateException if more than one inner learner is configured
     */
    public BatchClassifierLearner getInnerLearner() {
        if (innerLearners.length != 1) {
            throw new IllegalStateException("multiple inner learners");
        }
        return innerLearners[0];
    }

    @Override
    final public void setSchema(ExampleSchema schema) {
        this.schema = schema;
        // propagate the schema to every component learner
        for (int i = 0; i < innerLearners.length; i++) {
            innerLearners[i].setSchema(schema);
        }
        finalLearner.setSchema(schema);
    }

    @Override
    final public ExampleSchema getSchema() {
        return schema;
    }

    /**
     * Train the stacked classifier. For each fold, the inner learners are
     * trained on the fold's training split; their predictions on the fold's
     * test split become feature vectors of a transformed dataset. The final
     * learner is trained on that transformed dataset, and the inner learners
     * are then retrained on the full dataset for use at classification time.
     */
    @Override
    public Classifier batchTrain(Dataset dataset) {
        BasicDataset stackedData = new BasicDataset();
        Classifier[] innerClassifiers = new Classifier[innerLearners.length];

        // build transformed dataset of examples where features
        // are predictions of inner learners on test data, and
        // classes are the real classes
        Dataset.Split split = dataset.split(splitter);
        for (int k = 0; k < split.getNumPartitions(); k++) {
            Dataset trainData = split.getTrain(k);
            for (int i = 0; i < innerLearners.length; i++) {
                // reset so no state leaks across folds
                innerLearners[i].reset();
                log.info("training inner learner " + (i + 1) + "/" + innerLearners.length
                         + " on fold " + (k + 1) + "/" + split.getNumPartitions());
                innerClassifiers[i] = innerLearners[i].batchTrain(trainData);
            }
            Dataset testData = split.getTest(k);
            log.info("transforming test examples of fold " + (k + 1) + "/" + split.getNumPartitions());
            for (Example e : testData) {
                stackedData.add(new Example(transformInstance(schema, e, innerClassifiers), e.getLabel()));
            }
        }

        // train final learner on transformed data, and innerLearners on the real data
        log.info("training level-1 learner");
        Classifier finalClassifier = finalLearner.batchTrain(stackedData);
        log.info("result is " + finalClassifier);
        for (int i = 0; i < innerLearners.length; i++) {
            log.info("training inner learner " + (i + 1) + "/" + innerLearners.length + " on full dataset");
            innerClassifiers[i] = innerLearners[i].batchTrain(dataset);
        }
        classifier = new StackedClassifier(schema, innerClassifiers, finalClassifier);
        return classifier;
    }

    /**
     * Map an instance to the level-1 feature space: one numeric feature
     * per (inner learner, class) pair, whose value is that learner's
     * predicted weight for that class.
     */
    private static Instance transformInstance(
            ExampleSchema schema, Instance oldInstance, Classifier[] innerClassifiers) {
        MutableInstance newInstance = new MutableInstance();
        for (int i = 0; i < innerClassifiers.length; i++) {
            ClassLabel ithPrediction = innerClassifiers[i].classification(oldInstance);
            String learner = "learner_" + i;
            for (int h = 0; h < schema.getNumberOfClasses(); h++) {
                String className = schema.getClassName(h);
                double w = ithPrediction.getWeight(className);
                newInstance.addNumeric(new Feature(new String[]{learner, "class_" + className}), w);
            }
        }
        if (DEBUG) log.debug("Transformed " + newInstance + " <= " + oldInstance);
        return newInstance;
    }

    /**
     * Build a human-readable explanation of the level-1 feature transformation
     * by concatenating each inner learner's explanation of the instance.
     */
    private static String explainTransformedInstance(
            ExampleSchema schema, Instance oldInstance, Classifier[] innerClassifiers) {
        // StringBuilder: local, single-threaded — no need for StringBuffer's synchronization
        StringBuilder buf = new StringBuilder("");
        MutableInstance newInstance = new MutableInstance();
        for (int i = 0; i < innerClassifiers.length; i++) {
            ClassLabel ithPrediction = innerClassifiers[i].classification(oldInstance);
            String learner = "learner_" + i;
            for (int h = 0; h < schema.getNumberOfClasses(); h++) {
                String className = schema.getClassName(h);
                double w = ithPrediction.getWeight(className);
                newInstance.addNumeric(new Feature(new String[]{learner, "class_" + className}), w);
                // NOTE(review): this appends the same inner-learner explanation once per
                // class, so it is repeated getNumberOfClasses() times per learner — looks
                // accidental, but the text is user-facing so it is preserved as-is.
                buf.append("learner#" + (i + 1) + " predicts " + className + ":\n"
                           + innerClassifiers[i].explain(oldInstance) + "\n");
            }
        }
        if (DEBUG) log.debug("Transformed " + newInstance + " <= " + oldInstance);
        return buf.toString();
    }

    /**
     * Classifier produced by {@link StackedLearner}: transforms an instance
     * into the level-1 feature space via the inner classifiers, then delegates
     * the final decision to the level-1 classifier.
     */
    static private class StackedClassifier implements Classifier, Visible {

        private ExampleSchema schema;
        private Classifier[] innerClassifiers;
        private Classifier finalClassifier;

        public StackedClassifier(
                ExampleSchema schema, Classifier[] innerClassifiers, Classifier finalClassifier) {
            this.schema = schema;
            this.innerClassifiers = innerClassifiers;
            this.finalClassifier = finalClassifier;
        }

        @Override
        public ClassLabel classification(Instance instance) {
            Instance newInstance = transformInstance(schema, instance, innerClassifiers);
            return finalClassifier.classification(newInstance);
        }

        @Override
        public String explain(Instance instance) {
            StringBuilder buf = new StringBuilder("");
            buf.append(explainTransformedInstance(schema, instance, innerClassifiers));
            Instance newInstance = transformInstance(schema, instance, innerClassifiers);
            buf.append("final classifier:\n");
            buf.append(finalClassifier.explain(newInstance));
            return buf.toString();
        }

        @Override
        public Explanation getExplanation(Instance instance) {
            Explanation ex = new Explanation(explain(instance));
            return ex;
        }

        @Override
        public Viewer toGUI() {
            Viewer v = new ComponentViewer() {
                static final long serialVersionUID = 20071015;

                @Override
                public JComponent componentFor(Object o) {
                    StackedClassifier sc = (StackedClassifier) o;
                    JPanel mainPanel = new JPanel();
                    mainPanel.setLayout(new BorderLayout());
                    mainPanel.setBorder(new TitledBorder("Stacked Classifier"));
                    JPanel finalPanel = new JPanel();
                    finalPanel.setBorder(new TitledBorder("Final classifier"));
                    Viewer w = new SmartVanillaViewer(sc.finalClassifier);
                    finalPanel.add(w);
                    w.setSuperView(this);
                    mainPanel.add(finalPanel, BorderLayout.NORTH);
                    JPanel innerPanel = new JPanel();
                    innerPanel.setBorder(new TitledBorder("Inner classifier(s)"));
                    // BUG FIX: render the passed-in object's inner classifiers
                    // (sc.innerClassifiers), not the enclosing instance's field —
                    // consistent with the sc.finalClassifier usage above, and
                    // correct if componentFor is invoked with a different object.
                    for (int i = 0; i < sc.innerClassifiers.length; i++) {
                        Viewer u = new SmartVanillaViewer(sc.innerClassifiers[i]);
                        innerPanel.add(u);
                        u.setSuperView(this);
                    }
                    mainPanel.add(innerPanel, BorderLayout.SOUTH);
                    return new JScrollPane(mainPanel);
                }
            };
            return v;
        }
    }
}