/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* BoundaryVisualizer.java
* Copyright (C) 2002 Mark Hall
*
*/
package weka.gui.boundaryvisualizer;
import java.awt.*;
import java.awt.event.*;
import javax.swing.JPanel;
import javax.swing.JComboBox;
import javax.swing.JButton;
import javax.swing.JTextField;
import javax.swing.BorderFactory;
import javax.swing.DefaultComboBoxModel;
import javax.swing.JLabel;
import java.util.Vector;
import weka.core.*;
import weka.classifiers.Classifier;
import weka.classifiers.DistributionClassifier;
/**
* BoundaryVisualizer. Allows the visualization of classifier decision
* boundaries in two dimensions. A supplied classifier is first trained on
* supplied training data, then a data generator (currently using
* kernels) is used to generate new instances at points fixed in the two
* visualization dimensions but random in the other dimensions. These
* instances are classified by the classifier and plotted as points with
* colour corresponding to the probability distribution predicted by the
* classifier. By default, 2 * (number of training instances) instances are
* generated for each pixel in the display; the number of samples drawn from
* each generator can be changed in the GUI. Predicted probability
* distributions are weighted (according to the fixed visualization
* dimensions) and averaged to produce an RGB value for the pixel.
*
*
* @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a>
* @version $Revision: 1.1.1.1 $
* @since 1.0
* @see JPanel
*/
public class BoundaryVisualizer extends JPanel {

  /** Prefix of the entries shown in the x attribute combo box. */
  private static final String X_PREFIX = "X: ";

  /** Prefix of the entries shown in the y attribute combo box. */
  private static final String Y_PREFIX = "Y: ";

  /** Suffix appended to the names of non-nominal attributes in the combo boxes. */
  private static final String NUMERIC_SUFFIX = " (Num)";

  /** Suffix appended to the names of nominal attributes in the combo boxes. */
  private static final String NOMINAL_SUFFIX = " (Nom)";

  /**
   * Axis labels that would need more decimal places than this are drawn
   * with a single decimal place instead.
   */
  private static final int MAX_AXIS_PRECISION = 10;

  /**
   * Inner class to handle rendering the axis. A horizontal instance draws
   * the x axis line and the min/max x labels; a vertical instance does the
   * same for y. Labels are only drawn once training instances have been set.
   *
   * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a>
   * @version 1.0
   * @since 1.0
   * @see JPanel
   */
  private class AxisPanel extends JPanel {

    /** Gap (in pixels) between the axis line and the panel edge. */
    private static final int PAD = 5;

    /** true to render the vertical (y) axis, false for the horizontal (x) axis. */
    private final boolean m_vertical;

    /**
     * Creates an axis panel with a black background and a 10 point plain
     * font of the same family as the panel's default font.
     *
     * @param vertical true for the y axis, false for the x axis
     */
    public AxisPanel(boolean vertical) {
      m_vertical = vertical;
      setBackground(Color.black);
      setFont(new Font(getFont().getFamily(), Font.PLAIN, 10));
    }

    /**
     * Returns the preferred size. A horizontal axis keeps its current width
     * and is tall enough for the axis line plus one line of text. A vertical
     * axis is 50 pixels wide and keeps its current height.
     *
     * @return the preferred size of this panel
     */
    public Dimension getPreferredSize() {
      if (!m_vertical) {
        // Component.getFontMetrics needs no Graphics context; getGraphics()
        // returns null while the panel is not yet displayable.
        int fontHeight = getFontMetrics(getFont()).getHeight();
        return new Dimension(getSize().width, PAD + 2 + fontHeight);
      }
      return new Dimension(50, getSize().height);
    }

