package hex; import hex.Layer.VecsInput; import water.fvec.Vec; import javax.swing.*; import java.awt.*; import java.awt.event.ActionEvent; import java.awt.image.BufferedImage; import java.awt.image.WritableRaster; import java.util.Random; public class MnistCanvas extends Canvas { static final int PIXELS = 784, EDGE = 28; static Random _rand = new Random(); static int _level = 1; Trainer _trainer; public MnistCanvas(Trainer trainer) { _trainer = trainer; } public JPanel init() { JToolBar bar = new JToolBar(); bar.add(new JButton("refresh") { @Override protected void fireActionPerformed(ActionEvent event) { MnistCanvas.this.repaint(); } }); bar.add(new JButton("++") { @Override protected void fireActionPerformed(ActionEvent event) { _level++; } }); bar.add(new JButton("--") { @Override protected void fireActionPerformed(ActionEvent event) { _level--; } }); bar.add(new JButton("histo") { @Override protected void fireActionPerformed(ActionEvent event) { Histograms.initFromSwingThread(); Histograms.build(_trainer.layers()); } }); JPanel pane = new JPanel(); BorderLayout bord = new BorderLayout(); pane.setLayout(bord); pane.add("North", bar); setSize(1024, 1024); pane.add(this); return pane; } @Override public void paint(Graphics g) { Layer[] ls = _trainer.layers(); Vec[] vecs = ((VecsInput) ls[0]).vecs; // Vec resp = ((VecSoftmax) ls[ls.length - 1]).vec; int edge = 56, pad = 10; int rand = _rand.nextInt((int) vecs[0].length()); // Side { BufferedImage in = new BufferedImage(EDGE, EDGE, BufferedImage.TYPE_INT_RGB); WritableRaster r = in.getRaster(); // Input int[] pix = new int[PIXELS]; for( int i = 0; i < pix.length; i++ ) pix[i] = (int) (vecs[i].at8(rand)); r.setDataElements(0, 0, EDGE, EDGE, pix); g.drawImage(in, pad, pad, null); // Labels // g.drawString("" + resp.at8(rand), 10, 50); // g.drawString("RBM " + _level, 10, 70); } // Outputs int offset = pad; // float[] visible = new float[MnistNeuralNetTest.PIXELS]; // System.arraycopy(_images, rand * MnistNeuralNetTest.PIXELS, visible, 0, MnistNeuralNetTest.PIXELS); // for( int i = 0; i <= _level; i++ ) { // for( int pass = 0; pass < 10; pass++ ) { // if( i == _level ) { // int[] output = new int[visible.length]; // for( int v = 0; v < visible.length; v++ ) // output[v] = (int) Math.min(visible[v] * 255, 255); // BufferedImage out = new BufferedImage(MnistNeuralNetTest.EDGE, MnistNeuralNetTest.EDGE, // BufferedImage.TYPE_INT_RGB); // WritableRaster r = out.getRaster(); // r.setDataElements(0, 0, MnistNeuralNetTest.EDGE, MnistNeuralNetTest.EDGE, output); // BufferedImage image = new BufferedImage(edge, edge, BufferedImage.TYPE_INT_RGB); // Graphics2D ig = image.createGraphics(); // ig.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BICUBIC); // ig.clearRect(0, 0, edge, edge); // ig.drawImage(out, 0, 0, edge, edge, null); // ig.dispose(); // g.drawImage(image, pad * 2 + MnistNeuralNetTest.EDGE, offset, null); // offset += pad + edge; // } // if( _ls[i]._v != null ) { // float[] hidden = new float[_ls[i]._b.length]; // _ls[i].forward(visible, hidden); // visible = _ls[i].generate(hidden); // } // } // float[] t = new float[_ls[i]._b.length]; // _ls[i].forward(visible, t); // visible = t; // } // Weights int buf = EDGE + pad + pad; Layer layer = ls[_level]; double mean = 0; int n = layer._w.length; for( int i = 0; i < n; i++ ) mean += layer._w[i]; mean /= layer._w.length; double sigma = 0; for( int i = 0; i < layer._w.length; i++ ) { double d = layer._w[i] - mean; sigma += d * d; } sigma = Math.sqrt(sigma / (layer._w.length - 1)); for( int o = 0; o < layer._b.length; o++ ) { if( o % 10 == 0 ) { offset = pad; buf += pad + edge; } int[] start = new int[layer._previous._a.length]; for( int i = 0; i < layer._previous._a.length; i++ ) { double w = layer._w[o * layer._previous._a.length + i]; w = ((w - mean) / sigma) * 200; if( w >= 0 ) start[i] = ((int) Math.min(+w, 255)) << 8; //GREEN else start[i] = ((int) Math.min(-w, 255)) << 16; //RED } BufferedImage out = new BufferedImage(EDGE, EDGE, BufferedImage.TYPE_INT_RGB); WritableRaster r = out.getRaster(); r.setDataElements(0, 0, EDGE, EDGE, start); BufferedImage resized = new BufferedImage(edge, edge, BufferedImage.TYPE_INT_RGB); Graphics2D g2 = resized.createGraphics(); try { g2.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BICUBIC); g2.clearRect(0, 0, edge, edge); g2.drawImage(out, 0, 0, edge, edge, null); } finally { g2.dispose(); } g.drawImage(resized, buf, offset, null); offset += pad + edge; } } }