package com.yc.nlp.util; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import com.yc.nlp.pojo.Pre; import com.yc.nlp.pojo.Result; import com.yc.nlp.pojo.StageValue; import com.yc.nlp.pojo.Tag; import com.yc.nlp.pojo.WordTag; import com.yc.nlp.prob.AddOneProb; import com.yc.nlp.prob.BaseProb; import com.yc.nlp.prob.NormalProb; public class TnT { private Integer num; private Double l1; private Double l2; private Double l3; private Set<String> status; private BaseProb wd, eos, eosd, uni, bi, tri; private Map<String, Set<String>> word; private Map<Object, Double> trans; public TnT() { this(1000); } public TnT(Integer num) { this.num = num; this.l1 = 0.0; this.l2 = 0.0; this.l3 = 0.0; this.status = new HashSet<String>(); this.wd = new AddOneProb(); this.eos = new AddOneProb(); this.eosd = new AddOneProb(); this.uni = new NormalProb(); this.bi = new NormalProb(); this.tri = new NormalProb(); this.word = new HashMap<String, Set<String>>(); this.trans = new HashMap<Object, Double>(); } public Integer getNum() { return num; } /** * 将内存中的数据写到文件中 * * @param fname */ public void save(String fname) { MemFile.loadFromMem(fname, this); } /** * 将文件内容导入到内存 * * @param fname */ public void load(String fname) { try { byte[] result = MemFile.loadFromFile(fname, this); if (result != null) { MemFile.loadToMem(result, this); return; } throw new Exception("TnT读取" + fname + "文件出错!"); } catch (Exception e) { e.printStackTrace(); } } public double tntDiv(double v1, double v2) { if (v2 == 0) { return v2; } return v1 / v2; } public double getEos(String tag) { if (!eosd.exist(tag)) { return Math.log(1.0 / this.status.size()); } return Math.log(this.eos.get(tag + "-EOS")) - Math.log(this.eosd.get(tag)); } /** * 训练样本,将每个字的词性进行存储 * * @param data */ public void train(List<List<WordTag>> data) { Tuple<String> now = new Tuple<String>(); now.addAll(Arrays.asList("BOS", "BOS")); for (List<WordTag> wtList : data) { this.bi.add("BOS-BOS", 1); this.uni.add("BOS", 2); for (WordTag wt : wtList) { now.add(wt.getTag()); String tupleStr = now.subList(1, now.size()).toString(); this.status.add(wt.getTag()); this.wd.add(wt.toString(), 1); this.eos.add(tupleStr, 1); this.eosd.add(wt.getTag(), 1); this.uni.add(wt.getTag(), 1); this.bi.add(tupleStr, 1); this.tri.add(now.toString(), 1); if (!this.word.containsKey(wt.getWord())) { Set<String> tags = new HashSet<String>(); this.word.put(wt.getWord(), tags); } this.word.get(wt.getWord()).add(wt.getTag()); now.remove(0); } this.eos.add(now.get(now.size() - 1) + "-EOS", 1); } double tl1 = 0.0, tl2 = 0.0, tl3 = 0.0; for (String key : this.tri.samples()) { now = Tuple.fromStr(key); double c3 = this.tntDiv(this.tri.get(now.toString()) - 1, this.bi.get(now.subList(0, 2).toString()) - 1); double c2 = this.tntDiv(this.bi.get(now.subList(1, now.size()).toString()) - 1, this.uni.get(now.get(1)) - 1); double c1 = this.tntDiv(this.uni.get(now.get(2)) - 1, this.uni.getSum() - 1); double result=this.tri.get(now.toString()); if (c3 >= c1 && c3 >= c2) { tl3 += result; } else if (c2 >= c1 && c2 >= c3) { tl2 += result; } else if (c1 >= c2 && c1 >= c3) { tl1 += result; } } this.l1 = tl1 / (tl1 + tl2 + tl3); this.l2 = tl2 / (tl1 + tl2 + tl3); this.l3 = tl3 / (tl1 + tl2 + tl3); Set<String> newStatus = new HashSet<String>(); newStatus.addAll(status); newStatus.add("BOS"); for (String s1 : newStatus) { for (String s2 : newStatus) { for (String s3 : status) { if (s1.equals("BOS") && s2.equals("BOS") && s3.equals("s")) { System.out.println("aa"); } double uni = this.l1 * this.uni.frequency(s3); double bi = this.tntDiv(this.l2 * this.bi.get(s2 + "-" + s3), this.uni.get(s2)); double tri = this.tntDiv(this.l3 * this.tri.get(s1 + "-" + s2 + "-" + s3), this.bi.get(s1 + "-" + s2)); this.trans.put(s1 + "-" + s2 + "-" + s3, Math.log(uni + bi + tri)); } } } } /** * 获取每个字最有可能的一种词性 * * @param data * @return * @throws Exception */ public List<Result> tag(List<String> data) throws Exception { List<Tag> tags = new ArrayList<Tag>(getNum()); tags.add(new Tag(new Pre("BOS", "BOS"), 0.0, "")); Map<Pre, StageValue> stage = new HashMap<Pre, StageValue>(); for (String ch : data) { stage = new HashMap<Pre, StageValue>(); Set<String> samples = status; if (this.word.containsKey(ch)) { samples = this.word.get(ch); } for (String s : samples) { double wd = Math.log(this.wd.get(s + "-" + ch)) - Math.log(this.uni.get(s)); for (Tag tag : tags) { double p = tag.getScore() + wd + this.trans.get(tag.getPrefix().toString() + "-" + s); Pre pre = new Pre(tag.getPrefix().getTwo(), s); if (!stage.containsKey(pre) || p > stage.get(pre).getScore()) { stage.put(pre, new StageValue(p, tag.getSuffix().equals("") ? s : (tag.getSuffix() + "-" + s))); } } } tags.clear(); for (Map.Entry<Pre, StageValue> entry : stage.entrySet()) { tags.add(new Tag(entry.getKey(), entry.getValue().getScore(), entry.getValue().getValue())); } Collections.sort(tags, new Comparator<Tag>() { public int compare(Tag o1, Tag o2) { if (o2.getScore() == o1.getScore()) return 0; return (o2.getScore() - o1.getScore() > 0 ? 1 : -1); } }); while (tags.size() > getNum()) { tags = tags.subList(0, getNum()); } } tags.clear(); for (Map.Entry<Pre, StageValue> entry : stage.entrySet()) { double score = entry.getValue().getScore() + getEos(entry.getKey().getTwo()); tags.add(new Tag(entry.getKey(), score, entry.getValue().getValue())); } Collections.sort(tags, new Comparator<Tag>() { public int compare(Tag o1, Tag o2) { if (o2.getScore() == o1.getScore()) return 0; return (o2.getScore() - o1.getScore() > 0 ? 1 : -1); } }); List<Result> results = new ArrayList<Result>(); String[] tagArr = tags.get(0).getSuffix().split("-"); if (tagArr.length != data.size()) { throw new Exception("出错了!"); } for (int i = 0; i < data.size(); i++) { results.add(new Result(data.get(i), tagArr[i])); } return results; } public static void main(String[] args) { TnT tnt = new TnT(1000); tnt.save("seg.marshal"); tnt.load("seg.marshal"); } }