package org.deeplearning4j.examples.recurrent.seqclassification;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.datavec.api.berkeley.Pair;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
/**
* Sequence Classification Example Using a LSTM Recurrent Neural Network
*
* This example learns how to classify univariate time series as belonging to one of six categories.
* Categories are: Normal, Cyclic, Increasing trend, Decreasing trend, Upward shift, Downward shift
*
* Data is the UCI Synthetic Control Chart Time Series Data Set
* Details: https://archive.ics.uci.edu/ml/datasets/Synthetic+Control+Chart+Time+Series
* Data: https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data
* Image: https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/data.jpeg
*
* This example proceeds as follows:
* 1. Download and prepare the data (in downloadUCIData() method)
* (a) Split the 600 sequences into train set of size 450, and test set of size 150
* (b) Write the data into a format suitable for loading using the CSVSequenceRecordReader for sequence classification
* This format: one time series per file, and a separate file for the labels.
* For example, train/features/0.csv contains the features used with the labels file train/labels/0.csv
* Because the data is a univariate time series, we only have one column in the CSV files. Normally, each column
* would contain multiple values - one time step per row.
* Furthermore, because we have only one label for each time series, the labels CSV files contain only a single value
*
* 2. Load the training data using CSVSequenceRecordReader (to load/parse the CSV files) and SequenceRecordReaderDataSetIterator
* (to convert it to DataSet objects, ready to train)
* For more details on this step, see: http://deeplearning4j.org/usingrnns#data
*
* 3. Normalize the data. The raw data contain values that are too large for effective training, and need to be normalized.
* Normalization is conducted using NormalizerStandardize, based on statistics (mean, st.dev) collected on the training
* data only. Note that both the training data and test data are normalized in the same way.
*
* 4. Configure the network
* The data set here is very small, so we can't afford to use a large network with many parameters.
* We are using one small LSTM layer and one RNN output layer
*
* 5. Train the network for 40 epochs
* At each epoch, evaluate and print the accuracy and f1 on the test set
*
* @author Alex Black
*/
public class UCISequenceClassificationExample {
    private static final Logger log = LoggerFactory.getLogger(UCISequenceClassificationExample.class);

    //'baseDir': Base directory for the data. Change this if you want to save the data somewhere else
    private static final File baseDir = new File("src/main/resources/uci/");
    private static final File baseTrainDir = new File(baseDir, "train");
    private static final File featuresDirTrain = new File(baseTrainDir, "features");
    private static final File labelsDirTrain = new File(baseTrainDir, "labels");
    private static final File baseTestDir = new File(baseDir, "test");
    private static final File featuresDirTest = new File(baseTestDir, "features");
    private static final File labelsDirTest = new File(baseTestDir, "labels");

