package edu.fudan.nlp.parser.dep; import java.io.FileInputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.util.Arrays; import java.util.Set; import java.util.zip.GZIPInputStream; import edu.fudan.ml.classifier.Predict; import edu.fudan.ml.classifier.linear.Linear; import edu.fudan.ml.classifier.linear.inf.Inferencer; import edu.fudan.ml.types.Instance; import edu.fudan.ml.types.alphabet.AlphabetFactory; import edu.fudan.ml.types.alphabet.IFeatureAlphabet; import edu.fudan.ml.types.alphabet.LabelAlphabet; import edu.fudan.ml.types.sv.HashSparseVector; import edu.fudan.nlp.parser.Sentence; import edu.fudan.util.exception.UnsupportedDataTypeException; /** * 依赖句法分析器类 * * 输入单个分完词的句子(包含词性),使用Yamada分析算法完成依存结构分析。 * * @author cshen * @version Feb 16, 2009 */ public class YamadaParser extends Inferencer { private static final long serialVersionUID = 7114734594734593632L; // 对于左焦点词的每个词性,保存一张特征名到特征ID的对应表 LabelAlphabet postagAlphabet; // 对于左焦点词的每个词性,有一个分类模型 public Linear[] models; public AlphabetFactory factory; /** * 缺省词性, * 如果词性在训练语料中没有使用过,用缺省词性代替,比如”名词“ * 默认为null,不进行替换。遇到新词性时抛出异常 */ protected String defaultPOS = "名词"; /** * 设置缺省词性 * @param pos * @throws UnsupportedDataTypeException */ public void setDefaultPOS(String pos) throws UnsupportedDataTypeException{ int lpos = postagAlphabet.lookupIndex(pos); if(lpos==-1){ throw new UnsupportedDataTypeException("不支持词性:"+pos); } defaultPOS = pos; } /** * 构造函数 * * @param modelfile * 模型目录 * @throws ClassNotFoundException * @throws IOException */ public YamadaParser(String modelfile) throws IOException, ClassNotFoundException { loadModel(modelfile); factory.setStopIncrement(true); postagAlphabet = factory.buildLabelAlphabet("postag"); } private Predict<DependencyTree> _getBestParse(Sentence sent){ float score = 0; // 分析中的状态 ParsingState state = new ParsingState(sent,factory); postagAlphabet = factory.buildLabelAlphabet("postag"); while (!state.isFinalState()) { float[][] estimates; try { estimates = estimateActions(state); } catch (UnsupportedDataTypeException e) { return null; } if ((int) estimates[0][0] == 1) state.next(ParsingState.Action.LEFT); else if ((int) estimates[0][0] == 2) state.next(ParsingState.Action.RIGHT); else if ((int) estimates[0][1] == 1) state.next(ParsingState.Action.LEFT, estimates[1][1]); else state.next(ParsingState.Action.RIGHT, estimates[1][1]); if (estimates[0][0] != 0) score += Math.log10(estimates[1][0]); else score += Math.log10(estimates[1][1]); } score = (float) Math.exp(score); Predict<DependencyTree> res = new Predict<DependencyTree>(); res.add(state.trees.get(0),score); return res; } /** * 动作预测 * * 根据当前状态得到的特征,和训练好的模型,预测当前状态应采取的策略,用在测试中 * * @param featureAlphabet * 特征名到特征ID的对应表,特征抽取时使用特征名,模型中使用特征ID, * @param model * 分类模型 * @param features * 当前状态的特征 * @return 动作及其概率 [[动作1,概率1],[动作2,概率2],[动作3,概率3]] 动作: 1->LEFT; 2->RIGHT; * 0->SHIFT * @throws UnsupportedDataTypeException */ private float[][] estimateActions(ParsingState state) throws UnsupportedDataTypeException { // 当前状态的特征 HashSparseVector features = state.getFeatures(); Instance inst = new Instance(features.indices()); String pos = state.getLeftPos(); int lpos = postagAlphabet.lookupIndex(pos); if(lpos==-1) throw new UnsupportedDataTypeException("不支持词性:"+pos); LabelAlphabet actionList = factory.buildLabelAlphabet(pos); Predict<Integer> ret = models[lpos].classify(inst, actionList.size()); Object[] guess = ret.labels; float[][] result = new float[2][actionList.size()]; float total = 0; for (int i = 0; i < guess.length; i++) { if(guess[i]==null) //bug:可能为空,待修改。 xpqiu break; String action = actionList.lookupString((Integer)guess[i]); result[0][i] = 0; if (action.matches("L")) result[0][i] = 1; else if (action.matches("R")) result[0][i] = 2; result[1][i] = (float) Math.exp(ret.getScore(i)); total += result[1][i]; } for (int i = 0; i < guess.length; i++) { result[1][i] = result[1][i] / total; } return result; } /** * 分析单个句子 * * @param carrier 句子实例 * @param n * @return 整个句子的得分 */ public Predict<int[]> getBest(Instance carrier, int n) { throw new UnsupportedOperationException("Cannot find k-best trees in " + this.getClass().getName()); } /** * 加载模型 * * 以序列化方式加载模型 * * @param modelfile * 模型路径 * @throws IOException * @throws ClassNotFoundException */ public void loadModel(String modelfile) throws IOException, ClassNotFoundException { ObjectInputStream instream = new ObjectInputStream(new GZIPInputStream( new FileInputStream(modelfile))); factory = (AlphabetFactory) instream.readObject(); models = (Linear[]) instream.readObject(); instream.close(); IFeatureAlphabet features = factory.DefaultFeatureAlphabet(); features.setStopIncrement(true); } public Predict<int[]> getBest(Instance inst) { Sentence sent = (Sentence) inst; Predict<DependencyTree> res = _getBestParse(sent); float score = res.getScore(0); DependencyTree dt = res.getLabel(0); Predict<int[]> ret = new Predict<int[]>(); int[] preds = new int[sent.length()]; Arrays.fill(preds, -1); DependencyTree.toArrays(dt, preds); ret.add(preds,score); return ret; } public int[] parse(Instance inst) { return (int[]) getBest(inst).getLabel(0); } // public int[] parse(String[][] strings) { // return parse(new Sentence(strings)); // } public int[] parse(String[] words, String[] pos) { return parse(new Sentence(words, pos)); } public DependencyTree getBestParse(Sentence sent) { return _getBestParse(sent).getLabel(0); } public DependencyTree getBestParse(String[] words) { return getBestParse(words, null); } public DependencyTree getBestParse(String[] words, String[] tags) { return getBestParse(new Sentence(words, tags)); } public static void main(String args[]){ try { YamadaParser yp = new YamadaParser("./tmp/modelConll.mz"); } catch (ClassNotFoundException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } /** * 得到支持的词性标签集合 * @return 词性标签集合 */ public Set<String> getSupportedTags(){ Set<String> tagset = postagAlphabet.toSet(); return tagset; } }