package org.deeplearning4j.examples.dataexamples;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.meta.Prediction;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.ArrayList;
import java.util.List;
/**
 * This example is a version of the basic CSV example, but adds the following:
 * (a) Metadata tracking - i.e., where the data for each example comes from
 * (b) Additional evaluation information - getting the metadata for prediction errors
 *
 * @author Alex Black
 */
public class CSVExampleEvaluationMetaData {

    public static void main(String[] args) throws Exception {
        //First: get the dataset using the record reader. This is as per CSV example - see that example for details
        RecordReader recordReader = new CSVRecordReader(0, ",");
        recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
        int labelIndex = 4;     //Column index holding the class label
        int numClasses = 3;
        int batchSize = 150;    //Chosen so that a single call to next() returns the whole iris file

        RecordReaderDataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
        iterator.setCollectMetaData(true);  //Instruct the iterator to collect metadata, and store it in the DataSet objects
        DataSet allData = iterator.next();
        allData.shuffle(123);
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training
        DataSet trainingData = testAndTrain.getTrain();
        DataSet testData = testAndTrain.getTest();

        //Let's view the example metadata in the training and test sets:
        List<RecordMetaData> trainMetaData = trainingData.getExampleMetaData(RecordMetaData.class);
        List<RecordMetaData> testMetaData = testData.getExampleMetaData(RecordMetaData.class);

        //Let's show specifically which examples are in the training and test sets, using the collected metadata.
        //Both sets use the same "location <tab> URI" format so the two listings can be compared directly.
        String format = "%-20s\t%s";
        System.out.println(" +++++ Training Set Examples MetaData +++++");
        for (RecordMetaData recordMetaData : trainMetaData) {
            System.out.println(String.format(format, recordMetaData.getLocation(), recordMetaData.getURI()));
            //Also available: recordMetaData.getReaderClass()
        }
        System.out.println("\n\n +++++ Test Set Examples MetaData +++++");
        for (RecordMetaData recordMetaData : testMetaData) {
            System.out.println(String.format(format, recordMetaData.getLocation(), recordMetaData.getURI()));
        }

        //Normalize data as per basic CSV example
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainingData);       //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
        normalizer.transform(trainingData); //Apply normalization to the training data
        normalizer.transform(testData);     //Apply normalization to the test data. This is using statistics calculated from the *training* set

        //Configure a simple model. We're not using an optimal configuration here, in order to show evaluation/errors, later
        final int numInputs = 4;
        int hiddenLayerSize = 3;
        int outputNum = numClasses;         //One output unit per class
        int iterations = 50;
        long seed = 6;

        System.out.println("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .activation(Activation.TANH)
            .weightInit(WeightInit.XAVIER)
            .learningRate(0.1)
            .regularization(true).l2(1e-4)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(hiddenLayerSize).build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .activation(Activation.SOFTMAX).nIn(hiddenLayerSize).nOut(outputNum).build())
            .backprop(true).pretrain(false)
            .build();

        //Fit the model
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        //Report the score every 10 iterations: a print frequency larger than 'iterations' would show little or nothing
        model.setListeners(new ScoreIterationListener(10));
        model.fit(trainingData);

        //Evaluate the model on the test set
        Evaluation eval = new Evaluation(numClasses);
        INDArray output = model.output(testData.getFeatures());
        eval.eval(testData.getLabels(), output, testMetaData);  //Note we are passing in the test set metadata here
        System.out.println(eval.stats());

        //Get a list of prediction errors, from the Evaluation object
        //Prediction errors like this are only available after calling iterator.setCollectMetaData(true)
        List<Prediction> predictionErrors = eval.getPredictionErrors();
        if (predictionErrors == null) {
            //Defensive: treat a missing list as "no errors" rather than failing with a NullPointerException below
            predictionErrors = new ArrayList<>();
        }
        System.out.println("\n\n+++++ Prediction Errors +++++");
        for (Prediction p : predictionErrors) {
            System.out.println("Predicted class: " + p.getPredictedClass() + ", Actual class: " + p.getActualClass()
                + "\t" + p.getRecordMetaData(RecordMetaData.class).getLocation());
        }

        if (predictionErrors.isEmpty()) {
            //Nothing to reload: skip the metadata-based loading below
            System.out.println("No prediction errors on the test set");
        } else {
            //We can also load a subset of the data, to a DataSet object:
            List<RecordMetaData> predictionErrorMetaData = new ArrayList<>(predictionErrors.size());
            for (Prediction p : predictionErrors) {
                predictionErrorMetaData.add(p.getRecordMetaData(RecordMetaData.class));
            }
            DataSet predictionErrorExamples = iterator.loadFromMetaData(predictionErrorMetaData);
            normalizer.transform(predictionErrorExamples);  //Apply normalization to this subset

            //We can also load the raw data:
            List<Record> predictionErrorRawData = recordReader.loadFromMetaData(predictionErrorMetaData);

            //Print out the prediction errors, along with the raw data, normalized data, labels and network predictions.
            //NOTE(review): row i is matched to predictionErrors.get(i); this assumes loadFromMetaData preserves
            //the order of the metadata list - confirm against the DataVec documentation.
            for (int i = 0; i < predictionErrors.size(); i++) {
                Prediction p = predictionErrors.get(i);
                RecordMetaData meta = p.getRecordMetaData(RecordMetaData.class);
                INDArray features = predictionErrorExamples.getFeatures().getRow(i);
                INDArray labels = predictionErrorExamples.getLabels().getRow(i);
                List<Writable> rawData = predictionErrorRawData.get(i).getRecord();
                INDArray networkPrediction = model.output(features);
                System.out.println(meta.getLocation() + ": "
                    + "\tRaw Data: " + rawData
                    + "\tNormalized: " + features
                    + "\tLabels: " + labels
                    + "\tPredictions: " + networkPrediction);
            }
        }

        //Some other useful evaluation methods:
        List<Prediction> list1 = eval.getPredictions(1, 2);           //Predictions: actual class 1, predicted class 2
        List<Prediction> list2 = eval.getPredictionByPredictedClass(2);  //All predictions for predicted class 2
        List<Prediction> list3 = eval.getPredictionsByActualClass(2);    //All predictions for actual class 2
        System.out.println("\nNumber of predictions with actual class 1, predicted class 2: " + sizeOrZero(list1));
        System.out.println("Number of predictions with predicted class 2: " + sizeOrZero(list2));
        System.out.println("Number of predictions with actual class 2: " + sizeOrZero(list3));
    }

    /** Returns the size of the given list, or 0 when the list is null. */
    private static int sizeOrZero(List<?> list) {
        return list == null ? 0 : list.size();
    }
}