/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * ClassifierErrorsPlotInstances.java * Copyright (C) 2009 University of Waikato, Hamilton, New Zealand */ package weka.gui.explorer; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import weka.classifiers.IntervalEstimator; import weka.classifiers.evaluation.NumericPrediction; import weka.core.Attribute; import weka.core.DenseInstance; import weka.core.FastVector; import weka.core.Instance; import weka.core.Instances; import weka.core.Utils; import weka.gui.visualize.Plot2D; import weka.gui.visualize.PlotData2D; /** * A class for generating plottable visualization errors. * <p/> * Example usage: * <pre> * Instances train = ... // from somewhere * Instances test = ... // from somewhere * Classifier cls = ... // from somewhere * // build classifier * cls.buildClassifier(train); * // evaluate classifier and generate plot instances * ClassifierPlotInstances plotInstances = new ClassifierPlotInstances(); * plotInstances.setClassifier(cls); * plotInstances.setInstances(train); * plotInstances.setClassIndex(train.classIndex()); * plotInstances.setUp(); * Evaluation eval = new Evaluation(train); * for (int i = 0; i < test.numInstances(); i++) * plotInstances.process(test.instance(i), cls, eval); * // generate visualization * VisualizePanel visPanel = new VisualizePanel(); * visPanel.addPlot(plotInstances.getPlotData("plot name")); * visPanel.setColourIndex(plotInstances.getPlotInstances().classIndex()+1); * // clean up * plotInstances.cleanUp(); * </pre> * * @author fracpete (fracpete at waikato dot ac dot nz) * @version $Revision: 6904 $ */ public class ClassifierErrorsPlotInstances extends AbstractPlotInstances { /** for serialization. */ private static final long serialVersionUID = -3941976365792013279L; /** the minimum plot size for numeric errors. */ protected int m_MinimumPlotSizeNumeric; /** the maximum plot size for numeric errors. */ protected int m_MaximumPlotSizeNumeric; /** whether to save the instances for visualization or just evaluate the * instance. */ protected boolean m_SaveForVisualization; /** for storing the plot shapes. */ protected FastVector m_PlotShapes; /** for storing the plot sizes. */ protected FastVector m_PlotSizes; /** the classifier being used. */ protected Classifier m_Classifier; /** the class index. */ protected int m_ClassIndex; /** the Evaluation object to use. */ protected Evaluation m_Evaluation; /** * Initializes the members. */ protected void initialize() { super.initialize(); m_PlotShapes = new FastVector(); m_PlotSizes = new FastVector(); m_Classifier = null; m_ClassIndex = -1; m_Evaluation = null; m_SaveForVisualization = true; m_MinimumPlotSizeNumeric = ExplorerDefaults.getClassifierErrorsMinimumPlotSizeNumeric(); m_MaximumPlotSizeNumeric = ExplorerDefaults.getClassifierErrorsMaximumPlotSizeNumeric(); } /** * Sets the classifier used for making the predictions. * * @param value the classifier to use */ public void setClassifier(Classifier value) { m_Classifier = value; } /** * Returns the currently set classifier. * * @return the classifier in use */ public Classifier getClassifier() { return m_Classifier; } /** * Sets the 0-based class index. * * @param index the class index */ public void setClassIndex(int index) { m_ClassIndex = index; } /** * Returns the 0-based class index. * * @return the class index */ public int getClassIndex() { return m_ClassIndex; } /** * Sets the Evaluation object to use. * * @param value the evaluation to use */ public void setEvaluation(Evaluation value) { m_Evaluation = value; } /** * Returns the Evaluation object in use. * * @return the evaluation object */ public Evaluation getEvaluation() { return m_Evaluation; } /** * Sets whether the instances are saved for visualization or only evaluation * of the prediction is to happen. * * @param value if true then the instances will be saved */ public void setSaveForVisualization(boolean value) { m_SaveForVisualization = value; } /** * Returns whether the instances are saved for visualization for only * evaluation of the prediction is to happen. * * @return true if the instances are saved */ public boolean getSaveForVisualization() { return m_SaveForVisualization; } /** * Checks whether classifier, class index and evaluation are provided. */ protected void check() { super.check(); if (m_Classifier == null) throw new IllegalStateException("No classifier set!"); if (m_ClassIndex == -1) throw new IllegalStateException("No class index set!"); if (m_Evaluation == null) throw new IllegalStateException("No evaluation set"); } /** * Sets up the structure for the plot instances. Sets m_PlotInstances to null * if instances are not saved for visualization. * * @see #getSaveForVisualization() */ protected void determineFormat() { FastVector hv; Attribute predictedClass; Attribute classAt; FastVector attVals; int i; if (!m_SaveForVisualization) { m_PlotInstances = null; return; } hv = new FastVector(); classAt = m_Instances.attribute(m_ClassIndex); if (classAt.isNominal()) { attVals = new FastVector(); for (i = 0; i < classAt.numValues(); i++) attVals.addElement(classAt.value(i)); predictedClass = new Attribute("predicted" + classAt.name(), attVals); } else { predictedClass = new Attribute("predicted" + classAt.name()); } for (i = 0; i < m_Instances.numAttributes(); i++) { if (i == m_Instances.classIndex()) hv.addElement(predictedClass); hv.addElement(m_Instances.attribute(i).copy()); } m_PlotInstances = new Instances( m_Instances.relationName() + "_predicted", hv, m_Instances.numInstances()); m_PlotInstances.setClassIndex(m_ClassIndex + 1); } /** * Process a classifier's prediction for an instance and update a * set of plotting instances and additional plotting info. m_PlotShape * for nominal class datasets holds shape types (actual data points have * automatic shape type assignment; classifier error data points have * box shape type). For numeric class datasets, the actual data points * are stored in m_PlotInstances and m_PlotSize stores the error (which is * later converted to shape size values). * * @param toPredict the actual data point * @param classifier the classifier * @param eval the evaluation object to use for evaluating the classifier on * the instance to predict * @see #m_PlotShapes * @see #m_PlotSizes * @see #m_PlotInstances */ public void process(Instance toPredict, Classifier classifier, Evaluation eval) { double pred; double[] values; int i; try { pred = eval.evaluateModelOnceAndRecordPrediction(classifier, toPredict); if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) { toPredict = ((weka.classifiers.misc.InputMappedClassifier)classifier). constructMappedInstance(toPredict); } if (!m_SaveForVisualization) return; if (m_PlotInstances != null) { values = new double[m_PlotInstances.numAttributes()]; for (i = 0; i < m_PlotInstances.numAttributes(); i++) { if (i < toPredict.classIndex()) { values[i] = toPredict.value(i); } else if (i == toPredict.classIndex()) { values[i] = pred; values[i+1] = toPredict.value(i); i++; } else { values[i] = toPredict.value(i-1); } } m_PlotInstances.add(new DenseInstance(1.0, values)); if (toPredict.classAttribute().isNominal()) { if (toPredict.isMissing(toPredict.classIndex()) || Utils.isMissingValue(pred)) { m_PlotShapes.addElement(new Integer(Plot2D.MISSING_SHAPE)); } else if (pred != toPredict.classValue()) { // set to default error point shape m_PlotShapes.addElement(new Integer(Plot2D.ERROR_SHAPE)); } else { // otherwise set to constant (automatically assigned) point shape m_PlotShapes.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE)); } m_PlotSizes.addElement(new Integer(Plot2D.DEFAULT_SHAPE_SIZE)); } else { // store the error (to be converted to a point size later) Double errd = null; if (!toPredict.isMissing(toPredict.classIndex()) && !Utils.isMissingValue(pred)) { errd = new Double(pred - toPredict.classValue()); m_PlotShapes.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE)); } else { // missing shape if actual class not present or prediction is missing m_PlotShapes.addElement(new Integer(Plot2D.MISSING_SHAPE)); } m_PlotSizes.addElement(errd); } } } catch (Exception ex) { ex.printStackTrace(); } } /** * Scales numeric class predictions into shape sizes for plotting * in the visualize panel. */ protected void scaleNumericPredictions() { double maxErr; double minErr; double err; int i; Double errd; double temp; maxErr = Double.NEGATIVE_INFINITY; minErr = Double.POSITIVE_INFINITY; // find min/max errors for (i = 0; i < m_PlotSizes.size(); i++) { errd = (Double) m_PlotSizes.elementAt(i); if (errd != null) { err = Math.abs(errd.doubleValue()); if (err < minErr) minErr = err; if (err > maxErr) maxErr = err; } } // scale errors for (i = 0; i < m_PlotSizes.size(); i++) { errd = (Double) m_PlotSizes.elementAt(i); if (errd != null) { err = Math.abs(errd.doubleValue()); if (maxErr - minErr > 0) { temp = (((err - minErr) / (maxErr - minErr)) * (m_MaximumPlotSizeNumeric - m_MinimumPlotSizeNumeric + 1)); m_PlotSizes.setElementAt(new Integer((int) temp) + m_MinimumPlotSizeNumeric, i); } else { m_PlotSizes.setElementAt(new Integer(m_MinimumPlotSizeNumeric), i); } } else { m_PlotSizes.setElementAt(new Integer(m_MinimumPlotSizeNumeric), i); } } } /** * Adds the prediction intervals as additional attributes at the end. * Since classifiers can returns varying number of intervals per instance, * the dataset is filled with missing values for non-existing intervals. */ protected void addPredictionIntervals() { int maxNum; int num; int i; int n; FastVector preds; FastVector atts; Instances data; Instance inst; Instance newInst; double[] values; double[][] predInt; // determine the maximum number of intervals maxNum = 0; preds = m_Evaluation.predictions(); for (i = 0; i < preds.size(); i++) { num = ((NumericPrediction) preds.elementAt(i)).predictionIntervals().length; if (num > maxNum) maxNum = num; } // create new header atts = new FastVector(); for (i = 0; i < m_PlotInstances.numAttributes(); i++) atts.addElement(m_PlotInstances.attribute(i)); for (i = 0; i < maxNum; i++) { atts.addElement(new Attribute("predictionInterval_" + (i+1) + "-lowerBoundary")); atts.addElement(new Attribute("predictionInterval_" + (i+1) + "-upperBoundary")); atts.addElement(new Attribute("predictionInterval_" + (i+1) + "-width")); } data = new Instances(m_PlotInstances.relationName(), atts, m_PlotInstances.numInstances()); data.setClassIndex(m_PlotInstances.classIndex()); // update data for (i = 0; i < m_PlotInstances.numInstances(); i++) { inst = m_PlotInstances.instance(i); // copy old values values = new double[data.numAttributes()]; System.arraycopy(inst.toDoubleArray(), 0, values, 0, inst.numAttributes()); // add interval data predInt = ((NumericPrediction) preds.elementAt(i)).predictionIntervals(); for (n = 0; n < maxNum; n++) { if (n < predInt.length){ values[m_PlotInstances.numAttributes() + n*3 + 0] = predInt[n][0]; values[m_PlotInstances.numAttributes() + n*3 + 1] = predInt[n][1]; values[m_PlotInstances.numAttributes() + n*3 + 2] = predInt[n][1] - predInt[n][0]; } else { values[m_PlotInstances.numAttributes() + n*3 + 0] = Utils.missingValue(); values[m_PlotInstances.numAttributes() + n*3 + 1] = Utils.missingValue(); values[m_PlotInstances.numAttributes() + n*3 + 2] = Utils.missingValue(); } } // create new Instance newInst = new DenseInstance(inst.weight(), values); data.add(newInst); } m_PlotInstances = data; } /** * Performs optional post-processing. * * @see #scaleNumericPredictions() * @see #addPredictionIntervals() */ protected void finishUp() { super.finishUp(); if (!m_SaveForVisualization) return; if (m_Instances.attribute(m_ClassIndex).isNumeric()) { scaleNumericPredictions(); if (m_Classifier instanceof IntervalEstimator) addPredictionIntervals(); } } /** * Assembles and returns the plot. The relation name of the dataset gets * added automatically. * * @param name the name of the plot * @return the plot or null if plot instances weren't saved for visualization * @throws Exception if plot generation fails */ protected PlotData2D createPlotData(String name) throws Exception { PlotData2D result; if (!m_SaveForVisualization) return null; result = new PlotData2D(m_PlotInstances); result.setShapeSize(m_PlotSizes); result.setShapeType(m_PlotShapes); result.setPlotName(name + " (" + m_Instances.relationName() + ")"); // result.addInstanceNumberAttribute(); return result; } /** * For freeing up memory. Plot data cannot be generated after this call! */ public void cleanUp() { super.cleanUp(); m_Classifier = null; m_PlotShapes = null; m_PlotSizes = null; m_Evaluation = null; } }