/*******************************************************************************
* Copyright 2014 Felipe Takiyama
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
******************************************************************************/
package br.usp.poli.takiyama.acfove;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;
import br.usp.poli.takiyama.cfove.StdParfactor.StdParfactorBuilder;
import br.usp.poli.takiyama.common.Factor;
import br.usp.poli.takiyama.common.Marginal;
import br.usp.poli.takiyama.common.Parfactor;
import br.usp.poli.takiyama.prv.Binding;
import br.usp.poli.takiyama.prv.Constant;
import br.usp.poli.takiyama.prv.CountingFormula;
import br.usp.poli.takiyama.prv.LogicalVariable;
import br.usp.poli.takiyama.prv.Prv;
import br.usp.poli.takiyama.prv.RandomVariableSet;
import br.usp.poli.takiyama.prv.Substitution;
/**
 * Variable elimination over a propositionalized version of a {@link Marginal}.
 * <p>
 * At construction time the marginal has its aggregation parfactors converted,
 * its parfactors propositionalized on every logical variable and its counting
 * formulas expanded; the factors of the resulting parfactors and the ground
 * instances of the query are stored. {@link #run()} then sums out every random
 * variable that is not part of the query and multiplies what is left.
 * </p>
 * <p>
 * Observations (evidence) are not handled: only query variables are preserved.
 * </p>
 */
public class VariableElimination extends ACFOVE {

    /** Factors still in play; replaced by a new set after each elimination step. */
    private Set<Factor> factors;

    /** Propositionalized query variables; these are never summed out. */
    private final Set<Prv> query;

    /**
     * Eliminates every non-query variable and multiplies the remaining factors.
     * <p>
     * This method overwrites the stored set of factors as it goes, so the
     * instance keeps the reduced factors after the call.
     * </p>
     *
     * @return a parfactor built around the product of the remaining factors.
     *         The factor is <b>not</b> normalized. If there are no factors at
     *         all, the builder receives {@code null} as factor.
     */
    @Override
    public Parfactor run() {
        // A more general version would also put evidence variables in here.
        Set<Prv> preservables = new HashSet<Prv>(query);
        while (thereAreVariablesToEliminate(factors, preservables)) {
            Prv eliminable = selectPrvToEliminate(factors, preservables);
            factors = eliminate(factors, eliminable);
        }
        Factor product = multiplyAll(factors);
        return new StdParfactorBuilder().factor(product).build();
    }

