/*
* Copyright (c) 2011-2016, Peter Abeles. All Rights Reserved.
*
* This file is part of BoofCV (http://boofcv.org).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package boofcv.gui.learning;
import boofcv.gui.image.ShowImages;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.RandomMatrices;
import javax.swing.*;
import java.awt.*;
import java.awt.geom.Rectangle2D;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
* Visualizes a confusion matrix. Each element is assumed to have a value from 0 to 1.0
*
* @author Peter Abeles
*/
public class ConfusionMatrixPanel extends JPanel {

    // Matrix most recently passed to setMatrix(); copied into 'confusion' at the start of the next paint
    DenseMatrix64F temp = new DenseMatrix64F(1,1);
    // Matrix which is actually rendered
    DenseMatrix64F confusion = new DenseMatrix64F(1,1);
    // true if 'temp' has been modified since it was last copied into 'confusion'
    boolean dirty = false;

    // if true the cells are rendered in gray scale, otherwise in a blue to red color scale
    boolean gray = false;
    // if true the value of each cell (as a percentage) is drawn inside the cell
    boolean showNumbers = true;
    // if true, and labels have been provided, labels are drawn to the right of the matrix
    boolean showLabels = true;
    // if false then cells with a value of exactly zero have no number drawn in them
    boolean showZeros = true;

    // fraction of the width that labels occupy
    double labelViewFraction = 0.30;

    // Names of each category. null if no labels have been provided
    List<String> labels;

    // if set to a valid category then that category will be highlighted
    int highlightCategory = -1;

    // internal variables used for rendering
    int viewHeight, viewWidth;
    int gridHeight, gridWidth;
    boolean renderLabels;

    /**
     * Constructor that specifies the confusion matrix and width/height
     *
     * @param M The confusion matrix. Its contents are copied.
     * @param labels Optional labels for the confusion matrix. If null then no labels are rendered.
     * @param widthPixels preferred width and height of the panel in pixels
     * @param gray Render gray scale or color image
     */
    public ConfusionMatrixPanel( DenseMatrix64F M , List<String> labels, int widthPixels , boolean gray ) {
        this(widthPixels,labels!=null);
        setLabels(labels);
        setMatrix(M);
        this.gray = gray;
    }

    /**
     * Constructor in which the preferred width and height is specified in pixels
     *
     * @param widthPixels preferred width and height
     * @param hasLabels If true then the preferred height is reduced by the fraction of the width
     *                  which is reserved for labels, so that the cells stay square.
     */
    public ConfusionMatrixPanel(int widthPixels, boolean hasLabels ) {
        int heightPixels = widthPixels;
        if( hasLabels ) {
            heightPixels = (int)(heightPixels*(1.0-labelViewFraction));
        }
        setPreferredSize(new Dimension(widthPixels,heightPixels));
    }

    /**
     * Specifies the matrix which is to be displayed. The matrix is copied and a repaint is requested.
     *
     * @param A The new confusion matrix. Not modified.
     */
    public void setMatrix( DenseMatrix64F A ) {
        synchronized ( this ) {
            temp.set(A);
            dirty = true;
        }
        repaint();
    }

    public boolean isGray() {
        return gray;
    }

    public void setGray(boolean gray) {
        this.gray = gray;
    }

    public boolean isShowNumbers() {
        return showNumbers;
    }

    public void setShowNumbers(boolean showNumbers) {
        this.showNumbers = showNumbers;
    }

    public boolean isShowZeros() {
        return showZeros;
    }

    public void setShowZeros(boolean showZeros) {
        this.showZeros = showZeros;
    }

    public boolean isShowLabels() {
        return showLabels;
    }

    public void setShowLabels(boolean showLabels) {
        this.showLabels = showLabels;
    }

    /**
     * Specifies the label of each category. A copy of the list is saved.
     *
     * @param labels Labels for each category, or null to remove the labels.
     */
    public void setLabels(List<String> labels) {
        // labels are optional, so null must not blow up inside the ArrayList copy constructor
        this.labels = labels == null ? null : new ArrayList<>(labels);
    }

    public int getHighlightCategory() {
        return highlightCategory;
    }

