package hu.ppke.itk.nlpg.purepos.model.internal;
import hu.ppke.itk.nlpg.docmodel.IDocument;
import hu.ppke.itk.nlpg.docmodel.ISentence;
import hu.ppke.itk.nlpg.docmodel.IToken;
import hu.ppke.itk.nlpg.purepos.common.Util;
import hu.ppke.itk.nlpg.purepos.common.lemma.ILemmaTransformation;
import hu.ppke.itk.nlpg.purepos.common.lemma.LemmaUtil;
import hu.ppke.itk.nlpg.purepos.model.ISuffixGuesser;
import hu.ppke.itk.nlpg.purepos.model.ModelData;
import hu.ppke.itk.nlpg.purepos.model.SuffixTree;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.tuple.Pair;
/**
 * Log-linear lemma score combiner that mixes three scores: the lemma unigram
 * model score, the lemma suffix guesser score and the lemma frequency tree
 * score.
 * <p>
 * The weights are kept in the inherited {@code lambdas} list in a fixed order,
 * see {@link #UNIGRAM_INDEX}, {@link #SUFFIX_INDEX} and {@link #LEMMA_INDEX}.
 * Both {@link #calculateParameters} and {@link #combine} use this order.
 */
public class LogLinearTriCombiner extends LogLinearCombiner {

    private static final long serialVersionUID = 4264362007575382294L;

    /** Position of the lemma unigram model weight in {@code lambdas}. */
    private static final int UNIGRAM_INDEX = 0;

    /** Position of the lemma suffix guesser weight in {@code lambdas}. */
    private static final int SUFFIX_INDEX = 1;

    /** Position of the lemma frequency tree weight in {@code lambdas}. */
    private static final int LEMMA_INDEX = 2;

    /** Number of weights this combiner reads and writes. */
    private static final int LAMBDA_COUNT = 3;

