/*******************************************************************************
* Copyright 2014 Felipe Takiyama
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
******************************************************************************/
package br.usp.poli.takiyama.acfove;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Set;
import br.usp.poli.takiyama.common.Constraint;
import br.usp.poli.takiyama.common.Marginal;
import br.usp.poli.takiyama.common.Parfactor;
import br.usp.poli.takiyama.common.Scanner;
import br.usp.poli.takiyama.common.SplitResult;
import br.usp.poli.takiyama.common.StdMarginal.StdMarginalBuilder;
import br.usp.poli.takiyama.common.Tuple;
import br.usp.poli.takiyama.prv.Binding;
import br.usp.poli.takiyama.prv.NameGenerator;
import br.usp.poli.takiyama.prv.Prv;
import br.usp.poli.takiyama.prv.Prvs;
import br.usp.poli.takiyama.prv.RandomVariableSet;
import br.usp.poli.takiyama.prv.Substitution;
import br.usp.poli.takiyama.prv.Term;
import br.usp.poli.takiyama.utils.Lists;
import br.usp.poli.takiyama.utils.Sets;
/**
 * This operation makes all the necessary splits and expansions to
 * guarantee that the sets of random variables represented by parameterized
 * random variables in each parfactor of the given set are equal or
 * disjoint.
 * <p>
 * In other words, for any parameterized random variables p and q from
 * parfactors in the marginal, p and q represent identical or
 * disjoint sets of random variables.
 * </p>
 * <p>
 * This operation is used before multiplication and elimination
 * on parfactors.
 * </p>
 * <p>
 * Logical variables are temporarily renamed through {@link NameGenerator},
 * which holds shared state; {@link #run()} always resets it before
 * returning, including when shattering ends with an exception.
 * </p>
 *
 * @author Felipe Takiyama
 */
public final class Shatter implements MacroOperation {

    /** Marginal built at construction time; never reassigned, so run() is repeatable. */
    private final Marginal marginal;

    /**
     * A queue of elements that is traversed pair by pair and may change
     * while it is being traversed.
     * <p>
     * Iteration returns every pair {@code (queue[i], queue[j])} with
     * {@code i < j}, ordered by {@code i} and then by {@code j}. Every call
     * to {@link #add(Collection)}, and every call to {@link #remove(Tuple)}
     * that actually removes something, moves the cursor back to the first
     * pair. A for-each loop over this queue therefore ends only after a
     * complete pass during which the queue was not modified.
     * </p>
     * <p>
     * Unlike most {@link Iterable}s, {@link #iterator()} returns the same
     * shared cursor on every call; this is what lets modifications made
     * inside a for-each loop restart that loop.
     * </p>
     *
     * @param <T> the type of the queued elements
     */
    public static class MutableQueue<T> implements Iterable<Tuple<T>> {

        /** The queued elements, in insertion order (duplicates are kept). */
        private final List<T> queue;

        /** Shared cursor over pairs of queued elements. */
        private final MutableQueueIterator iterator;

        /**
         * Creates a queue holding the elements of the specified collection.
         *
         * @param c the initial elements
         * @throws IllegalArgumentException if {@code c} has fewer than two
         *         elements
         */
        public MutableQueue(Collection<? extends T> c) throws IllegalArgumentException {
            if (c.size() < 2) {
                throw new IllegalArgumentException(
                        "MutableQueue needs at least 2 elements, got " + c.size());
            }
            queue = new ArrayList<T>(c);
            iterator = new MutableQueueIterator();
        }

        /**
         * Appends the specified elements to the queue and restarts iteration
         * from the first pair.
         *
         * @param c the elements to append
         */
        public void add(Collection<? extends T> c) {
            queue.addAll(c);
            iterator.reset();
        }

        /**
         * Removes one occurrence of each element of the specified tuple from
         * the queue. Elements that are not in the queue are ignored.
         * <p>
         * Iteration restarts from the first pair only if at least one
         * element was actually removed; restarting on a no-op could make the
         * iteration start over indefinitely.
         * </p>
         *
         * @param t the elements to remove
         */
        public void remove(Tuple<T> t) {
            boolean modified = false;
            for (int k = 0; k < t.size(); k++) {
                // '|=' (not '||') so that every element is removed
                modified |= queue.remove(t.get(k));
            }
            if (modified) {
                iterator.reset();
            }
        }

