/**- * This program trains a RNN to predict category of a news headlines. It uses word vector generated from PrepareWordVector.java. * - Labeled News are stored in \dl4j-examples\src\main\resources\NewsData\LabelledNews folder in train and test folders. * - categories.txt file in \dl4j-examples\src\main\resources\NewsData\LabelledNews folder contains category code and description. * - This categories are used along with actual news for training. * - news word vector is contained in \dl4j-examples\src\main\resources\NewsData\NewsWordVector.txt file. * - Trained model is stored in \dl4j-examples\src\main\resources\NewsData\NewsModel.net file * - News Data contains only 3 categories currently. * - Data set structure is as given below * - categories.txt - this file contains various categories in category id,category description format. Sample categories are as below * 0,crime * 1,politics * 2,bollywood * 3,Business&Development * - For each category id above, there is a file containig actual news headlines, e.g. * 0.txt - contains news for crime headlines * 1.txt - contains news for politics headlines * 2.txt - contains news for bollywood * 3.txt - contains news for Business&Development * - You can add any new category by adding one line in categories.txt and respective news file in folder mentioned above. * - Below are training results with the news data given with this example. * ==========================Scores======================================== * Accuracy: 0.9343 * Precision: 0.9249 * Recall: 0.9327 * F1 Score: 0.9288 * ======================================================================== * <p> * Note : * - This code is a modification of original example named Word2VecSentimentRNN.java * - Results may vary with the data you use to train this network * <p> * <b>KIT Solutions Pvt. Ltd. (www.kitsol.com)</b> */ package org.deeplearning4j.examples.recurrent.processnews; import org.datavec.api.util.ClassPathResource; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.layers.GravesLSTM; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; import org.deeplearning4j.util.ModelSerializer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.File; public class TrainNews { public static String userDirectory = ""; public static String DATA_PATH = ""; public static String WORD_VECTORS_PATH = ""; public static WordVectors wordVectors; private static TokenizerFactory tokenizerFactory; public static void main(String[] args) throws Exception { userDirectory = new ClassPathResource("NewsData").getFile().getAbsolutePath() + File.separator; DATA_PATH = userDirectory + "LabelledNews"; WORD_VECTORS_PATH = userDirectory + "NewsWordVector.txt"; int batchSize = 50; //Number of examples in each minibatch int nEpochs = 1000; //Number of epochs (full passes of training data) to train on int truncateReviewsToLength = 300; //Truncate reviews with length (# words) greater than this //DataSetIterators for training and testing respectively //Using AsyncDataSetIterator to do data loading in a separate thread; this may improve performance vs. waiting for data to load wordVectors = WordVectorSerializer.loadTxtVectors(new File(WORD_VECTORS_PATH)); TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory(); tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor()); NewsIterator iTrain = new NewsIterator.Builder() .dataDirectory(DATA_PATH) .wordVectors(wordVectors) .batchSize(batchSize) .truncateLength(truncateReviewsToLength) .tokenizerFactory(tokenizerFactory) .train(true) .build(); NewsIterator iTest = new NewsIterator.Builder() .dataDirectory(DATA_PATH) .wordVectors(wordVectors) .batchSize(batchSize) .tokenizerFactory(tokenizerFactory) .truncateLength(truncateReviewsToLength) .train(false) .build(); //DataSetIterator train = new AsyncDataSetIterator(iTrain,1); //DataSetIterator test = new AsyncDataSetIterator(iTest,1); int inputNeurons = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length; // 100 in our case int outputs = iTrain.getLabels().size(); tokenizerFactory = new DefaultTokenizerFactory(); tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor()); //Set up network configuration MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) .updater(Updater.RMSPROP) .regularization(true).l2(1e-5) .weightInit(WeightInit.XAVIER) .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(1.0) .learningRate(0.0018) .list() .layer(0, new GravesLSTM.Builder().nIn(inputNeurons).nOut(200) .activation("softsign").build()) .layer(1, new RnnOutputLayer.Builder().activation("softmax") .lossFunction(LossFunctions.LossFunction.MCXENT).nIn(200).nOut(outputs).build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); net.setListeners(new ScoreIterationListener(1)); System.out.println("Starting training"); for (int i = 0; i < nEpochs; i++) { net.fit(iTrain); iTrain.reset(); System.out.println("Epoch " + i + " complete. Starting evaluation:"); //Run evaluation. This is on 25k reviews, so can take some time Evaluation evaluation = new Evaluation(); while (iTest.hasNext()) { DataSet t = iTest.next(); INDArray features = t.getFeatureMatrix(); INDArray lables = t.getLabels(); //System.out.println("labels : " + lables); INDArray inMask = t.getFeaturesMaskArray(); INDArray outMask = t.getLabelsMaskArray(); INDArray predicted = net.output(features, false); //System.out.println("predicted : " + predicted); evaluation.evalTimeSeries(lables, predicted, outMask); } iTest.reset(); System.out.println(evaluation.stats()); } ModelSerializer.writeModel(net, userDirectory + "NewsModel.net", true); System.out.println("----- Example complete -----"); } }