    /**
     * Estimates the three combination weights from the training document and
     * stores them, normalised by their sum, in {@code lambdas}.
     * <p>
     * Every weight starts from 1.0, or from the value already present in
     * {@code lambdas} when that list holds at least three elements. For each
     * token of the document the candidate lemmata produced by the suffix
     * guesser are scored by all three models; for each model the difference
     * between the score of the token's own lemma and the best candidate score
     * is computed, and the weight of the model whose difference is strictly
     * the largest is increased by that difference. On ties no weight changes.
     *
     * @param doc
     *            training document whose sentences and tokens are iterated
     * @param rawModeldata
     *            raw model data supplying the tag n-gram model, the lemma
     *            suffix tree, the lemma frequency tree and the lemma unigram
     *            model
     * @param data
     *            model data supplying the tag vocabulary used to convert the
     *            guessed lemma transformations
     */
    @Override
    public void calculateParameters(IDocument doc, RawModelData rawModeldata,
            ModelData<String, Integer> data) {
        Map<Integer, Double> aprioriProbs = rawModeldata.tagNGramModel
                .getWordAprioriProbs();
        Double theta = SuffixTree.calculateTheta(aprioriProbs);
        ISuffixGuesser<String, ILemmaTransformation<String, Integer>> lemmaSuffixGuesser = rawModeldata.lemmaSuffixTree
                .createGuesser(theta);
        ISuffixGuesser<String, String> lemmaProb = rawModeldata.lemmaFreqTree
                .createGuesser(theta);
        LemmaUnigramModel<String> lemmaUnigramModel = rawModeldata.lemmaUnigramModel;

        double lambdaU = 1.0;
        double lambdaS = 1.0;
        double lambdaL = 1.0;
        // All three entries are read, so all three must be present. They are
        // read in the same order in which they are stored below and consumed
        // in combine().
        if (lambdas != null && lambdas.size() >= LAMBDA_COUNT) {
            lambdaU = lambdas.get(UNIGRAM_INDEX);
            lambdaS = lambdas.get(SUFFIX_INDEX);
            lambdaL = lambdas.get(LEMMA_INDEX);
        }
        lambdas = new ArrayList<Double>(LAMBDA_COUNT);

        for (ISentence sentence : doc.getSentences()) {
            for (IToken tok : sentence) {
                // Candidate analyses for the word form with their suffix
                // guesser scores.
                Map<IToken, Pair<ILemmaTransformation<String, Integer>, Double>> suffixProbs = LemmaUtil
                        .batchConvert(lemmaSuffixGuesser
                                .getTagLogProbabilities(tok.getToken()), tok
                                .getToken(), data.tagVocabulary);

                // Score every candidate with the other two models.
                Map<IToken, Double> uniProbs = new HashMap<IToken, Double>();
                Map<IToken, Double> lemmaProbs = new HashMap<IToken, Double>();
                for (IToken t : suffixProbs.keySet()) {
                    Double uniScore = lemmaUnigramModel.getLogProb(t.getStem());
                    uniProbs.put(t, uniScore);
                    Double lemmaScore = lemmaProb.getTagLogProbability(
                            t.getStem(), LemmaUtil.mainPosTag(t.getTag()));
                    lemmaProbs.put(t, lemmaScore);
                }

                // NOTE(review): behaviour of findMax/findMax2 on an empty
                // candidate map is not visible here - confirm they cannot
                // return null.
                Map.Entry<IToken, Double> uniMax = Util.findMax(uniProbs);
                Pair<IToken, Double> suffixMax = Util.findMax2(suffixProbs);
                Map.Entry<IToken, Double> lemmaMax = Util.findMax(lemmaProbs);

                // Scores of the lemma actually annotated on the token.
                Double actUniProb = lemmaUnigramModel.getLogProb(tok.getStem());
                Double actLemmaProb = lemmaProb.getTagLogProbability(
                        tok.getStem(), LemmaUtil.mainPosTag(tok.getTag()));
                Double actSuffProb;
                if (suffixProbs.containsKey(tok)) {
                    actSuffProb = suffixProbs.get(tok).getValue();
                } else {
                    // The guesser did not propose the annotated analysis.
                    actSuffProb = Util.UNKOWN_VALUE;
                }

                double uniProp = actUniProb - uniMax.getValue();
                double suffProp = actSuffProb - suffixMax.getValue();
                double lemmaProp = actLemmaProb - lemmaMax.getValue();

                // Only a strictly largest difference changes a weight.
                if (uniProp > suffProp && uniProp > lemmaProp) {
                    lambdaU += uniProp;
                } else if (suffProp > uniProp && suffProp > lemmaProp) {
                    lambdaS += suffProp;
                } else if (lemmaProp > uniProp && lemmaProp > suffProp) {
                    lambdaL += lemmaProp;
                }
            }
        }

        // Normalise by the sum and store in index order:
        // UNIGRAM_INDEX, SUFFIX_INDEX, LEMMA_INDEX.
        double sum = lambdaU + lambdaS + lambdaL;
        lambdas.add(lambdaU / sum);
        lambdas.add(lambdaS / sum);
        lambdas.add(lambdaL / sum);
    }

    /**
     * Computes the weighted sum of the three lemma scores of a candidate.
     *
     * @param tok
     *            candidate token; its stem, word form and tag are scored
     * @param t
     *            lemma transformation whose suffix guesser score is looked up
     *            for the word form of {@code tok}
     * @param compiledModelData
     *            compiled model data supplying the unigram lemma model, the
     *            lemma guesser and the suffix lemma model
     * @param modelData
     *            not used by this implementation
     * @return unigram score, smoothed suffix guesser score and lemma score,
     *         each multiplied by its weight from {@code lambdas} and summed
     */
    @Override
    public Double combine(IToken tok, ILemmaTransformation<String, Integer> t,
            CompiledModelData<String, Integer> compiledModelData,
            ModelData<String, Integer> modelData) {
        LemmaUnigramModel<String> unigramLemmaModel = compiledModelData.unigramLemmaModel;
        Double uniScore = unigramLemmaModel.getLogProb(tok.getStem());
        Double suffixScore = smooth(compiledModelData.lemmaGuesser
                .getTagLogProbability(tok.getToken(), t));
        Double lemmaProb = compiledModelData.suffixLemmaModel
                .getTagLogProbability(tok.getStem(),
                        LemmaUtil.mainPosTag(tok.getTag()));
        return uniScore * lambdas.get(UNIGRAM_INDEX)
                + suffixScore * lambdas.get(SUFFIX_INDEX)
                + lemmaProb * lambdas.get(LEMMA_INDEX);
    }
}