    /**
     * Draws the axis line and, if training instances are present, the
     * min and max bound labels for this axis.
     *
     * @param g the graphics context to paint with
     */
    public void paintComponent(Graphics g) {
      super.paintComponent(g);
      FontMetrics fm = getFontMetrics(getFont());
      Dimension d = getSize();
      Dimension plotSize = m_boundaryPanel.getSize();
      int ascent = fm.getAscent();
      g.setColor(Color.gray);

      if (!m_vertical) {
        g.drawLine(d.width, PAD, d.width - plotSize.width, PAD);
        if (getInstances() != null) {
          String minString = formatAxisValue(m_minX);
          String maxString = formatAxisValue(m_maxX);
          // min label starts at the left end of the axis line; max label
          // ends at the right edge of this panel
          g.drawString(minString, d.width - plotSize.width, PAD + ascent + 2);
          g.drawString(maxString, d.width - fm.stringWidth(maxString),
                       PAD + ascent + 2);
        }
      } else {
        g.drawLine(d.width - PAD, 0, d.width - PAD, plotSize.height);
        if (getInstances() != null) {
          String minString = formatAxisValue(m_minY);
          String maxString = formatAxisValue(m_maxY);
          // both labels end just left of the axis line: min at the bottom
          // of the plot height, max at the top
          g.drawString(minString, d.width - PAD - fm.stringWidth(minString) - 2,
                       plotSize.height);
          g.drawString(maxString, d.width - PAD - fm.stringWidth(maxString) - 2,
                       ascent);
        }
      }
    }
  }

  /** The training instances. */
  private Instances m_trainingInstances;

  /** The classifier to use. */
  private DistributionClassifier m_classifier;

  /** Plot area width (in pixels). */
  private int m_plotAreaWidth = 512;

  /** Plot area height (in pixels). */
  private int m_plotAreaHeight = 384;

  /** The plotting panel. */
  private BoundaryPanel m_boundaryPanel;

  // combo boxes for selecting the class attribute, class values (for
  // colouring pixels), and visualization attributes
  private JComboBox m_classAttBox = new JComboBox();
  private JComboBox m_redClassValueBox = new JComboBox();
  private JComboBox m_greenClassValueBox = new JComboBox();
  private JComboBox m_blueClassValueBox = new JComboBox();
  private JComboBox m_xAttBox = new JComboBox();
  private JComboBox m_yAttBox = new JComboBox();

  /** Fixed size given to every combo box: half the plot width wide. */
  private final Dimension m_comboSize =
    new Dimension(m_plotAreaWidth / 2,
                  m_classAttBox.getPreferredSize().height);

  /** Toggles between starting and stopping the plot ("Start"/"Stop"). */
  private JButton m_startBut = new JButton("Start");

  /** Holds all of the configuration controls. */
  private JPanel m_controlPanel;

  // separate panels for rendering axis information
  private AxisPanel m_xAxisPanel;
  private AxisPanel m_yAxisPanel;

  // min and max values for visualization dimensions (NaN when the selected
  // attribute has no non-missing values)
  private double m_maxX;
  private double m_maxY;
  private double m_minX;
  private double m_minY;

  // attribute indices of the current visualization dimensions
  private int m_xIndex;
  private int m_yIndex;

  /** The number of samples to use from each kernel when plotting pixels. */
  private int m_numberOfSamplesFromEachGeneratingModel = 2;

  /** Text field through which the user edits the number of samples. */
  private JTextField m_samplesText =
    new JTextField("" + m_numberOfSamplesFromEachGeneratingModel);

