package com.m4rkl1u.autoencoder;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.quick.QuickPropagation;

public class AutoEncoder {

    /** Parameters learned for one encoder layer: its weights, activation function and node count. */
    public class MLParams {
        public double[] weights;
        public ActivationFunction func;
        public int nodes;

        public MLParams(double[] weights, ActivationFunction func, int nodes) {
            this.weights = weights;
            this.func = func;
            this.nodes = nodes;
        }
    }

    private List<MLParams> params;
    private MLDataSet dataset;
    private ActivationFunction func;
    private BasicNetwork network;
    private BasicNetwork hiddenNet;
    private MLDataSet intermediateDataset;

    public AutoEncoder() {
        params = new ArrayList<MLParams>();
        dataset = new BasicMLDataSet();
    }

    public void setData(MLDataSet dataset) {
        this.dataset = dataset;
    }

    /** Adds every row as an autoencoder pair (ideal == input). */
    public void setData(double[][] p) {
        for (int i = 0; i < p.length; i++) {
            double[] input = p[i];
            MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(input));
            dataset.add(pair);
        }
    }

    public void addData(double[] p) {
        MLDataPair pair = new BasicMLDataPair(new BasicMLData(p), new BasicMLData(p));
        dataset.add(pair);
    }

    public void setFunc(ActivationFunction func) {
        this.func = func;
    }

    /**
     * Adds and trains one encoder layer. Any previously trained layers are frozen
     * into hiddenNet and used to transform the data; the new layer is then trained
     * as a single-hidden-layer autoencoder on that transformed data.
     */
    public void addLayer(ActivationFunction func, int nodes) {
        if (params.size() > 0) {
            buildNetwork();
            transformData();
        } else {
            intermediateDataset = this.dataset;
        }

        network = new BasicNetwork();
        network.addLayer(new BasicLayer(new ActivationLinear(), true, intermediateDataset.getInputSize()));
        network.addLayer(new BasicLayer(func, true, nodes));
        network.addLayer(new BasicLayer(new ActivationTANH(), false, intermediateDataset.getIdealSize()));
        network.getStructure().finalizeStructure();
        network.reset();

        train(nodes);
    }

    /** Trains the current single-layer autoencoder and stores its encoder (input-to-hidden) weights. */
    public void train(int nodes) {
        Propagation propagation = new QuickPropagation(network, intermediateDataset, 0.01);
        propagation.setThreadCount(0);

        for (int i = 0; i < 100; i++) {
            propagation.iteration();
        }
        System.out.println("In deep layer: " + params.size()
                + " Training error " + propagation.getError());

        // Copy the weights from layer 0 to layer 1, including the bias neuron of the input layer.
        int fromNodes = network.getLayerTotalNeuronCount(0);
        int toNodes = network.getLayerNeuronCount(1);
        int numWeight = fromNodes * toNodes;
        double[] weights = new double[numWeight];
        int k = 0;
        for (int i = 0; i < fromNodes; i++) {
            for (int j = 0; j < toNodes; j++) {
                weights[k++] = network.getWeight(0, i, j);
            }
        }

        ActivationFunction func = network.getActivation(1);
        MLParams param = new MLParams(weights, func, nodes);
        params.add(param);

        System.out.println("Add weight: " + weights.length
                + "\n and the activation function: " + func.toString());
    }

    /** Pushes the raw data through the frozen encoder stack to build the next layer's training set. */
    private void transformData() {
        intermediateDataset = new BasicMLDataSet();
        for (int i = 0; i < this.dataset.getRecordCount(); i++) {
            MLData input = hiddenNet.compute(dataset.get(i).getInput());
            intermediateDataset.add(input, input);
        }
    }

    /** Rebuilds the encoder stack (hiddenNet) from the stored layer parameters. */
    private void buildNetwork() {
        hiddenNet = new BasicNetwork();
        hiddenNet.addLayer(new BasicLayer(new ActivationLinear(), true, dataset.getInputSize()));
        for (int i = 0; i < params.size() - 1; i++) {
            hiddenNet.addLayer(new BasicLayer(params.get(i).func, true, params.get(i).nodes));
        }
        hiddenNet.addLayer(new BasicLayer(params.get(params.size() - 1).func, false,
                params.get(params.size() - 1).nodes));
        hiddenNet.getStructure().finalizeStructure();

        // Restore the stored encoder weights layer by layer.
        for (int i = 0; i < params.size(); i++) {
            double[] layerWeights = params.get(i).weights;
            int j = 0;
            int fromCount = hiddenNet.getLayerTotalNeuronCount(i);
            int toCount = hiddenNet.getLayerNeuronCount(i + 1);
            for (int fromNeuron = 0; fromNeuron < fromCount; fromNeuron++) {
                for (int toNeuron = 0; toNeuron < toCount; toNeuron++) {
                    hiddenNet.setWeight(i, fromNeuron, toNeuron, layerWeights[j++]);
                }
            }
        }
    }

    /** Returns the hidden representation of the first record at the given encoder depth. */
    public double[] represent(int layer) {
        assert (layer <= params.size());

        hiddenNet = new BasicNetwork();
        hiddenNet.addLayer(new BasicLayer(new ActivationLinear(), true, dataset.getInputSize()));
        for (int i = 0; i < params.size() - 1 && i < layer - 1; i++) {
            hiddenNet.addLayer(new BasicLayer(params.get(i).func, true, params.get(i).nodes));
        }
        hiddenNet.addLayer(new BasicLayer(params.get(layer - 1).func, false, params.get(layer - 1).nodes));
        hiddenNet.getStructure().finalizeStructure();

        // Restore the stored weights for every layer up to and including the requested one.
        for (int i = 0; i < params.size() && i < layer; i++) {
            double[] layerWeights = params.get(i).weights;
            int j = 0;
            int fromCount = hiddenNet.getLayerTotalNeuronCount(i);
            int toCount = hiddenNet.getLayerNeuronCount(i + 1);
            for (int fromNeuron = 0; fromNeuron < fromCount; fromNeuron++) {
                for (int toNeuron = 0; toNeuron < toCount; toNeuron++) {
                    hiddenNet.setWeight(i, fromNeuron, toNeuron, layerWeights[j++]);
                }
            }
        }

        return hiddenNet.compute(this.dataset.get(0).getInput()).getData();
    }
}
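
/*
 * Hypothetical usage sketch, not part of the original class: it only exercises the
 * public methods defined above (setData, addLayer, represent). The class name
 * AutoEncoderExample, the toy data values and the layer sizes (4 -> 3 -> 2) are
 * illustrative assumptions, not taken from the original source.
 */
class AutoEncoderExample {
    public static void main(String[] args) {
        // Tiny toy dataset with 4-dimensional records (assumed values).
        double[][] data = {
                { 0.1, 0.9, 0.2, 0.8 },
                { 0.2, 0.8, 0.1, 0.9 },
                { 0.9, 0.1, 0.8, 0.2 },
                { 0.8, 0.2, 0.9, 0.1 }
        };

        AutoEncoder encoder = new AutoEncoder();
        encoder.setData(data);

        // Greedily train two encoder layers: 4 -> 3, then 3 -> 2.
        encoder.addLayer(new ActivationSigmoid(), 3);
        encoder.addLayer(new ActivationSigmoid(), 2);

        // Print the 2-dimensional hidden representation of the first record.
        System.out.println(Arrays.toString(encoder.represent(2)));
    }
}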