package hip.mahout;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.mutable.MutableInt;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Small hand-rolled text classifier with a demo entry point.
 *
 * <p>"NB" in the name presumably stands for naive Bayes: a document's score
 * for a category is the product of one weight per word, multiplied by the
 * fraction of training documents that belong to that category.
 */
public class HomegrownNBClassifier {

    /**
     * Trains a classifier on four hard-coded documents, prints the learned
     * model, then prints the per-category scores for "quick money".
     *
     * @param args ignored
     */
    public static void main(String... args) {
        Classifier c = new Classifier();
        c.train("ham windyhill roofing estimate");
        c.train("ham quick hadoop meetup");
        c.train("spam cheap quick xanax");
        c.train("spam quick easy money");
        System.out.println(c);
        c.classify("quick money");
    }

    /**
     * Holds one {@link Category} per label seen during training and scores
     * new documents against all of them.
     */
    public static class Classifier {
        /** Categories keyed by their label. */
        Map<String, Category> categories = new HashMap<String, Category>();

        /** Number of documents successfully passed to {@link #train(String)}. */
        int numDocuments;

        /**
         * Adds one training document.
         *
         * <p>The document is split on whitespace; the first token is the
         * category label and the remaining tokens are the words. After the
         * category has been updated, the probability of every known category
         * is recomputed against the new document total.
         *
         * @param document whitespace-separated text of the form
         *                 {@code "<label> <word> <word> ..."}
         * @throws IllegalArgumentException if {@code document} is null or
         *                                  contains no tokens
         */
        public void train(String document) {
            // Parse and validate BEFORE touching any counters: a rejected
            // document must not inflate numDocuments and thereby skew every
            // category probability computed afterwards.
            // StringUtils.split yields null for null input and an empty
            // array for empty/whitespace-only input.
            String[] parts = StringUtils.split(document);
            if (parts == null || parts.length == 0) {
                throw new IllegalArgumentException(
                    "Training document must start with a category label,"
                        + " got: " + document);
            }

            numDocuments++;
            String category = parts[0];
            List<String> words = Arrays.asList(
                Arrays.copyOfRange(parts, 1, parts.length));

            Category cat = categories.get(category);
            if (cat == null) {
                cat = new Category(category);
                categories.put(category, cat);
            }
            cat.train(words);

            // The total changed, so every category's share changed too,
            // not only the one that was just trained.
            for (Category c : categories.values()) {
                c.updateProbability(numDocuments);
            }
        }

        /**
         * Prints, for every known category, the score of the given text.
         *
         * <p>The score is the product of
         * {@link Category#weightedProbability(String)} over all
         * whitespace-separated words, multiplied by the category's
         * {@code categoryProbability}. One line per category is written to
         * {@code System.out}; categories are visited in the iteration order
         * of the backing {@code HashMap}.
         *
         * @param words whitespace-separated words to classify
         * @throws IllegalArgumentException if {@code words} is null
         */
        public void classify(String words) {
            String[] parts = StringUtils.split(words);
            if (parts == null) {
                // Fail up front instead of with an NPE inside the loop below.
                throw new IllegalArgumentException("words must not be null");
            }
            for (Category c : categories.values()) {
                double p = 1.0;
                for (String word : parts) {
                    p *= c.weightedProbability(word);
                }
                System.out.println("Probability of document '" + words +
                    "' for category '" + c.label +
                    "' is " + (p * c.categoryProbability));
            }
        }

        /**
         * Returns the string form of every category, each followed by a
         * newline.
         */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            for (Category cat : categories.values()) {
                sb.append(cat).append("\n");
            }
            return sb.toString();
        }
    }

    /**
     * Per-label training state: how many documents carried the label and how
     * often each word occurred in them.
     */
    public static class Category {
        /**
         * Weight returned for a word that never occurred in this category.
         * Presumably non-zero so that a single unseen word does not force the
         * whole product in {@link Classifier#classify(String)} to zero.
         */
        private static final double UNSEEN_WORD_PROBABILITY = 0.1;

        /** Label this category was created with. */
        String label;

        /** Number of {@link #train(List)} calls, i.e. documents in this category. */
        int numDocuments;

        /** Last value computed by {@link #updateProbability(int)}. */
        double categoryProbability;

        /** Occurrence count per word, summed over all training documents. */
        Map<String, MutableInt> features =
            new HashMap<String, MutableInt>();

        /**
         * @param label label identifying this category
         */
        public Category(String label) {
            this.label = label;
        }

        /**
         * Records one document: bumps the document counter and increments the
         * occurrence count of every word (repeated words count repeatedly).
         *
         * @param words words of the document, without the category label
         */
        void train(List<String> words) {
            numDocuments++;
            for (String word : words) {
                MutableInt i = features.get(word);
                if (i == null) {
                    i = new MutableInt(0);
                    features.put(word, i);
                }
                i.increment();
            }
        }

        /**
         * Recomputes {@code categoryProbability} as this category's document
         * count divided by {@code totalDocuments}.
         *
         * @param totalDocuments number of documents across all categories
         */
        void updateProbability(int totalDocuments) {
            categoryProbability = (double) numDocuments /
                (double) totalDocuments;
        }

        /**
         * Returns the weight of a word within this category.
         *
         * <p>For a known word this is its occurrence count divided by the
         * number of documents in the category. Because occurrences (not
         * documents) are counted, the value can exceed 1 when a word repeats
         * inside a document. Unknown words get a fixed fallback weight.
         *
         * @param word word to look up
         * @return the word's weight, or {@code 0.1} if the word was never
         *         seen in this category
         */
        double weightedProbability(String word) {
            MutableInt i = features.get(word);
            if (i == null) {
                return UNSEEN_WORD_PROBABILITY;
            }
            return i.doubleValue() / (double) numDocuments;
        }

        /**
         * Returns a header line with label, document count and category
         * probability, followed by one line per word with its count, in the
         * iteration order of the backing {@code HashMap}.
         */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("Category = ").append(label)
                .append(", numDocs = ").append(numDocuments)
                .append(", categoryProbability = ")
                .append(categoryProbability);
            for (Map.Entry<String, MutableInt> entry : features.entrySet()) {
                sb.append("\n ").append(entry.getKey()).append(" ")
                    .append(entry.getValue());
            }
            return sb.toString();
        }
    }
}