package org.deeplearning4j.examples.recurrent.processnews;

import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Interactive test program for the news-category classifier. It uses the word vectors and the trained
 * network generated by PrepareWordVector.java and TrainNews.java.
 * <ul>
 * <li>Type or paste a news headline (headlines from Indian news channels work best) and click the
 * Check button.</li>
 * <li>The predicted category appears in the row below the Check button, next to the "Category" label.</li>
 * </ul>
 * <p>
 * <b>KIT Solutions Pvt. Ltd. (www.kitsol.com)</b>
 */
public class TestNews extends javax.swing.JFrame {

    private static final Logger LOG = Logger.getLogger(TestNews.class.getName());

    /** Number of network outputs: Crime, Politics, Bollywood, Business&Development. */
    private static final int NUM_CATEGORIES = 4;

    // Populated once by loadResources() on a SwingWorker thread. The EDT only reads them after
    // SwingWorker.done() has called get(), which provides the happens-before edge; the Check button
    // stays disabled until then.
    private static WordVectors wordVectors;
    private static TokenizerFactory tokenizerFactory;
    private static MultiLayerNetwork net;
    /** Lines of LabelledNews/categories.txt; line {@code i} describes network output {@code i}. */
    private static List<String> categoryLabels;

    // Variables declaration - do not modify
    private javax.swing.JButton jButton1;
    private javax.swing.JLabel jLabel1;
    private javax.swing.JLabel jLabel2;
    private javax.swing.JLabel jLabel3;
    private javax.swing.JScrollPane jScrollPane1;
    private javax.swing.JTextArea jTextArea1;
    // End of variables declaration

    /**
     * Creates the form. The Check button starts disabled so that it cannot be clicked before the
     * model has been loaded.
     */
    public TestNews() {
        initComponents();
        jButton1.setEnabled(false);
    }

