package neuralnetworks; import javax.swing.*; import java.awt.*; import java.awt.event.*; public class GUITest_2H extends JFrame { static float[] in1 = {0.1f, 0.1f, 0.9f}; static float[] in2 = {0.1f, 0.9f, 0.1f}; static float[] in3 = {0.9f, 0.1f, 0.1f}; static float[] out1 = {0.9f, 0.1f, 0.1f}; static float[] out2 = {0.1f, 0.1f, 0.9f}; static float[] out3 = {0.1f, 0.9f, 0.1f}; static float[] test1 = {0.1f, 0.1f, 0.9f}; static float[] test2 = {0.1f, 0.9f, 0.1f}; static float[] test3 = {0.9f, 0.1f, 0.1f}; Neural_2H nn = new Neural_2H(3, 3, 3, 3); Plot1DPanel inputPanel = new Plot1DPanel(3, 0.0f, 1.0f, nn.inputs); Plot1DPanel hidden1Panel = new Plot1DPanel(3, 0.0f, 1.0f, nn.hidden1); Plot1DPanel hidden2Panel = new Plot1DPanel(3, 0.0f, 1.0f, nn.hidden2); Plot1DPanel outputPanel = new Plot1DPanel(3, 0.0f, 1.0f, nn.outputs); Plot2DPanel w1Panel = new Plot2DPanel(3, 3, nn.clampWeight(-100), nn.clampWeight(100), nn.W1); Plot2DPanel w2Panel = new Plot2DPanel(3, 3, nn.clampWeight(-100), nn.clampWeight(100), nn.W2); Plot2DPanel w3Panel = new Plot2DPanel(3, 3, nn.clampWeight(-100), nn.clampWeight(100), nn.W3); JButton jButton1 = new JButton(); JLabel jLabel1 = new JLabel(); JLabel jLabel2 = new JLabel(); JLabel jLabel2b = new JLabel(); JLabel jLabel3 = new JLabel(); JLabel jLabel4 = new JLabel(); JLabel jLabel4b = new JLabel(); JLabel jLabel5 = new JLabel(); public GUITest_2H() { try { nn.addTrainingExample(in1, out1); nn.addTrainingExample(in2, out2); nn.addTrainingExample(in3, out3); jbInit(); this.setSize(440, 450); this.setVisible(true); } catch (Exception e) { e.printStackTrace(); } } public static void main(String[] args) { new GUITest_2H(); } private void jbInit() throws Exception { this.getContentPane().setLayout(null); inputPanel.setBounds(new Rectangle(5, 30, 400, 20)); hidden1Panel.setBounds(new Rectangle(5, 138, 400, 20)); hidden2Panel.setBounds(new Rectangle(5, 238, 400, 20)); outputPanel.setBounds(new Rectangle(5, 340, 400, 20)); w1Panel.setBounds(new Rectangle(160, 50, 61, 61)); w2Panel.setBounds(new Rectangle(160, 158, 61, 61)); w3Panel.setBounds(new Rectangle(160, 258, 61, 61)); jButton1.setText("Reset and Run"); jButton1.setBounds(new Rectangle(246, 380, 148, 28)); jButton1.addMouseListener(new java.awt.event.MouseAdapter() { public void mousePressed(MouseEvent e) { try { do_run_button(e); } catch (InterruptedException e1) { // TODO Auto-generated catch block e1.printStackTrace(); } } }); this.setDefaultCloseOperation(3); jLabel1.setText("Input neurons:"); jLabel1.setBounds(new Rectangle(4, 10, 144, 19)); jLabel2.setText("Hidden 1 neurons:"); jLabel2.setBounds(new Rectangle(4, 111, 144, 19)); jLabel2b.setText("Hidden 2 neurons:"); jLabel2b.setBounds(new Rectangle(4, 211, 144, 19)); jLabel3.setText("Output neurons:"); jLabel3.setBounds(new Rectangle(4, 317, 240, 19)); jLabel4.setText("Input to H1 weights"); jLabel4.setBounds(new Rectangle(230, 80, 170, 19)); jLabel4b.setText("H1 to H2 weights"); jLabel4b.setBounds(new Rectangle(230, 180, 170, 19)); jLabel5.setText("H2 to output weights"); jLabel5.setBounds(new Rectangle(230, 280, 170, 19)); this.getContentPane().add(inputPanel, null); this.getContentPane().add(hidden1Panel, null); this.getContentPane().add(hidden2Panel, null); this.getContentPane().add(outputPanel, null); this.getContentPane().add(w1Panel, null); this.getContentPane().add(w2Panel, null); this.getContentPane().add(w3Panel, null); this.getContentPane().add(jButton1, null); this.getContentPane().add(jLabel1, null); this.getContentPane().add(jLabel2, null); this.getContentPane().add(jLabel2b, null); this.getContentPane().add(jLabel3, null); this.getContentPane().add(jLabel4, null); this.getContentPane().add(jLabel4b, null); this.getContentPane().add(jLabel5, null); this.getContentPane().setBackground(Color.white); } void do_run_button(MouseEvent e) throws InterruptedException { jButton1.setEnabled(false); Graphics g1 = inputPanel.getGraphics(); Graphics g2 = hidden1Panel.getGraphics(); Graphics g3 = hidden2Panel.getGraphics(); Graphics g4 = outputPanel.getGraphics(); Graphics g5 = w1Panel.getGraphics(); Graphics g6 = w2Panel.getGraphics(); Graphics g7 = w3Panel.getGraphics(); training_loop: for (int i = 0; i < 25000; i++) { if (i == 5000 || i == 8000 || i == 10000 | i == 12000) nn.TRAINING_RATE *= 0.75f; float error = nn.train(); if (i > 0 && i % 500 == 0) { // // If the error is too large, slightly randomize weights: if (error > 0.75) { nn.randomizeWeights(); nn.TRAINING_RATE = 0.75f; } else if (error > 0.3) { nn.slightlyRandomizeWeights(); } System.out.println("cycle " + i + " error is " + error); if (error < 0.1) break training_loop; } inputPanel.paint(g1); hidden1Panel.paint(g2); hidden2Panel.paint(g3); outputPanel.paint(g4); w1Panel.paint(g5); w2Panel.paint(g6); w3Panel.paint(g7); } float [] answers; Thread.sleep(2000); answers = nn.recall(in1); for (int i=0; i<3; i++) nn.outputs[i] = answers[i]; inputPanel.paint(g1); hidden1Panel.paint(g2); hidden2Panel.paint(g3); outputPanel.paint(g4); Thread.sleep(2000); answers = nn.recall(in2); for (int i=0; i<3; i++) nn.outputs[i] = answers[i]; inputPanel.paint(g1); hidden1Panel.paint(g2); hidden2Panel.paint(g3); outputPanel.paint(g4); Thread.sleep(2000); answers = nn.recall(in3); for (int i=0; i<3; i++) nn.outputs[i] = answers[i]; inputPanel.paint(g1); hidden1Panel.paint(g2); hidden2Panel.paint(g3); outputPanel.paint(g4); jButton1.setEnabled(true); //nn.save("/tmp/neural_2h_save"); } }