/**
* A test program that uses the word vectors and the trained network generated by PrepareWordVector.java and TrainNews.java.
* - Type or copy/paste a news headline (headlines from Indian news channels work best) and click the Check button
* to see the predicted category, shown next to the "Category" label below the Check button.
* <p>
* <b>KIT Solutions Pvt. Ltd. (www.kitsol.com)</b>
*/
package org.deeplearning4j.examples.recurrent.processnews;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * Swing window that classifies a typed news headline using the recurrent network and word
 * vectors loaded from the {@code NewsData} classpath directory, and shows the predicted
 * category name read from {@code LabelledNews/categories.txt}.
 */
public class TestNews extends javax.swing.JFrame {
    private static final Logger LOGGER = Logger.getLogger(TestNews.class.getName());
    private static String WORD_VECTORS_PATH = "";
    // The fields below are written by main() on the launcher thread and read on the Swing
    // event-dispatch thread, hence volatile. They stay null until loading succeeds.
    private static volatile WordVectors wordVectors;
    private static volatile TokenizerFactory tokenizerFactory;
    private static volatile String userDirectory = "";
    private static volatile MultiLayerNetwork net;
    // Variables declaration - do not modify
    private javax.swing.JButton jButton1;
    private javax.swing.JLabel jLabel1;
    private javax.swing.JLabel jLabel2;
    private javax.swing.JLabel jLabel3;
    private javax.swing.JScrollPane jScrollPane1;
    private javax.swing.JTextArea jTextArea1;

    public TestNews() {
        initComponents();
    }

