/*******************************************************************************
* Copyright 2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers.minsum;
import static com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.*;
import static org.junit.Assert.*;
import java.util.ArrayList;
import org.junit.Test;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.CurrentModel;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.model.variables.VariableBlock;
import com.analog.lyric.dimple.options.BPOptions;
import com.analog.lyric.dimple.solvers.minsum.MinSumSolver;
import com.analog.lyric.dimple.solvers.optimizedupdate.UpdateApproach;
import com.analog.lyric.dimple.test.DimpleTestBase;
/**
* Tests for Min-sum solver
* @since 0.08
* @author Christopher Barber
*/
public class TestMinSum extends DimpleTestBase
{
/**
* Test using min-sum so solve "small parsimony" problem in evolutionary bioinformatics.
* <p>
* The test generates a binary tree of random mutations to a DNA strand.
* It then uses the shape of the tree and the leaves of the tree to attempt
* to reconstruct the values of the strands on the internal nodes. Min-sum should
* be able to produce an answer that has the same or better score than the original
* tree, where the score is the sum of the hamming distances along the edges.
* <p>
* The idea for this test was taken from an assignment in the Coursera Bioinformatics Algorithms courses from UCSD.
*/
@Test
public void smallParsimonyTest()
{
smallParsimonyCase(1, 10, .1);
smallParsimonyCase(5, 20, .1);
smallParsimonyCase(10, 20, .1);
smallParsimonyCase(20, 20, .2);
}
private void smallParsimonyCase(int strandLength, int treeSize, double mutationRate)
{
// Each nucleotide in the strand is considered to be independent of the others,
// so the resulting variables in the graph should not be connected to each other
// through a factor. Instead, we will use variable blocks to tie them together.
//
// We use the guess values to store the generated nucleotide values for each variable.
// TODO - use data abstraction layer for this TBD
final DiscreteDomain nucleotides = DiscreteDomain.create('A','C','G','T');
final FactorFunction delta = new FactorFunction() {
@Override
public double evalEnergy(Value[] values)
{
assertEquals(2, values.length);
return values[0].valueEquals(values[1]) ? 0.0 : 1.0;
}
};
final FactorGraph fg = new FactorGraph();
// Need to attach solver so that we can set guesses.
fg.setOption(BPOptions.updateApproach, UpdateApproach.NORMAL);
fg.setSolverFactory(new MinSumSolver());
try (CurrentModel curModel = using(fg))
{
// Generate root block
final VariableBlock root = block(discretes("n0_", nucleotides, strandLength));
for (Variable var : root)
{
int index = testRand.nextInt(4);
var.setGuess(nucleotides.getElement(index));
}
ArrayList<VariableBlock> leaves = new ArrayList<>(treeSize);
ArrayList<VariableBlock> internal = new ArrayList<>();
int mutations = 0;
leaves.add(root);
for (int i = 1; i < treeSize; i += 2)
{
// Pick a random leaf
VariableBlock leaf = leaves.remove(testRand.nextInt(leaves.size()));
internal.add(leaf);
// Mutate it randomly twice to create child blocks. This simulates multiple steps of
// evolution. A single mutation event would be expected to result in one mutated child
// and a clone.
VariableBlock left = block(discretes(String.format("n%d_", i), nucleotides, strandLength));
leaves.add(left);
for (int j = 0; j < strandLength; ++j)
{
Discrete cur = (Discrete)left.get(j), prev = (Discrete)leaf.get(j);
int prevIndex = prev.getGuessIndex();
int nextIndex = testRand.nextBoolean(mutationRate) ? testRand.nextInt(4) : prevIndex;
if (prevIndex != nextIndex)
++mutations;
cur.setGuessIndex(nextIndex);
addFactor(delta, prev, cur);
}
VariableBlock right = block(discretes(String.format("n%d_", i+1), nucleotides, strandLength));
leaves.add(right);
for (int j = 0; j < strandLength; ++j)
{
Discrete cur = (Discrete)right.get(j), prev = (Discrete)leaf.get(j);
int prevIndex = prev.getGuessIndex();
int nextIndex = testRand.nextBoolean(mutationRate) ? testRand.nextInt(4) : prevIndex;
if (prevIndex != nextIndex)
++mutations;
cur.setGuessIndex(nextIndex);
addFactor(delta, prev, cur);
}
}
// Now solve using Min-Sum
// The initial score should simply be the sum of all of the energy functions for the guesses,
// which should be the same as the total hamming distance.
double expectedScore = 0.0;
for (Factor factor : fg.getFactors())
{
expectedScore += factor.getScore();
}
assertEquals(mutations, expectedScore, 0.0);
// clear the guesses on the non-leaves
for (VariableBlock block : internal)
{
for (Variable var : block)
{
var.setGuess(null);
// Add a random input to avoid ambiguous result - see BUG 408
var.setPrior(randomPrior());
}
}
// set fixed values on the leaves
for (VariableBlock block : leaves)
{
for (Variable var : block)
{
Discrete d = (Discrete)var;
d.setPriorIndex(d.getGuessIndex());
}
}
fg.solve();
// Set guesses from inferred best value
for (VariableBlock block : internal)
{
for (Variable var : block)
{
Discrete d = (Discrete)var;
double maxBelief = Double.NEGATIVE_INFINITY;
for (double belief : d.getBelief())
{
// Make sure BUG 408 will not affect the result
assertNotEquals(belief, maxBelief, 0.0);
if (belief > maxBelief)
maxBelief = belief;
}
d.setGuessIndex(d.getValueIndex());
}
}
double score = 0.0;
for (Factor factor : fg.getFactors())
{
score += factor.getScore();
}
if (expectedScore < score)
{
fail(String.format("Expected %f or better but got %f", expectedScore, score));
}
}
}
private double[] randomPrior()
{
double[] input = new double[4];
for (int i = 0; i < 4; ++i)
{
input[i] = 1000 + testRand.nextDouble();
}
return input;
}
}