/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * BoundaryPanel.java * Copyright (C) 2002 Mark Hall * */ package weka.gui.boundaryvisualizer; import java.awt.*; import java.awt.event.*; import javax.swing.JPanel; import java.util.Vector; import weka.core.*; import weka.classifiers.Classifier; import weka.classifiers.DistributionClassifier; import weka.classifiers.bayes.NaiveBayesSimple; import weka.classifiers.bayes.NaiveBayes; import weka.classifiers.trees.j48.J48; import weka.classifiers.lazy.IBk; import weka.classifiers.functions.Logistic; import weka.clusterers.EM; import weka.filters.Filter; import weka.filters.unsupervised.attribute.Remove; import weka.filters.unsupervised.attribute.Add; /** * BoundaryPanel. A class to handle the plotting operations * associated with generating a 2D picture of a classifier's decision * boundaries. * * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a> * @version $Revison: 1.0 $ * @since 1.0 * @see JPanel */ public class BoundaryPanel extends JPanel { // training data private Instances m_trainingData; // distribution classifier to use private DistributionClassifier m_classifier; // data generator to use private DataGenerator m_dataGenerator; // index of the class attribute private int m_classIndex = -1; // indexes of class values (any other classes are grouped together // and correspond to the colour black private int [] m_rgbClassValues = new int [3]; // attributes for visualizing on private int m_xAttribute; private int m_yAttribute; // min, max and ranges of these attributes private double m_minX; private double m_minY; private double m_maxX; private double m_maxY; private double m_rangeX; private double m_rangeY; // pixel width and height in terms of attribute values private double m_pixHeight; private double m_pixWidth; // used for offscreen drawing private Image m_osi = null; // width and height of the display area private int m_panelWidth; private int m_panelHeight; // the number of samples to take from each generating model for each // pixel private int m_numberOfSamplesFromEachGeneratingModel = 2; // listeners to be notified when plot is complete private Vector m_listeners = new Vector(); // small inner class for rendering the bitmap on to private class PlotPanel extends JPanel { public void paintComponent(Graphics g) { super.paintComponent(g); if (m_osi != null) { g.drawImage(m_osi,0,0,this); } } } // the actual plotting area private PlotPanel m_plotPanel = new PlotPanel(); // thread for running the plotting operation in private Thread m_plotThread = null; // Stop the plotting thread private boolean m_stopPlotting = false; /** * Creates a new <code>BoundaryPanel</code> instance. * * @param panelWidth the width in pixels of the panel * @param panelHeight the height in pixels of the panel */ public BoundaryPanel(int panelWidth, int panelHeight) { m_panelWidth = panelWidth; m_panelHeight = panelHeight; setLayout(new BorderLayout()); m_plotPanel.setMinimumSize(new Dimension(m_panelWidth, m_panelHeight)); m_plotPanel.setPreferredSize(new Dimension(m_panelWidth, m_panelHeight)); m_plotPanel.setMaximumSize(new Dimension(m_panelWidth, m_panelHeight)); add(m_plotPanel, BorderLayout.CENTER); setPreferredSize(m_plotPanel.getPreferredSize()); setMaximumSize(m_plotPanel.getMaximumSize()); setMinimumSize(m_plotPanel.getMinimumSize()); for (int i = 0; i < 3; i++) { m_rgbClassValues[i] = -1; } } /** * Set up the off screen bitmap for rendering to */ protected void initialize() { int iwidth = m_plotPanel.getWidth(); int iheight = m_plotPanel.getHeight(); // System.err.println(iwidth+" "+iheight); m_osi = m_plotPanel.createImage(iwidth, iheight); Graphics m = m_osi.getGraphics(); m.fillRect(0,0,iwidth,iheight); } /** * Stop the plotting thread */ public void stopPlotting() { m_stopPlotting = true; } private void computeMinMaxAtts() { m_minX = Double.MAX_VALUE; m_minY = Double.MAX_VALUE; m_maxX = Double.MIN_VALUE; m_maxY = Double.MIN_VALUE; for (int i = 0; i < m_trainingData.numInstances(); i++) { Instance inst = m_trainingData.instance(i); double x = inst.value(m_xAttribute); double y = inst.value(m_yAttribute); if (x != Instance.missingValue()) { if (x < m_minX) { m_minX = x; } if (x > m_maxX) { m_maxX = x; } } if (y != Instance.missingValue()) { if (y < m_minY) { m_minY = y; } if (y > m_maxY) { m_maxY = y; } } } m_rangeX = (m_maxX - m_minX); m_rangeY = (m_maxY - m_minY); m_pixWidth = m_rangeX / (double)m_panelWidth; m_pixHeight = m_rangeY / (double) m_panelHeight; } /** * Return the x attribute value that corresponds to the middle of * the pix'th horizontal pixel * * @param pix the horizontal pixel number * @return a value in attribute space */ private double getMidX(int pix) { double midX = m_minX + (pix * m_pixWidth); midX += (m_pixWidth / 2.0); return midX; } /** * Return the y attribute value that corresponds to the middle of * the pix'th vertical pixel * * @param pix the vertical pixel number * @return a value in attribute space */ private double getMidY(int pix) { double midY = m_minY + (pix * m_pixHeight); midY += (m_pixHeight / 2.0); return midY; } /** * Start the plotting thread * * @exception Exception if an error occurs */ public void start() throws Exception { if (m_trainingData == null) { throw new Exception("No training data set (BoundaryPanel)"); } if (m_classifier == null) { throw new Exception("No classifier set (BoundaryPanel)"); } if (m_dataGenerator == null) { throw new Exception("No data generator set (BoundaryPanel)"); } if (m_trainingData.attribute(m_xAttribute).isNominal() || m_trainingData.attribute(m_yAttribute).isNominal()) { throw new Exception("Visualization dimensions must be numeric " +"(BoundaryPanel)"); } computeMinMaxAtts(); if (m_plotThread == null) { m_plotThread = new Thread() { public void run() { m_stopPlotting = false; try { /* if (m_osi == null) { initialize(); repaint(); } */ initialize(); repaint(); // train the classifier // System.err.println("Building classifier..."); m_classifier.buildClassifier(m_trainingData); // make a copy of the training data minus the class Remove rmF = new Remove(); rmF.setAttributeIndices(""+(m_classIndex+1)); rmF.setInvertSelection(false); rmF.setInputFormat(m_trainingData); rmF.setAttributeIndices(""+(m_classIndex+1)); Instances trainNoClass = Filter.useFilter(m_trainingData, rmF); Instances trainNoClassHeader = new Instances(trainNoClass,0); // build DataGenerator // System.err.println("Building data generator..."); boolean [] attsToWeightOn = new boolean[trainNoClass.numAttributes()]; attsToWeightOn[m_xAttribute] = true; attsToWeightOn[m_yAttribute] = true; m_dataGenerator.setWeightingDimensions(attsToWeightOn); m_dataGenerator.buildGenerator(trainNoClass); int samplesPerPixel = m_numberOfSamplesFromEachGeneratingModel * m_dataGenerator.getNumGeneratingModels(); // generate samples Add addF = new Add(); addF.setInputFormat(trainNoClass); addF.setAttributeIndex(m_classIndex); double pixelMidX = 0; double pixelMidY = 0; double [] dist; double [] weightingAttsValues = new double [attsToWeightOn.length]; double [] vals = new double[m_trainingData.numAttributes()]; Instance predInst = new Instance(1.0, vals); predInst.setDataset(m_trainingData); abortPlot: for (int i = 0; i < m_panelHeight; i++) { pixelMidY = getMidY(m_panelHeight-i-1); for (int j = 0; j < m_panelWidth; j++) { if (m_stopPlotting) { break abortPlot; } pixelMidX = getMidX(j); double sumOfWeights = 0; double [] sumOfProbs = new double [m_trainingData.classAttribute().numValues()]; for (int z = 0; z < samplesPerPixel; z++) { weightingAttsValues[m_xAttribute] = pixelMidX; weightingAttsValues[m_yAttribute] = pixelMidY; m_dataGenerator.setWeightingValues(weightingAttsValues); Instance newInst = m_dataGenerator.generateInstanceFast(); sumOfWeights += newInst.weight(); int index = 0; for (int k = 0; k < predInst.numAttributes(); k++) { if (k != m_trainingData.classIndex()) { vals[k] = newInst.value(index); index++; } } // classify the instance dist = m_classifier.distributionForInstance(predInst); for (int k = 0; k < sumOfProbs.length; k++) { sumOfProbs[k] += (dist[k] * newInst.weight()); } } // average Utils.normalize(sumOfProbs, sumOfWeights); // plot the point Graphics osg = m_osi.getGraphics(); Graphics g = m_plotPanel.getGraphics(); float [] colVal = new float[3]; for (int k = 0; k < 3; k++) { if (k < sumOfProbs.length) { if (m_rgbClassValues[k] != -1) { colVal[k] = (float)sumOfProbs[m_rgbClassValues[k]]; } } if (colVal[k] < 0) { colVal[k] = 0; } if (colVal[k] > 1) { colVal[k] = 1; } } osg.setColor(new Color(colVal[0], colVal[1], colVal[2])); osg.drawLine(j,i,j,i); if (j == 0) { g.drawImage(m_osi,0,0,m_plotPanel); } } } } catch (Exception ex) { ex.printStackTrace(); } finally { m_plotThread = null; // notify any listeners that we are finished Vector l; ActionEvent e = new ActionEvent(this, 0, ""); synchronized(this) { l = (Vector)m_listeners.clone(); } for (int i = 0; i < l.size(); i++) { ActionListener al = (ActionListener)l.elementAt(i); al.actionPerformed(e); } } } }; m_plotThread.setPriority(Thread.MIN_PRIORITY); m_plotThread.start(); } } /** * Set how many samples to take from each generating model for use in * computing the colour of a pixel * * @param ns the number of samples to use from each generating model */ public void setNumberOfSamplesFromEachGeneratingModel(int ns) { if (ns >= 1) { m_numberOfSamplesFromEachGeneratingModel = ns; } } /** * Set the training data to use * * @param trainingData the training data * @exception Exception if an error occurs */ public void setTrainingData(Instances trainingData) throws Exception { m_trainingData = trainingData; if (m_trainingData.classIndex() < 0) { throw new Exception("No class attribute set (BoundaryPanel)"); } m_classIndex = m_trainingData.classIndex(); } /** * Register a listener to be notified when plotting completes * * @param newListener the listener to add */ public void addActionListener(ActionListener newListener) { m_listeners.add(newListener); } /** * Remove a listener * * @param removeListener the listener to remove */ public void removeActionListener(ActionListener removeListener) { m_listeners.removeElement(removeListener); } /** * Set the classifier to use. * * @param classifier the classifier to use */ public void setClassifier(DistributionClassifier classifier) { m_classifier = classifier; } /** * Set the data generator to use for generating new instances * * @param dataGenerator the data generator to use */ public void setDataGenerator(DataGenerator dataGenerator) { m_dataGenerator = dataGenerator; } /** * Set the class value index for the red colour * * @param classVal an <code>int</code> value * @exception Exception if an error occurs */ public void setRedClassValue(int classVal) throws Exception { setClassValue(0, classVal); } /** * Set the class value index for the green colour * * @param classVal an <code>int</code> value * @exception Exception if an error occurs */ public void setGreenClassValue(int classVal) throws Exception { setClassValue(1, classVal); } /** * Set the class value index for the blue colour * * @param classVal an <code>int</code> value * @exception Exception if an error occurs */ public void setBlueClassValue(int classVal) throws Exception { setClassValue(2, classVal); } /** * Set a class value for a particular colour (RGB) * * @param index the colour - 0 = red, 1 = green, 2 = blue * @param classVal the class value index to associate with the colour * @exception Exception if an error occurs */ private void setClassValue(int index, int classVal) throws Exception { if (m_trainingData == null) { throw new Exception("No training data set (BoundaryPanel)"); } if (classVal < 0 || classVal > m_trainingData.classAttribute().numValues()) { throw new Exception("Class value out of range (BoundaryPanel)"); } m_rgbClassValues[index] = classVal; } /** * Set the x attribute index * * @param xatt index of the attribute to use on the x axis * @exception Exception if an error occurs */ public void setXAttribute(int xatt) throws Exception { if (m_trainingData == null) { throw new Exception("No training data set (BoundaryPanel)"); } if (xatt < 0 || xatt > m_trainingData.numAttributes()) { throw new Exception("X attribute out of range (BoundaryPanel)"); } if (m_trainingData.attribute(xatt).isNominal()) { throw new Exception("Visualization dimensions must be numeric " +"(BoundaryPanel)"); } if (m_trainingData.numDistinctValues(xatt) < 2) { throw new Exception("Too few distinct values for X attribute " +"(BoundaryPanel)"); } m_xAttribute = xatt; } /** * Set the y attribute index * * @param yatt index of the attribute to use on the y axis * @exception Exception if an error occurs */ public void setYAttribute(int yatt) throws Exception { if (m_trainingData == null) { throw new Exception("No training data set (BoundaryPanel)"); } if (yatt < 0 || yatt > m_trainingData.numAttributes()) { throw new Exception("X attribute out of range (BoundaryPanel)"); } if (m_trainingData.attribute(yatt).isNominal()) { throw new Exception("Visualization dimensions must be numeric " +"(BoundaryPanel)"); } if (m_trainingData.numDistinctValues(yatt) < 2) { throw new Exception("Too few distinct values for Y attribute " +"(BoundaryPanel)"); } m_yAttribute = yatt; } /** * Main method for testing this class * * @param args a <code>String[]</code> value */ public static void main (String [] args) { try { if (args.length < 7) { System.err.println("Usage : BoundaryPanel <dataset> " +"<class col> <Red classVal(index)> " +"<Green classVal(index)> " +"<Blue classVal(index)> <xAtt> <yAtt>"); System.exit(1); } final javax.swing.JFrame jf = new javax.swing.JFrame("Weka classification boundary visualizer"); jf.getContentPane().setLayout(new BorderLayout()); final BoundaryPanel bv = new BoundaryPanel(200,200); jf.getContentPane().add(bv, BorderLayout.CENTER); jf.setSize(bv.getMinimumSize()); // jf.setSize(200,200); jf.addWindowListener(new java.awt.event.WindowAdapter() { public void windowClosing(java.awt.event.WindowEvent e) { jf.dispose(); System.exit(0); } }); jf.pack(); jf.setVisible(true); // bv.initialize(); bv.repaint(); System.err.println("Loading instances from : "+args[0]); java.io.Reader r = new java.io.BufferedReader( new java.io.FileReader(args[0])); Instances i = new Instances(r); i.setClassIndex(Integer.parseInt(args[1])); bv.setTrainingData(i); bv.setClassifier(new Logistic()); bv.setDataGenerator(new KDDataGenerator()); bv.setRedClassValue(Integer.parseInt(args[2])); bv.setGreenClassValue(Integer.parseInt(args[3])); bv.setBlueClassValue(Integer.parseInt(args[4])); bv.setXAttribute(Integer.parseInt(args[5])); bv.setYAttribute(Integer.parseInt(args[6])); bv.start(); } catch (Exception ex) { ex.printStackTrace(); } } }