package ruc.irm.classification; import java.io.DataInputStream; import java.io.DataOutput; import java.io.DataOutputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.util.Collection; import java.util.HashMap; import java.util.Map; public class NaiveBayesClassifier { /** * 记录每个类别下出现的文档数量, 用于计算P(C)使用 */ Variable VARIABLE = new Variable(); /** * 词语在所有类别中的总数量 */ Map<String, Integer> TERM_TOTAL_COUNT = new HashMap<String, Integer>(); /** * 训练一篇文档 * @param doc */ public void training(Instance doc) { VARIABLE.addInstance(doc); } /** * 保存训练结果 * @throws IOException */ void save(File file) throws IOException{ DataOutput out = new DataOutputStream(new FileOutputStream(file)); VARIABLE.write(out); } public void load(File file) throws IOException{ DataInputStream in = new DataInputStream(new FileInputStream(file)); VARIABLE = Variable.read(in); } /** * 计算P(C) * @param category * @return */ public double getCategoryProbability(String category){ return Math.log(VARIABLE.getDocCount(category)*1.0f/VARIABLE.getDocCount()); } /** * 计算P(feature|cateogry),返回的是取对数后的数值 * @param feature * @param category * @return */ public double getFeatureProbability(String feature, String category){ int m = VARIABLE.getFeatureCount(); return Math.log((VARIABLE.getDocCount(feature, category)+1.0)/(VARIABLE.getDocCount(category)+m)); } /** * 计算给定实例文档属于指定类别的概率,返回的是取对数后的数值 * @param category * @param doc * @return */ public double getProbability(String category, Instance doc) { double result = getCategoryProbability(category); for(String feature:doc.getWords()){ if(VARIABLE.containFeature(feature)){ result += getFeatureProbability(feature, category); } } return result; } public String getCategory(Instance doc){ Collection<String> categories = VARIABLE.getCategories(); double best = Double.NEGATIVE_INFINITY; String bestName = null; for(String c:categories){ double current = getProbability(c, doc); // System.out.println(c + ":" + current); if(best<current){ best = current; bestName = c; } } return bestName; } public static void main(String[] args) throws IOException { NaiveBayesClassifier classifier = new NaiveBayesClassifier(); // File samplePath = new File("./corpus/Sample"); // for(File categoryPath:samplePath.listFiles()){ // String category = categoryPath.getName(); // for(File f:categoryPath.listFiles()){ // classifier.training(new Instance(category, f, "GBK")); // } // } // classifier.save(new File("result.dat")); // System.out.println("Finished!"); classifier.load(new File("result.dat")); Instance doc = new Instance(null, new File("/tmp/10.txt"), "GBK"); System.out.println(classifier.getCategory(doc)); } }