/*******************************************************************************
 * Copyright (c) 2012 György Orosz, Attila Novák.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the GNU Lesser Public License v3
 * which accompanies this distribution, and is available at
 * http://www.gnu.org/licenses/
 *
 * This file is part of PurePos.
 *
 * PurePos is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PurePos is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser Public License for more details.
 *
 * Contributors:
 * György Orosz - initial API and implementation
 ******************************************************************************/
package hu.ppke.itk.nlpg.purepos.model.internal;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import junit.framework.Assert;

import org.junit.Test;

/**
 * Unit tests for {@link ProbModel}: one test builds the model from the root
 * of an {@link NGramModel} plus a list of lambda weights, the other queries a
 * model wrapped around a hand-built {@link DoubleTrieNode} tree.
 */
public class ProbModelTest {

    /**
     * Builds a {@link ProbModel} from an order-3 {@link NGramModel} and the
     * weights {@code [0.0, 1.0, 2.0, 4.0]}, then checks the values returned
     * by {@code getProb} for empty, one-element and two-element contexts.
     *
     * NOTE(review): the expected values are consistent with a weighted sum
     * of per-order relative frequencies (e.g. 0.25 + 2.0 * 0.5 = 1.25 for a
     * one-element context) - confirm against the ProbModel implementation.
     */
    @Test
    public void creation() {
        ArrayList<Double> lambdas = new ArrayList<Double>(Arrays.asList(0.0,
                1.0, 2.0, 4.0));
        NGramModel<Integer> imodel = new NGramModel<Integer>(3);
        imodel.addWord(Arrays.asList(1, 2), 3);
        imodel.addWord(Arrays.asList(2, 3), 4);
        imodel.addWord(Arrays.asList(22, 3), 6);
        imodel.addWord(Arrays.asList(1, 2), 5);
        ProbModel<Integer> model = new ProbModel<Integer>(imodel.root, lambdas);

        // unigrams
        assertProb(0.25, model, new ArrayList<Integer>(), 3);
        assertProb(0.25, model, new ArrayList<Integer>(), 4);
        assertProb(0.25, model, new ArrayList<Integer>(), 5);
        // a word that was never added yields 0.0
        assertProb(0.0, model, new ArrayList<Integer>(), -1);

        // bigrams
        assertProb(1.25, model, Arrays.asList(2), 3);
        assertProb(1.25, model, Arrays.asList(2), 5);
        assertProb(1.25, model, Arrays.asList(3), 4);
        assertProb(1.25, model, Arrays.asList(3), 6);
        assertProb(0.0, model, Arrays.asList(3), -1);

        // a context that was never added: the expected value equals the
        // unigram value asserted above
        assertProb(0.25, model, Arrays.asList(-1), 3);

        // trigrams
        assertProb(3.25, model, Arrays.asList(1, 2), 3);
        assertProb(3.25, model, Arrays.asList(1, 2), 5);
        assertProb(5.25, model, Arrays.asList(22, 3), 6);
        assertProb(5.25, model, Arrays.asList(2, 3), 4);
        assertProb(0.0, model, Arrays.asList(2, 3), -1);
    }

    /**
     * Wraps hand-built {@link DoubleTrieNode} trees in a {@link ProbModel}
     * and checks that {@code getProb} returns the stored values and that
     * {@code getWordProbs(context).get(word)} agrees with it.
     */
    @Test
    public void getProbTest() {
        // A root without any words: getProb gives 0.0 and the word map has
        // no entry for the word.
        DoubleTrieNode<Integer> root = new DoubleTrieNode<Integer>(0);
        ProbModel<Integer> model = new ProbModel<Integer>(root);
        Double prob = model.getProb(new ArrayList<Integer>(), 1);
        Assert.assertEquals(0.0, prob);
        Double wordProb = model.getWordProbs(new ArrayList<Integer>()).get(1);
        Assert.assertNull(wordProb);

        // A root holding two words and no children.
        root = new DoubleTrieNode<Integer>(0);
        root.addWord(1, 0.1);
        root.addWord(2, 0.2);
        model = new ProbModel<Integer>(root);
        assertProbMatchesWordProbs(0.1, model, new ArrayList<Integer>(), 1);
        assertProbMatchesWordProbs(0.2, model, new ArrayList<Integer>(), 2);

        // A tree with two children under the root (1 and 2) and two
        // grandchildren under child 1 (11 and 12).
        root = new DoubleTrieNode<Integer>(0);
        root.addWord(1, 0.1);
        root.addWord(2, 0.2);
        DoubleTrieNode<Integer> c1 = new DoubleTrieNode<Integer>(1);
        c1.addWord(3, 0.3);
        c1.addWord(4, 0.4);
        DoubleTrieNode<Integer> c2 = new DoubleTrieNode<Integer>(2);
        c2.addWord(5, 0.5);
        c2.addWord(4, 0.4);
        DoubleTrieNode<Integer> c11 = new DoubleTrieNode<Integer>(11);
        c11.addWord(3, 0.33);
        DoubleTrieNode<Integer> c12 = new DoubleTrieNode<Integer>(12);
        c12.addWord(4, 0.44);
        c1.addChild(c12);
        c1.addChild(c11);
        root.addChild(c1);
        root.addChild(c2);
        model = new ProbModel<Integer>(root);

        // NOTE(review): the expectations below imply that the LAST context
        // element selects the root's child and the one before it selects the
        // grandchild - confirm against ProbModel.
        assertProbMatchesWordProbs(0.33, model, Arrays.asList(11, 1), 3);
        assertProbMatchesWordProbs(0.44, model, Arrays.asList(12, 1), 4);
        assertProbMatchesWordProbs(0.4, model, Arrays.asList(2), 4);
        assertProbMatchesWordProbs(0.5, model, Arrays.asList(2), 5);

        // too big context: the context has more elements than the tree has
        // matching levels, and the expected value is the one stored at the
        // deepest node that still matches
        assertProbMatchesWordProbs(0.4, model, Arrays.asList(1, 2), 4);
        assertProbMatchesWordProbs(0.44, model,
                Arrays.asList(1, 2, 3, 4, 12, 1), 4);
    }

    /**
     * Asserts that {@code model.getProb(context, word)} equals
     * {@code expected} exactly (boxed {@code Double} equality, no delta).
     */
    private static void assertProb(double expected, ProbModel<Integer> model,
            List<Integer> context, int word) {
        Double actual = model.getProb(context, word);
        Assert.assertEquals(Double.valueOf(expected), actual);
    }

    /**
     * Asserts that {@code model.getProb(context, word)} equals
     * {@code expected} and that {@code model.getWordProbs(context)} holds the
     * same value for {@code word}.
     */
    private static void assertProbMatchesWordProbs(double expected,
            ProbModel<Integer> model, List<Integer> context, int word) {
        Double prob = model.getProb(context, word);
        Assert.assertEquals(Double.valueOf(expected), prob);
        Double fromWordProbs = model.getWordProbs(context).get(word);
        Assert.assertEquals(prob, fromWordProbs);
    }
}