/******************************************************************************* * Copyright 2014 Felipe Takiyama * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ******************************************************************************/ package br.usp.poli.takiyama.sandbox; import static org.junit.Assert.assertEquals; import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; import org.junit.Test; import br.usp.poli.takiyama.acfove.AggParfactor.AggParfactorBuilder; import br.usp.poli.takiyama.cfove.StdParfactor.StdParfactorBuilder; import br.usp.poli.takiyama.common.Factor; import br.usp.poli.takiyama.common.Parfactor; import br.usp.poli.takiyama.common.StdFactor; import br.usp.poli.takiyama.prv.Constant; import br.usp.poli.takiyama.prv.LogicalVariable; import br.usp.poli.takiyama.prv.Or; import br.usp.poli.takiyama.prv.Prv; import br.usp.poli.takiyama.prv.StdLogicalVariable; import br.usp.poli.takiyama.prv.StdPrv; import br.usp.poli.takiyama.utils.Lists; import br.usp.poli.takiyama.utils.TestUtils; public class Temp5 { /** * Checking the factor created just after performing sum-out on * aggregation parfactor. */ @Test public void testBigJackpotInference() { int populationSize = 2; LogicalVariable person = StdLogicalVariable.getInstance("Person", "x", populationSize); Prv big_jackpot = StdPrv.getBooleanInstance("big_jackpot"); Prv played = StdPrv.getBooleanInstance("played", person); Prv matched_6 = StdPrv.getBooleanInstance("matched_6", person); Prv jackpot_won = StdPrv.getBooleanInstance("jackpot_won"); List<BigDecimal> fBigJackpot = TestUtils.toBigDecimalList(0.8, 0.2); List<BigDecimal> fPlayed = TestUtils.toBigDecimalList(0.95, 0.05, 0.85, 0.15); List<BigDecimal> fMatched6 = TestUtils.toBigDecimalList(1.0, 0.0, 0.99999993, 0.00000007); Parfactor g1 = new StdParfactorBuilder().variables(big_jackpot).values(fBigJackpot).build(); Parfactor g2 = new StdParfactorBuilder().variables(big_jackpot, played).values(fPlayed).build(); Parfactor g3 = new StdParfactorBuilder().variables(played, matched_6).values(fMatched6).build(); Parfactor g4 = new AggParfactorBuilder(matched_6, jackpot_won, Or.OR).context(big_jackpot).build(); Parfactor g2xg3 = g2.multiply(g3); Parfactor afterEliminatingPlayed = g2xg3.sumOut(played); Parfactor g4xg5 = g4.multiply(afterEliminatingPlayed); Parfactor afterEliminatingMatched6 = g4xg5.sumOut(matched_6); Factor r = getCorrectResultOfBigJackpotInference(populationSize); Parfactor expected = new StdParfactorBuilder().factor(r).build(); assertEquals(expected, afterEliminatingMatched6); } // propositionalizes the model above and calculates the factor created // after eliminating the aggregation parfactor private Factor getCorrectResultOfBigJackpotInference(int n) { List<Prv> matched_6 = new ArrayList<Prv>(n); for (int i = 0; i < n; i++) { Constant p = Constant.getInstance("x" + i); matched_6.add(StdPrv.getBooleanInstance("matched_6", p)); } Prv jackpot_won = StdPrv.getBooleanInstance("jackpot_won"); // Creates factor on matched_6(x0) ... matched_6(xn) jackpot_won() List<Prv> vars = Lists.listOf(matched_6); vars.add(jackpot_won); List<BigDecimal> fagg = new ArrayList<BigDecimal>(); for (int i = 0; i < (int) Math.pow(2, n); i++) { fagg.add(BigDecimal.ZERO); fagg.add(BigDecimal.ONE); } fagg.set(0, BigDecimal.ONE); fagg.set(1, BigDecimal.ZERO); Factor jw = StdFactor.getInstance("", vars, fagg); // Creates factor on big_jackpot() matched_6(X) Prv big_jackpot = StdPrv.getBooleanInstance("big_jackpot"); List<BigDecimal> f5vals = TestUtils.toBigDecimalList(0.9999999965, 0.0000000035, 0.9999999895, 0.0000000105); List<Factor> f5 = new ArrayList<Factor>(n); for (Prv m6 : matched_6) { List<Prv> rvs = new ArrayList<Prv>(2); rvs.add(big_jackpot); rvs.add(m6); f5.add(StdFactor.getInstance("", rvs, f5vals)); } // Creates factor big_jackpot() List<BigDecimal> fbj = TestUtils.toBigDecimalList(0.8, 0.2); Factor bj = StdFactor.getInstance("", big_jackpot, fbj); // multiplies all factors Factor product = jw; for (Factor f : f5) { product = product.multiply(f); } // eliminates matched_6(x) Factor result = product; for (Prv m6 : matched_6) { result = result.sumOut(m6); } return result; } }