import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.PriorityQueue; public class NaiveBayes { public static final int MOST_COMMON_COUNT = 0; public static final int LEAST_COMMON_COUNT = 0; public static final boolean T42 = false; public static final boolean T43 = false; private Map<String, Integer> docCounts; /** * label -> number of words(including dups) in given class (n) */ private Map<String, Integer> denoms; /** * word -> label -> P(word | label) */ private Map<String, Map<String, Double>> prob; /** * label -> P(label) */ private Map<String, Double> prior; private int numTrainingDocs; /** * @param data label -> List of all Documents with that label */ public NaiveBayes(Map<String, List<TrainingDocument>> data){ this.denoms = new HashMap<String, Integer>(); this.docCounts = new HashMap<String, Integer>(); Map<String, Map<String, Integer>> vocab = buildVocab(data); this.prior = new HashMap<String, Double>(); this.numTrainingDocs = 0; for(String label : data.keySet()){ this.numTrainingDocs += data.get(label).size(); this.prior.put(label, Math.log10((double)data.get(label).size() / (double)data.size())); } this.prob = new HashMap<String, Map<String, Double>>(); this.train(vocab); this.denoms = null; } /** * @param d Document object of file to classify * @return most likely label that the given document is classified as */ public String classify(Document d){ String maxLabel = null; Double maxProb = 0.0; for(String label : this.prior.keySet()){ Double curProb = this.prior.get(label); for(String term : d){ if(this.prob.containsKey(term)){ //if term is in the vocabulary //multiplication accounts for duplicates curProb += d.getCount(term) * prob.get(term).get(label); } } if(curProb >= maxProb || maxLabel == null){ maxProb = curProb; maxLabel = label; } } return maxLabel; } /** * @param vocab word -> label -> number of times given word appears in Text of the given label */ private void train(Map<String, Map<String, Integer>> vocab){ for(String term : vocab.keySet()){ if(!this.prob.containsKey(term)) this.prob.put(term, new HashMap<String, Double>()); for(String label : this.prior.keySet()){ this.prob.get(term).put(label, getProb(vocab, term, label)); } } } /** * @param vocab word -> label -> number of times given word appears in Text of the given label * @param term word to look up * @param label label to look up * @return P(word | label) */ private Double getProb(Map<String, Map<String, Integer>> vocab, String term, String label){ Map<String, Integer> cnts = vocab.get(term); double nk; if(cnts.containsKey(label)) nk = cnts.get(label); else nk = 0.0; if(T42) nk = nk * Math.log10((double) this.numTrainingDocs / (double)this.docCounts.get(term)); else if(T43) nk = Math.log10(1.0 + nk); Integer n = this.denoms.get(label); return Math.log10(((double)nk + 1.0) / (double)(n + vocab.size())); } /** * @param docs label -> List of all Documents with that label * @return hash of vocabulary: word -> label -> number of times given word appears in Text of the given label */ @SuppressWarnings("unused") private Map<String, Map<String, Integer>> buildVocab(Map<String, List<TrainingDocument>> docs){ Map<String, Map<String, Integer>> rtn = new HashMap<String, Map<String, Integer>>(); for(String label : docs.keySet()){ for(TrainingDocument d : docs.get(label)){ int cnt = 0; for(String t : d){ if(!this.docCounts.containsKey(t)) this.docCounts.put(t, 0); this.docCounts.put(t, this.docCounts.get(t) + 1); if(!rtn.containsKey(t)) rtn.put(t, new HashMap<String, Integer>()); Map<String, Integer> cnts = rtn.get(t); if(!cnts.containsKey(d.getLabel())) cnts.put(d.getLabel(), 0); cnts.put(d.getLabel(), cnts.get(d.getLabel()) + 1); cnt += d.getCount(t); } if(!this.denoms.containsKey(d.getLabel())) this.denoms.put(d.getLabel(), 0); this.denoms.put(d.getLabel(), this.denoms.get(d.getLabel()) + cnt); } } if(MOST_COMMON_COUNT > 0 || LEAST_COMMON_COUNT > 0){ PriorityQueue<Pair> mostCommon = new PriorityQueue<Pair>(); PriorityQueue<Pair> leastCommon = new PriorityQueue<Pair>(); for(String word : rtn.keySet()){ int wc = reduce(rtn.get(word).values()); mostCommon.add(new Pair(word, wc)); leastCommon.add(new Pair(word, -wc)); } for(int i = 0; i < MOST_COMMON_COUNT; i++){ rtn.remove(mostCommon.poll().s); } for(int i = 0; i < LEAST_COMMON_COUNT; i++){ rtn.remove(leastCommon.poll().s); } } return rtn; } private static int reduce(Collection<Integer> c){ int rtn = 0; for(Integer i : c) rtn += i; return rtn; } }