package net.demilich.metastone.game.behaviour.neutralnetwork;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Random;
public class NeuralNetwork implements Serializable {
	// serialver for backwards compatibility
	private static final long serialVersionUID = 1165374168397424904L;

	// the random number generator passed to every HiddenUnit created by
	// NeuralNetwork(int, int[])
	public static final Random random = new Random();

	/**
	 * Reads and returns a network from the given file, as written by
	 * {@link #writeTo(String)}.
	 * <p>
	 * This uses Java native deserialization. Only read files from trusted
	 * sources, because deserializing untrusted data is a well-known attack
	 * vector.
	 *
	 * @param filename
	 *            The file to read from
	 * @return the deserialized network
	 * @throws IOException
	 *             if the file cannot be opened or read, or its contents are not
	 *             a valid serialization stream
	 * @throws ClassNotFoundException
	 *             if the class of a serialized object cannot be found
	 */
	public static NeuralNetwork readFrom(String filename) throws IOException, ClassNotFoundException {
		// Both streams are closed even when the stream header is invalid or
		// readObject()/the cast throws.
		try (FileInputStream fileIn = new FileInputStream(filename);
				ObjectInputStream ois = new ObjectInputStream(fileIn)) {
			return (NeuralNetwork) ois.readObject();
		}
	}

	// the layers of the network; the last entry of hidden is the output layer
	// (getValue reads its result from hidden[hidden.length - 1])
	public InputUnit[] input;
	public HiddenUnit[][] hidden;

	/**
	 * Builds a neural network with the given number of input units, hidden
	 * units, and output units. Thus, calling
	 *
	 * new NeuralNetwork(10, new int[] {20, 5});
	 *
	 * creates a neural network with 10 input units, a layer of 20 hidden units,
	 * and then 5 output units.
	 * <p>
	 * Layer 0 is connected to the input units, and every later layer is
	 * connected to the layer before it. Each hidden unit receives the shared
	 * {@link #random} generator.
	 *
	 * @param input
	 *            The number of input units
	 * @param hidden
	 *            The number of units per layer; its length is the number of
	 *            layers, and the last entry is the number of output units
	 */
	public NeuralNetwork(int input, int[] hidden) {
		this.input = new InputUnit[input];
		this.hidden = new HiddenUnit[hidden.length][];
		for (int i = 0; i < hidden.length; i++) {
			this.hidden[i] = new HiddenUnit[hidden[i]];
		}
		for (int i = 0; i < input; i++) {
			this.input[i] = new InputUnit();
		}
		// Layers are filled in order so that layer i - 1 is fully populated
		// before the units of layer i are connected to it.
		for (int i = 0; i < hidden.length; i++) {
			for (int j = 0; j < hidden[i]; j++) {
				if (i == 0) {
					this.hidden[i][j] = new HiddenUnit(this.input, random);
				} else {
					this.hidden[i][j] = new HiddenUnit(this.hidden[i - 1], random);
				}
			}
		}
	}

	/**
	 * Builds a neural network with the same shape as the provided network,
	 * initialised from that network's weights.
	 * <p>
	 * NOTE(review): whether the new units copy or share {@code weights} arrays
	 * with {@code net} depends on the HiddenUnit constructor, which is not
	 * visible here. Confirm before mutating either network.
	 *
	 * @param net
	 *            The network to base it off of
	 */
	public NeuralNetwork(NeuralNetwork net) {
		this.input = new InputUnit[net.input.length];
		this.hidden = new HiddenUnit[net.hidden.length][];
		for (int i = 0; i < input.length; i++) {
			this.input[i] = new InputUnit();
		}
		for (int i = 0; i < net.hidden.length; i++) {
			this.hidden[i] = new HiddenUnit[net.hidden[i].length];
			for (int j = 0; j < net.hidden[i].length; j++) {
				// connect to this network's own units, never to those of net
				if (i == 0) {
					this.hidden[i][j] = new HiddenUnit(this.input, net.hidden[i][j].weights);
				} else {
					this.hidden[i][j] = new HiddenUnit(this.hidden[i - 1], net.hidden[i][j].weights);
				}
			}
		}
	}

	/**
	 * Calculates the network value given the provided input.
	 * <p>
	 * Each value is stored in the corresponding input unit. Every unit is then
	 * recomputed layer by layer, and the values of the last layer are returned.
	 * <p>
	 * NOTE(review): the input length is not validated. A longer array throws
	 * ArrayIndexOutOfBoundsException. With a shorter array, the remaining input
	 * units are not updated by this call and keep whatever value they held
	 * before.
	 *
	 * @param input
	 *            The input to check
	 * @return The network value from this input, one entry per output unit
	 */
	public double[] getValue(double[] input) {
		HiddenUnit[] outputLayer = hidden[hidden.length - 1];
		double[] result = new double[outputLayer.length];
		for (int i = 0; i < input.length; i++) {
			this.input[i].setValue(input[i]);
		}
		// forward pass: earlier layers must be recomputed before later ones
		for (HiddenUnit[] layer : hidden) {
			for (HiddenUnit unit : layer) {
				unit.recompute();
			}
		}
		for (int j = 0; j < outputLayer.length; j++) {
			result[j] = outputLayer[j].getValue();
		}
		return result;
	}

	/**
	 * Writes this network to the given file using Java serialization, in a
	 * form readable by {@link #readFrom(String)}.
	 *
	 * @param filename
	 *            The file to write to
	 * @throws IOException
	 *             if the file cannot be opened or written
	 */
	public void writeTo(String filename) throws IOException {
		// Both streams are closed even when writing the header or the object
		// fails; closing the object stream also flushes it.
		try (FileOutputStream fileOut = new FileOutputStream(filename);
				ObjectOutputStream oos = new ObjectOutputStream(fileOut)) {
			oos.writeObject(this);
			oos.flush();
		}
	}
}