  /**
   * Creates a new <code>BoundaryVisualizer</code> instance, lays out the
   * controls, plot panel and axis panels, and registers all listeners.
   * Listeners are registered here, once, so that repeated calls to
   * setInstances do not accumulate duplicate handlers.
   */
  public BoundaryVisualizer() {
    setLayout(new BorderLayout());

    JComboBox[] boxes = {
      m_classAttBox, m_redClassValueBox, m_greenClassValueBox,
      m_blueClassValueBox, m_xAttBox, m_yAttBox
    };
    for (int i = 0; i < boxes.length; i++) {
      // pin min/preferred/max so the layouts keep every box the same size
      boxes[i].setMinimumSize(m_comboSize);
      boxes[i].setPreferredSize(m_comboSize);
      boxes[i].setMaximumSize(m_comboSize);
    }

    m_controlPanel = new JPanel();
    m_controlPanel.setLayout(new BorderLayout());

    JPanel cHolder = new JPanel();
    cHolder.setBorder(BorderFactory.createTitledBorder("Class Attribute"));
    cHolder.add(m_classAttBox);

    JPanel cValHolder = new JPanel();
    cValHolder.setLayout(new GridLayout(3, 1));
    cValHolder.setBorder(BorderFactory.createTitledBorder("Class Values"));
    cValHolder.add(m_redClassValueBox);
    cValHolder.add(m_greenClassValueBox);
    cValHolder.add(m_blueClassValueBox);

    JPanel vAttHolder = new JPanel();
    vAttHolder.setLayout(new GridLayout(2, 1));
    vAttHolder.setBorder(BorderFactory.
                         createTitledBorder("Visualization Attributes"));
    vAttHolder.add(m_xAttBox);
    vAttHolder.add(m_yAttBox);

    JPanel colOne = new JPanel();
    colOne.setLayout(new BorderLayout());
    colOne.add(cHolder, BorderLayout.NORTH);
    colOne.add(vAttHolder, BorderLayout.CENTER);
    m_samplesText.setBorder(BorderFactory.
                            createTitledBorder("Num. samples per generator"));
    m_samplesText.setBackground(colOne.getBackground());
    colOne.add(m_samplesText, BorderLayout.SOUTH);

    JPanel colTwo = new JPanel();
    colTwo.setLayout(new BorderLayout());
    colTwo.add(cValHolder, BorderLayout.NORTH);
    JPanel startPanel = new JPanel();
    startPanel.setBorder(BorderFactory.createTitledBorder("Start/Stop"));
    startPanel.setLayout(new BorderLayout());
    startPanel.add(m_startBut, BorderLayout.CENTER);
    colTwo.add(startPanel, BorderLayout.SOUTH);

    m_controlPanel.add(colOne, BorderLayout.WEST);
    m_controlPanel.add(colTwo, BorderLayout.CENTER);
    add(m_controlPanel, BorderLayout.NORTH);

    m_boundaryPanel = new BoundaryPanel(m_plotAreaWidth, m_plotAreaHeight);
    m_boundaryPanel.setDataGenerator(new KDDataGenerator());
    add(m_boundaryPanel, BorderLayout.CENTER);

    m_xAxisPanel = new AxisPanel(false);
    add(m_xAxisPanel, BorderLayout.SOUTH);
    m_yAxisPanel = new AxisPanel(true);
    add(m_yAxisPanel, BorderLayout.WEST);

    m_startBut.setEnabled(false);
    m_startBut.addActionListener(new ActionListener() {
      public void actionPerformed(ActionEvent e) {
        if (m_startBut.getText().equals("Start")) {
          startPlotting();
        } else {
          m_boundaryPanel.stopPlotting();
          m_startBut.setText("Start");
          setControlEnabledStatus(true);
        }
      }
    });

    // presumably fired by the boundary panel when plotting ends; restores
    // the controls to their idle state
    m_boundaryPanel.addActionListener(new ActionListener() {
      public void actionPerformed(ActionEvent e) {
        m_startBut.setText("Start");
        setControlEnabledStatus(true);
      }
    });

    m_classAttBox.addActionListener(new ActionListener() {
      public void actionPerformed(ActionEvent e) {
        setUpClassValuesBoxes();
      }
    });

    ItemListener boundsUpdater = new ItemListener() {
      public void itemStateChanged(ItemEvent e) {
        if (e.getStateChange() == ItemEvent.SELECTED) {
          computeBounds();
          repaint();
        }
      }
    };
    m_xAttBox.addItemListener(boundsUpdater);
    m_yAttBox.addItemListener(boundsUpdater);
  }

