package org.deeplearning4j.examples.recurrent.regression;

import org.apache.commons.io.FileUtils;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.examples.userInterface.util.GradientsAndParamsListener;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.jfree.ui.RefineryUtilities;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.swing.*;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.List;

/**
 * This example was inspired by Jason Brownlee's regression examples for Keras, found here:
 * http://machinelearningmastery.com/time-series-prediction-lstm-recurrent-neural-networks-python-keras/
 * <p>
 * It demonstrates multi time step regression using LSTM
 */
public class MultiTimestepRegressionExample {
    private static final Logger LOGGER = LoggerFactory.getLogger(MultiTimestepRegressionExample.class);

    /**
     * Resolves a classpath resource to a {@link File}.
     *
     * @param fileName resource path on the classpath (e.g. "/rnnRegression")
     * @return the resolved file
     * @throws UncheckedIOException if the resource cannot be located; a missing example
     *         resource is a configuration problem, so we wrap the IOException in an
     *         unchecked exception instead of throwing a raw java.lang.Error
     */
    private static File initBaseFile(String fileName) {
        try {
            return new ClassPathResource(fileName).getFile();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Directory layout for the generated per-example CSV files (features + labels,
    // separate train and test subtrees).
    private static File baseDir = initBaseFile("/rnnRegression");
    private static File baseTrainDir = new File(baseDir, "multiTimestepTrain");
    private static File featuresDirTrain = new File(baseTrainDir, "features");
    private static File labelsDirTrain = new File(baseTrainDir, "labels");
    private static File baseTestDir = new File(baseDir, "multiTimestepTest");
    private static File featuresDirTest = new File(baseTestDir, "features");
    private static File labelsDirTest = new File(baseTestDir, "labels");

    // Number of columns in the raw CSV; set by setNumOfVariables() once the file is read.
    private static int numOfVariables = 0;

    public static void main(String[] args) throws Exception {
        //Set number of examples for training, testing, and time steps
        int trainSize = 100;
        int testSize = 20;
        int numberOfTimesteps = 20;

        //Prepare multi time step data, see method comments for more info
        List<String> rawStrings = prepareTrainAndTest(trainSize, testSize, numberOfTimesteps);

        //Make sure miniBatchSize is divisible by trainSize and testSize,
        //as rnnTimeStep will not accept different sized examples
        int miniBatchSize = 10;

        // ----- Load the training data -----
        SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
        trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/train_%d.csv", 0, trainSize - 1));
        SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
        trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/train_%d.csv", 0, trainSize - 1));

        DataSetIterator trainDataIter = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, -1, true,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

        //Normalize the training data
        NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
        normalizer.fitLabel(true);
        normalizer.fit(trainDataIter);              //Collect training data statistics
        trainDataIter.reset();

        // ----- Load the test data -----
        //Same process as for the training data.
        SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
        testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/test_%d.csv", trainSize, trainSize + testSize - 1));
        SequenceRecordReader testLabels = new CSVSequenceRecordReader();
        testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/test_%d.csv", trainSize, trainSize + testSize - 1));

        DataSetIterator testDataIter = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize, -1, true,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

        //Apply the normalization learned on the training data to both iterators
        trainDataIter.setPreProcessor(normalizer);
        testDataIter.setPreProcessor(normalizer);

        // ----- Configure the network -----
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(140)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .learningRate(0.15)
            .list()
            .layer(0, new GravesLSTM.Builder().activation(Activation.TANH).nIn(numOfVariables).nOut(10)
                .build())
            .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .activation(Activation.IDENTITY).nIn(10).nOut(numOfVariables).build())
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new GradientsAndParamsListener(net, 100), new ScoreIterationListener(20));

        // ----- Train the network, evaluating the test set performance at each epoch -----
        int nEpochs = 50;
        for (int i = 0; i < nEpochs; i++) {
            net.fit(trainDataIter);
            trainDataIter.reset();
            LOGGER.info("Epoch " + i + " complete. Time series evaluation:");

            //Evaluate on the test set. Use the actual column count of the data rather than a
            //hard-coded value, so the evaluation stays consistent with the network's nIn/nOut.
            RegressionEvaluation evaluation = new RegressionEvaluation(numOfVariables);
            while (testDataIter.hasNext()) {
                DataSet t = testDataIter.next();
                INDArray features = t.getFeatureMatrix();
                INDArray labels = t.getLabels();
                INDArray predicted = net.output(features, true);
                evaluation.evalTimeSeries(labels, predicted);
            }
            LOGGER.info(evaluation.stats());
            testDataIter.reset();
        }

        /*
         * All code below this point is only necessary for plotting
         */

        //Init rnnTimeStep with train data and predict test data
        while (trainDataIter.hasNext()) {
            DataSet t = trainDataIter.next();
            net.rnnTimeStep(t.getFeatureMatrix());
        }

        trainDataIter.reset();

        //NOTE(review): only the first test mini-batch is predicted/plotted here — with
        //testSize=20 and miniBatchSize=10 the second batch is ignored. Confirm this is intended.
        DataSet t = testDataIter.next();
        INDArray predicted = net.rnnTimeStep(t.getFeatureMatrix());
        normalizer.revertLabels(predicted);

        //Convert raw string data to INDArrays for plotting
        INDArray trainArray = createIndArrayFromStringList(rawStrings, 0, trainSize);
        INDArray testArray = createIndArrayFromStringList(rawStrings, trainSize, testSize);

        //Create plot with our data
        XYSeriesCollection c = new XYSeriesCollection();
        createSeries(c, trainArray, 0, "Train data");
        createSeries(c, testArray, trainSize - 1, "Actual test data");
        createSeries(c, predicted, trainSize - 1, "Predicted test data");

        plotDataset(c);

        LOGGER.info("----- Example Complete -----");
    }

    /**
     * Creates an INDArray from a list of strings.
     * Used for plotting purposes.
     *
     * @param rawStrings raw CSV lines (comma-separated values, one row per time step)
     * @param startIndex index of the first line to include
     * @param length     number of lines to include
     * @return an INDArray containing the parsed values
     */
    private static INDArray createIndArrayFromStringList(List<String> rawStrings, int startIndex, int length) {
        List<String> stringList = rawStrings.subList(startIndex, startIndex + length);

        //Transpose while parsing: primitives[variable][timeStep]
        double[][] primitives = new double[numOfVariables][stringList.size()];
        for (int i = 0; i < stringList.size(); i++) {
            String[] vals = stringList.get(i).split(",");
            for (int j = 0; j < vals.length; j++) {
                //parseDouble avoids the needless boxing of Double.valueOf
                primitives[j][i] = Double.parseDouble(vals[j]);
            }
        }
        return Nd4j.create(new int[]{1, length}, primitives);
    }

    /**
     * Used to create the different time series for plotting purposes.
     * Predicted series come out of the network with an extra leading dimension
     * (mini-batch), hence the extra slice(0) for those.
     */
    private static void createSeries(XYSeriesCollection seriesCollection, INDArray data, int offset, String name) {
        int nRows = data.shape()[2];
        boolean predicted = name.startsWith("Predicted");
        XYSeries series = new XYSeries(name);

        for (int i = 0; i < nRows; i++) {
            if (predicted)
                series.add(i + offset, data.slice(0).slice(0).getDouble(i));
            else
                series.add(i + offset, data.slice(0).getDouble(i));
        }

        seriesCollection.addSeries(series);
    }

    /**
     * Generate an xy plot of the datasets provided.
     */
    private static void plotDataset(XYSeriesCollection c) {

        String title = "Regression example";
        String xAxisLabel = "Timestep";
        String yAxisLabel = "Number of passengers";
        PlotOrientation orientation = PlotOrientation.VERTICAL;
        boolean legend = true;
        boolean tooltips = false;
        boolean urls = false;
        JFreeChart chart = ChartFactory.createXYLineChart(title, xAxisLabel, yAxisLabel, c, orientation, legend, tooltips, urls);

        // get a reference to the plot for further customisation...
        final XYPlot plot = chart.getXYPlot();

        // Auto zoom to fit time series in initial window
        final NumberAxis rangeAxis = (NumberAxis) plot.getRangeAxis();
        rangeAxis.setAutoRange(true);

        JPanel panel = new ChartPanel(chart);

        JFrame f = new JFrame();
        f.add(panel);
        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
        f.pack();
        f.setTitle("Training Data");

        RefineryUtilities.centerFrameOnScreen(f);
        f.setVisible(true);
    }

    /**
     * This method shows how you based on a CSV file can preprocess your data to the structure expected for a
     * multi time step problem. This example uses a single column CSV as input, but the example should be easy to modify
     * for use with a multi column input as well.
     *
     * @param trainSize         number of training examples to generate
     * @param testSize          number of test examples to generate
     * @param numberOfTimesteps number of time steps per feature sequence
     * @return the raw CSV lines (used later for plotting)
     * @throws IOException if the raw file cannot be read or the per-example files cannot be written
     */
    private static List<String> prepareTrainAndTest(int trainSize, int testSize, int numberOfTimesteps) throws IOException {
        Path rawPath = Paths.get(baseDir.getAbsolutePath() + "/passengers_raw.csv");

        List<String> rawStrings = Files.readAllLines(rawPath, Charset.defaultCharset());
        setNumOfVariables(rawStrings);

        //Remove all files before generating new ones
        FileUtils.cleanDirectory(featuresDirTrain);
        FileUtils.cleanDirectory(labelsDirTrain);
        FileUtils.cleanDirectory(featuresDirTest);
        FileUtils.cleanDirectory(labelsDirTest);

        //Each training example i: features = rows [i, i+numberOfTimesteps), label = row i+numberOfTimesteps
        for (int i = 0; i < trainSize; i++) {
            Path featuresPath = Paths.get(featuresDirTrain.getAbsolutePath() + "/train_" + i + ".csv");
            Path labelsPath = Paths.get(labelsDirTrain.getAbsolutePath() + "/train_" + i + ".csv");
            for (int j = 0; j < numberOfTimesteps; j++) {
                Files.write(featuresPath, rawStrings.get(i + j).concat(System.lineSeparator()).getBytes(),
                    StandardOpenOption.APPEND, StandardOpenOption.CREATE);
            }
            Files.write(labelsPath, rawStrings.get(i + numberOfTimesteps).concat(System.lineSeparator()).getBytes(),
                StandardOpenOption.APPEND, StandardOpenOption.CREATE);
        }

        //Test examples continue where the training window ends
        for (int i = trainSize; i < testSize + trainSize; i++) {
            Path featuresPath = Paths.get(featuresDirTest.getAbsolutePath() + "/test_" + i + ".csv");
            Path labelsPath = Paths.get(labelsDirTest.getAbsolutePath() + "/test_" + i + ".csv");
            for (int j = 0; j < numberOfTimesteps; j++) {
                Files.write(featuresPath, rawStrings.get(i + j).concat(System.lineSeparator()).getBytes(),
                    StandardOpenOption.APPEND, StandardOpenOption.CREATE);
            }
            Files.write(labelsPath, rawStrings.get(i + numberOfTimesteps).concat(System.lineSeparator()).getBytes(),
                StandardOpenOption.APPEND, StandardOpenOption.CREATE);
        }

        return rawStrings;
    }

    /** Records the number of CSV columns (variables) from the first raw line. */
    private static void setNumOfVariables(List<String> rawStrings) {
        numOfVariables = rawStrings.get(0).split(",").length;
    }
}