/******************************************************************************* * Copyright 2014 Felipe Takiyama * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ******************************************************************************/ package br.usp.poli.takiyama.acfove; import static org.junit.Assert.assertEquals; import java.util.HashSet; import org.junit.Test; import br.usp.poli.takiyama.common.Constraint; import br.usp.poli.takiyama.common.Marginal; import br.usp.poli.takiyama.common.Parfactor; import br.usp.poli.takiyama.common.StdMarginal.StdMarginalBuilder; import br.usp.poli.takiyama.prv.Prv; import br.usp.poli.takiyama.prv.RandomVariableSet; import br.usp.poli.takiyama.utils.Example; public class SickDeath { /** * Network: sick and death (Braz 2005) * Query: someDeath * Evidence: none * Population size: 10 * */ @Test public void querySomeDeath() { // Network initialization int domainSize = 10; Example network = Example.sickDeathNetwork(domainSize); Parfactor ge = network.parfactor("gepidemic"); Parfactor gs = network.parfactor("gsick"); Parfactor gd = network.parfactor("gdeath"); Parfactor gsd = network.parfactor("gsomedeath"); // Query Prv someDeath = network.prv("someDeath ( )"); RandomVariableSet query = RandomVariableSet.getInstance(someDeath, new HashSet<Constraint>(0)); // Input marginal Marginal input = new StdMarginalBuilder(5).parfactors(ge, gs, gd, gsd).preservable(query).build(); // Runs AC-FOVE on input marginal ACFOVE acfove = new LoggedACFOVE(input); Parfactor result = acfove.run(); // Calculates the correct result // Sum out sick Prv sick = network.prv("sick ( Person )"); Parfactor afterSumOutSick = gs.multiply(gd).sumOut(sick); // Sum out death Prv death = network.prv("death ( Person )"); Parfactor afterSumOutDeath = afterSumOutSick.multiply(gsd).sumOut(death); // Sum out epidemic Prv epidemic = network.prv("epidemic ( )"); Parfactor afterSumOutEpidemic = afterSumOutDeath.multiply(ge).sumOut(epidemic); Parfactor expected = afterSumOutEpidemic; // Compares expected with result assertEquals(expected, result); } }