package org.deeplearning4j.examples.unsupervised.variational.plot;
import org.deeplearning4j.berkeley.Pair;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.data.xy.XYDataset;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import javax.swing.*;
import javax.swing.event.ChangeEvent;
import javax.swing.event.ChangeListener;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.List;
/**
* Plotting methods for the VariationalAutoEncoder example
* @author Alex Black
*/
public class PlotUtil {
//Scatterplot util used for CenterLossMnistExample
public static void scatterPlot(List<Pair<INDArray,INDArray>> data, List<Integer> epochCounts, String title ){
double xMin = Double.MAX_VALUE;
double xMax = -Double.MAX_VALUE;
double yMin = Double.MAX_VALUE;
double yMax = -Double.MAX_VALUE;
for(Pair<INDArray,INDArray> p : data){
INDArray maxes = p.getFirst().max(0);
INDArray mins = p.getFirst().min(0);
xMin = Math.min(xMin, mins.getDouble(0));
xMax = Math.max(xMax, maxes.getDouble(0));
yMin = Math.min(yMin, mins.getDouble(1));
yMax = Math.max(yMax, maxes.getDouble(1));
}
double plotMin = Math.min(xMin, yMin);
double plotMax = Math.max(xMax, yMax);
JPanel panel = new ChartPanel(createChart(data.get(0).getFirst(), data.get(0).getSecond(), plotMin, plotMax, title + " (epoch " + epochCounts.get(0) + ")"));
JSlider slider = new JSlider(0,epochCounts.size()-1,0);
slider.setSnapToTicks(true);
final JFrame f = new JFrame();
slider.addChangeListener(new ChangeListener() {
private JPanel lastPanel = panel;
@Override
public void stateChanged(ChangeEvent e) {
JSlider slider = (JSlider)e.getSource();
int value = slider.getValue();
JPanel panel = new ChartPanel(createChart(data.get(value).getFirst(), data.get(value).getSecond(), plotMin, plotMax, title + " (epoch " + epochCounts.get(value) + ")"));
if(lastPanel != null){
f.remove(lastPanel);
}
lastPanel = panel;
f.add(panel, BorderLayout.CENTER);
f.setTitle(title);
f.revalidate();
}
});
f.setLayout(new BorderLayout());
f.add(slider, BorderLayout.NORTH);
f.add(panel, BorderLayout.CENTER);
f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
f.pack();
f.setTitle(title);
f.setVisible(true);
}
public static void plotData(List<INDArray> xyVsIter, INDArray labels, double axisMin, double axisMax, int plotFrequency){
JPanel panel = new ChartPanel(createChart(xyVsIter.get(0), labels, axisMin, axisMax));
JSlider slider = new JSlider(0,xyVsIter.size()-1,0);
slider.setSnapToTicks(true);
final JFrame f = new JFrame();
slider.addChangeListener(new ChangeListener() {
private JPanel lastPanel = panel;
@Override
public void stateChanged(ChangeEvent e) {
JSlider slider = (JSlider)e.getSource();
int value = slider.getValue();
JPanel panel = new ChartPanel(createChart(xyVsIter.get(value), labels, axisMin, axisMax));
if(lastPanel != null){
f.remove(lastPanel);
}
lastPanel = panel;
f.add(panel, BorderLayout.CENTER);
f.setTitle(getTitle(value, plotFrequency));
f.revalidate();
}
});
f.setLayout(new BorderLayout());
f.add(slider, BorderLayout.NORTH);
f.add(panel, BorderLayout.CENTER);
f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
f.pack();
f.setTitle(getTitle(0, plotFrequency));
f.setVisible(true);
}
private static String getTitle(int recordNumber, int plotFrequency){
return "MNIST Test Set - Latent Space Encoding at Training Iteration " + recordNumber * plotFrequency;
}
//Test data
private static XYDataset createDataSet(INDArray features, INDArray labelsOneHot){
int nRows = features.rows();
int nClasses = labelsOneHot.columns();
XYSeries[] series = new XYSeries[nClasses];
for( int i=0; i<nClasses; i++){
series[i] = new XYSeries(String.valueOf(i));
}
INDArray classIdx = Nd4j.argMax(labelsOneHot, 1);
for( int i=0; i<nRows; i++ ){
int idx = classIdx.getInt(i);
series[idx].add(features.getDouble(i, 0), features.getDouble(i, 1));
}
XYSeriesCollection c = new XYSeriesCollection();
for( XYSeries s : series) c.addSeries(s);
return c;
}
private static JFreeChart createChart(INDArray features, INDArray labels, double axisMin, double axisMax) {
return createChart(features, labels, axisMin, axisMax, "Variational Autoencoder Latent Space - MNIST Test Set");
}
private static JFreeChart createChart(INDArray features, INDArray labels, double axisMin, double axisMax, String title ) {
XYDataset dataset = createDataSet(features, labels);
JFreeChart chart = ChartFactory.createScatterPlot(title,
"X", "Y", dataset, PlotOrientation.VERTICAL, true, true, false);
XYPlot plot = (XYPlot) chart.getPlot();
plot.getRenderer().setBaseOutlineStroke(new BasicStroke(0));
plot.setNoDataMessage("NO DATA");
plot.setDomainPannable(false);
plot.setRangePannable(false);
plot.setDomainZeroBaselineVisible(true);
plot.setRangeZeroBaselineVisible(true);
plot.setDomainGridlineStroke(new BasicStroke(0.0f));
plot.setDomainMinorGridlineStroke(new BasicStroke(0.0f));
plot.setDomainGridlinePaint(Color.blue);
plot.setRangeGridlineStroke(new BasicStroke(0.0f));
plot.setRangeMinorGridlineStroke(new BasicStroke(0.0f));
plot.setRangeGridlinePaint(Color.blue);
plot.setDomainMinorGridlinesVisible(true);
plot.setRangeMinorGridlinesVisible(true);
XYLineAndShapeRenderer renderer
= (XYLineAndShapeRenderer) plot.getRenderer();
renderer.setSeriesOutlinePaint(0, Color.black);
renderer.setUseOutlinePaint(true);
NumberAxis domainAxis = (NumberAxis) plot.getDomainAxis();
domainAxis.setAutoRangeIncludesZero(false);
domainAxis.setRange(axisMin, axisMax);
domainAxis.setTickMarkInsideLength(2.0f);
domainAxis.setTickMarkOutsideLength(2.0f);
domainAxis.setMinorTickCount(2);
domainAxis.setMinorTickMarksVisible(true);
NumberAxis rangeAxis = (NumberAxis) plot.getRangeAxis();
rangeAxis.setTickMarkInsideLength(2.0f);
rangeAxis.setTickMarkOutsideLength(2.0f);
rangeAxis.setMinorTickCount(2);
rangeAxis.setMinorTickMarksVisible(true);
rangeAxis.setRange(axisMin, axisMax);
return chart;
}
public static class MNISTLatentSpaceVisualizer {
private double imageScale;
private List<INDArray> digits; //Digits (as row vectors), one per INDArray
private int plotFrequency;
private int gridWidth;
public MNISTLatentSpaceVisualizer(double imageScale, List<INDArray> digits, int plotFrequency) {
this.imageScale = imageScale;
this.digits = digits;
this.plotFrequency = plotFrequency;
this.gridWidth = (int)Math.sqrt(digits.get(0).size(0)); //Assume square, nxn rows
}
private String getTitle(int recordNumber){
return "Reconstructions Over Latent Space at Training Iteration " + recordNumber * plotFrequency;
}
public void visualize(){
JFrame frame = new JFrame();
frame.setTitle(getTitle(0));
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
frame.setLayout(new BorderLayout());
JPanel panel = new JPanel();
panel.setLayout(new GridLayout(0,gridWidth));
JSlider slider = new JSlider(0,digits.size()-1, 0);
slider.addChangeListener(new ChangeListener() {
@Override
public void stateChanged(ChangeEvent e) {
JSlider slider = (JSlider)e.getSource();
int value = slider.getValue();
panel.removeAll();
List<JLabel> list = getComponents(value);
for(JLabel image : list){
panel.add(image);
}
frame.setTitle(getTitle(value));
frame.revalidate();
}
});
frame.add(slider, BorderLayout.NORTH);
List<JLabel> list = getComponents(0);
for(JLabel image : list){
panel.add(image);
}
frame.add(panel, BorderLayout.CENTER);
frame.setVisible(true);
frame.pack();
}
private List<JLabel> getComponents(int idx){
List<JLabel> images = new ArrayList<>();
List<INDArray> temp = new ArrayList<>();
for( int i=0; i<digits.get(idx).size(0); i++ ){
temp.add(digits.get(idx).getRow(i));
}
for( INDArray arr : temp ){
BufferedImage bi = new BufferedImage(28,28,BufferedImage.TYPE_BYTE_GRAY);
for( int i=0; i<784; i++ ){
bi.getRaster().setSample(i % 28, i / 28, 0, (int)(255*arr.getDouble(i)));
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((int)(imageScale*28),(int)(imageScale*28),Image.SCALE_REPLICATE);
ImageIcon scaled = new ImageIcon(imageScaled);
images.add(new JLabel(scaled));
}
return images;
}
}
}