  /**
   * Pushes the current control settings to the boundary panel and starts
   * plotting. Does nothing unless both training data and a classifier have
   * been supplied. Errors are printed and leave the controls enabled.
   */
  private void startPlotting() {
    if (m_trainingInstances == null || m_classifier == null) {
      return;
    }
    try {
      int samples = readNumberOfSamples();
      m_numberOfSamplesFromEachGeneratingModel = samples;
      m_boundaryPanel.setNumberOfSamplesFromEachGeneratingModel(samples);
      m_trainingInstances.setClassIndex(m_classAttBox.getSelectedIndex());
      m_boundaryPanel.setClassifier(m_classifier);
      m_boundaryPanel.setTrainingData(m_trainingInstances);
      m_boundaryPanel.setRedClassValue(m_redClassValueBox.getSelectedIndex());
      m_boundaryPanel.
        setGreenClassValue(m_greenClassValueBox.getSelectedIndex());
      m_boundaryPanel.setBlueClassValue(m_blueClassValueBox.getSelectedIndex());
      m_boundaryPanel.setXAttribute(m_xIndex);
      m_boundaryPanel.setYAttribute(m_yIndex);
      m_boundaryPanel.start();
      m_startBut.setText("Stop");
      setControlEnabledStatus(false);
    } catch (Exception ex) {
      ex.printStackTrace();
    }
  }

  /**
   * Reads the number of samples per generator from the text field. Input
   * that is not an integer, or is less than 1, is rejected: the field is
   * reset to the last accepted value and that value is returned.
   *
   * @return a positive number of samples
   */
  private int readNumberOfSamples() {
    int samples;
    try {
      samples = Integer.parseInt(m_samplesText.getText().trim());
    } catch (NumberFormatException ex) {
      samples = -1; // treated as invalid below
    }
    if (samples < 1) {
      samples = m_numberOfSamplesFromEachGeneratingModel;
      m_samplesText.setText("" + samples);
    }
    return samples;
  }

  /**
   * Set the enabled status of the controls
   *
   * @param status a <code>boolean</code> value
   */
  private void setControlEnabledStatus(boolean status) {
    m_classAttBox.setEnabled(status);
    m_redClassValueBox.setEnabled(status);
    m_greenClassValueBox.setEnabled(status);
    m_blueClassValueBox.setEnabled(status);
    m_xAttBox.setEnabled(status);
    m_yAttBox.setEnabled(status);
    m_samplesText.setEnabled(status);
  }

  /**
   * Set a classifier to use
   *
   * @param newClassifier the classifier to use
   * @exception Exception if the classifier is not a
   * <code>DistributionClassifier</code> (including when it is null)
   */
  public void setClassifier(Classifier newClassifier) throws Exception {
    if (!(newClassifier instanceof DistributionClassifier)) {
      throw new Exception("Classifier must be a distribution classifier!");
    }
    m_classifier = (DistributionClassifier)newClassifier;
  }

  /**
   * Recovers an attribute name from a combo box label by removing the
   * given prefix and suffix exactly once each (if present). Unlike removing
   * every occurrence, this keeps attribute names that themselves contain
   * the prefix or suffix text intact.
   *
   * @param label the combo box entry
   * @param prefix the prefix to strip from the start
   * @param suffix the suffix to strip from the end
   * @return the attribute name
   */
  private static String stripLabel(String label, String prefix, String suffix) {
    String name = label;
    if (name.startsWith(prefix)) {
      name = name.substring(prefix.length());
    }
    if (name.endsWith(suffix)) {
      name = name.substring(0, name.length() - suffix.length());
    }
    return name;
  }

  /**
   * Resolves the selected x/y visualization attributes to attribute indices
   * and recomputes their min/max values over the training instances. Does
   * nothing if there is no data or either attribute box has no selection.
   */
  private void computeBounds() {
    String xLabel = (String)m_xAttBox.getSelectedItem();
    String yLabel = (String)m_yAttBox.getSelectedItem();
    if (xLabel == null || yLabel == null || m_trainingInstances == null) {
      return;
    }
    String xName = stripLabel(xLabel, X_PREFIX, NUMERIC_SUFFIX);
    String yName = stripLabel(yLabel, Y_PREFIX, NUMERIC_SUFFIX);

    m_xIndex = -1;
    m_yIndex = -1;
    for (int i = 0; i < m_trainingInstances.numAttributes(); i++) {
      String attName = m_trainingInstances.attribute(i).name();
      if (attName.equals(xName)) {
        m_xIndex = i;
      }
      if (attName.equals(yName)) {
        m_yIndex = i;
      }
    }

    if (m_xIndex != -1 && m_yIndex != -1) {
      double[] xRange = attributeRange(m_xIndex);
      double[] yRange = attributeRange(m_yIndex);
      m_minX = xRange[0];
      m_maxX = xRange[1];
      m_minY = yRange[0];
      m_maxY = yRange[1];
    }
  }