    public static void main(String[] args) throws Exception {
        downloadUCIData();

        // ----- Load the training data -----
        //Note that we have 450 training files for features: train/features/0.csv through train/features/449.csv
        SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
        trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
        SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
        trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));

        int miniBatchSize = 10;
        int numLabelClasses = 6;
        DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
            false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

        //Normalize the training data. Statistics (mean, st.dev) are collected from the training data ONLY
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainData);              //Collect training data statistics
        trainData.reset();
        //Use previously collected statistics to normalize on-the-fly. Each DataSet returned by 'trainData' iterator will be normalized
        trainData.setPreProcessor(normalizer);

        // ----- Load the test data -----
        //Same process as for the training data.
        SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
        testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
        SequenceRecordReader testLabels = new CSVSequenceRecordReader();
        testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));

        DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize, numLabelClasses,
            false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
        testData.setPreProcessor(normalizer);   //Note that we are using the exact same normalization process as the training data

        // ----- Configure the network -----
        //One small GravesLSTM layer + RNN output layer: the data set is small, so a large network would overfit
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)    //Random number generator seed for improved repeatability. Optional.
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .learningRate(0.005)
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)  //Not always required, but helps with this data set
            .gradientNormalizationThreshold(0.5)
            .list()
            .layer(0, new GravesLSTM.Builder().activation(Activation.TANH).nIn(1).nOut(10).build())
            .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .activation(Activation.SOFTMAX).nIn(10).nOut(numLabelClasses).build())
            .pretrain(false).backprop(true).build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(20));   //Print the score (loss function value) every 20 iterations

        // ----- Train the network, evaluating the test set performance at each epoch -----
        int nEpochs = 40;
        String str = "Test set evaluation at epoch %d: Accuracy = %.2f, F1 = %.2f";
        for (int i = 0; i < nEpochs; i++) {
            net.fit(trainData);

            //Evaluate on the test set:
            Evaluation evaluation = net.evaluate(testData);
            log.info(String.format(str, i, evaluation.accuracy(), evaluation.f1()));

            testData.reset();
            trainData.reset();
        }

        log.info("----- Example Complete -----");
    }

    /**
     * Downloads the UCI Synthetic Control data set (600 univariate time series, one series per line,
     * space-separated) and converts it into a format that CSVSequenceRecordReader/DL4J can read:
     * one CSV file per series (one value per row) plus a single-value label CSV per series.
     * The 600 series are shuffled (fixed seed, for repeatability) and split 450 train / 150 test.
     * Does nothing if {@code baseDir} already exists.
     *
     * @throws Exception if the download fails or the output directories cannot be created
     */
    private static void downloadUCIData() throws Exception {
        if (baseDir.exists()) return; //Data already exists, don't download it again

        String url = "https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data";
        //Specify the charset explicitly: the no-charset overload uses the platform default, which is non-portable
        String data = IOUtils.toString(new URL(url), StandardCharsets.UTF_8);

        String[] lines = data.split("\n");

        //Create output directories. mkdirs() also creates missing parents (e.g. if src/main/resources is absent),
        //and we fail fast if creation is not possible rather than writing into nowhere later
        for (File dir : new File[]{featuresDirTrain, labelsDirTrain, featuresDirTest, labelsDirTest}) {
            if (!dir.mkdirs()) {
                throw new IllegalStateException("Could not create directory: " + dir.getAbsolutePath());
            }
        }

        int lineCount = 0;
        List<Pair<String, Integer>> contentAndLabels = new ArrayList<>();
        for (String line : lines) {
            line = line.trim();
            if (line.isEmpty()) continue;   //Skip blank lines (e.g., a trailing newline) and strip padding/'\r'
            //Raw format is one whole series per line; rewrite as one value per row for the sequence CSV reader
            String transposed = line.replaceAll(" +", "\n");
            //Labels: first 100 examples (lines) are label 0, second 100 examples are label 1, and so on
            contentAndLabels.add(new Pair<>(transposed, lineCount++ / 100));
        }

        //Randomize and do a train/test split:
        Collections.shuffle(contentAndLabels, new Random(12345));

        int nTrain = 450;   //75% train, 25% test
        int trainCount = 0;
        int testCount = 0;
        for (Pair<String, Integer> p : contentAndLabels) {
            //Write output in a format we can read, in the appropriate locations
            File outPathFeatures;
            File outPathLabels;
            if (trainCount < nTrain) {
                outPathFeatures = new File(featuresDirTrain, trainCount + ".csv");
                outPathLabels = new File(labelsDirTrain, trainCount + ".csv");
                trainCount++;
            } else {
                outPathFeatures = new File(featuresDirTest, testCount + ".csv");
                outPathLabels = new File(labelsDirTest, testCount + ".csv");
                testCount++;
            }

            FileUtils.writeStringToFile(outPathFeatures, p.getFirst(), StandardCharsets.UTF_8);
            FileUtils.writeStringToFile(outPathLabels, p.getSecond().toString(), StandardCharsets.UTF_8);
        }
    }
}