/******************************************************************************* * Copyright 2014 Felipe Takiyama * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ******************************************************************************/ package br.usp.poli.takiyama.acfove; import static org.junit.Assert.assertEquals; import java.io.IOException; import java.math.BigDecimal; import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Set; import org.junit.Ignore; import org.junit.Test; import org.junit.experimental.runners.Enclosed; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; import br.usp.poli.takiyama.acfove.AggParfactor.AggParfactorBuilder; import br.usp.poli.takiyama.cfove.StdParfactor.StdParfactorBuilder; import br.usp.poli.takiyama.common.AggregationParfactor; import br.usp.poli.takiyama.common.Constraint; import br.usp.poli.takiyama.common.Distribution; import br.usp.poli.takiyama.common.Factor; import br.usp.poli.takiyama.common.InequalityConstraint; import br.usp.poli.takiyama.common.InputOutput; import br.usp.poli.takiyama.common.Marginal; import br.usp.poli.takiyama.common.Parfactor; import br.usp.poli.takiyama.common.SplitResult; import br.usp.poli.takiyama.common.StdFactor; import br.usp.poli.takiyama.common.StdMarginal.StdMarginalBuilder; import br.usp.poli.takiyama.prv.Constant; import 
br.usp.poli.takiyama.prv.CountingFormula; import br.usp.poli.takiyama.prv.LogicalVariable; import br.usp.poli.takiyama.prv.Or; import br.usp.poli.takiyama.prv.Prv; import br.usp.poli.takiyama.prv.RandomVariableSet; import br.usp.poli.takiyama.prv.StdLogicalVariable; import br.usp.poli.takiyama.prv.StdPrv; import br.usp.poli.takiyama.samples.WaterSprinklerNetwork; import br.usp.poli.takiyama.utils.Example; import br.usp.poli.takiyama.utils.Lists; import br.usp.poli.takiyama.utils.MathUtils; import br.usp.poli.takiyama.utils.Sets; import br.usp.poli.takiyama.utils.TestUtils; @RunWith(Enclosed.class) public class ACFOVETest { private static class Utils { private static List<BigDecimal> toBigDecimalList(double ... list) { List<BigDecimal> result = new ArrayList<BigDecimal>(list.length); for (int i = 0; i < list.length; i++) { result.add(BigDecimal.valueOf(list[i])); } return result; } } @Ignore("Early debug sessions >> fourth step is not correct") @RunWith(Parameterized.class) public static class StepByStepExampleComputation { private static List<BigDecimal> toBigDecimalList(double ... 
list) { List<BigDecimal> result = new ArrayList<BigDecimal>(list.length); for (int i = 0; i < list.length; i++) { result.add(BigDecimal.valueOf(list[i])); } return result; } @Parameters public static Collection<Object[]> data() { int populationSize = 10; LogicalVariable lot = StdLogicalVariable.getInstance("Lot", "lot", populationSize); Constant lot1 = Constant.getInstance("lot1"); Constraint lot_lot1 = InequalityConstraint.getInstance(lot, lot1); Prv rain = StdPrv.getBooleanInstance("rain"); Prv sprinkler = StdPrv.getBooleanInstance("sprinkler", lot); Prv sprinkler_lot1 = StdPrv.getBooleanInstance("sprinkler", lot1); Prv wet_grass = StdPrv.getBooleanInstance("wet_grass", lot); Prv wet_grass_lot1 = StdPrv.getBooleanInstance("wet_grass", lot1); Prv formula = CountingFormula.getInstance(lot, wet_grass, lot_lot1); List<BigDecimal> f1 = toBigDecimalList(0.8, 0.2); List<BigDecimal> f2 = toBigDecimalList(0.6, 0.4); List<BigDecimal> f3 = toBigDecimalList(1.0, 0.0, 0.2, 0.8, 0.1, 0.9, 0.01, 0.99); List<BigDecimal> f4 = toBigDecimalList(0.0, 1.0); List<BigDecimal> f2xf3 = new ArrayList<BigDecimal>(8); for (int i = 0; i < 8; i++) { f2xf3.add(f2.get((i / 2) % 2).multiply(f3.get(i), MathUtils.CONTEXT)); } List<BigDecimal> f5 = new ArrayList<BigDecimal>(4); f5.add(f2xf3.get(0).add(f2xf3.get(2), MathUtils.CONTEXT)); f5.add(f2xf3.get(1).add(f2xf3.get(3), MathUtils.CONTEXT)); f5.add(f2xf3.get(4).add(f2xf3.get(6), MathUtils.CONTEXT)); f5.add(f2xf3.get(5).add(f2xf3.get(7), MathUtils.CONTEXT)); List<BigDecimal> f4xf5 = new ArrayList<BigDecimal>(4); for (int i = 0; i < 4; i++) { f4xf5.add(f4.get(i % 2).multiply(f5.get(i), MathUtils.CONTEXT)); } List<BigDecimal> f6 = new ArrayList<BigDecimal>(2); f6.add(f4xf5.get(0).add(f4xf5.get(1), MathUtils.CONTEXT)); f6.add(f4xf5.get(2).add(f4xf5.get(3), MathUtils.CONTEXT)); int n = populationSize; List<BigDecimal> f7 = new ArrayList<BigDecimal>(2 * n); for (int i = 0; i < n; i++) { f7.add(MathUtils.pow(f5.get(0), n - i - 1, 
1).multiply(MathUtils.pow(f5.get(1), i, 1), MathUtils.CONTEXT)); } for (int i = 0; i < n; i++) { f7.add(MathUtils.pow(f5.get(2), n - i - 1, 1).multiply(MathUtils.pow(f5.get(3), i, 1), MathUtils.CONTEXT)); } List<BigDecimal> f1xf6xf7 = new ArrayList<BigDecimal>(2 * n); for (int i = 0; i < 2 * n; i++) { f1xf6xf7.add(f1.get(i / n).multiply(f6.get(i / n).multiply(f7.get(i), MathUtils.CONTEXT), MathUtils.CONTEXT)); } List<BigDecimal> f8 = new ArrayList<BigDecimal>(n); for (int i = 0; i < n; i++) { f8.add(f1xf6xf7.get(i).add(f1xf6xf7.get(n + i), MathUtils.CONTEXT)); } /** * Alternative path * * When processing set Phi 2, we can either eliminate * wet_grass(lot1) or sprinkler(lot1). * The order in which it is done does not affect final result, * but affects these individual tests. Kisynski opted to eliminate * sprinkler(lot1) first. * Below I assemble the case when wet_grass(lot1) is eliminated * first. This changes set Phi 3, after that the intermediate * results remain the same. */ List<BigDecimal> f3xf4 = new ArrayList<BigDecimal>(8); for (int i = 0; i < 8; i++) { f3xf4.add(f4.get(i % 2).multiply(f3.get(i), MathUtils.CONTEXT)); } f3xf4.add(f4.get(0).multiply(f3.get(0), MathUtils.CONTEXT)); f3xf4.add(f4.get(0).multiply(f3.get(1), MathUtils.CONTEXT)); f3xf4.add(f4.get(1).multiply(f3.get(2), MathUtils.CONTEXT)); f3xf4.add(f4.get(1).multiply(f3.get(3), MathUtils.CONTEXT)); f3xf4.add(f4.get(0).multiply(f3.get(4), MathUtils.CONTEXT)); f3xf4.add(f4.get(0).multiply(f3.get(5), MathUtils.CONTEXT)); f3xf4.add(f4.get(1).multiply(f3.get(6), MathUtils.CONTEXT)); f3xf4.add(f4.get(1).multiply(f3.get(7), MathUtils.CONTEXT)); List<BigDecimal> f5alt = new ArrayList<BigDecimal>(4); f5alt.add(f3xf4.get(0).add(f3xf4.get(1), MathUtils.CONTEXT)); f5alt.add(f3xf4.get(2).add(f3xf4.get(3), MathUtils.CONTEXT)); f5alt.add(f3xf4.get(4).add(f3xf4.get(5), MathUtils.CONTEXT)); f5alt.add(f3xf4.get(6).add(f3xf4.get(7), MathUtils.CONTEXT)); Parfactor g1 = new 
StdParfactorBuilder().variables(rain).values(f1).build(); Parfactor g2 = new StdParfactorBuilder().variables(sprinkler).values(f2).build(); Parfactor g3 = new StdParfactorBuilder().variables(rain, sprinkler, wet_grass).values(f3).build(); Parfactor g4 = new StdParfactorBuilder().variables(wet_grass_lot1).values(f4).build(); Parfactor g5 = new StdParfactorBuilder().variables(rain, sprinkler_lot1, wet_grass_lot1).values(f3).build(); Parfactor g6 = new StdParfactorBuilder().variables(rain, sprinkler, wet_grass).values(f3).constraints(lot_lot1).build(); Parfactor g7 = new StdParfactorBuilder().variables(sprinkler_lot1).values(f2).build(); Parfactor g8 = new StdParfactorBuilder().variables(sprinkler).values(f2).constraints(lot_lot1).build(); Parfactor g9 = new StdParfactorBuilder().variables(rain, wet_grass).values(f5).constraints(lot_lot1).build(); Parfactor g10 = new StdParfactorBuilder().variables(rain, wet_grass_lot1).values(f5).build(); Parfactor g10alt = new StdParfactorBuilder().variables(rain, sprinkler_lot1).values(f5alt).build(); Parfactor g11 = new StdParfactorBuilder().variables(rain).values(f6).build(); Parfactor g12 = new StdParfactorBuilder().variables(rain, formula).values(f7).build(); Parfactor g13 = new StdParfactorBuilder().variables(formula).values(f8).build(); Set<Constraint> constraints = Sets.setOf(lot_lot1); RandomVariableSet query = RandomVariableSet.getInstance(wet_grass, constraints); InputOutput<Marginal, Marginal> inOut = InputOutput.getInstance(); Marginal input, output; // first step // input = new StdMarginalBuilder(4).parfactors(g1, g2, g3, g4).preservable(query).build(); // output = new StdMarginalBuilder(6).parfactors(g1, g4, g5, g6, g7, g8).preservable(query).build(); // inOut.add(input, output); // second step - sums out ground(sprinkler(Lot)): {Lot != lot1} from g8 input = new StdMarginalBuilder(6).parfactors(g1, g4, g5, g6, g7, g8).preservable(query).build(); output = new StdMarginalBuilder(5).parfactors(g1, g4, g5, g7, 
g9).preservable(query).build(); inOut.add(input, output); // third step - sums out ground(sprinkler(lot1) from g5 x g7 input = new StdMarginalBuilder(6).parfactors(g1, g4, g5, g7, g9).preservable(query).build(); // turns out my algorithm chose the alternative path.... output = new StdMarginalBuilder(4).parfactors(g1, g4, g9, g10).preservable(query).build(); // output = new StdMarginalBuilder(4).parfactors(g1, g7, g9, g10alt).preservable(query).build(); inOut.add(input, output); // fourth step input = new StdMarginalBuilder(4).parfactors(g1, g4, g9, g10).preservable(query).build(); output = new StdMarginalBuilder(3).parfactors(g1, g9, g11).preservable(query).build(); inOut.add(input, output); // fifth step input = new StdMarginalBuilder(3).parfactors(g1, g9, g11).preservable(query).build(); output = new StdMarginalBuilder(3).parfactors(g1, g11, g12).preservable(query).build(); inOut.add(input, output); // sixth step input = new StdMarginalBuilder(3).parfactors(g1, g11, g12).preservable(query).build(); output = new StdMarginalBuilder(1).parfactors(g13).preservable(query).build(); inOut.add(input, output); // all steps // input = new StdMarginalBuilder(4).parfactors(g1, g2, g3, g4).preservable(query).build(); // output = new StdMarginalBuilder(1).parfactors(g13).preservable(query).build(); // inOut.add(input, output); return inOut.toCollection(); } private Marginal input; private Marginal expected; public StepByStepExampleComputation(Marginal input, Marginal expected) { this.input = input; this.expected = expected; } @Test public void testStep() { ACFOVE acfove = new ACFOVE(input); Marginal result = acfove.runStep(); assertEquals(expected, result); } } public static class ExampleComputation { private static List<BigDecimal> toBigDecimalList(double ... 
list) { List<BigDecimal> result = new ArrayList<BigDecimal>(list.length); for (int i = 0; i < list.length; i++) { result.add(BigDecimal.valueOf(list[i])); } return result; } // TODO: remove repeted code @Ignore(" while testing smaller features") @Test public void testExampleComputation() { int populationSize = 10; LogicalVariable lot = StdLogicalVariable.getInstance("Lot", "lot", populationSize); Constant lot1 = Constant.getInstance("lot1"); Constraint lot_lot1 = InequalityConstraint.getInstance(lot, lot1); Prv rain = StdPrv.getBooleanInstance("rain"); Prv sprinkler = StdPrv.getBooleanInstance("sprinkler", lot); Prv wet_grass = StdPrv.getBooleanInstance("wet_grass", lot); Prv wet_grass_lot1 = StdPrv.getBooleanInstance("wet_grass", lot1); Prv formula = CountingFormula.getInstance(lot, wet_grass, lot_lot1); List<BigDecimal> f1 = toBigDecimalList(0.8, 0.2); List<BigDecimal> f2 = toBigDecimalList(0.6, 0.4); List<BigDecimal> f3 = toBigDecimalList(1.0, 0.0, 0.2, 0.8, 0.1, 0.9, 0.01, 0.99); List<BigDecimal> f4 = toBigDecimalList(0.0, 1.0); List<BigDecimal> f2xf3 = new ArrayList<BigDecimal>(8); for (int i = 0; i < 8; i++) { f2xf3.add(f2.get((i / 2) % 2).multiply(f3.get(i), MathUtils.CONTEXT)); } List<BigDecimal> f5 = new ArrayList<BigDecimal>(4); f5.add(f2xf3.get(0).add(f2xf3.get(2), MathUtils.CONTEXT)); f5.add(f2xf3.get(1).add(f2xf3.get(3), MathUtils.CONTEXT)); f5.add(f2xf3.get(4).add(f2xf3.get(6), MathUtils.CONTEXT)); f5.add(f2xf3.get(5).add(f2xf3.get(7), MathUtils.CONTEXT)); List<BigDecimal> f4xf5 = new ArrayList<BigDecimal>(4); for (int i = 0; i < 4; i++) { f4xf5.add(f4.get(i % 2).multiply(f5.get(i), MathUtils.CONTEXT)); } List<BigDecimal> f6 = new ArrayList<BigDecimal>(2); f6.add(f4xf5.get(0).add(f4xf5.get(1), MathUtils.CONTEXT)); f6.add(f4xf5.get(2).add(f4xf5.get(3), MathUtils.CONTEXT)); int n = populationSize; List<BigDecimal> f7 = new ArrayList<BigDecimal>(2 * n); for (int i = 0; i < n; i++) { f7.add(MathUtils.pow(f5.get(0), n - i - 1, 
1).multiply(MathUtils.pow(f5.get(1), i, 1), MathUtils.CONTEXT)); } for (int i = 0; i < n; i++) { f7.add(MathUtils.pow(f5.get(2), n - i - 1, 1).multiply(MathUtils.pow(f5.get(3), i, 1), MathUtils.CONTEXT)); } List<BigDecimal> f1xf6xf7 = new ArrayList<BigDecimal>(2 * n); for (int i = 0; i < 2 * n; i++) { f1xf6xf7.add(f1.get(i / n).multiply(f6.get(i / n).multiply(f7.get(i), MathUtils.CONTEXT), MathUtils.CONTEXT)); } List<BigDecimal> f8 = new ArrayList<BigDecimal>(n); for (int i = 0; i < n; i++) { f8.add(f1xf6xf7.get(i).add(f1xf6xf7.get(n + i), MathUtils.CONTEXT)); } Parfactor g1 = new StdParfactorBuilder().variables(rain).values(f1).build(); Parfactor g2 = new StdParfactorBuilder().variables(sprinkler).values(f2).build(); Parfactor g3 = new StdParfactorBuilder().variables(rain, sprinkler, wet_grass).values(f3).build(); Parfactor g4 = new StdParfactorBuilder().variables(wet_grass_lot1).values(f4).build(); Parfactor g13 = new StdParfactorBuilder().variables(formula).values(f8).build(); Set<Constraint> constraints = Sets.setOf(lot_lot1); RandomVariableSet query = RandomVariableSet.getInstance(wet_grass, constraints); Marginal input = new StdMarginalBuilder(4).parfactors(g1, g2, g3, g4).preservable(query).build(); ACFOVE acfove = new LoggedACFOVE(input); Parfactor result = acfove.run(); Parfactor expected = g13; assertEquals(expected, result); } } @RunWith(Parameterized.class) public static class StepByStepExampleComputationWithAggregation { @Parameters public static Collection<Object[]> data() { int populationSize = 5; LogicalVariable person = StdLogicalVariable.getInstance("Person", "x", populationSize); Prv big_jackpot = StdPrv.getBooleanInstance("big_jackpot"); Prv played = StdPrv.getBooleanInstance("played", person); Prv matched_6 = StdPrv.getBooleanInstance("matched_6", person); Prv jackpot_won = StdPrv.getBooleanInstance("jackpot_won"); List<BigDecimal> fBigJackpot = Utils.toBigDecimalList(0.8, 0.2); List<BigDecimal> fPlayed = Utils.toBigDecimalList(0.95, 0.05, 
0.85, 0.15); List<BigDecimal> fMatched6 = Utils.toBigDecimalList(1.0, 0.0, 0.99999993, 0.00000007); List<BigDecimal> fJackpotWonPrime = Utils.toBigDecimalList(0.999999975, 0.000000025); List<BigDecimal> fPlayedxfMatched6 = new ArrayList<BigDecimal>(8); for (int i = 0; i < 8; i++) { fPlayedxfMatched6.add(fPlayed.get(i / 2).multiply(fMatched6.get(i % 4), MathUtils.CONTEXT)); } List<BigDecimal> fMatched6Prime = new ArrayList<BigDecimal>(4); fMatched6Prime.add(fPlayedxfMatched6.get(0).add(fPlayedxfMatched6.get(2), MathUtils.CONTEXT)); fMatched6Prime.add(fPlayedxfMatched6.get(1).add(fPlayedxfMatched6.get(3), MathUtils.CONTEXT)); fMatched6Prime.add(fPlayedxfMatched6.get(4).add(fPlayedxfMatched6.get(6), MathUtils.CONTEXT)); fMatched6Prime.add(fPlayedxfMatched6.get(5).add(fPlayedxfMatched6.get(7), MathUtils.CONTEXT)); double [] fJackpotWon = { 0.9999999825, 0.9999999450, 0.0000000175, 0.0000000550 }; Parfactor g1 = new StdParfactorBuilder().variables(big_jackpot).values(fBigJackpot).build(); Parfactor g2 = new StdParfactorBuilder().variables(big_jackpot, played).values(fPlayed).build(); Parfactor g3 = new StdParfactorBuilder().variables(played, matched_6).values(fMatched6).build(); Parfactor g4 = new AggParfactorBuilder(matched_6, jackpot_won, Or.OR).context(big_jackpot).build(); Parfactor g5 = new StdParfactorBuilder().variables(big_jackpot, matched_6).values(fMatched6Prime).build(); Parfactor g7 = new StdParfactorBuilder().variables(big_jackpot, jackpot_won).values(fJackpotWon).build(); Parfactor g8 = new StdParfactorBuilder().variables(jackpot_won).values(fJackpotWonPrime).build(); RandomVariableSet query = RandomVariableSet.getInstance(played, Sets.<Constraint>getInstance(0)); InputOutput<Marginal, Marginal> inOut = InputOutput.getInstance(); Marginal input, output; // first step - multiplies g2 by g3 and sums out played(Person) input = new StdMarginalBuilder(4).parfactors(g1, g2, g3, g4).preservable(query).build(); output = new StdMarginalBuilder(3).parfactors(g1, g4, 
g5).preservable(query).build(); inOut.add(input, output); // second step - multiplies g4 by g5 and sums out matched_6(Person) input = new StdMarginalBuilder(3).parfactors(g1, g4, g5).preservable(query).build(); output = new StdMarginalBuilder(2).parfactors(g1, g7).preservable(query).build(); inOut.add(input, output); // third step - sums out big_jackpot() input = new StdMarginalBuilder(2).parfactors(g1, g7).preservable(query).build(); output = new StdMarginalBuilder(1).parfactors(g8).preservable(query).build(); inOut.add(input, output); return inOut.toCollection(); } private Marginal input; private Marginal expected; public StepByStepExampleComputationWithAggregation(Marginal input, Marginal expected) { this.input = input; this.expected = expected; } @Test @Ignore("Activate when aggregation parfactors are no longer converted in the beginning of the algorithm") public void testStep() { ACFOVE acfove = new ACFOVE(input); Marginal result = acfove.runStep(); assertEquals(expected, result); } } /** * Tests for AC-FOVE algorithm using example 3.14 of Kisynski (2010). * These tests assume that all aggregation parfactors are converted to * standard parfactors in the beginning of the algorithm. This is a simpler, * less efficient approach. 
*/
    @RunWith(Parameterized.class)
    public static class AggregationExampleWithConversion {

        /**
         * Builds one (input marginal, expected parfactor) pair for each
         * population size from 1 to 4.
         */
        @Parameters
        public static Collection<Object[]> data() {

            int populationSize = 4;
            InputOutput<Marginal, Parfactor> inOut = InputOutput.getInstance();

            for (int n = 1; n <= populationSize; n++) {

                LogicalVariable person = StdLogicalVariable.getInstance("Person", "x", n);

                Prv big_jackpot = StdPrv.getBooleanInstance("big_jackpot");
                Prv played = StdPrv.getBooleanInstance("played", person);
                Prv matched_6 = StdPrv.getBooleanInstance("matched_6", person);
                Prv jackpot_won = StdPrv.getBooleanInstance("jackpot_won");

                List<BigDecimal> fBigJackpot = Utils.toBigDecimalList(0.8, 0.2);
                List<BigDecimal> fPlayed = Utils.toBigDecimalList(
                        0.95, 0.05,
                        0.85, 0.15);
                List<BigDecimal> fMatched6 = Utils.toBigDecimalList(
                        1.0, 0.0,
                        0.99999993, 0.00000007);
                List<BigDecimal> perPerson = Utils.toBigDecimalList(
                        0.9999999965, 0.0000000035,
                        0.9999999895, 0.0000000105);

                // r0 = 0.8 * 0.9999999965^n + 0.2 * 0.9999999895^n
                BigDecimal firstTerm = fBigJackpot.get(0)
                        .multiply(MathUtils.pow(perPerson.get(0), n, 1), MathUtils.CONTEXT);
                BigDecimal secondTerm = fBigJackpot.get(1)
                        .multiply(MathUtils.pow(perPerson.get(2), n, 1), MathUtils.CONTEXT);
                BigDecimal r0 = firstTerm.add(secondTerm, MathUtils.CONTEXT);

                // The direct expression for r1 (built on getSum) is too long
                // to read comfortably; taking the complement of r0 is easier.
                BigDecimal r1 = BigDecimal.ONE.subtract(r0, MathUtils.CONTEXT);

                List<BigDecimal> fResult = Lists.listOf(r0, r1);

                Parfactor g1 = new StdParfactorBuilder().variables(big_jackpot).values(fBigJackpot).build();
                Parfactor g2 = new StdParfactorBuilder().variables(big_jackpot, played).values(fPlayed).build();
                Parfactor g3 = new StdParfactorBuilder().variables(played, matched_6).values(fMatched6).build();
                Parfactor g4 = new AggParfactorBuilder(matched_6, jackpot_won, Or.OR).context(big_jackpot).build();

                RandomVariableSet query = RandomVariableSet.getInstance(jackpot_won, Sets.<Constraint>getInstance(0));

                Marginal input = new StdMarginalBuilder(4).parfactors(g1, g2, g3, g4).preservable(query).build();
                Parfactor expected = new StdParfactorBuilder().variables(jackpot_won).values(fResult).build();

                inOut.add(input, expected);
            }

            return inOut.toCollection();
        }

        /**
         * Returns
         * &sum;<sub>i=1..n</sub> a<sub>0</sub><sup>n-i</sup>.a<sub>1</sub><sup>i</sup>
         * <p>
         * In the part of the file reviewed here it is only mentioned in a
         * comment.
         */
        private static BigDecimal getSum(BigDecimal a0, BigDecimal a1, int n) {
            BigDecimal total = BigDecimal.ZERO;
            for (int exponent = 1; exponent <= n; exponent++) {
                BigDecimal term = MathUtils.pow(a0, n - exponent, 1)
                        .multiply(MathUtils.pow(a1, exponent, 1), MathUtils.CONTEXT);
                total = total.add(term, MathUtils.CONTEXT);
            }
            return total;
        }

        /** Marginal handed to the algorithm. */
        private Marginal input;

        /** Parfactor expected from a complete run. */
        private Parfactor expected;

        public AggregationExampleWithConversion(Marginal input, Parfactor expected) {
            this.input = input;
            this.expected = expected;
        }

        /**
         * Runs the algorithm on the input marginal and compares the outcome
         * with the expected parfactor.
         */
        @Test
        public void testAggregationExampleWithConversion() {
            Parfactor result = new ACFOVE(input).run();
            assertEquals(expected, result);
        }
    }

    public static class CorrectnessTest {

        /**
         * Exploratory run on a model with parfactors on b and on (b, r(X));
         * prints the outcome instead of asserting on it.
         */
        @Ignore
        @Test
        public void testNodeWithCommonParent() {

            int domainSize = 3;

            LogicalVariable x = StdLogicalVariable.getInstance("X", "x", domainSize);
            Constant x1 = Constant.getInstance("x1");

            Prv b = StdPrv.getBooleanInstance("b");
            Prv r = StdPrv.getBooleanInstance("r", x);
            Prv r11 = StdPrv.getBooleanInstance("r", x1);

            List<BigDecimal> f1 = Utils.toBigDecimalList(0.2, 0.8);
            List<BigDecimal> f2 = Utils.toBigDecimalList(1.0, 0.0, 0.1, 0.9);

            Parfactor g1 = new StdParfactorBuilder().variables(b).values(f1).build();
            Parfactor g2 = new StdParfactorBuilder().variables(b, r).values(f2).build();

            RandomVariableSet query = RandomVariableSet.getInstance(b, Sets.<Constraint>getInstance(0));

            Marginal input = new StdMarginalBuilder(2).parfactors(g1, g2).preservable(query).build();
            Marginal groundedInput = propositionalizeAll(input);

//            ACFOVE ve = new ACFOVE(groundedInput);
//            Parfactor groundedResult = ve.run();
//            System.out.println(groundedResult);

            ACFOVE acfove =
new LoggedACFOVE(input); Parfactor result = acfove.run(); System.out.println(result); } @Ignore("First need to know how to eliminate variable from parfactor of type (a(X), b(Y))") @Test public void testNodesWithCommonParent() { int domainSize = 3; LogicalVariable x = StdLogicalVariable.getInstance("X", "x", domainSize); LogicalVariable y = StdLogicalVariable.getInstance("Y", "y", domainSize); Constant x1 = Constant.getInstance("x1"); Constant y1 = Constant.getInstance("y1"); Prv b = StdPrv.getBooleanInstance("b", y); Prv r = StdPrv.getBooleanInstance("r", x, y); Prv r11 = StdPrv.getBooleanInstance("r", x1, y1); List<BigDecimal> f1 = Utils.toBigDecimalList(0.2, 0.8); List<BigDecimal> f2 = Utils.toBigDecimalList(1.0, 0.0, 0.1, 0.9); Parfactor g1 = new StdParfactorBuilder().variables(b).values(f1).build(); Parfactor g2 = new StdParfactorBuilder().variables(b, r).values(f2).build(); RandomVariableSet query = RandomVariableSet.getInstance(r11, Sets.<Constraint>getInstance(0)); Marginal input = new StdMarginalBuilder(2).parfactors(g1, g2).preservable(query).build(); Marginal groundedInput = propositionalizeAll(input); System.out.println(g1.multiply(g2)); ACFOVE ve = new ACFOVE(groundedInput); Parfactor groundedResult = ve.run(); System.out.println(groundedResult); ACFOVE acfove = new LoggedACFOVE(input); Parfactor result = acfove.run(); System.out.println(result); } private Marginal propositionalizeAll(Marginal m) { Marginal result = m; for (Parfactor p : m) { for (LogicalVariable v : p.logicalVariables()) { result = new Propositionalize(result, p, v).run(); } } return result; } @Ignore("Test with Kisynski example first") @Test public void testExistsNodeManually() { int domainSize = 2; LogicalVariable x = StdLogicalVariable.getInstance("X", "x", domainSize); LogicalVariable y = StdLogicalVariable.getInstance("Y", "y", domainSize); Constant x1 = Constant.getInstance("x1"); Constant y1 = Constant.getInstance("y1"); Prv b = StdPrv.getBooleanInstance("b", y); Prv by = 
CountingFormula.getInstance(y, b); Prv r = StdPrv.getBooleanInstance("r", x, y); Prv a = StdPrv.getBooleanInstance("and", x, y); Prv e = StdPrv.getBooleanInstance("exists", x); Prv eaux = StdPrv.getBooleanInstance("exists_aux", x); Prv ex = CountingFormula.getInstance(x, e); List<BigDecimal> fb = Utils.toBigDecimalList(0.1, 0.9); List<BigDecimal> fr = Utils.toBigDecimalList(0.2, 0.8); List<BigDecimal> fand = Utils.toBigDecimalList(1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0); List<BigDecimal> f6 = Utils.toBigDecimalList(1.0, 1.0, 1.0, 0.0); List<BigDecimal> f7 = Utils.toBigDecimalList(1.0, 0.0, -1.0, 1.0); Parfactor g1 = new StdParfactorBuilder().variables(b).values(fb).build(); Parfactor g2 = new StdParfactorBuilder().variables(r).values(fr).build(); Parfactor g3 = new StdParfactorBuilder().variables(r, b, a).values(fand).build(); Parfactor g4 = new AggParfactorBuilder(a, e, Or.OR).context(b).build(); Parfactor g5 = g2.multiply(g3).sumOut(r); Parfactor g6 = g5.multiply(g1); Parfactor g7 = g6.multiply(g4).sumOut(a); Parfactor g8 = g7.count(x); Parfactor g9 = g8.sumOut(b); // Trying to use special agg parfactor conversion // Parfactor g5 = g2.multiply(g3).sumOut(r); // Parfactor g6 = new StdParfactorBuilder().variables(a, eaux).values(f6).build(); // Parfactor g7 = new StdParfactorBuilder().variables(e, eaux).values(f7).build(); // Parfactor g8 = g5.multiply(g6).sumOut(a); // Parfactor g9 = g1.multiply(g8).count(y); // Parfactor g10 = g9.multiply(g7).sumOut(eaux); // Parfactor g11 = g10.count(x).sumOut(by); System.out.println("Welcome to this incredible test!"); System.out.println(getCorrectResultOfExistsNodeManually(domainSize)); System.out.println(g9); } private Factor getCorrectResultOfExistsNodeManually(int n) { // List of constants List<Constant> x = new ArrayList<Constant>(n); List<Constant> y = new ArrayList<Constant>(n); for (int i = 0; i < n; i++) { x.add(Constant.getInstance("x" + i)); y.add(Constant.getInstance("y" + i)); } // Creates random variables 
List<Prv> b = new ArrayList<Prv>(n); List<Prv> r = new ArrayList<Prv>(n * n); List<Prv> a = new ArrayList<Prv>(n * n); List<Prv> e = new ArrayList<Prv>(n); for (int i = 0; i < n; i++) { b.add(StdPrv.getBooleanInstance("b", y.get(i))); e.add(StdPrv.getBooleanInstance("exists", x.get(i))); for (int j = 0; j < n; j++) { r.add(StdPrv.getBooleanInstance("r", x.get(i), y.get(j))); a.add(StdPrv.getBooleanInstance("and", x.get(i), y.get(j))); } } // Creates factors on b(Y) List<BigDecimal> vb = Utils.toBigDecimalList(0.1, 0.9); List<Factor> fb = new ArrayList<Factor>(n); for (int i = 0; i < n; i++) { int index = i; List<Prv> rvs = Lists.listOf(b.get(index)); fb.add(StdFactor.getInstance("", rvs, vb)); } // Creates factors on r(X,Y) List<BigDecimal> vr = Utils.toBigDecimalList(0.2, 0.8); List<Factor> fr = new ArrayList<Factor>(n); for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { int index = i * n + j; List<Prv> rvs = Lists.listOf(r.get(index)); fr.add(StdFactor.getInstance("", rvs, vr)); } } // Creates factors on r(X,Y), b(Y), and(X,Y) List<BigDecimal> vand = Utils.toBigDecimalList(1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0); List<Factor> fa = new ArrayList<Factor>(n * n); for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { List<Prv> rvs = Lists.listOf(b.get(j), r.get(i * n + j), a.get(i * n + j)); fa.add(StdFactor.getInstance("", rvs, vand)); } } // Creates factors on and(X,y1), ..., and(X,yn), exists(X) int vexistsSize = (int) Math.pow(2, n); List<BigDecimal> vexists = new ArrayList<BigDecimal>(vexistsSize); for (int i = 0; i < vexistsSize; i++) { vexists.add(BigDecimal.ZERO); vexists.add(BigDecimal.ONE); } vexists.set(0, BigDecimal.ONE); vexists.set(1, BigDecimal.ZERO); List<Factor> fe = new ArrayList<Factor>(n); for (int i = 0; i < n; i++) { List<Prv> rvs = new ArrayList<Prv>(n + 1); for (int j = 0; j < n; j++) { rvs.add(a.get(i * n + j)); } rvs.add(e.get(i)); fe.add(StdFactor.getInstance("", rvs, vexists)); } // The stupidest way to solve it // Factor 
result = StdFactor.getInstance(); // for (Factor f : fb) { // result = result.multiply(f); // } // for (Factor f : fr) { // result = result.multiply(f); // } // for (Factor f : fa) { // result = result.multiply(f); // } // for (Factor f : fe) { // result = result.multiply(f); // } // for (Prv v : result.variables()) { // if (!e.contains(v)) { // result = result.sumOut(v); // } // } // Step by step // Eliminate r(X,Y) List<Factor> afterEliminating_r = new ArrayList<Factor>(n * n); for (int i = 0; i < n * n; i++) { afterEliminating_r.add(fr.get(i).multiply(fa.get(i)).sumOut(r.get(i))); } // Eliminate and(X,Y) List<Factor> afterEliminating_and = new ArrayList<Factor>(n * n); for (int i = 0; i < n; i++) { Factor product = fe.get(i); for (Prv v : fe.get(i).variables()) { for (Factor f : afterEliminating_r) { if (f.variables().contains(v)) { product = product.multiply(f); } } } for (Prv and : a) { if (product.variables().contains(and)) { product = product.sumOut(and); } } afterEliminating_and.add(product); } // Eliminate b(Y) -- need to multiply all remaining factors ... 
Factor product = StdFactor.getInstance(); for (Factor f : afterEliminating_and) { product = product.multiply(f); } for (Factor f : fb) { product = product.multiply(f); } // and eliminate each b(yi) Factor result = product; for (Prv by : b) { result = result.sumOut(by); } result = result.sumOut(e.get(0)); return result; } @Ignore("Wrong: this is the case to use Aggregation parfactors") @Test public void testFullyConnectedNetwork() throws IOException { int domainSize = 3; LogicalVariable x = StdLogicalVariable.getInstance("X", "x", domainSize); LogicalVariable y = StdLogicalVariable.getInstance("Y", "y", domainSize); Constant x1 = Constant.getInstance("x1"); Constant x2 = Constant.getInstance("x2"); Constant y1 = Constant.getInstance("y1"); Prv e = StdPrv.getBooleanInstance("e", x); Prv b = StdPrv.getBooleanInstance("b", y); Prv e1 = StdPrv.getBooleanInstance("e", x1); List<BigDecimal> f1 = Utils.toBigDecimalList(0.2, 0.8); double r1n = Math.pow(0.1, domainSize); List<BigDecimal> f2 = Utils.toBigDecimalList(1.0, 0.0, r1n, 1.0 - r1n); Parfactor g1 = new StdParfactorBuilder().variables(b).values(f1).build(); Parfactor g2 = new StdParfactorBuilder().variables(b, e).values(f2).build(); RandomVariableSet query = RandomVariableSet.getInstance(e1, Sets.<Constraint>getInstance(0)); Marginal input = new StdMarginalBuilder(2).parfactors(g1, g2).preservable(query).build(); Marginal groundedInput = propositionalizeAll(input); Parfactor product = new StdParfactorBuilder().build(); for (Parfactor p : groundedInput) { product = product.multiply(p); } ACFOVE ve = new LoggedACFOVE(groundedInput); Parfactor groundedResult = ve.run(); System.out.println("Hi: " + groundedResult); } @Ignore("Testing manually") @Test public void testInferenceOnExistsNodeSimplified() throws IOException { int n = 1; LogicalVariable x = StdLogicalVariable.getInstance("X", "x", n); LogicalVariable y = StdLogicalVariable.getInstance("Y", "y", n); Constant x1 = Constant.getInstance("x1"); Prv b = 
StdPrv.getBooleanInstance("b", y); Prv e = StdPrv.getBooleanInstance("e", x); Prv r = StdPrv.getBooleanInstance("r", x, y); Prv r1 = StdPrv.getBooleanInstance("and", x, y); Prv e1 = StdPrv.getBooleanInstance("e", x1); List<BigDecimal> fb = TestUtils.toBigDecimalList(0.1, 0.9); List<BigDecimal> fr = TestUtils.toBigDecimalList(0.2, 0.8); List<BigDecimal> fr1 = TestUtils.toBigDecimalList(1, 0, 1, 0, 1, 0, 0, 1); Parfactor gb = new StdParfactorBuilder().variables(b).values(fb).build(); Parfactor gr = new StdParfactorBuilder().variables(r).values(fr).build(); Parfactor gr1 = new StdParfactorBuilder().variables(b, r, r1).values(fr1).build(); Parfactor ge = new AggParfactorBuilder(r1, e, Or.OR).build(); RandomVariableSet query = RandomVariableSet.getInstance(e1, new HashSet<Constraint>(0)); Marginal input = new StdMarginalBuilder().parfactors(gb, gr, gr1, ge).preservable(query).build(); input = propositionalizeAll(input); ACFOVE ve = new LoggedACFOVE(input); Parfactor groundedResult = ve.run(); System.out.println("Hi again: " + groundedResult); } @Test public void testBigJackpotInference() { int populationSize = 4; LogicalVariable person = StdLogicalVariable.getInstance("Person", "x", populationSize); Prv big_jackpot = StdPrv.getBooleanInstance("big_jackpot"); Prv played = StdPrv.getBooleanInstance("played", person); Prv matched_6 = StdPrv.getBooleanInstance("matched_6", person); Prv jackpot_won = StdPrv.getBooleanInstance("jackpot_won"); List<BigDecimal> fBigJackpot = Utils.toBigDecimalList(0.8, 0.2); List<BigDecimal> fPlayed = Utils.toBigDecimalList(0.95, 0.05, 0.85, 0.15); List<BigDecimal> fMatched6 = Utils.toBigDecimalList(1.0, 0.0, 0.99999993, 0.00000007); Parfactor g1 = new StdParfactorBuilder().variables(big_jackpot).values(fBigJackpot).build(); Parfactor g2 = new StdParfactorBuilder().variables(big_jackpot, played).values(fPlayed).build(); Parfactor g3 = new StdParfactorBuilder().variables(played, matched_6).values(fMatched6).build(); Parfactor g4 = new 
AggParfactorBuilder(matched_6, jackpot_won, Or.OR).context(big_jackpot).build(); Parfactor g2xg3 = g2.multiply(g3); Parfactor afterEliminatingPlayed = g2xg3.sumOut(played); Parfactor g4xg5 = g4.multiply(afterEliminatingPlayed); Parfactor afterEliminatingMatched6 = g4xg5.sumOut(matched_6); Parfactor g6xg1 = g1.multiply(afterEliminatingMatched6); Parfactor afterEliminatingBigJackpot = g6xg1.sumOut(big_jackpot); Factor r = getCorrectResultOfBigJackpotInference(populationSize); Parfactor expected = new StdParfactorBuilder().factor(r).build(); assertEquals(expected, afterEliminatingBigJackpot); } // propositionalizes the model above and calculates P(jackpot_won) private Factor getCorrectResultOfBigJackpotInference(int n) { List<Prv> matched_6 = new ArrayList<Prv>(n); for (int i = 0; i < n; i++) { Constant p = Constant.getInstance("x" + i); matched_6.add(StdPrv.getBooleanInstance("matched_6", p)); } Prv jackpot_won = StdPrv.getBooleanInstance("jackpot_won"); // Creates factor on matched_6(x0) ... 
matched_6(xn) jackpot_won() List<Prv> vars = Lists.listOf(matched_6); vars.add(jackpot_won); List<BigDecimal> fagg = new ArrayList<BigDecimal>(); for (int i = 0; i < (int) Math.pow(2, n); i++) { fagg.add(BigDecimal.ZERO); fagg.add(BigDecimal.ONE); } fagg.set(0, BigDecimal.ONE); fagg.set(1, BigDecimal.ZERO); Factor jw = StdFactor.getInstance("", vars, fagg); // Creates factor on big_jackpot() matched_6(X) Prv big_jackpot = StdPrv.getBooleanInstance("big_jackpot"); List<BigDecimal> f5vals = TestUtils.toBigDecimalList(0.9999999965, 0.0000000035, 0.9999999895, 0.0000000105); List<Factor> f5 = new ArrayList<Factor>(n); for (Prv m6 : matched_6) { List<Prv> rvs = new ArrayList<Prv>(2); rvs.add(big_jackpot); rvs.add(m6); f5.add(StdFactor.getInstance("", rvs, f5vals)); } // Creates factor big_jackpot() List<BigDecimal> fbj = TestUtils.toBigDecimalList(0.8, 0.2); Factor bj = StdFactor.getInstance("", big_jackpot, fbj); // multiplies all factors Factor product = jw; for (Factor f : f5) { product = product.multiply(f); } // eliminates matched_6(x) Factor result = product; for (Prv m6 : matched_6) { result = result.sumOut(m6); } // Multiplies by big_jackpot result = result.multiply(bj); // eliminates big_jackpot result = result.sumOut(big_jackpot); return result; } } /** * Some tests involving the lifted version of the water sprinkler network. * Data structures used in test cases from this class are created by * {@link WaterSprinklerNetwork} class. */ public static class WaterSprinklerNetworkTest { /** * Network: water sprinkler * Query: wet_grass(Lot) * Evidence: none * Population size: n >= 5 * * For n < 5, the algorithm propositionalizes the logical variable Lot * because it is cheaper than counting it. In the end, the result will * be propositionalized. This test assumes the answer is lifted so it * is easier to compare with the expected result. Last time I checked * the answer is correct for n < 5 too. 
*/ @Test public void queryWetGrass() { int domainSize = 5; WaterSprinklerNetwork wsn = new WaterSprinklerNetwork(domainSize); wsn.setQuery(wsn.wetGrass); Marginal input = wsn.getMarginal(); // Runs AC-FOVE on input marginal ACFOVE acfove = new LoggedACFOVE(input); Parfactor result = acfove.run(); Parfactor expected = getResultWetGrass(wsn); // Compares expected with result assertEquals(expected, result); } private Parfactor getResultWetGrass(WaterSprinklerNetwork wsn) { Parfactor afterSumOutSprinkler = wsn.sprinklerParfactor .multiply(wsn.wetGrassParfactor).sumOut(wsn.sprinkler); Parfactor afterCountingLot = afterSumOutSprinkler.count(wsn.lot); Parfactor afterSumOutRain = afterCountingLot .multiply(wsn.rainParfactor).sumOut(wsn.rain); Parfactor afterSumOutCloudy = afterSumOutRain .multiply(wsn.cloudyParfactor).sumOut(wsn.cloudy); return afterSumOutCloudy; } /** * Network: water sprinkler * Query: rain() * Evidence: wet_grass(lot0) = true * Population size: 1 * * Even though the answer is consistent, the real probability is obtained * by dividing the result by the normalizing constant. I thought AC-FOVE * results did not need this kind of correction. Is it correct? 
* * TODO Check the need for normalizing constants */ @Test public void queryRainGivenWetGrassWithOneLot() { int domainSize = 1; WaterSprinklerNetwork wsn = new WaterSprinklerNetwork(domainSize); wsn.setEvidence(wsn.wetGrass, 0); wsn.setQuery(wsn.rain); Marginal input = wsn.getMarginal(); // Runs AC-FOVE on input marginal ACFOVE acfove = new LoggedACFOVE(input); Parfactor result = acfove.run(); Parfactor expected = getResultRainGivenWetGrassWithOneLot(wsn); // Compares expected with result assertEquals(expected, result); } private Parfactor getResultRainGivenWetGrassWithOneLot(WaterSprinklerNetwork wsn) { // Sum out wet_grass(lot0) Prv wetGrass = wsn.wetGrass.apply(wsn.getLot(0)); Parfactor wetGrassParfactor = wsn.wetGrassParfactor.apply(wsn.getLot(0)); Factor afterSumOutWetGrass = wsn.evidenceParfactor.factor() .multiply(wetGrassParfactor.factor()).sumOut(wetGrass); // Sum out sprinkler(lot0) Prv sprinkler = wsn.sprinkler.apply(wsn.getLot(0)); Parfactor sprinklerParfactor = wsn.sprinklerParfactor.apply(wsn.getLot(0)); Factor afterSumOutSprinkler = afterSumOutWetGrass .multiply(sprinklerParfactor.factor()).sumOut(sprinkler); // Sum out cloudy Factor afterSumOutCloudy = afterSumOutSprinkler .multiply(wsn.cloudyParfactor.factor()) .multiply(wsn.rainParfactor.factor()).sumOut(wsn.cloudy); Parfactor expected = new StdParfactorBuilder().factor(afterSumOutCloudy).build(); return expected; } /** * Network: water sprinkler * Query: sprinkler(lot0) * Evidence: wet_grass(lot0) = true * Population size: 1 * * Even though the answer is consistent, the real probability is obtained * by dividing the result by the normalizing constant. I thought AC-FOVE * results did not need this kind of correction. Is it correct? 
* * TODO Check the need for normalizing constants */ @Test public void querySprinklerGivenWetGrassWithOneLot() { int domainSize = 1; WaterSprinklerNetwork wsn = new WaterSprinklerNetwork(domainSize); wsn.setEvidence(wsn.wetGrass, 0); wsn.setQuery(wsn.sprinkler.apply(wsn.getLot(0))); Marginal input = wsn.getMarginal(); // Runs AC-FOVE on input marginal ACFOVE acfove = new LoggedACFOVE(input); Parfactor result = acfove.run(); Parfactor expected = getResultSprinklerGivenWetGrassWithOneLot(wsn); // Compares expected with result assertEquals(expected, result); } private Parfactor getResultSprinklerGivenWetGrassWithOneLot(WaterSprinklerNetwork wsn) { // Sum out wet_grass(lot0) Prv wetGrass = wsn.wetGrass.apply(wsn.getLot(0)); Parfactor wetGrassParfactor = wsn.wetGrassParfactor.apply(wsn.getLot(0)); Factor afterSumOutWetGrass = wsn.evidenceParfactor.factor() .multiply(wetGrassParfactor.factor()).sumOut(wetGrass); // Sum out rain Factor afterSumOutRain = afterSumOutWetGrass .multiply(wsn.rainParfactor.factor()).sumOut(wsn.rain); // Sum out cloudy Factor afterSumOutCloudy = afterSumOutRain .multiply(wsn.cloudyParfactor.factor()) .multiply(wsn.sprinklerParfactor.apply(wsn.getLot(0)).factor()) .sumOut(wsn.cloudy); Parfactor expected = new StdParfactorBuilder().factor(afterSumOutCloudy).build(); return expected; } /** * Network: Water Sprinkler * Query: rain() * Evidence: wet_grass(lot0) = true * Population size: 100 */ @Test public void queryRainGivenWetGrassWithManyLots() { int domainSize = 100; WaterSprinklerNetwork wsn = new WaterSprinklerNetwork(domainSize); wsn.setEvidence(wsn.wetGrass, 0); wsn.setQuery(wsn.rain); Marginal input = wsn.getMarginal(); // Runs AC-FOVE on input marginal ACFOVE acfove = new LoggedACFOVE(input); Parfactor result = acfove.run(); Parfactor expected = getResultyRainGivenWetGrassWithManyLots(wsn); // Compares expected with result assertEquals(expected, result); } private Parfactor getResultyRainGivenWetGrassWithManyLots(WaterSprinklerNetwork wsn) 
{ // Splits the marginal on the evidence and the query Parfactor g1 = wsn.cloudyParfactor; Parfactor g2 = wsn.rainParfactor; SplitResult splitSprinkler = wsn.sprinklerParfactor.splitOn(wsn.getLot(0)); Parfactor g3 = splitSprinkler.residue().iterator().next(); Parfactor g3_0 = splitSprinkler.result(); SplitResult splitWetGrass = wsn.wetGrassParfactor.splitOn(wsn.getLot(0)); Parfactor g4 = splitWetGrass.residue().iterator().next(); Parfactor g4_0 = splitWetGrass.result(); Parfactor g5 = wsn.evidenceParfactor; // sum out wet_grass(Lot):{Lot!=lot0} Parfactor afterSumOutWetGrass = g4.sumOut(wsn.wetGrass); // sum out sprinkler(Lot):{Lot!=lot0} Parfactor afterSumOutSprinkler = afterSumOutWetGrass.multiply(g3).sumOut(wsn.sprinkler); // sum out wet_grass(lot0) Prv wetgrass_lot0 = wsn.wetGrass.apply(wsn.getLot(0)); Parfactor afterSumOutWetGrassLot0 = g5.multiply(g4_0).sumOut(wetgrass_lot0); // sum out sprinkler(lot0) Prv sprinkler_lot0 = wsn.sprinkler.apply(wsn.getLot(0)); Parfactor afterSumOutSprinklerLot0 = afterSumOutWetGrassLot0.multiply(g3_0).sumOut(sprinkler_lot0); // sum out cloudy() Parfactor afterSumOutCloudy = afterSumOutSprinkler .multiply(afterSumOutSprinklerLot0).multiply(g1).multiply(g2) .sumOut(wsn.cloudy); return afterSumOutCloudy; } /** * Network: Water Sprinkler * Query: sprinkler(lot0) * Evidence: wet_grass(lot0) = true * Population size: 10 */ @Test public void querySprinklerGivenWetGrassWithManyLots() { int domainSize = 100; WaterSprinklerNetwork wsn = new WaterSprinklerNetwork(domainSize); wsn.setEvidence(wsn.wetGrass, 0); wsn.setQuery(wsn.sprinkler.apply(wsn.getLot(0))); Marginal input = wsn.getMarginal(); // Runs AC-FOVE on input marginal ACFOVE acfove = new LoggedACFOVE(input); Parfactor result = acfove.run(); Parfactor expected = getResultSprinklerGivenWetGrassWithManyLots(wsn); // Compares expected with result assertEquals(expected, result); } private Parfactor getResultSprinklerGivenWetGrassWithManyLots(WaterSprinklerNetwork wsn) { // Splits the 
marginal on the evidence and the query Parfactor g1 = wsn.cloudyParfactor; Parfactor g2 = wsn.rainParfactor; SplitResult splitSprinkler = wsn.sprinklerParfactor.splitOn(wsn.getLot(0)); Parfactor g3 = splitSprinkler.residue().iterator().next(); Parfactor g3_0 = splitSprinkler.result(); SplitResult splitWetGrass = wsn.wetGrassParfactor.splitOn(wsn.getLot(0)); Parfactor g4 = splitWetGrass.residue().iterator().next(); Parfactor g4_0 = splitWetGrass.result(); Parfactor g5 = wsn.evidenceParfactor; // sum out wet_grass(Lot):{Lot!=lot0} Parfactor afterSumOutWetGrass = g4.sumOut(wsn.wetGrass); // sum out sprinkler(Lot):{Lot!=lot0} Parfactor afterSumOutSprinkler = afterSumOutWetGrass.multiply(g3).sumOut(wsn.sprinkler); // sum out wet_grass(lot0) Prv wetgrass_lot0 = wsn.wetGrass.apply(wsn.getLot(0)); Parfactor afterSumOutWetGrassLot0 = g5.multiply(g4_0).sumOut(wetgrass_lot0); // sum out rain() Parfactor afterSumOutRain = afterSumOutWetGrassLot0 .multiply(afterSumOutSprinkler).multiply(g2).sumOut(wsn.rain); // sum out cloudy() Parfactor afterSumOutCloudy = afterSumOutRain.multiply(g1).multiply(g3_0).sumOut(wsn.cloudy); return afterSumOutCloudy; } } /** * Test using the competing workshop network. 
* @see Example#competingWorkshopsNetwork(int, int) */ public static class CompetingWorkshops { /** * Network: competing workshops (Milch 2008) * Query: success * Evidence: none * Population size: 10 workshops, 1000 people * */ @Test public void querySomeDeath() { // Network initialization int numberOfPeople = 10; int numberOfWorkshops = 10; Example network = Example.competingWorkshopsNetwork(numberOfWorkshops, numberOfPeople); Parfactor gh = network.parfactor("ghot"); Parfactor ga = network.parfactor("gattends"); Parfactor gs = network.parfactor("gsuccess"); // Query Prv success = network.prv("success ( )"); RandomVariableSet query = RandomVariableSet.getInstance(success, new HashSet<Constraint>(0)); // Input marginal Marginal input = new StdMarginalBuilder(5).parfactors(gh, ga, gs).preservable(query).build(); // Runs AC-FOVE on input marginal ACFOVE acfove = new LoggedACFOVE(input); Parfactor result = acfove.run(); // Calculates the correct result // Sum out hot Prv hot = network.prv("hot ( Workshop )"); Parfactor afterSumOutHot = gh.multiply(ga).sumOut(hot); // Converts aggregation parfactor to standard parfactors Distribution converted = ((AggregationParfactor) gs).toStdParfactors(); // Gets the converted parfactor that contains the counting formula Prv attends = network.prv("attends ( Person )"); for (Parfactor p : converted) { if (!p.contains(attends)) { gs = p; } } // Sum out attends LogicalVariable person = network.lv("Person"); attends = CountingFormula.getInstance(person, attends); Parfactor afterSumOutAttends = afterSumOutHot.multiply(gs).sumOut(attends); Parfactor expected = afterSumOutAttends; // Compares expected with result assertEquals(expected, result); } } /** * Test using the sick and death network. 
* @see Example#sickDeathNetwork(int)
     */
    public static class SickDeath {

        /**
         * Network: sick and death (Braz 2005)
         * Query: someDeath
         * Evidence: none
         * Population size: 10
         *
         * The expected parfactor is computed by hand by summing out sick,
         * death and epidemic, in that order.
         */
        @Test
        public void querySomeDeath() {
            // Network initialization
            int people = 10;
            Example network = Example.sickDeathNetwork(people);

            Parfactor epidemicPf = network.parfactor("gepidemic");
            Parfactor sickPf = network.parfactor("gsick");
            Parfactor deathPf = network.parfactor("gdeath");
            Parfactor someDeathPf = network.parfactor("gsomedeath");

            // Query
            Prv someDeath = network.prv("someDeath ( )");
            RandomVariableSet target = RandomVariableSet.getInstance(
                    someDeath, new HashSet<Constraint>(0));

            // Input marginal
            Marginal marginal = new StdMarginalBuilder(5)
                    .parfactors(epidemicPf, sickPf, deathPf, someDeathPf)
                    .preservable(target)
                    .build();

            // Lets AC-FOVE solve the query
            ACFOVE solver = new LoggedACFOVE(marginal);
            Parfactor actual = solver.run();

            // Hand-computed answer starts here.

            // Sum out sick
            Prv sick = network.prv("sick ( Person )");
            Parfactor withoutSick = sickPf.multiply(deathPf).sumOut(sick);

            // Sum out death
            Prv death = network.prv("death ( Person )");
            Parfactor withoutDeath = withoutSick.multiply(someDeathPf).sumOut(death);

            // Sum out epidemic
            Prv epidemic = network.prv("epidemic ( )");
            Parfactor expected = withoutDeath.multiply(epidemicPf).sumOut(epidemic);

            assertEquals(expected, actual);
        }
    }
}