    /**
     * This method is called from within the constructor to initialize the form.
     * WARNING: Do NOT modify this code. The content of this method is always
     * regenerated by the Form Editor.
     */
    @SuppressWarnings("unchecked")
    // <editor-fold defaultstate="collapsed" desc="Generated Code">
    private void initComponents() {
        this.setTitle("Predict News Category - KITS");
        jLabel1 = new javax.swing.JLabel();
        jScrollPane1 = new javax.swing.JScrollPane();
        jTextArea1 = new javax.swing.JTextArea();
        jButton1 = new javax.swing.JButton();
        jLabel2 = new javax.swing.JLabel();
        jLabel3 = new javax.swing.JLabel();
        setDefaultCloseOperation(javax.swing.WindowConstants.EXIT_ON_CLOSE);
        jLabel1.setText("Type News Here");
        jTextArea1.setColumns(20);
        jTextArea1.setRows(5);
        jScrollPane1.setViewportView(jTextArea1);
        jButton1.setText("Check");
        jButton1.addActionListener(new java.awt.event.ActionListener() {
            public void actionPerformed(java.awt.event.ActionEvent evt) {
                jButton1ActionPerformed(evt);
            }
        });
        jLabel2.setText("Category");
        javax.swing.GroupLayout layout = new javax.swing.GroupLayout(getContentPane());
        getContentPane().setLayout(layout);
        layout.setHorizontalGroup(
            layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
            .addGroup(layout.createSequentialGroup()
                .addContainerGap()
                .addGroup(layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
                    .addComponent(jScrollPane1, javax.swing.GroupLayout.DEFAULT_SIZE, 380, Short.MAX_VALUE)
                    .addGroup(layout.createSequentialGroup()
                        .addGroup(layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
                            .addComponent(jLabel1)
                            .addComponent(jButton1))
                        .addGap(0, 0, Short.MAX_VALUE))
                    .addGroup(layout.createSequentialGroup()
                        .addComponent(jLabel2)
                        .addGap(18, 18, 18)
                        .addComponent(jLabel3, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE)))
                .addContainerGap())
        );
        layout.setVerticalGroup(
            layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
            .addGroup(layout.createSequentialGroup()
                .addContainerGap()
                .addComponent(jLabel1)
                .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED)
                .addComponent(jScrollPane1, javax.swing.GroupLayout.PREFERRED_SIZE, 134, javax.swing.GroupLayout.PREFERRED_SIZE)
                .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED)
                .addComponent(jButton1)
                .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED)
                .addGroup(layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
                    .addComponent(jLabel3)
                    .addComponent(jLabel2))
                .addContainerGap(javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE))
        );
        pack();
    }// </editor-fold>

    /**
     * Handles a click on the Check button: vectorizes the typed headline, runs it through the
     * network and shows the name of the highest-scoring category in {@code jLabel3}.
     * Problems (model not loaded, no known words, unreadable categories file) are reported in
     * the same label instead of failing silently or throwing on the event-dispatch thread.
     */
    private void jButton1ActionPerformed(java.awt.event.ActionEvent evt) {
        // main() loads these after the window appears; the user may click before (or after a
        // failed) load.
        if (net == null || wordVectors == null || tokenizerFactory == null) {
            jLabel3.setText("Model not loaded (see log)");
            return;
        }
        DataSet testNews;
        try {
            testNews = prepareTestData(jTextArea1.getText());
        } catch (IllegalArgumentException e) {
            jLabel3.setText(e.getMessage());
            return;
        }
        INDArray predicted = net.output(testNews.getFeatureMatrix(), false);
        jLabel3.setText(lookupCategoryName(argMaxCategory(predicted)));
    }

    /**
     * Returns the index along dimension 1 of {@code predicted} whose column sum is largest
     * (the first such index on ties).
     * NOTE(review): assumes dimension 1 is the category axis of the network output,
     * e.g. [1, categories, timeSteps] — confirm against the trained model.
     */
    private static int argMaxCategory(INDArray predicted) {
        int numCategories = predicted.shape()[1];
        double best = Double.NEGATIVE_INFINITY;
        int bestIndex = 0;
        for (int i = 0; i < numCategories; i++) {
            // doubleValue() works for any Number subtype; a (double) cast would not.
            double score = predicted.getColumn(i).sumNumber().doubleValue();
            if (score > best) {
                best = score;
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Reads {@code LabelledNews/categories.txt} under the data directory and returns the second
     * comma-separated field of line {@code index} (lines look like {@code <id>,<name>}).
     * Returns a short error text, after logging the cause, if the file cannot be read or the
     * line is missing or malformed.
     */
    private static String lookupCategoryName(int index) {
        File categories = new File(userDirectory + "LabelledNews" + File.separator + "categories.txt");
        List<String> lines = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(categories))) {
            String line;
            while ((line = reader.readLine()) != null) {
                lines.add(line);
            }
        } catch (IOException e) {
            LOGGER.log(Level.WARNING, "Could not read categories file " + categories, e);
            return "Cannot read categories file";
        }
        if (index < 0 || index >= lines.size()) {
            LOGGER.warning("No category on line " + index + " of " + categories);
            return "Unknown category " + index;
        }
        String[] fields = lines.get(index).split(",");
        if (fields.length < 2) {
            LOGGER.warning("Malformed category line " + index + " in " + categories + ": " + lines.get(index));
            return "Unknown category " + index;
        }
        return fields[1];
    }

    /**
     * Entry point: shows the window on the event-dispatch thread, then loads the tokenizer, the
     * trained network and the word vectors from the {@code NewsData} classpath directory.
     * Until loading succeeds, the Check button reports that the model is not loaded.
     */
    public static void main(String[] args) {
        try {
            for (javax.swing.UIManager.LookAndFeelInfo info : javax.swing.UIManager.getInstalledLookAndFeels()) {
                if ("Nimbus".equals(info.getName())) {
                    javax.swing.UIManager.setLookAndFeel(info.getClassName());
                    break;
                }
            }
        } catch (ClassNotFoundException | InstantiationException | IllegalAccessException
                | javax.swing.UnsupportedLookAndFeelException ex) {
            LOGGER.log(Level.SEVERE, null, ex);
        }
        // Swing components must be created and shown on the event-dispatch thread.
        java.awt.EventQueue.invokeLater(new Runnable() {
            @Override
            public void run() {
                new TestNews().setVisible(true);
            }
        });
        try {
            userDirectory = new ClassPathResource("NewsData").getFile().getAbsolutePath() + File.separator;
            WORD_VECTORS_PATH = userDirectory + "NewsWordVector.txt";
            // Fully configure the factory before publishing it to the EDT via the volatile field.
            TokenizerFactory factory = new DefaultTokenizerFactory();
            factory.setTokenPreProcessor(new CommonPreprocessor());
            tokenizerFactory = factory;
            net = ModelSerializer.restoreMultiLayerNetwork(userDirectory + "NewsModel.net");
            wordVectors = WordVectorSerializer.loadTxtVectors(new File(WORD_VECTORS_PATH));
        } catch (Exception e) {
            // Keep the window open; the Check button reports the missing model.
            LOGGER.log(Level.SEVERE, "Failed to load the news model or word vectors", e);
        }
    }

    /**
     * Converts one headline into a single-example DataSet for the recurrent network.
     * Tokens not present in the word-vector vocabulary are dropped; each remaining token becomes
     * one time step whose features are that word's vector. The labels (category 0 at the last
     * time step) are only a placeholder for the DataSet constructor; callers here use only the
     * features.
     *
     * @param i_news headline text typed by the user
     * @return features of shape [1, vectorSize, tokens], placeholder labels [1, 4, tokens],
     *         and all-ones feature mask / last-step label mask of shape [1, tokens]
     * @throws IllegalArgumentException if no token of the text is in the word-vector vocabulary
     */
    private static DataSet prepareTestData(String i_news) {
        List<String> tokens = new ArrayList<>();
        for (String t : tokenizerFactory.create(i_news).getTokens()) {
            if (wordVectors.hasWord(t)) {
                tokens.add(t);
            }
        }
        int seqLength = tokens.size();
        if (seqLength == 0) {
            // A zero-length sequence would otherwise fail below at time index -1.
            throw new IllegalArgumentException("No known words in the text");
        }
        INDArray features = Nd4j.create(1, wordVectors.lookupTable().layerSize(), seqLength);
        INDArray labels = Nd4j.create(1, 4, seqLength); //labels: Crime, Politics, Bollywood, Business&Development
        INDArray featuresMask = Nd4j.zeros(1, seqLength);
        INDArray labelsMask = Nd4j.zeros(1, seqLength);
        for (int j = 0; j < seqLength; j++) {
            INDArray vector = wordVectors.getWordVectorMatrix(tokens.get(j));
            features.put(new INDArrayIndex[]{NDArrayIndex.point(0),
                    NDArrayIndex.all(),
                    NDArrayIndex.point(j)},
                vector);
            featuresMask.putScalar(new int[]{0, j}, 1.0);
        }
        labels.putScalar(new int[]{0, 0, seqLength - 1}, 1.0);
        labelsMask.putScalar(new int[]{0, seqLength - 1}, 1.0);
        return new DataSet(features, labels, featuresMask, labelsMask);
    }
}