package org.deeplearning4j.examples.feedforward.regression;
import org.deeplearning4j.examples.feedforward.regression.function.MathFunction;
import org.deeplearning4j.examples.feedforward.regression.function.SinXDivXMathFunction;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import javax.swing.*;
import java.util.Collections;
import java.util.List;
import java.util.Random;
/**Example: Train a network to reproduce certain mathematical functions, and plot the results.
* Plotting of the network output occurs every 'plotFrequency' epochs. Thus, the plot shows the accuracy of the network
* predictions as training progresses.
* A number of mathematical functions are implemented here.
* Note the use of the identity function on the network output layer, for regression
*
* @author Alex Black
*/
public class RegressionMathFunctions {

    //Random number generator seed, for reproducibility
    public static final int seed = 12345;
    //Number of iterations per minibatch
    public static final int iterations = 1;
    //Number of epochs (full passes of the data)
    public static final int nEpochs = 2000;
    //How frequently should we plot the network output?
    public static final int plotFrequency = 500;
    //Number of data points
    public static final int nSamples = 1000;
    //Batch size: i.e., each epoch has nSamples/batchSize parameter updates
    public static final int batchSize = 100;
    //Network learning rate
    public static final double learningRate = 0.01;
    public static final Random rng = new Random(seed);
    public static final int numInputs = 1;
    public static final int numOutputs = 1;

    public static void main(final String[] args){
        //Switch these two options to do different functions with different networks
        final MathFunction fn = new SinXDivXMathFunction();
        final MultiLayerConfiguration conf = getDeepDenseLayerNetworkConfiguration();

        //Generate the training data: nSamples evenly spaced x values in [-10,10], as an (nSamples x 1) column vector
        final INDArray x = Nd4j.linspace(-10, 10, nSamples).reshape(nSamples, 1);
        final DataSetIterator iterator = getTrainingData(x, fn, batchSize, rng);

        //Create the network
        final MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(1));

        //Train the network on the full data set, and evaluate it periodically
        final INDArray[] networkPredictions = new INDArray[nEpochs / plotFrequency];
        for (int i = 0; i < nEpochs; i++) {
            iterator.reset();
            net.fit(iterator);
            //Snapshot the network's predictions every 'plotFrequency' epochs for later plotting.
            //Epochs are 0-indexed, so epoch i fills slot i/plotFrequency (0 .. nEpochs/plotFrequency - 1).
            if ((i + 1) % plotFrequency == 0) networkPredictions[i / plotFrequency] = net.output(x, false);
        }

        //Plot the target data and the network predictions
        plot(fn, x, fn.getFunctionValues(x), networkPredictions);
    }

    /** Returns the network configuration, 2 hidden DenseLayers of size 50.
     * Uses SGD with Nesterov momentum, Xavier weight init, tanh hidden activations,
     * and an identity-activation MSE output layer (standard setup for regression).
     */
    private static MultiLayerConfiguration getDeepDenseLayerNetworkConfiguration() {
        final int numHiddenNodes = 50;
        return new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(learningRate)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                .activation(Activation.TANH).build())
            .layer(1, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
                .activation(Activation.TANH).build())
            .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                //Identity activation on the output layer: required for unbounded regression targets
                .activation(Activation.IDENTITY)
                .nIn(numHiddenNodes).nOut(numOutputs).build())
            .pretrain(false).backprop(true).build();
    }

    /** Create a DataSetIterator for training
     * @param x X values
     * @param function Function to evaluate
     * @param batchSize Batch size (number of examples for every call of DataSetIterator.next())
     * @param rng Random number generator (for repeatability)
     */
    private static DataSetIterator getTrainingData(final INDArray x, final MathFunction function, final int batchSize, final Random rng) {
        final INDArray y = function.getFunctionValues(x);
        final DataSet allData = new DataSet(x, y);

        //Split into single-example DataSets, then shuffle so minibatches are not ordered by x
        final List<DataSet> list = allData.asList();
        Collections.shuffle(list, rng);
        //Fixed raw type: ListDataSetIterator is generic
        return new ListDataSetIterator<>(list, batchSize);
    }

    /** Plot the target function values and the periodic network predictions on one chart.
     * @param function Function being learned (used for chart/axis labels)
     * @param x X values (shared by all series)
     * @param y True function values (labels)
     * @param predicted Network predictions, one INDArray per recorded epoch snapshot
     */
    private static void plot(final MathFunction function, final INDArray x, final INDArray y, final INDArray... predicted) {
        final XYSeriesCollection dataSet = new XYSeriesCollection();
        addSeries(dataSet, x, y, "True Function (Labels)");
        for (int i = 0; i < predicted.length; i++) {
            addSeries(dataSet, x, predicted[i], String.valueOf(i));
        }

        final JFreeChart chart = ChartFactory.createXYLineChart(
            "Regression Example - " + function.getName(),      // chart title
            "X",                                               // x axis label
            function.getName() + "(X)",                        // y axis label
            dataSet,                                           // data
            PlotOrientation.VERTICAL,
            true,                                              // include legend
            true,                                              // tooltips
            false                                              // urls
        );

        //Swing components must be created and shown on the Event Dispatch Thread,
        //not the main thread - see the Swing concurrency guidelines
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                final ChartPanel panel = new ChartPanel(chart);
                final JFrame f = new JFrame();
                f.add(panel);
                f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
                f.pack();
                f.setVisible(true);
            }
        });
    }

    /** Copy an (x, y) pair of INDArrays into the collection as a named XYSeries.
     * Assumes x and y have the same number of elements.
     */
    private static void addSeries(final XYSeriesCollection dataSet, final INDArray x, final INDArray y, final String label) {
        final double[] xd = x.data().asDouble();
        final double[] yd = y.data().asDouble();
        final XYSeries s = new XYSeries(label);
        for (int j = 0; j < xd.length; j++) s.add(xd[j], yd[j]);
        dataSet.addSeries(s);
    }
}