package com.yc.nlp.classification;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.yc.nlp.pojo.ClassifyResult;
import com.yc.nlp.prob.AddOneProb;
import com.yc.nlp.util.MemFile;

/**
 * Naive Bayes text classifier.
 *
 * <p>Keeps one {@link AddOneProb} word-count model per category and the grand
 * total of all per-category sums. Training data is supplied through
 * {@link #train(List)}; {@link #classify(List)} returns the most probable
 * category together with its normalised probability.
 *
 * <p>This class performs no synchronisation of its own.
 *
 * @author uohzoaix
 *
 */
public class Bayes {

    private static final Logger logger = LoggerFactory.getLogger(Bayes.class);

    // NOTE(review): neither field has an accessor in this class, so
    // MemFile.bayesLoadToMem presumably populates them some other way
    // (e.g. reflection) - keep their names and declared types stable.

    /** Category label -> word-count model for that category. */
    private Map<String, AddOneProb> d;

    /** Sum of {@code getSum()} over every category model in {@link #d}. */
    private Double total;

    /** Creates an empty, untrained classifier. */
    public Bayes() {
        logger.debug("initialize bayes begin...");
        d = new HashMap<String, AddOneProb>();
        total = 0.0;
        logger.debug("initialize bayes end...");
    }

    /**
     * Persists the model through {@link MemFile#saveToFile}.
     *
     * <p>The saved object is a map with two entries: {@code "total"} (the
     * grand total) and {@code "d"} (a synchronized copy of the per-category
     * models).
     *
     * @param fname target file name, passed to {@code MemFile.saveToFile} as is
     */
    public void save(String fname) {
        Map<String, Object> data = new HashMap<String, Object>();
        data.put("total", this.total);
        // Copy into a synchronized map exactly as before so the persisted
        // object graph keeps the same concrete map type.
        Map<String, AddOneProb> probdata = Collections.synchronizedMap(new HashMap<String, AddOneProb>());
        probdata.putAll(this.d);
        data.put("d", probdata);
        MemFile.saveToFile(data, fname);
    }

    /**
     * Loads a previously saved model into this instance.
     *
     * <p>Loading is best-effort: if the file yields no data or any exception
     * is raised, the failure is logged and this method returns normally.
     *
     * @param fname source file name, passed to {@code MemFile.loadFromFile} as is
     */
    public void load(String fname) {
        try {
            byte[] result = MemFile.loadFromFile(fname, this);
            if (result == null) {
                // No data came back; report it without throwing an exception
                // only to catch it again a few lines later.
                logger.error("Bayes读取" + fname + "文件出错!");
                return;
            }
            MemFile.bayesLoadToMem(result, this);
        } catch (Exception e) {
            // Broad catch kept on purpose: callers rely on load() never throwing.
            logger.error("Bayes读取" + fname + "文件出错!", e);
        }
    }

    /**
     * Adds labelled samples to the model and refreshes the grand total.
     *
     * <p>May be called repeatedly; the total is recomputed from the category
     * models each time, so earlier counts are never added twice.
     *
     * @param data samples; for each element, index 0 must be a
     *             {@code List<String>} of words and index 1 the category
     *             label (its {@code toString()} is used as the key)
     */
    @SuppressWarnings("unchecked")
    public void train(List<Object[]> data) {
        for (Object[] sample : data) {
            String category = sample[1].toString();
            AddOneProb prob = this.d.get(category);
            if (prob == null) {
                prob = new AddOneProb();
                this.d.put(category, prob);
            }
            // Unchecked by necessity: the sample layout is fixed by convention only.
            for (String word : (List<String>) sample[0]) {
                prob.add(word, 1);
            }
        }
        // Recompute rather than accumulate: the per-category sums already
        // include everything seen in earlier train() calls or a load().
        double sum = 0.0;
        for (AddOneProb prob : this.d.values()) {
            sum += prob.getSum();
        }
        this.total = sum;
    }

    /**
     * Classifies a tokenised document.
     *
     * <p>For every category the log-score is
     * {@code log(categorySum / total) + sum over words of log(frequency(word))}.
     * The prior term is applied once per document. Scores are then normalised
     * as {@code 1 / sum_j exp(score_j - score_k)} and the category with the
     * largest value wins.
     *
     * @param x the words of the document to classify
     * @return the winning category and its normalised probability; an empty
     *         label with probability 0 when no category has been trained
     */
    public ClassifyResult classify(List<String> x) {
        Map<String, Double> scores = new HashMap<String, Double>();
        for (Map.Entry<String, AddOneProb> entry : this.d.entrySet()) {
            AddOneProb prob = entry.getValue();
            // Class prior in log space, counted once for the whole document.
            double score = Math.log(prob.getSum()) - Math.log(this.total);
            for (String word : x) {
                // presumably frequency() is the smoothed P(word | category) - TODO confirm
                score += Math.log(prob.frequency(word));
            }
            scores.put(entry.getKey(), score);
        }

        String ret = "";
        double prob = 0;
        // Iterate this.d's key set (not scores) so tie-breaking order is unchanged.
        for (String key : this.d.keySet()) {
            double own = scores.get(key);
            double now = 0;
            for (String otherKey : this.d.keySet()) {
                // exp() may overflow to Infinity; 1 / Infinity then yields 0,
                // which simply makes this category lose.
                now += Math.exp(scores.get(otherKey) - own);
            }
            now = 1 / now;
            if (now > prob) {
                ret = key;
                prob = now;
            }
        }
        return new ClassifyResult(ret, prob);
    }
}