        /**
         * Cursor over the pairs {@code (queue[i], queue[j])}, {@code i < j}.
         * {@code (i, j)} is the position of the last pair returned;
         * {@code (-1, 0)} means no pair was returned since the last reset.
         */
        private class MutableQueueIterator implements Iterator<Tuple<T>> {
            private int i;
            private int j;

            private MutableQueueIterator() {
                reset();
            }

            /** Moves the cursor back before the first pair. */
            private void reset() {
                i = -1;
                j = 0;
            }

            @Override
            public boolean hasNext() {
                int n = queue.size();
                // Either row i has a column after j, or row i + 1 exists and
                // has at least one pair (i + 1, i + 2).
                return (i < n - 2) || (j < n - 1);
            }

            /**
             * Returns the next pair of queued elements.
             *
             * @throws NoSuchElementException if every pair has already been
             *         returned since the last reset
             */
            @Override
            public Tuple<T> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                if (i == -1) {
                    // First pair since the last reset
                    i = 0;
                    j = 1;
                } else if (j < queue.size() - 1) {
                    // Next column in the same row
                    j++;
                } else {
                    // First pair of the next row
                    i++;
                    j = i + 1;
                }
                Tuple<T> pair = Tuple.getInstance(Lists.listOf(queue.get(i), queue.get(j)));
                return pair;
            }

            /**
             * Throws {@link UnsupportedOperationException}.
             */
            @Override
            public void remove() {
                throw new UnsupportedOperationException();
            }
        }

        /**
         * Returns the shared cursor of this queue (the same instance on
         * every call).
         */
        @Override
        public Iterator<Tuple<T>> iterator() {
            return iterator;
        }