    public void setHighlightCategory(int highlightCategory) {
        this.highlightCategory = highlightCategory;
    }

    /**
     * Renders the labels (optional), the confusion matrix, and the highlighted category (optional).
     * The method is synchronized on this panel, which is the same lock used by
     * {@link #setMatrix} and {@link #whatIsAtPoint}.
     */
    @Override
    public synchronized void paint( Graphics g ) {
        // grab the most recent matrix if it has been changed
        if (dirty) {
            confusion.set(temp);
            dirty = false;
        }

        Graphics2D g2 = (Graphics2D)g;

        int numCategories = confusion.getNumRows();

        viewHeight = getHeight();
        viewWidth = getWidth();

        gridHeight = viewHeight;
        gridWidth = viewWidth;

        renderLabels = this.showLabels && labels != null;
        if (renderLabels) {
            // only the width is reduced; labels are drawn to the right of the matrix
            gridWidth = (int)(gridWidth*(1.0 - labelViewFraction));
        }

        // Nothing to draw, and every computation below divides by the number of categories
        if( numCategories <= 0 )
            return;

        double fontSize = Math.min(gridWidth/numCategories,gridHeight/numCategories);

        g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);

        if(renderLabels) {
            renderLabels(g2, fontSize);
        }

        renderMatrix(g2, fontSize);