    /**
     * Tells whether some factor still mentions a variable that must be
     * eliminated.
     * <p>
     * The check is read-only: the list returned by {@code Factor.variables()}
     * is only inspected, never modified, so it is safe whether that list is a
     * copy, a live view or an unmodifiable collection.
     * </p>
     *
     * @param allFactors the factors to inspect
     * @param preservables the variables that must not be eliminated
     * @return {@code true} if at least one factor has a variable that is not
     *         in {@code preservables}
     */
    private boolean thereAreVariablesToEliminate(Set<Factor> allFactors, Set<Prv> preservables) {
        for (Factor factor : allFactors) {
            for (Prv prv : factor.variables()) {
                if (!preservables.contains(prv)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Picks a variable to eliminate: the first variable found, in iteration
     * order, that is not preservable.
     *
     * @param factors the factors to inspect
     * @param preservables the variables that must not be eliminated
     * @return a non-preservable variable from some factor, or {@code null} if
     *         there is none
     */
    private Prv selectPrvToEliminate(Set<Factor> factors, Set<Prv> preservables) {
        for (Factor factor : factors) {
            for (Prv prv : factor.variables()) {
                if (!preservables.contains(prv)) {
                    return prv;
                }
            }
        }
        return null;
    }

    /**
     * Eliminates a random variable from a set of factors: all factors that
     * mention the variable are multiplied together, the variable is summed out
     * of the product, and the result replaces those factors.
     *
     * @param allFactors the current set of factors; it is not modified
     * @param eliminable the variable to sum out
     * @return a new set with the factors that do not mention
     *         {@code eliminable} plus the summed-out product
     * @throws IllegalStateException if no factor mentions {@code eliminable}
     */
    private Set<Factor> eliminate(Set<Factor> allFactors, Prv eliminable) {
        Set<Factor> factorsAfterElimination = new HashSet<Factor>(allFactors);
        Factor product = null;
        for (Factor factor : allFactors) {
            if (factor.variables().contains(eliminable)) {
                product = (product == null) ? factor : product.multiply(factor);
                factorsAfterElimination.remove(factor);
            }
        }
        if (product == null) {
            // Fail with a meaningful message instead of a bare NullPointerException.
            throw new IllegalStateException(
                    "No factor contains the variable to eliminate: " + eliminable);
        }
        factorsAfterElimination.add(product.sumOut(eliminable));
        return factorsAfterElimination;
    }

    /**
     * Multiplies all factors in the specified set, in iteration order.
     *
     * @param factors the factors to multiply
     * @return the product of all factors, or {@code null} if the set is empty
     */
    private Factor multiplyAll(Set<Factor> factors) {
        Factor product = null;
        for (Factor factor : factors) {
            product = (product == null) ? factor : product.multiply(factor);
        }
        return product;
    }

    /**
     * Prepares variable elimination for the specified marginal, using the
     * specified log level.
     *
     * @param parfactors the marginal (distribution and query) to solve
     * @param logLevel the log level handed to the superclass
     */
    public VariableElimination(Marginal parfactors, Level logLevel) {
        super(parfactors, logLevel);
        this.factors = propositionalizeAll(removeAggregation(parfactors));
        this.query = propositionalizeQuery(parfactors);
    }

    /**
     * Prepares variable elimination for the specified marginal.
     *
     * @param parfactors the marginal (distribution and query) to solve
     */
    public VariableElimination(Marginal parfactors) {
        super(parfactors);
        this.factors = propositionalizeAll(removeAggregation(parfactors));
        this.query = propositionalizeQuery(parfactors);
    }

    /**
     * Returns the factors of the specified marginal after propositionalizing
     * it completely. The resulting factors can be used with VE algorithms.
     * <p>
     * Every parfactor is propositionalized on every logical variable it
     * contains, then every counting formula is fully expanded, and finally
     * the factor of each resulting parfactor is extracted.
     * </p>
     *
     * @param marginal The marginal to propositionalize.
     * @return the factors of the specified marginal, completely
     *         propositionalized.
     */
    private Set<Factor> propositionalizeAll(Marginal marginal) {
        Set<LogicalVariable> logicalVariables = new HashSet<LogicalVariable>();

        // get all logical variables
        for (Parfactor parfactor : marginal) {
            logicalVariables.addAll(parfactor.logicalVariables());
        }

        // Auxiliary set of parfactors
        Set<Parfactor> parfactors = marginal.distribution().toSet();

        // propositionalizes all parfactors in the set on all logical variables
        for (LogicalVariable lv : logicalVariables) {
            for (Parfactor parfactor : parfactors) {
                if (parfactor.logicalVariables().contains(lv)) {
                    MacroOperation propositionalize = new Propositionalize(marginal, parfactor, lv);
                    marginal = propositionalize.run();
                }
            }
            // updates the set of parfactors
            parfactors = marginal.distribution().toSet();
        }

        // expands all counting formulas
        // NOTE(review): the for-each below keeps iterating over the set that
        // 'parfactors' referred to when the loop started; the reassignment at
        // the end of the body only affects what is read after the loop. A
        // parfactor with more than one counting formula may therefore be
        // expanded from a stale reference - confirm against FullExpand.
        for (Parfactor parfactor : parfactors) {
            for (Prv prv : parfactor.prvs()) {
                if (prv instanceof CountingFormula) {
                    MacroOperation fullExpand = new FullExpand(marginal, parfactor, prv);
                    marginal = fullExpand.run();
                }
            }
            // updates the set of parfactors
            parfactors = marginal.distribution().toSet();
        }

        // extracts factors from parfactors
        return getFactors(parfactors);
    }

    /**
     * Propositionalizes the query of the specified marginal.
     * <p>
     * Each parameter is substituted on its own, one binding at a time, so the
     * result is only fully ground when the query has at most one parameter.
     * A query without parameters is returned as is.
     * </p>
     *
     * @param marginal the marginal whose preservable variables are the query
     * @return the set of query variables after substitution
     */
    private Set<Prv> propositionalizeQuery(Marginal marginal) {
        RandomVariableSet preservable = marginal.preservable();
        Set<Prv> groundQuery = new HashSet<Prv>();
        for (LogicalVariable propositionalizable : preservable.parameters()) {
            for (Constant individual : propositionalizable.individualsSatisfying(preservable.constraints())) {
                Binding b = Binding.getInstance(propositionalizable, individual);
                Substitution s = Substitution.getInstance(b);
                groundQuery.add(preservable.prv().apply(s));
            }
        }
        if (preservable.parameters().isEmpty()) {
            groundQuery.add(preservable.prv());
        }
        return groundQuery;
    }

    /**
     * Returns the factors from the specified set of parfactors.
     *
     * @param parfactors the parfactors to read
     * @return a set with the factor of each parfactor
     */
    private Set<Factor> getFactors(Set<Parfactor> parfactors) {
        Set<Factor> extracted = new HashSet<Factor>(parfactors.size());
        for (Parfactor p : parfactors) {
            extracted.add(p.factor());
        }
        return extracted;
    }

    /**
     * Converts every aggregation parfactor found in the specified marginal
     * into standard parfactors.
     * <p>
     * The iteration runs over the marginal received as argument; the marginal
     * produced by each conversion is fed to the next conversion.
     * </p>
     *
     * @param marginal the marginal to convert
     * @return the marginal produced by the last conversion, or the specified
     *         marginal if it has no aggregation parfactor
     */
    private static Marginal removeAggregation(Marginal marginal) {
        for (Parfactor parfactor : marginal) {
            if (parfactor instanceof AggParfactor) {
                MacroOperation convert = new ConvertToStdParfactors(marginal, parfactor);
                marginal = convert.run();
            }
        }
        return marginal;
    }
}