package neuralnetworks;
import javax.swing.*;
import java.awt.*;
import java.awt.event.*;
public class GUITest_1H extends JFrame {
static float[] in1 = {0.1f, 0.1f, 0.9f};
static float[] in2 = {0.1f, 0.9f, 0.1f};
static float[] in3 = {0.9f, 0.1f, 0.1f};
static float[] out1 = {0.9f, 0.1f, 0.1f};
static float[] out2 = {0.1f, 0.1f, 0.9f};
static float[] out3 = {0.1f, 0.9f, 0.1f};
static float[] test1 = {0.1f, 0.1f, 0.9f};
static float[] test2 = {0.1f, 0.9f, 0.1f};
static float[] test3 = {0.9f, 0.1f, 0.1f};
Neural_1H nn = new Neural_1H(3, 3, 3);
Plot1DPanel inputPanel = new Plot1DPanel(3, 0f, 1.0f, nn.inputs);
Plot1DPanel hiddenPanel = new Plot1DPanel(3, 0f, 1.0f, nn.hidden);
Plot1DPanel outputPanel = new Plot1DPanel(3, 0f, 1.0f, nn.outputs);
Plot2DPanel w1Panel = new Plot2DPanel(3, 3, -1.0f, 1.0f, nn.W1);
Plot2DPanel w2Panel = new Plot2DPanel(3, 3, -4.0f, 4.0f, nn.W2);
JButton jButton1 = new JButton();
JLabel jLabel1 = new JLabel();
JLabel jLabel2 = new JLabel();
JLabel jLabel3 = new JLabel();
JLabel jLabel4 = new JLabel();
JLabel jLabel5 = new JLabel();
public GUITest_1H() {
try {
nn.addTrainingExample(in1, out1);
nn.addTrainingExample(in2, out2);
nn.addTrainingExample(in3, out3);
jbInit();
this.setSize(450, 350);
this.setVisible(true);
} catch (Exception e) {
e.printStackTrace();
}
}
public static void main(String[] args) {
GUITest_1H GUITest_1H1 = new GUITest_1H();
}
private void jbInit() throws Exception {
this.getContentPane().setLayout(null);
inputPanel.setBounds(new Rectangle(5, 30, 400, 20));
hiddenPanel.setBounds(new Rectangle(5, 138, 400, 20));
outputPanel.setBounds(new Rectangle(5, 240, 400, 20));
w1Panel.setBounds(new Rectangle(150, 50, 61, 61));
w2Panel.setBounds(new Rectangle(150, 158, 61, 61));
jButton1.setText("Reset and Run");
jButton1.setBounds(new Rectangle(246, 290, 148, 28));
jButton1.addMouseListener(new java.awt.event.MouseAdapter() {
public void mousePressed(MouseEvent e) {
try {
do_run_button(e);
} catch (InterruptedException e1) {
// TODO Auto-generated catch block
e1.printStackTrace();
}
}
});
this.setDefaultCloseOperation(3);
jLabel1.setText("Input neurons:");
jLabel1.setBounds(new Rectangle(4, 10, 144, 19));
jLabel2.setText("Hidden neurons:");
jLabel2.setBounds(new Rectangle(4, 115, 144, 19));
jLabel3.setText("Output neurons:");
jLabel3.setBounds(new Rectangle(4, 220, 240, 19));
jLabel4.setText("input to hidden weights");
jLabel4.setBounds(new Rectangle(220, 70, 150, 19));
jLabel5.setText("hidden to output weights");
jLabel5.setBounds(new Rectangle(220, 180, 190, 19));
this.getContentPane().add(inputPanel, null);
this.getContentPane().add(hiddenPanel, null);
this.getContentPane().add(outputPanel, null);
this.getContentPane().add(w1Panel, null);
this.getContentPane().add(w2Panel, null);
this.getContentPane().add(jButton1, null);
this.getContentPane().add(jLabel1, null);
this.getContentPane().add(jLabel2, null);
this.getContentPane().add(jLabel3, null);
this.getContentPane().add(jLabel4, null);
this.getContentPane().add(jLabel5, null);
this.getContentPane().setBackground(Color.white);
}
void do_run_button(MouseEvent e) throws InterruptedException {
Graphics g1 = inputPanel.getGraphics();
Graphics g2 = hiddenPanel.getGraphics();
Graphics g3 = outputPanel.getGraphics();
Graphics g4 = w1Panel.getGraphics();
Graphics g5 = w2Panel.getGraphics();
training_loop:
for (int i = 0; i < 5000; i++) {
if (i == 1000 || i == 3000 || i == 4000 | i ==4500) nn.learningRate *= 0.75f;
float error = nn.train();
if (i > 0 && i % 500 == 0) {
//
// If the error is too large, slightly randomize weights:
if (error > 0.75) {
nn.randomizeWeights();
nn.learningRate = 0.75f;
} else if (error > 0.3) {
nn.slightlyRandomizeWeights();
}
System.out.println("cycle " + i + " error is " + error);
if (error < 0.1) break training_loop;
}
inputPanel.paint(g1);
hiddenPanel.paint(g2);
outputPanel.paint(g3);
w1Panel.paint(g4);
w2Panel.paint(g5);
}
float [] answers;
Thread.sleep(2000);
answers = nn.recall(in1);
for (int i=0; i<3; i++) nn.outputs[i] = answers[i];
inputPanel.paint(g1);
hiddenPanel.paint(g2);
outputPanel.paint(g3);
Thread.sleep(2000);
answers = nn.recall(in2);
for (int i=0; i<3; i++) nn.outputs[i] = answers[i];
inputPanel.paint(g1);
hiddenPanel.paint(g2);
outputPanel.paint(g3);
Thread.sleep(2000);
answers = nn.recall(in3);
for (int i=0; i<3; i++) nn.outputs[i] = answers[i];
inputPanel.paint(g1);
hiddenPanel.paint(g2);
outputPanel.paint(g3);
jButton1.setEnabled(true);
}
}