        if( highlightCategory >= 0 && highlightCategory < numCategories ) {
            g2.setColor(new Color(255,255,0,100));

            // inset the highlight by 10% of a cell on each side
            int ry = (int)(0.1*gridHeight / numCategories);
            int rx = (int)(0.1*gridWidth / numCategories);

            int y0 = highlightCategory * gridHeight / numCategories;
            int y1 = (highlightCategory + 1) * gridHeight / numCategories;

            int x0 = highlightCategory * gridWidth / numCategories;
            int x1 = (highlightCategory + 1) * gridWidth / numCategories;

            // the column is only highlighted inside the matrix; the row extends across the label too
            g2.fillRect(x0+rx,0,x1-x0-2*rx,gridHeight);
            g2.fillRect(0,y0+ry,viewWidth,y1-y0-2*ry);
        }
    }

    /**
     * Renders the names on each category to the side of the confusion matrix
     */
    private void renderLabels(Graphics2D g2, double fontSize) {
        int numCategories = confusion.getNumRows();

        int longestLabel = 0;
        for (int i = 0; i < numCategories; i++) {
            longestLabel = Math.max(longestLabel,labelFor(i).length());
        }

        Font fontLabel = new Font("monospaced", Font.BOLD, (int)(0.055*longestLabel*fontSize + 0.5));
        g2.setFont(fontLabel);
        FontMetrics metrics = g2.getFontMetrics(fontLabel);

        // clear the background
        g2.setColor(Color.WHITE);
        g2.fillRect(gridWidth,0,viewWidth-gridWidth,viewHeight);

        // draw the text
        g2.setColor(Color.BLACK);
        for (int i = 0; i < numCategories; i++) {
            String label = labelFor(i);

            int y0 = i * gridHeight / numCategories;
            int y1 = (i + 1) * gridHeight / numCategories;

            // center the text inside the label region for this row
            Rectangle2D r = metrics.getStringBounds(label,null);

            float adjX = (float)(r.getX()*2 + r.getWidth())/2.0f;
            float adjY = (float)(r.getY()*2 + r.getHeight())/2.0f;

            float x = ((viewWidth+gridWidth)/2f-adjX);
            float y = ((y1+y0)/2f-adjY);

            g2.drawString(label, x, y);
        }
    }

    /**
     * Returns the label for the specified category, or an empty string if there is no label for it.
     * This allows a label list which is shorter than the matrix, or which contains null, to be rendered.
     */
    private String labelFor( int category ) {
        if( labels == null || category >= labels.size() )
            return "";
        String label = labels.get(category);
        return label == null ? "" : label;
    }

    /**
     * Renders the confusion matrix and visualizes the value in each cell with a color and optionally a number.
     */
    private void renderMatrix(Graphics2D g2, double fontSize) {
        int numCategories = confusion.getNumRows();

        Font fontNumber = new Font("Serif", Font.BOLD, (int)(0.6*fontSize + 0.5));
        g2.setFont(fontNumber);
        FontMetrics metrics = g2.getFontMetrics(fontNumber);

        for (int i = 0; i < numCategories; i++) {
            int y0 = i*gridHeight/numCategories;
            int y1 = (i+1)*gridHeight/numCategories;

            for (int j = 0; j < numCategories; j++) {
                int x0 = j*gridWidth/numCategories;
                int x1 = (j+1)*gridWidth/numCategories;

                double value = confusion.unsafe_get(i,j);

                int red,green,blue;
                if( gray ) {
                    red = green = blue = toChannel(1.0-value);
                } else {
                    green = 0;
                    red = toChannel(value);
                    blue = toChannel(1.0-value);
                }
                g2.setColor(new Color(red, green, blue));

                g2.fillRect(x0,y0,x1-x0,y1-y0);

                // Render numbers inside the squares. Pick a color so that the number is visible no matter what
                // the color of the square is
                if( showNumbers && (showZeros || value != 0 )) {
                    int brightness = (red+green+blue)/3;

                    // value is shown as a percentage rounded to the nearest integer
                    String text = ""+(int)(value*100.0+0.5);
                    Rectangle2D r = metrics.getStringBounds(text,null);

                    float adjX = (float)(r.getX()*2 + r.getWidth())/2.0f;
                    float adjY = (float)(r.getY()*2 + r.getHeight())/2.0f;

                    float x = ((x1+x0)/2f-adjX);
                    float y = ((y1+y0)/2f-adjY);

                    // black text on bright cells and white text on dark cells
                    int textShade = brightness > 127 ? 0 : 255;

                    g2.setColor(new Color(textShade,textShade,textShade));
                    g2.drawString(text,x,y);
                }
            }
        }
    }

    /**
     * Converts a fraction into a color channel value. Values outside of 0 to 1.0 are clamped to
     * 0 to 255 since {@link Color#Color(int, int, int)} throws an exception for anything outside that range.
     */
    private static int toChannel( double fraction ) {
        int channel = (int)(255*fraction);
        return Math.max(0,Math.min(255,channel));
    }

    /**
     * Use to sample the panel to see what is being displayed at the location clicked.  All coordinates
     * are in panel coordinates. If the point is to the right of the matrix then it is assumed to be on a
     * label, and both row and col are set to the label's category. If the panel has not been
     * rendered yet, or the matrix is empty, then row and col are set to -1 and insideMatrix to false.
     *
     * @param pixelX x-axis in panel coordinates
     * @param pixelY y-axis in panel coordinates
     * @param output (Optional) storage for output.
     * @return Information on what is at the specified location
     */
    public LocationInfo whatIsAtPoint( int pixelX , int pixelY , LocationInfo output ) {
        if( output == null )
            output = new LocationInfo();

        synchronized ( this ) {
            int numCategories = confusion.getNumRows();

            // the grid size is zero until the first paint; avoid dividing by zero
            if( numCategories <= 0 || gridWidth <= 0 || gridHeight <= 0 ) {
                output.insideMatrix = false;
                output.row = output.col = -1;
            } else if( pixelX >= gridWidth ) {
                output.insideMatrix = false;
                output.col = output.row = pixelY*numCategories/gridHeight;
            } else {
                output.insideMatrix = true;
                output.row = pixelY*numCategories/gridHeight;
                output.col = pixelX*numCategories/gridWidth;
            }
        }

        return output;
    }

    /**
     * Contains information on what was at the point
     */
    public static class LocationInfo {
        // true if the point is inside the matrix, false if it is in the label region or nothing is there
        public boolean insideMatrix;
        // row and column of the cell in the confusion matrix
        public int row,col;
    }

    public static void main(String[] args) {
        DenseMatrix64F m = RandomMatrices.createRandom(5,5,0,1,new Random(234));

        List<String> labels = new ArrayList<>();
        for (int i = 0; i < m.numRows; i++) {
            labels.add("Label "+i);
        }

        ConfusionMatrixPanel confusion = new ConfusionMatrixPanel(m,labels,300,false);
        confusion.setHighlightCategory(2);
        ShowImages.showWindow(confusion,"Window",true);
    }
}