package edu.fudan.nlp.parser.dep; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; import java.util.Iterator; import java.util.Set; import edu.fudan.ml.classifier.Predict; import edu.fudan.ml.classifier.linear.Linear; import edu.fudan.ml.types.Instance; import edu.fudan.ml.types.alphabet.AlphabetFactory; import edu.fudan.ml.types.alphabet.IFeatureAlphabet; import edu.fudan.ml.types.alphabet.LabelAlphabet; import edu.fudan.nlp.parser.Sentence; import edu.fudan.nlp.parser.Target; import edu.fudan.util.exception.LoadModelException; import edu.fudan.util.exception.UnsupportedDataTypeException; import gnu.trove.list.array.TIntArrayList; /** * 依赖句法分析器类,同时标注依赖关系类型 * * 输入单个分完词的句子(包含词性),使用Yamada分析算法完成依存结构分析。 * * @author */ public class JointParser implements Serializable{ private static final long serialVersionUID = 7114734594734593632L; private int ysize; private AlphabetFactory factory; private Linear models; private IFeatureAlphabet fa; private LabelAlphabet la; /** * 构造函数 * * @param modelfile * 模型目录 * @throws LoadModelException */ public JointParser(String modelfile) throws LoadModelException { models = Linear.loadFrom(modelfile); factory = models.getAlphabetFactory(); fa = factory.DefaultFeatureAlphabet(); la = factory.DefaultLabelAlphabet(); ysize = la.size(); factory.setStopIncrement(true); } public static int[] addFeature(IFeatureAlphabet fa, ArrayList<String> str, int ysize) { TIntArrayList indices = new TIntArrayList(); for(String s: str){ int i = fa.lookupIndex(s,ysize); if(i!=-1) indices.add(i); } return indices.toArray(); } private void doNext(String action,JointParsingState state){ char act = action.charAt(0); String relation = action.substring(1); switch(act){ case 'L': state.next(JointParsingState.Action.LEFT, relation);break; case 'R': state.next(JointParsingState.Action.RIGHT, relation);break; default: System.out.println("状态动作错误"); } } private void doNext(String action,float est1,JointParsingState state){ char act = action.charAt(0); String relation = action.substring(1); switch(act){ case 'L': state.next(JointParsingState.Action.LEFT, est1,relation);break; case 'R': state.next(JointParsingState.Action.RIGHT, est1,relation);break; default: System.out.println("状态动作错误"); } } private Predict<DependencyTree> _getBestParse(Sentence sent){ float score = 0.0f; // 分析中的状态 JointParsingState state = new JointParsingState(sent); while (!state.isFinalState()) { Predict<String> estimates; estimates = estimateActions(state); String action = estimates.getLabel(0); if (!action.equals("S")){ doNext(action ,state); score +=estimates.getScore(0); } else{ action = estimates.getLabel(1); float s = estimates.getScore(1); doNext(action,s,state); score +=estimates.getScore(1); } } Predict<DependencyTree> res = new Predict<DependencyTree>(); res.add(state.trees.get(0),score); return res; } /** * 动作预测 * * 根据当前状态得到的特征,和训练好的模型,预测当前状态应采取的策略,用在测试中 * * @param featureAlphabet * 特征名到特征ID的对应表,特征抽取时使用特征名,模型中使用特征ID, * @param model * 分类模型 * @param features * 当前状态的特征 * @return 动作及其概率 [[动作1,概率1],[动作2,概率2],[动作3,概率3]] 动作: 1->LEFT; 2->RIGHT; * 0->SHIFT */ private Predict<String> estimateActions(JointParsingState state) { // 当前状态的特征 ArrayList<String> features = state.getFeatures(); Instance inst = new Instance(addFeature(fa, features, ysize)); Predict<Integer> ret = models.classify(inst,ysize); ret.normalize(); Predict<String> result =new Predict<String>(2); float total = 0; for (int i = 0; i < 2; i++) { Integer guess = ret.getLabel(i); if(guess==null) //bug:可能为空,待修改。 xpqiu break; String action = la.lookupString(guess); result.add(action,ret.getScore(i)); } return result; } public Target jointParse(Instance inst) { Sentence sent = (Sentence) inst; Predict<DependencyTree> res = _getBestParse(sent); DependencyTree dt = res.getLabel(0); Target target = new Target(sent.length()); target=target.ValueOf(dt); return target; } public int[] parse(Instance inst) { return (int[]) getBest(inst).getLabel(0); } // public int[] parse(String[][] strings) { // return parse(new Sentence(strings)); // } public int[] parse(String[] words, String[] pos) { return parse(new Sentence(words, pos)); } public DependencyTree parse2T(Sentence sent) { Predict<DependencyTree> res = _getBestParse(sent); DependencyTree dt = res.getLabel(0); return dt; } /** * 得到依存句法树 * @param words 词数组 * @param pos 词性数组 * @return * @throws UnsupportedDataTypeException */ public DependencyTree parse2T(String[] words, String[] pos){ return parse2T(new Sentence(words, pos)); } public Target parse2R(Instance inst) { return jointParse(inst); } public Target parse2R(String[] words, String[] pos) { return parse2R(new Sentence(words, pos)); } public String parse2String(String[] words, String[] pos,boolean b) { Target target = parse2R(words,pos); int[] heads = target.getHeads(); String[] rel = target.getRelations(); StringBuffer sb = new StringBuffer(); if(b){ for(int j = 0; j < words.length; j++){ sb.append(words[j]); if(j<words.length-1) sb.append(" "); } sb.append("\n"); for(int j = 0; j < pos.length; j++){ sb.append(pos[j]); if(j<pos.length-1) sb.append(" "); } sb.append("\n"); } for(int j = 0; j < heads.length; j++){ sb.append(heads[j]); if(j<heads.length-1) sb.append(" "); } sb.append("\n"); for(int j = 0; j < rel.length; j++){ if(rel[j]==null) sb.append("核心词"); else sb.append(rel[j]); if(j<heads.length-1) sb.append(" "); } return sb.toString(); } public Predict<int[]> getBest(Instance inst) { Sentence sent = (Sentence) inst; Predict<DependencyTree> res = _getBestParse(sent); float score = res.getScore(0); DependencyTree dt = res.getLabel(0); Predict<int[]> ret = new Predict<int[]>(); int[] preds = new int[sent.length()]; Arrays.fill(preds, -1); DependencyTree.toArrays(dt, preds); ret.add(preds,score); return ret; } /** * 得到支持的依存关系类型集合 * @return 词性标签集合 */ public Set<String> getSupportedTypes(){ Set<String> typeset = new HashSet<String>(); Set<String> set = factory.DefaultLabelAlphabet().toSet(); Iterator<String> itt = set.iterator(); while(itt.hasNext()){ String type = itt.next(); if(type.length() ==1 ) continue; typeset.add(type.substring(1)); } return typeset; } }