    /**
     * This method is called from within the constructor to initialize the form.
     * WARNING: Do NOT modify this code. The content of this method is always
     * regenerated by the Form Editor.
     */
    @SuppressWarnings("unchecked")
    // <editor-fold defaultstate="collapsed" desc="Generated Code">
    private void initComponents() {
        this.setTitle("Predict News Category - KITS");
        jLabel1 = new javax.swing.JLabel();
        jScrollPane1 = new javax.swing.JScrollPane();
        jTextArea1 = new javax.swing.JTextArea();
        jButton1 = new javax.swing.JButton();
        jLabel2 = new javax.swing.JLabel();
        jLabel3 = new javax.swing.JLabel();

        setDefaultCloseOperation(javax.swing.WindowConstants.EXIT_ON_CLOSE);

        jLabel1.setText("Type News Here");

        jTextArea1.setColumns(20);
        jTextArea1.setRows(5);
        jScrollPane1.setViewportView(jTextArea1);

        jButton1.setText("Check");
        jButton1.addActionListener(new java.awt.event.ActionListener() {
            public void actionPerformed(java.awt.event.ActionEvent evt) {
                jButton1ActionPerformed(evt);
            }
        });

        jLabel2.setText("Category");

        javax.swing.GroupLayout layout = new javax.swing.GroupLayout(getContentPane());
        getContentPane().setLayout(layout);
        layout.setHorizontalGroup(
            layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
            .addGroup(layout.createSequentialGroup()
                .addContainerGap()
                .addGroup(layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
                    .addComponent(jScrollPane1, javax.swing.GroupLayout.DEFAULT_SIZE, 380, Short.MAX_VALUE)
                    .addGroup(layout.createSequentialGroup()
                        .addGroup(layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
                            .addComponent(jLabel1)
                            .addComponent(jButton1))
                        .addGap(0, 0, Short.MAX_VALUE))
                    .addGroup(layout.createSequentialGroup()
                        .addComponent(jLabel2)
                        .addGap(18, 18, 18)
                        .addComponent(jLabel3, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE)))
                .addContainerGap())
        );
        layout.setVerticalGroup(
            layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
            .addGroup(layout.createSequentialGroup()
                .addContainerGap()
                .addComponent(jLabel1)
                .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED)
                .addComponent(jScrollPane1, javax.swing.GroupLayout.PREFERRED_SIZE, 134, javax.swing.GroupLayout.PREFERRED_SIZE)
                .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED)
                .addComponent(jButton1)
                .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED)
                .addGroup(layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
                    .addComponent(jLabel3)
                    .addComponent(jLabel2))
                .addContainerGap(javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE))
        );

        pack();
    }// </editor-fold>

    /**
     * Handles the Check button: classifies the text in the text area and shows the predicted
     * category name in {@code jLabel3}. Runs on the EDT.
     */
    private void jButton1ActionPerformed(java.awt.event.ActionEvent evt) {
        if (net == null || wordVectors == null || tokenizerFactory == null || categoryLabels == null) {
            // Defensive: the button is disabled until loading succeeds, but never dereference a missing model.
            jLabel3.setText("Model not loaded");
            return;
        }
        String text = jTextArea1.getText();
        if (knownTokens(text).isEmpty()) {
            // Without at least one in-vocabulary word there is no time step to feed the network.
            jLabel3.setText("None of the words are known to the model");
            return;
        }
        DataSet testNews = prepareTestData(text);
        INDArray predicted = net.output(testNews.getFeatureMatrix(), false);
        jLabel3.setText(categoryName(bestCategoryIndex(predicted)));
    }

    /**
     * Returns the index of the output class with the highest total score. Each score is
     * {@code predicted.getColumn(i)} summed over all of its entries.
     * NOTE(review): assumes the network output is [1, numClasses, timeSteps] for a single headline,
     * so each score sums the class probability over every time step. Confirm against TrainNews.
     */
    private static int bestCategoryIndex(INDArray predicted) {
        long numClasses = predicted.size(1);
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < numClasses; i++) {
            // Compute once per class; doubleValue() avoids relying on sumNumber() returning a Double.
            double score = predicted.getColumn(i).sumNumber().doubleValue();
            if (score > bestScore) {
                bestScore = score;
                best = i;
            }
        }
        return best;
    }

    /**
     * Maps an output index to a display name: the second comma-separated field of the
     * corresponding line of categories.txt. Falls back to the whole line if it has no comma, and
     * to a placeholder if the file has fewer lines than the network has outputs.
     */
    private static String categoryName(int index) {
        if (index < 0 || index >= categoryLabels.size()) {
            return "Unknown category #" + index;
        }
        String line = categoryLabels.get(index);
        String[] fields = line.split(",");
        return fields.length > 1 ? fields[1] : line;
    }

    /**
     * Disables the Check button and loads the model, word vectors and category labels on a
     * background thread. When loading finishes, the button is re-enabled, or the failure is logged
     * and reported in the category label. Must be called on the EDT.
     */
    private void loadModelInBackground() {
        jButton1.setEnabled(false);
        jLabel3.setText("Loading model...");
        new javax.swing.SwingWorker<Void, Void>() {
            @Override
            protected Void doInBackground() throws Exception {
                loadResources();
                return null;
            }

            @Override
            protected void done() {
                try {
                    get(); // rethrows loading failures and publishes the static fields to the EDT
                    jLabel3.setText("");
                    jButton1.setEnabled(true);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    jLabel3.setText("Model loading interrupted");
                } catch (ExecutionException e) {
                    Throwable cause = e.getCause() != null ? e.getCause() : e;
                    LOG.log(Level.SEVERE, "Failed to load the news model, word vectors or categories", cause);
                    jLabel3.setText("Failed to load model - see log");
                }
            }
        }.execute();
    }

    /**
     * Loads the tokenizer, trained network, word vectors and category labels from the
     * {@code NewsData} classpath directory.
     *
     * @throws IOException if any of the resource files cannot be found or read
     */
    private static void loadResources() throws IOException {
        String userDirectory = new ClassPathResource("NewsData").getFile().getAbsolutePath() + File.separator;

        TokenizerFactory factory = new DefaultTokenizerFactory();
        factory.setTokenPreProcessor(new CommonPreprocessor());
        tokenizerFactory = factory;

        net = ModelSerializer.restoreMultiLayerNetwork(userDirectory + "NewsModel.net");
        wordVectors = WordVectorSerializer.loadTxtVectors(new File(userDirectory + "NewsWordVector.txt"));

        // Read once at startup instead of on every click; explicit charset instead of the platform default.
        File categories = new File(userDirectory + "LabelledNews" + File.separator + "categories.txt");
        categoryLabels = Files.readAllLines(categories.toPath(), StandardCharsets.UTF_8);
    }

    /** Switches to the Nimbus look and feel when it is installed; logs and keeps the default otherwise. */
    private static void installNimbusLookAndFeel() {
        try {
            for (javax.swing.UIManager.LookAndFeelInfo info : javax.swing.UIManager.getInstalledLookAndFeels()) {
                if ("Nimbus".equals(info.getName())) {
                    javax.swing.UIManager.setLookAndFeel(info.getClassName());
                    break;
                }
            }
        } catch (ClassNotFoundException | InstantiationException | IllegalAccessException
                | javax.swing.UnsupportedLookAndFeelException ex) {
            LOG.log(Level.SEVERE, null, ex);
        }
    }

    /**
     * Entry point: builds and shows the window on the Event Dispatch Thread, then loads the model
     * in the background.
     */
    public static void main(String args[]) {
        installNimbusLookAndFeel();
        java.awt.EventQueue.invokeLater(new Runnable() {
            @Override
            public void run() {
                TestNews test = new TestNews();
                test.setVisible(true);
                test.loadModelInBackground();
            }
        });
    }

    /**
     * Tokenizes {@code text} with the shared tokenizer factory and keeps only the tokens present
     * in the word-vector vocabulary, in their original order.
     */
    private static List<String> knownTokens(String text) {
        List<String> known = new ArrayList<>();
        for (String token : tokenizerFactory.create(text).getTokens()) {
            if (wordVectors.hasWord(token)) {
                known.add(token);
            }
        }
        return known;
    }

    /**
     * Converts one headline into a single-example {@link DataSet}.
     * <ul>
     * <li>Features are [1, vectorSize, n], where n is the number of in-vocabulary tokens; time step
     * {@code j} holds the word vector of the j-th known token.</li>
     * <li>The feature mask is [1, n] and is all ones.</li>
     * <li>Labels are placeholders: class 0 is set at the last time step, and only that step is
     * unmasked. The caller uses only the features.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if no token of {@code i_news} is in the vocabulary
     */
    private static DataSet prepareTestData(String i_news) {
        List<String> tokens = knownTokens(i_news);
        if (tokens.isEmpty()) {
            // A zero-length sequence would make the "last time step" index -1 below.
            throw new IllegalArgumentException("No word of the input is in the word-vector vocabulary");
        }
        int timeSteps = tokens.size();
        int vectorSize = wordVectors.lookupTable().layerSize();

        INDArray features = Nd4j.create(1, vectorSize, timeSteps);
        INDArray labels = Nd4j.create(1, NUM_CATEGORIES, timeSteps);
        INDArray featuresMask = Nd4j.zeros(1, timeSteps);
        INDArray labelsMask = Nd4j.zeros(1, timeSteps);

        for (int j = 0; j < timeSteps; j++) {
            INDArray vector = wordVectors.getWordVectorMatrix(tokens.get(j));
            features.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
            featuresMask.putScalar(new int[]{0, j}, 1.0);
        }

        int lastStep = timeSteps - 1;
        labels.putScalar(new int[]{0, 0, lastStep}, 1.0);
        labelsMask.putScalar(new int[]{0, lastStep}, 1.0);

        return new DataSet(features, labels, featuresMask, labelsMask);
    }
}