package samples.expert;
import hex.deeplearning.DeepLearningModel;
import hex.deeplearning.DeepLearningTask;
import hex.deeplearning.Neurons;
import javax.swing.*;
import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.image.BufferedImage;
import java.awt.image.WritableRaster;
public class DeepLearningVisualization extends Canvas {
static int _level = 1;
Neurons[] _neurons;
public DeepLearningVisualization(Neurons[] neurons) {
_neurons = neurons;
}
public JPanel init() {
JToolBar bar = new JToolBar();
bar.add(new JButton("refresh") {
@Override protected void fireActionPerformed(ActionEvent event) {
DeepLearningVisualization.this.repaint();
}
});
bar.add(new JButton("++") {
@Override protected void fireActionPerformed(ActionEvent event) {
if (_level < _neurons.length-2) _level++;
}
});
bar.add(new JButton("--") {
@Override protected void fireActionPerformed(ActionEvent event) {
if (_level > 1) _level--;
}
});
JPanel pane = new JPanel();
BorderLayout bord = new BorderLayout();
pane.setLayout(bord);
pane.add("North", bar);
setSize(1024, 1024);
pane.add(this);
return pane;
}
@Override public void paint(Graphics g) {
Neurons layer = _neurons[_level];
int edge = 56, pad = 10;
final int EDGE = (int) Math.ceil(Math.sqrt(layer._previous._a.size()));
assert (layer._previous._a.size() <= EDGE * EDGE);
int offset = pad;
int buf = EDGE + pad + pad;
double mean = 0;
long n = layer._w.size();
for (int i = 0; i < n; i++)
mean += layer._w.raw()[i];
mean /= layer._w.size();
double sigma = 0;
for (int i = 0; i < layer._w.size(); i++) {
double d = layer._w.raw()[i] - mean;
sigma += d * d;
}
sigma = Math.sqrt(sigma / (layer._w.size() - 1));
for (int o = 0; o < layer._a.size(); o++) {
if (o % 10 == 0) {
offset = pad;
buf += pad + edge;
}
int[] pic = new int[EDGE * EDGE];
for (int i = 0; i < layer._previous._a.size(); i++) {
double w = layer._w.get(o, i);
w = ((w - mean) / sigma) * 200;
if (w >= 0)
pic[i] = ((int) Math.min(+w, 255)) << 8; //GREEN
else
pic[i] = ((int) Math.min(-w, 255)) << 16; //RED
}
BufferedImage out = new BufferedImage(EDGE, EDGE, BufferedImage.TYPE_INT_RGB);
WritableRaster r = out.getRaster();
r.setDataElements(0, 0, EDGE, EDGE, pic);
BufferedImage resized = new BufferedImage(edge, edge, BufferedImage.TYPE_INT_RGB);
Graphics2D g2 = resized.createGraphics();
try {
g2.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BICUBIC);
g2.clearRect(0, 0, edge, edge);
g2.drawImage(out, 0, 0, edge, edge, null);
} finally {
g2.dispose();
}
g.drawImage(resized, buf, offset, null);
offset += pad + edge;
}
}
static JFrame frame = new JFrame("H2O Deep Learning");
static public void visualize(final DeepLearningModel dlm) {
Neurons[] neurons = DeepLearningTask.makeNeuronsForTesting(dlm.model_info());
frame.dispose();
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
DeepLearningVisualization canvas = new DeepLearningVisualization(neurons);
frame.setContentPane(canvas.init());
frame.pack();
frame.setLocationRelativeTo(null);
frame.setVisible(true);
}
}