import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
/**
 * Command-line driver for a Naive Bayes document classifier.
 *
 * <p>Usage: {@code nb <trainFolder> <testFolder>}. Each folder is expected to
 * contain one sub-directory per class label; the name of the sub-directory is
 * the label and every entry inside it is treated as one document of that class.
 * The program trains on the first folder and prints the classification accuracy
 * (as a percentage) on the training folder and then on the test folder.
 */
public class nb {
    /** When true, per-class accuracy and the most common misclassifications are printed. */
    private static final boolean DEBUG = false;

    /**
     * Trains the classifier and prints training and test accuracy.
     *
     * @param args {@code args[0]} is the training folder, {@code args[1]} the test folder
     * @throws IllegalArgumentException if the argument count is not 2 or a folder
     *         cannot be listed
     */
    public static void main(String[] args){
        if(args.length != 2){
            throw new IllegalArgumentException("Invalid arguments");
        }

        //Build the training data: one list of documents per label folder.
        File trainFolder = new File(args[0]);
        Map<String, List<TrainingDocument>> trainData = new HashMap<String, List<TrainingDocument>>();
        for(File d : listChildren(trainFolder)){
            if(!d.isDirectory()){
                continue;
            }
            String label = d.getName();
            List<TrainingDocument> docs = new ArrayList<TrainingDocument>();
            for(File f : listChildren(d)){
                docs.add(new TrainingDocument(f, label));
            }
            trainData.put(label, docs);
        }

        //Train the classifier.
        NaiveBayes b = new NaiveBayes(trainData);

        //Run the algorithm on the training data, then on the test data.
        System.out.println("Training Accuracy: " + evaluate(b, trainFolder));
        System.out.println("Test Accuracy: " + evaluate(b, new File(args[1])));
    }

    /**
     * Classifies every document under {@code folder} and returns the overall accuracy.
     *
     * @param b the trained classifier
     * @param folder folder whose sub-directories are named after the expected labels
     * @return percentage (0-100) of documents whose guessed label equals the name of
     *         their enclosing sub-directory; {@code NaN} if no documents were found
     * @throws IllegalArgumentException if {@code folder} or one of its sub-directories
     *         cannot be listed
     */
    private static double evaluate(NaiveBayes b, File folder){
        int totalCorrect = 0;
        int totalOverall = 0;
        for(File d : listChildren(folder)){
            if(!d.isDirectory()){
                continue;
            }
            String label = d.getName();
            int correct = 0;
            int total = 0;
            //Wrong guesses for this class, keyed by the guessed label.
            Map<String, Integer> mcc = new HashMap<String, Integer>();
            for(File f : listChildren(d)){
                total++;
                String guess = b.classify(new Document(f));
                if(guess.equals(label)){
                    correct++;
                }
                else{
                    Integer seen = mcc.get(guess);
                    mcc.put(guess, seen == null ? 1 : seen + 1);
                }
            }
            if(DEBUG){
                double acc = correct * 100.0 / (double)total;
                System.out.println("Class: " + label + " Accuracy: " + acc);
                printErrors(mcc, total);
            }
            totalCorrect += correct;
            totalOverall += total;
        }
        //Floating-point division: yields NaN rather than throwing when nothing was classified.
        return totalCorrect * 100.0 / (double)totalOverall;
    }

    /**
     * Lists the entries of a directory, failing with a descriptive error instead of
     * the NullPointerException that iterating a null {@code listFiles()} result causes.
     *
     * @param dir directory to list
     * @return the entries of {@code dir}, never null
     * @throws IllegalArgumentException if {@code dir} is not a directory or cannot be read
     */
    private static File[] listChildren(File dir){
        File[] children = dir.listFiles();
        if(children == null){
            throw new IllegalArgumentException("Cannot list directory: " + dir.getPath());
        }
        return children;
    }

    /**
     * Prints out the occurrence rate of error guesses. Only classes that account for
     * more than 10 percent of all guesses are reported, in the order the
     * PriorityQueue yields them (i.e. Pair's natural ordering).
     * @param m map of incorrectly guessed classes as keys and their counts as the values
     * @param total total number of guesses made
     */
    private static void printErrors(Map<String, Integer> m, int total){
        PriorityQueue<Pair> q = new PriorityQueue<Pair>();
        for(Map.Entry<String, Integer> e : m.entrySet()){
            q.add(new Pair(e.getKey(), e.getValue()));
        }
        while(!q.isEmpty()){
            Pair p = q.poll();
            double occurrenceRate = p.i / (double)total;
            if(occurrenceRate > 0.1){
                //Scale to 0-100 so the printed number matches the "percent" wording.
                System.out.println("\tIncorrectly guessed class: " + p.s + " " + (occurrenceRate * 100.0) + " percent of the time.");
            }
        }
    }
}