        /**
         * Returns a new set holding the distinct elements currently queued.
         */
        public Set<T> toSet() {
            return new HashSet<T>(queue);
        }
    }

    /**
     * Creates a shattering operation over a marginal built from the
     * specified one.
     *
     * @param marginal the marginal whose parfactors will be shattered
     */
    public Shatter(Marginal marginal) {
        this.marginal = new StdMarginalBuilder().add(marginal).build();
    }

    /**
     * Shatters the parfactors of the marginal.
     * <p>
     * Marginals with fewer than two parfactors are returned as they are.
     * Otherwise logical variables are simplified and renamed apart, pairs of
     * parfactors are unified until no pair unifies, the original names are
     * restored, and logical variables are simplified again. This method does
     * not modify this object, so it may be called more than once.
     * </p>
     *
     * @return the shattered marginal, keeping the same preservable variables
     */
    @Override
    public Marginal run() {
        if (marginal.distribution().size() < 2) {
            return marginal;
        }
        Marginal simplified = simplifyLogicalVariables(marginal);
        Set<Parfactor> shattered;
        try {
            Marginal renamed = renameAllLogicalVariables(simplified);
            Set<Parfactor> renamedShattered = shatterAll(renamed.distribution().toSet());
            // Renames back logical variables
            shattered = Sets.apply(NameGenerator.getOldNames(), renamedShattered);
        } finally {
            // Clears buffered names even if shattering failed, so they do
            // not leak into later operations
            NameGenerator.reset();
        }
        Marginal result = new StdMarginalBuilder().parfactors(shattered)
                .preservable(marginal.preservable()).build();
        // Simplifies logical variables after shattering
        return simplifyLogicalVariables(result);
    }

    /**
     * Repeatedly unifies pairs of parfactors until no pair unifies.
     * Unified pairs are replaced by the result of their unification. Sets
     * with fewer than two parfactors are returned unchanged.
     */
    private Set<Parfactor> shatterAll(Set<Parfactor> parfactors) {
        if (parfactors.size() < 2) {
            // Nothing to pair; MutableQueue would reject this set
            return parfactors;
        }
        MutableQueue<Parfactor> queue = new MutableQueue<Parfactor>(parfactors);
        for (Tuple<Parfactor> pair : queue) {
            Marginal unifiedSet = unify(pair.get(0), pair.get(1));
            if (!unifiedSet.isEmpty()) {
                // Both calls restart the pair iteration of the for-each loop
                queue.remove(pair);
                queue.add(unifiedSet.distribution().toSet());
            }
        }
        return queue.toSet();
    }

    /**
     * Returns a marginal with the same preservable variables as the
     * specified one, in which every parfactor was replaced by the result of
     * its {@code simplifyLogicalVariables()}. This replaces logical variables
     * constrained to a single constant with that constant.
     */
    private Marginal simplifyLogicalVariables(Marginal source) {
        StdMarginalBuilder builder = new StdMarginalBuilder(source.size());
        for (Parfactor p : source) {
            builder.add(p.simplifyLogicalVariables());
        }
        RandomVariableSet query = source.preservable();
        return builder.preservable(query).build();
    }

    /**
     * Returns a marginal with the same preservable variables as the
     * specified one, in which the logical variables of every parfactor were
     * renamed. This avoids repeating logical variable names across different
     * parfactors. The renaming is recorded by {@link NameGenerator}.
     */
    private Marginal renameAllLogicalVariables(Marginal source) {
        StdMarginalBuilder builder = new StdMarginalBuilder(source.size());
        for (Parfactor p : source) {
            builder.add(renameLogicalVariables(p));
        }
        RandomVariableSet query = source.preservable();
        return builder.preservable(query).build();
    }

    /**
     * Renames the logical variables of the specified parfactor, collected
     * through a {@link Scanner} over it. New names are generated by
     * {@link NameGenerator}.
     */
    private Parfactor renameLogicalVariables(Parfactor p) {
        Parfactor scanned = new Scanner(p);
        return p.apply(NameGenerator.rename(scanned.logicalVariables()));
    }

    /**
     * Tries to unify two parfactors, stopping at the first pair of PRVs (one
     * from each parfactor) that unify. Fully shattering a pair of parfactors
     * may take several calls.
     *
     * @return the parfactors resulting from the unification, or an empty
     *         marginal if {@code p1} and {@code p2} have no unifiable PRVs
     */
    private Marginal unify(Parfactor p1, Parfactor p2) {
        for (Prv prv1 : p1.prvs()) {
            for (Prv prv2 : p2.prvs()) {
                Marginal result = unify(p1, prv1, p2, prv2);
                if (!result.isEmpty()) {
                    return result;
                }
            }
        }
        return new StdMarginalBuilder().build();
    }

    /**
     * Unifies {@code p1} and {@code p2} on {@code prv1} and {@code prv2}.
     * <p>
     * {@code prv1} must belong to {@code p1} and {@code prv2} must belong to
     * {@code p2}. This is not checked; violating it will most likely end in
     * an {@link IndexOutOfBoundsException}.
     * </p>
     *
     * @return the parfactors resulting from the unification, or an empty
     *         marginal if {@code prv1} and {@code prv2} do not unify
     */
    private Marginal unify(Parfactor p1, Prv prv1, Parfactor p2, Prv prv2) {
        // PRV objects change when their parfactor is split, so remember
        // their positions instead
        int indexOfPrv1 = p1.prvs().indexOf(prv1);
        int indexOfPrv2 = p2.prvs().indexOf(prv2);
        StdMarginalBuilder result = new StdMarginalBuilder();
        try {
            Substitution mgu = Prvs.mgu(prv1, prv2);
            // The MGU must be consistent with all constraints, including
            // constraints from counting formulas
            Set<Constraint> allConstraints = Sets.union(p1.constraints(),
                    p2.constraints(), prv1.constraints(), prv2.constraints());
            if (!mgu.isEmpty() && mgu.isConsistentWith(allConstraints)) {
                // Splits p1 and p2 on the MGU
                SplitResult firstSplit = split(p1, mgu);
                SplitResult secondSplit = split(p2, mgu);

                // Splits the first result on the second result's constraints
                Prv splitPrv2 = secondSplit.result().prvs().get(indexOfPrv2);
                Set<Constraint> secondConstraints = Sets.union(
                        secondSplit.result().constraints(), splitPrv2.constraints());
                SplitResult firstSplitOnConstraints = split(firstSplit, secondConstraints);

                // ...and symmetrically the second result on the first's
                Prv splitPrv1 = firstSplit.result().prvs().get(indexOfPrv1);
                Set<Constraint> firstConstraints = Sets.union(
                        firstSplit.result().constraints(), splitPrv1.constraints());
                SplitResult secondSplitOnConstraints = split(secondSplit, firstConstraints);

                Set<Parfactor> union = Sets.union(
                        firstSplitOnConstraints.distribution().toSet(),
                        secondSplitOnConstraints.distribution().toSet());
                result.parfactors(union);
                result.preservable(marginal.preservable());
            }
            // Otherwise the PRVs do not unify and the result stays empty
        } catch (IllegalArgumentException ignored) {
            // Treated as "PRVs represent disjoint sets of random variables".
            // NOTE(review): this also swallows IllegalArgumentExceptions
            // thrown by the splits below the MGU computation; confirm intent.
        }
        return result.build();
    }

    /**
     * Returns the result of splitting the specified parfactor on the
     * specified MGU.
     * <p>
     * When the MGU is consistent with a set of inequality constraints,
     * parameterized random variables represent non-disjoint and possibly
     * non-identical sets of random variables. To make them identical, we
     * split the parfactor involved on the MGU.
     * </p>
     * <p>
     * This method splits the specified parfactor on each binding of the MGU
     * whose first term is one of its logical variables, expanding counting
     * formulas first when possible. Bindings that cannot be split on are
     * applied directly. The result depends on the order of the bindings.
     * </p>
     *
     * @param parfactor The parfactor to split
     * @param mgu The Most General Unifier to split this parfactor on
     * @return The split parfactor, with the residues of every split
     */
    private SplitResult split(Parfactor parfactor, Substitution mgu) {
        Parfactor result = parfactor;
        StdMarginalBuilder residues = new StdMarginalBuilder();
        for (Binding bind : mgu.asList()) {
            Parfactor scanner = new Scanner(result);
            if (!scanner.logicalVariables().contains(bind.firstTerm())) {
                continue;
            }
            Substitution bindAsSub = Substitution.getInstance(bind);
            // Expands all counting formulas - not sure this is always right
            result = expand(result, bindAsSub);
            if (result.isSplittable(bindAsSub)) {
                SplitResult split = result.splitOn(bindAsSub);
                result = split.result();
                // NOTE(review): assumes parfactors() accumulates across calls
                residues.parfactors(split.residue());
            } else {
                result = result.apply(bindAsSub);
            }
        }
        return SplitResult.getInstance(result, residues.build());
    }

    /**
     * Returns the result of splitting a split result further on the
     * specified constraints.
     * <p>
     * Each constraint is converted to a substitution. Whenever the current
     * parfactor can be split on it, the split's result goes to the
     * by-products and splitting continues on the split's residue. The
     * residues of {@code splitResult} are kept as by-products.
     * </p>
     *
     * @see #split(Parfactor, Substitution)
     */
    private SplitResult split(SplitResult splitResult, Set<Constraint> constraints) {
        Parfactor residue = splitResult.result();
        StdMarginalBuilder byProduct = new StdMarginalBuilder();
        byProduct.parfactors(splitResult.residue());
        for (Constraint constraint : constraints) {
            Substitution constraintAsSub = convertToSubstitution(constraint);
            residue = expand(residue, constraintAsSub);
            if (residue.isSplittable(constraintAsSub)) {
                SplitResult split = residue.splitOn(constraintAsSub);
                // NOTE(review): assumes the residue holds exactly one parfactor
                residue = split.residue().iterator().next();
                byProduct.parfactors(split.result());
            }
        }
        return SplitResult.getInstance(residue, byProduct.build());
    }

    /**
     * Returns the result of expanding, on the replacement given by the
     * specified substitution, every counting formula of the specified
     * parfactor that the parfactor reports as expandable.
     */
    private Parfactor expand(Parfactor parfactor, Substitution sub) {
        Parfactor expanded = parfactor;
        // Iterates over the PRVs of the parfactor as originally given
        for (Prv prv : parfactor.prvs()) {
            if (expanded.isExpandable(prv, sub)) {
                Term term = sub.getReplacement(prv.boundVariable());
                expanded = expanded.expand(prv, term);
            }
        }
        return expanded;
    }

    /**
     * Converts the specified constraint to a single-binding substitution.
     * Falls back to the inverse binding when {@code toBinding()} throws
     * {@link IllegalStateException}.
     */
    private Substitution convertToSubstitution(Constraint constraint) {
        try {
            return Substitution.getInstance(constraint.toBinding());
        } catch (IllegalStateException e) {
            return Substitution.getInstance(constraint.toInverseBinding());
        }
    }

    /**
     * Returns the largest representable cost, {@link Integer#MAX_VALUE}.
     */
    @Override
    public int cost() {
        return Integer.MAX_VALUE;
    }

    /**
     * Returns 0: shattering only splits and expands parfactors.
     */
    @Override
    public int numberOfRandomVariablesEliminated() {
        return 0;
    }

    @Override
    public String toString() {
        return "SHATTER";
    }
}