package edu.usc.cssl.tacit.classify.svm.services; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.File; import java.io.FileOutputStream; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; import java.io.ObjectOutputStream; import java.text.DateFormat; import java.text.NumberFormat; import java.text.SimpleDateFormat; import java.util.Date; import java.util.HashMap; import java.util.HashSet; import java.util.TreeMap; import org.apache.commons.math3.stat.inference.AlternativeHypothesis; import org.apache.commons.math3.stat.inference.BinomialTest; import edu.usc.cssl.tacit.common.ui.views.ConsoleView; public class SVMClassify { private String intermediatePath; private File modelFile; private boolean doTfidf; private int featureMapIndex; private HashMap<String, Integer> featureMap = new HashMap<String, Integer>(); private HashMap<String, Integer> dfMap = new HashMap<String, Integer>(); private String delimiters = " .,;'\"!-()[]{}:?"; private int noOfDocuments = 0; public SVMClassify(String class1Name, String class2Name, String outputFolder) { this.intermediatePath = outputFolder + System.getProperty("file.separator") + "SVM-Classification"; } public void buildDfMap(File inputFile) throws IOException { BufferedReader br = new BufferedReader(new FileReader(inputFile)); // ConsoleView.writeInConsole("Building map for: "+inputFile.getAbsolutePath()); String currentLine; StringBuilder fullFile = new StringBuilder(); while ((currentLine = br.readLine()) != null) { fullFile.append(currentLine + ' '); } String input = fullFile.toString(); for (char c : delimiters.toCharArray()) input = input.replace(c, ' '); HashSet<String> wordSet = new HashSet<String>(); for (String word : input.split("\\s+")) { wordSet.add(word); } for (String word : wordSet) { if (!(dfMap.containsKey(word))) { dfMap.put(word, 1); } else { dfMap.put(word, dfMap.get(word) + 1); } } br.close(); } public HashMap<String, Double> fileToBow(File inputFile) throws IOException { HashMap<String, Double> hashMap = new HashMap<String, Double>(); BufferedReader br = new BufferedReader(new FileReader(inputFile)); String currentLine; // Converting the file to one string for faster processing StringBuilder fullFile = new StringBuilder(); while ((currentLine = br.readLine()) != null) { fullFile.append(currentLine + ' '); } // ConsoleView.writeInConsole(fullFile); String input = fullFile.toString(); for (char c : delimiters.toCharArray()) input = input.replace(c, ' '); // ConsoleView.writeInConsole(input); for (String word : input.split("\\s+")) { if (!hashMap.containsKey(word)) hashMap.put(word, (double) 1); else { hashMap.put(word, hashMap.get(word) + 1); } } // If TF.IDF method, multiply each hashMap value with IDF. IDF = log10( // noOfDocuments / no of documents containing the current word) if (doTfidf) { double tfidf = 0; for (String word : hashMap.keySet()) { Integer docsContaining; if ((docsContaining = dfMap.get(word)) != null) { tfidf = hashMap.get(word) * (Math.log10(noOfDocuments / (double) docsContaining)); // ConsoleView.writeInConsole(word+" - "+noOfDocuments+"/"+(double)docsContaining); } else { continue; // If new word, none of the training documents // will contain it. So, skip. } hashMap.put(word, tfidf); } } // ConsoleView.writeInConsole(hashMap); br.close(); return hashMap; } public String BowToString(HashMap<String, Double> bow) { TreeMap<Integer, Double> integerMap = new TreeMap<Integer, Double>(); for (String word : bow.keySet()) { if (featureMap.containsKey(word)) { integerMap.put(featureMap.get(word), bow.get(word)); } else { featureMapIndex = featureMapIndex + 1; featureMap.put(word, featureMapIndex); integerMap.put(featureMapIndex, bow.get(word)); } } // ConsoleView.writeInConsole(integerMap.toString()); // ConsoleView.writeInConsole(bow.toString()); StringBuilder sb = new StringBuilder(); for (int i : integerMap.keySet()) { sb.append(i + ":" + integerMap.get(i) + " "); } // ConsoleView.writeInConsole(sb.toString().trim()); return sb.toString().trim(); } public String BowToTestString(HashMap<String, Double> bow) { TreeMap<Integer, Double> integerMap = new TreeMap<Integer, Double>(); for (String word : bow.keySet()) { if (featureMap.containsKey(word)) { integerMap.put(featureMap.get(word), bow.get(word)); } else { // Ignore new words } } StringBuilder sb = new StringBuilder(); for (int i : integerMap.keySet()) { sb.append(i + ":" + integerMap.get(i) + " "); } return sb.toString().trim(); } public HashMap<Integer, Double> computePredictiveWeights(File modelFile) throws IOException { BufferedReader br = new BufferedReader(new FileReader(modelFile)); HashMap<Integer, Double> weights = new HashMap<Integer, Double>(); String currentLine; while ((currentLine = br.readLine()) != null) { if (currentLine.equals("SV")) { break; } } while ((currentLine = br.readLine()) != null) { String[] items = currentLine.split("\\s+"); double alpha = Double.parseDouble(items[0]); // ConsoleView.writeInConsole(alpha); for (int i = 1; i < items.length; i++) { String[] pair = items[i].split(":"); int featureID = Integer.parseInt(pair[0]); double weight = Double.parseDouble(pair[1]); // ConsoleView.writeInConsole(pair[0]+" "+pair[1]); if (weights.containsKey(featureID)) { weights.put(featureID, weights.get(featureID) + (alpha * weight)); } else { weights.put(featureID, alpha * weight); } } } // ConsoleView.writeInConsole(weights); br.close(); return weights; } public int cross_train(String kVal, String label1, File[] trainFiles1, String label2, File[] trainFiles2, boolean doPredictiveWeights, Date dateObj) throws IOException { int ret = 0; modelFile = new File(intermediatePath + "_" + kVal + ".model"); File trainFile = new File(intermediatePath + "_" + kVal + ".train"); this.doTfidf = true; featureMapIndex = 0; featureMap.clear(); dfMap.clear(); noOfDocuments = 0; if (doTfidf) { for (File file : trainFiles1) { if (file.getAbsolutePath().contains("DS_Store")) continue; noOfDocuments = noOfDocuments + 1; // Count the total no of // documents buildDfMap(file); } for (File file : trainFiles2) { if (file.getAbsolutePath().contains("DS_Store")) continue; noOfDocuments = noOfDocuments + 1; // Count the total no of // documents buildDfMap(file); } // ConsoleView.writeInConsole("dfmap -"+dfMap); ConsoleView .printlInConsoleln("Finished building document frequency map."); } BufferedWriter bw = new BufferedWriter(new FileWriter(trainFile)); for (File file : trainFiles1) { if (file.getAbsolutePath().contains("DS_Store")) continue; // ConsoleView.writeInConsole("Reading File "+file.toString()); bw.write("+1 " + BowToString(fileToBow(file))); bw.newLine(); } for (File file : trainFiles2) { if (file.getAbsolutePath().contains("DS_Store")) continue; // ConsoleView.writeInConsole("Reading File "+file.toString()); bw.write("-1 " + BowToString(fileToBow(file))); bw.newLine(); } ConsoleView.printlInConsoleln("Total number of documents - " + noOfDocuments + ". Total unique features - " + featureMapIndex); // ConsoleView.printlInConsoleln("Finished building SVM-format training file - "+trainFile.getAbsolutePath()); bw.close(); String[] train_arguments; ConsoleView.printlInConsoleln("Linear Kernel selected"); train_arguments = new String[4]; train_arguments[0] = "-t"; train_arguments[1] = "0"; train_arguments[2] = trainFile.getAbsolutePath(); train_arguments[3] = modelFile.getAbsolutePath(); DateFormat df = new SimpleDateFormat("MM-dd-yy-HH-mm-ss"); ConsoleView.printlInConsoleln("Training the classifier..."); double[] result = SVMTrain.main(train_arguments); double crossValResult = result[0]; double pvalue = result[1]; // ConsoleView.printlInConsoleln("Model file created - "+modelFile.getAbsolutePath()); // Saving the feature map File hashmap = new File(intermediatePath + "_" + kVal + ".hashmap"); ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream( hashmap)); oos.writeObject(featureMap); oos.flush(); oos.close(); // ConsoleView.printlInConsoleln("Feature Map saved - "+hashmap.getAbsolutePath()); HashMap<Integer, String> reverseMap = new HashMap<Integer, String>(); for (String k : featureMap.keySet()) { reverseMap.put(featureMap.get(k), k); } if (doPredictiveWeights) { // PredictiveWeights pw = new PredictiveWeights(); File weightsFile = new File(intermediatePath + "-weights" + "-" + kVal +"-"+df.format(dateObj)+".csv"); BufferedWriter weightsWriter = new BufferedWriter(new FileWriter( weightsFile)); // HashMap<Integer,Double> weightsMap = // pw.computePredictiveWeights(modelFile); HashMap<Integer, Double> weightsMap = computePredictiveWeights(modelFile); weightsWriter.write("Word,ID,Weight\n"); for (Integer i : weightsMap.keySet()) { // System.out.print(i+" "); weightsWriter.write(reverseMap.get(i) + "," + i + "," + weightsMap.get(i) + "\n"); } ConsoleView.printlInConsoleln("Created Predictive Weights file - " + weightsFile.getAbsolutePath()); weightsWriter.close(); } return ret; } public double cross_predict(String kVal, String label1, File[] testFiles1, String label2, File[] testFiles2) throws IOException { // if TFIDF method, clear and rebuild df map dfMap.clear(); noOfDocuments = 0; if (doTfidf) { for (File file : testFiles1) { if (file.getAbsolutePath().contains("DS_Store")) continue; noOfDocuments = noOfDocuments + 1; buildDfMap(file); } for (File file : testFiles2) { if (file.getAbsolutePath().contains("DS_Store")) continue; noOfDocuments = noOfDocuments + 1; buildDfMap(file); } // ConsoleView.writeInConsole("dfmap -"+dfMap); ConsoleView .printlInConsoleln("Finished building document frequency map."); } // Create a test file just like the training file was created. // Use the existing featureMap, ignore new words. File testFile = new File(intermediatePath + "_" + kVal + ".test"); BufferedWriter bw = new BufferedWriter(new FileWriter(testFile)); for (File file : testFiles1) { if (file.getAbsolutePath().contains("DS_Store")) continue; // ConsoleView.writeInConsole("Reading File "+file.toString()); bw.write("+1 " + BowToTestString(fileToBow(file))); bw.newLine(); } for (File file : testFiles2) { if (file.getAbsolutePath().contains("DS_Store")) continue; // ConsoleView.writeInConsole("Reading File "+file.toString()); bw.write("-1 " + BowToTestString(fileToBow(file))); bw.newLine(); } // ConsoleView.printlInConsoleln("Finished building SVM-format test file - "+testFile.getAbsolutePath()); bw.close(); // ConsoleView.printlInConsoleln("Model file loaded - "+modelFile.getAbsolutePath()); String[] predict_arguments = new String[3]; predict_arguments[0] = testFile.getAbsolutePath(); predict_arguments[1] = modelFile.getAbsolutePath(); predict_arguments[2] = intermediatePath + "_" + kVal + ".out"; double[] result = SVMPredict.main(predict_arguments); int correct = (int) result[0], total = (int) result[1]; // double pvalue = result[2]; BinomialTest btest = new BinomialTest(); double p = 0.5; double pvalue = btest.binomialTest(total, correct, p, AlternativeHypothesis.TWO_SIDED); // ConsoleView.printlInConsoleln("Created SVM output file - "+intermediatePath+"_"+kVal+".out"); ConsoleView.printlInConsoleln("Accuracy = " + (double) correct / total * 100 + "% (" + correct + "/" + total + ") (classification)\n"); ConsoleView.printlInConsoleln("Binomial Test P value = " + pvalue); NumberFormat nf = NumberFormat.getInstance(); nf.setMaximumFractionDigits(Integer.MAX_VALUE); // ConsoleView.writeInConsole(nf.format(pvalue)); if (pvalue != 0) { if (pvalue > 0.5) pvalue = Math.abs(pvalue - 1); } return (double) correct / total * 100; } }