package org.deeplearning4j.examples.feedforward.regression; import org.deeplearning4j.examples.feedforward.regression.function.MathFunction; import org.deeplearning4j.examples.feedforward.regression.function.SinXDivXMathFunction; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.jfree.chart.ChartFactory; import org.jfree.chart.ChartPanel; import org.jfree.chart.JFreeChart; import org.jfree.chart.plot.PlotOrientation; import org.jfree.data.xy.XYSeries; import org.jfree.data.xy.XYSeriesCollection; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; import javax.swing.*; import java.util.Collections; import java.util.List; import java.util.Random; /**Example: Train a network to reproduce certain mathematical functions, and plot the results. * Plotting of the network output occurs every 'plotFrequency' epochs. Thus, the plot shows the accuracy of the network * predictions as training progresses. * A number of mathematical functions are implemented here. * Note the use of the identity function on the network output layer, for regression * * @author Alex Black */ public class RegressionMathFunctions { //Random number generator seed, for reproducability public static final int seed = 12345; //Number of iterations per minibatch public static final int iterations = 1; //Number of epochs (full passes of the data) public static final int nEpochs = 2000; //How frequently should we plot the network output? public static final int plotFrequency = 500; //Number of data points public static final int nSamples = 1000; //Batch size: i.e., each epoch has nSamples/batchSize parameter updates public static final int batchSize = 100; //Network learning rate public static final double learningRate = 0.01; public static final Random rng = new Random(seed); public static final int numInputs = 1; public static final int numOutputs = 1; public static void main(final String[] args){ //Switch these two options to do different functions with different networks final MathFunction fn = new SinXDivXMathFunction(); final MultiLayerConfiguration conf = getDeepDenseLayerNetworkConfiguration(); //Generate the training data final INDArray x = Nd4j.linspace(-10,10,nSamples).reshape(nSamples, 1); final DataSetIterator iterator = getTrainingData(x,fn,batchSize,rng); //Create the network final MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); net.setListeners(new ScoreIterationListener(1)); //Train the network on the full data set, and evaluate in periodically final INDArray[] networkPredictions = new INDArray[nEpochs/ plotFrequency]; for( int i=0; i<nEpochs; i++ ){ iterator.reset(); net.fit(iterator); if((i+1) % plotFrequency == 0) networkPredictions[i/ plotFrequency] = net.output(x, false); } //Plot the target data and the network predictions plot(fn,x,fn.getFunctionValues(x),networkPredictions); } /** Returns the network configuration, 2 hidden DenseLayers of size 50. */ private static MultiLayerConfiguration getDeepDenseLayerNetworkConfiguration() { final int numHiddenNodes = 50; return new NeuralNetConfiguration.Builder() .seed(seed) .iterations(iterations) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .learningRate(learningRate) .weightInit(WeightInit.XAVIER) .updater(Updater.NESTEROVS).momentum(0.9) .list() .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes) .activation(Activation.TANH).build()) .layer(1, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes) .activation(Activation.TANH).build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MSE) .activation(Activation.IDENTITY) .nIn(numHiddenNodes).nOut(numOutputs).build()) .pretrain(false).backprop(true).build(); } /** Create a DataSetIterator for training * @param x X values * @param function Function to evaluate * @param batchSize Batch size (number of examples for every call of DataSetIterator.next()) * @param rng Random number generator (for repeatability) */ private static DataSetIterator getTrainingData(final INDArray x, final MathFunction function, final int batchSize, final Random rng) { final INDArray y = function.getFunctionValues(x); final DataSet allData = new DataSet(x,y); final List<DataSet> list = allData.asList(); Collections.shuffle(list,rng); return new ListDataSetIterator(list,batchSize); } //Plot the data private static void plot(final MathFunction function, final INDArray x, final INDArray y, final INDArray... predicted) { final XYSeriesCollection dataSet = new XYSeriesCollection(); addSeries(dataSet,x,y,"True Function (Labels)"); for( int i=0; i<predicted.length; i++ ){ addSeries(dataSet,x,predicted[i],String.valueOf(i)); } final JFreeChart chart = ChartFactory.createXYLineChart( "Regression Example - " + function.getName(), // chart title "X", // x axis label function.getName() + "(X)", // y axis label dataSet, // data PlotOrientation.VERTICAL, true, // include legend true, // tooltips false // urls ); final ChartPanel panel = new ChartPanel(chart); final JFrame f = new JFrame(); f.add(panel); f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE); f.pack(); f.setVisible(true); } private static void addSeries(final XYSeriesCollection dataSet, final INDArray x, final INDArray y, final String label){ final double[] xd = x.data().asDouble(); final double[] yd = y.data().asDouble(); final XYSeries s = new XYSeries(label); for( int j=0; j<xd.length; j++ ) s.add(xd[j],yd[j]); dataSet.addSeries(s); } }