  /**
   * Computes the min and max of an attribute over the non-missing values
   * in the training instances. The accumulators start at +/- infinity
   * (Double.MIN_VALUE is the smallest positive double, so it cannot seed a
   * maximum over negative data).
   *
   * @param attIndex the attribute index
   * @return {min, max}, or {NaN, NaN} if every value is missing
   */
  private double[] attributeRange(int attIndex) {
    double min = Double.POSITIVE_INFINITY;
    double max = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < m_trainingInstances.numInstances(); i++) {
      Instance inst = m_trainingInstances.instance(i);
      if (!inst.isMissing(attIndex)) {
        double value = inst.value(attIndex);
        if (value < min) {
          min = value;
        }
        if (value > max) {
          max = value;
        }
      }
    }
    if (min > max) {
      // no non-missing values were seen
      min = Double.NaN;
      max = Double.NaN;
    }
    return new double[] {min, max};
  }

  /**
   * Formats a bound for display as an axis label. Values with a fractional
   * part get int(|log10(|value|)|) + 2 decimal places (one decimal place if
   * that exceeds MAX_AXIS_PRECISION); whole values get one decimal place.
   * Non-finite values (e.g. an all-missing attribute) are shown as "?".
   *
   * @param value the value to format
   * @return the label text
   */
  private static String formatAxisValue(double value) {
    if (Double.isNaN(value) || Double.isInfinite(value)) {
      return "?";
    }
    double magnitude = Math.abs(value);
    int whole = (int)magnitude;
    double decimal = magnitude - whole;
    // floor(log10) of the integer part (1 when it is zero); only used to
    // build the width argument passed to Utils.doubleToString
    int nondecimal = (whole > 0)
      ? (int)(Math.log(whole) / Math.log(10))
      : 1;
    int precision = (decimal > 0)
      ? (int)Math.abs(Math.log(magnitude) / Math.log(10)) + 2
      : 1;
    if (precision > MAX_AXIS_PRECISION) {
      precision = 1;
    }
    return Utils.doubleToString(value, nondecimal + 1 + precision, precision);
  }

  /**
   * Get the training instances
   *
   * @return the training instances (null until setInstances is called)
   */
  public Instances getInstances() {
    return m_trainingInstances;
  }

  /**
   * Set the training instances. Rebuilds the class attribute and
   * visualization attribute combo boxes (only numeric attributes can be
   * visualized), selects the second numeric attribute for y when there is
   * one, recomputes the axis bounds and refreshes the class value boxes.
   *
   * @param inst the instances to use
   * @exception Exception if the data has fewer than three attributes
   */
  public void setInstances(Instances inst) throws Exception {
    if (inst.numAttributes() < 3) {
      throw new Exception("Not enough attributes in the data to visualize!");
    }
    m_trainingInstances = inst;

    // setup combo boxes
    String [] classAttNames = new String [m_trainingInstances.numAttributes()];
    Vector xAttNames = new Vector();
    Vector yAttNames = new Vector();
    for (int i = 0; i < m_trainingInstances.numAttributes(); i++) {
      classAttNames[i] = m_trainingInstances.attribute(i).name()
        + (m_trainingInstances.attribute(i).isNominal()
           ? NOMINAL_SUFFIX
           : NUMERIC_SUFFIX);
      if (m_trainingInstances.attribute(i).isNumeric()) {
        xAttNames.addElement(X_PREFIX + classAttNames[i]);
        yAttNames.addElement(Y_PREFIX + classAttNames[i]);
      }
    }

    m_classAttBox.setModel(new DefaultComboBoxModel(classAttNames));
    m_xAttBox.setModel(new DefaultComboBoxModel(xAttNames));
    m_yAttBox.setModel(new DefaultComboBoxModel(yAttNames));
    if (xAttNames.size() > 1) {
      m_yAttBox.setSelectedIndex(1);
    }

    computeBounds();
    // keep the class value boxes (and the Start button) in step with the
    // new data instead of leaving the previous dataset's values behind
    setUpClassValuesBoxes();
    revalidate();
    repaint();
  }

  /**
   * Set up the class values combo boxes. For a nominal class attribute the
   * red/green/blue boxes list its values (defaulting to the first, second
   * and third value where available) and Start is enabled if both
   * visualization attributes are selected; otherwise the boxes are emptied
   * and Start is disabled.
   */
  private void setUpClassValuesBoxes() {
    if (m_trainingInstances == null) {
      return;
    }
    int classIndex = m_classAttBox.getSelectedIndex();
    if (classIndex < 0) {
      return;
    }
    if (m_trainingInstances.attribute(classIndex).isNominal()) {
      Vector rNames = new Vector();
      Vector gNames = new Vector();
      Vector bNames = new Vector();
      int numValues = m_trainingInstances.attribute(classIndex).numValues();
      for (int i = 0; i < numValues; i++) {
        String name = m_trainingInstances.attribute(classIndex).value(i);
        rNames.addElement("Red: " + name);
        gNames.addElement("Green: " + name);
        bNames.addElement("Blue: " + name);
      }
      m_redClassValueBox.setModel(new DefaultComboBoxModel(rNames));
      m_greenClassValueBox.setModel(new DefaultComboBoxModel(gNames));
      m_blueClassValueBox.setModel(new DefaultComboBoxModel(bNames));
      if (gNames.size() > 1) {
        m_greenClassValueBox.setSelectedIndex(1);
        m_blueClassValueBox.setSelectedIndex(1);
      }
      if (bNames.size() > 2) {
        m_blueClassValueBox.setSelectedIndex(2);
      }
      m_startBut.setEnabled(m_xAttBox.getSelectedIndex() >= 0
                            && m_yAttBox.getSelectedIndex() >= 0);
    } else {
      m_redClassValueBox.setModel(new DefaultComboBoxModel());
      m_greenClassValueBox.setModel(new DefaultComboBoxModel());
      m_blueClassValueBox.setModel(new DefaultComboBoxModel());
      m_startBut.setEnabled(false);
    }
  }

  /**
   * Main method for testing this class. Expects a dataset file name, a
   * classifier class name, and optionally the classifier's options.
   *
   * @param args a <code>String[]</code> value
   */
  public static void main(String [] args) {
    try {
      if (args.length < 2) {
        System.err.println("Usage : BoundaryVisualizer <dataset> <classifier "
                           + "[classifier options]>");
        System.exit(1);
      }
      // NOTE(review): Swing components are built and updated from the main
      // thread here rather than the event dispatch thread.
      final javax.swing.JFrame jf =
        new javax.swing.JFrame("Weka classification boundary visualizer");
      jf.getContentPane().setLayout(new BorderLayout());
      BoundaryVisualizer bv = new BoundaryVisualizer();
      jf.getContentPane().add(bv, BorderLayout.CENTER);
      jf.addWindowListener(new java.awt.event.WindowAdapter() {
        public void windowClosing(java.awt.event.WindowEvent e) {
          jf.dispose();
          System.exit(0);
        }
      });
      jf.pack();
      jf.setVisible(true);
      jf.setResizable(false);

      System.err.println("Loading instances from : " + args[0]);
      java.io.Reader r = new java.io.BufferedReader(
        new java.io.FileReader(args[0]));
      Instances data;
      try {
        data = new Instances(r);
      } finally {
        r.close();
      }
      bv.setInstances(data);

      // everything after the classifier name is passed on as its options
      String [] argsR = null;
      if (args.length > 2) {
        argsR = new String [args.length - 2];
        System.arraycopy(args, 2, argsR, 0, argsR.length);
      }
      Classifier c = Classifier.forName(args[1], argsR);
      bv.setClassifier(c);
    } catch (Exception ex) {
      ex.printStackTrace();
    }
  }
}