/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers.gibbs;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import java.util.Arrays;
import java.util.List;
import org.junit.Test;
import com.analog.lyric.dimple.data.DataStack;
import com.analog.lyric.dimple.data.PriorDataLayer;
import com.analog.lyric.dimple.data.ValueDataLayer;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.values.IndexedValue;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.solvers.gibbs.GibbsDiscrete;
import com.analog.lyric.dimple.solvers.gibbs.GibbsOptions;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph;
import com.analog.lyric.dimple.test.DimpleTestBase;
import com.analog.lyric.dimple.test.data.TestDataLayer;
/**
* Test of Gibbs with large discrete factor based on "Subtle motif" finding problem from
* Coursera Bioinformatics Algorithm course.
* <p>
* The problem is this: given a set of DNA strands, you are to find a length-k (in this case 15)
* subsequence from each -- known as a "motif" -- that optimizes a similarity score. This code simply
* uses a simple integer score representing the number of differences to a "consensus" sequence that
* consists of the most popular nucleotides at each position across the chosen motifs. Other approaches
* may attempt to minimize the entropy of the collection of the motifs, or maximize the KL-divergence
* between the motifs and the distribution of nucleotides in the strands if it is not uniform.
* <p>
* This test exercises ability to support large discrete factors that are too big for
* factor tables and Gibbs's use of the FactorFunction.updateEnergy method.
* <p>
* @since 0.08
* @author Christopher Barber
*/
public class TestDNAMotifFinder extends DimpleTestBase
{
/**
 * Runs the motif-finding scenario twice: first with the score function reporting that
 * {@code updateEnergy} should be used, then with it reporting that it should not.
 */
@Test
public void test()
{
    for (boolean useUpdateEnergy : new boolean[] { true, false })
    {
        testCase(useUpdateEnergy);
    }
}
/**
 * Runs the Gibbs motif search over the strands in {@code _dna} and checks the quality of the result.
 * <p>
 * The offsets returned by {@code findMotifOffsets} are re-scored with a freshly constructed
 * {@code MotifScoreFunction}; the test passes when that score is at most 39.
 *
 * @param useUpdateEnergy forwarded to both the search and the re-scoring function; it is the value
 *        their {@code useUpdateEnergy} method returns.
 */
private void testCase(boolean useUpdateEnergy)
{
    // Flip to true to dump the expected and found motif of every strand plus the final score.
    boolean verbose = false;

    final int motifLength = 15;
    final int searchCutoffScore = 38;
    final int maxAcceptedScore = 39;

    final List<String> strands = Arrays.asList(_dna);
    final int nStrands = strands.size();

    final int[] foundOffsets = findMotifOffsets(motifLength, strands, searchCutoffScore, useUpdateEnergy);
    final MotifScoreFunction scorer = new MotifScoreFunction(motifLength, strands, useUpdateEnergy);

    // Wrap each found offset in a Value so the score function can evaluate the whole assignment.
    final Value[] offsetValues = new Value[nStrands];
    for (int i = 0; i < nStrands; ++i)
    {
        final String strand = strands.get(i);

        offsetValues[i] = Value.create(DiscreteDomain.range(0, strand.length()));
        offsetValues[i].setInt(foundOffsets[i]);

        final int expected = _motifOffsets[i];
        final int actual = foundOffsets[i];
        if (verbose)
        {
            System.out.format("Expected: %3d %s Actual: %3d %s\n",
                expected, strand.substring(expected, expected + motifLength),
                actual, strand.substring(actual, actual + motifLength)
            );
        }
    }

    final double score = scorer.evalEnergy(offsetValues);
    if (verbose)
    {
        System.out.format("Score: %f\n", score);
    }

    assertTrue(score <= maxAcceptedScore);
}
private static String[] _dna = new String[] {
"TCCGAACAACGAGTAGGCGTACTCACCGGCATGGCCGGATACACCGACCATCGCGGACGAGAAAGGGAGGGCTGAAATACAGACAGCGTACTGTATTAAGCAGAAACGAGAGGAGACAGATCTCATCCCTGGTGTGGTGGAACTGGAGGACTCGCCTCGGTGTGAGTCGTAAGGTGACCGACGATGAAATGCAAGTTCCAACGGCCAACAGCGCGTCAACAACAATGCGCACGTGTCGTAAACTGACGTGAGGTCCCCTTATAGCCCATGAAGAACTTTGACTCGCCTCCGGTAGCCGCTAGTTTTATGCGTGAATGTGCGTTATGCCAACTCAAATGTCTCGCAAGTCAATGAATCAACCGCGATTCTTTATTAACTTCATATCAGGCTAACAAGGACAGACAGCAACAAAGTTCTGCAAAACTGTTCCGGTCTCATCCCTAACTCTCTAACTGATAACAGTCTAACTTGCACCAAGAGTCGCTCGATCGACCAAAGAAATTACCGCCGCCTTGCAGTTCCGATGCCTGGAGTCCCCCTCCGTGTGAAGGTGATAAACCATTTGTCCAACAATGTTAGACAATAAACCACGTAAAAGGC",
"GGAACTAGCTTAAAAAATAGCAGGGTGTGCCTGATCCTTCCGGTGTTTAAGTAGAAGGCAGGACGGACAGAGTTCCATCCACAGAAGCATAGTTTGATCGTATTGGCGACAGGCTGATGCGAAGCTCCGCTCCAAACGAGAGAGATAAATGCATGCGGTTTGGCCTAAGGCGGGGGGGCAACCCGGCTTATCAATTAGCTAGCCTTGCTTTGGAACAAGGGCCAAGCGGGAGGTAAACTCTTCAGCCCGGGTGTCCCAGTAGCGCGATTTGGTGCTAGCCAGGTTTCGATCAAAAGGGGCTCTTGCAACGCTCTCTTCTAAAAATAAAATGCAATTAGTTGGCAGGGTTGATGAGTGTCGAATCCTTGCAAGCGAGATTCTTCCATGCAGTTGCAGCGGGGCGAGGCCAAAGAGCTCAGCTAGCTTGGGGACTCGCGCCCTGCTTATTCACCCTCGGTGCAACTAATCCTTACCGTGAATTTGGTAGATGTCCAAGCATTGTTCTTTATTAATGCACGTGTTAATAGGGGTAGACTATTCCCGTCCCGGCCTACGGTGTCAAAATCAAGTAGGCGCCGAGATTATTCTTGCATGCCTGTA",
"TGCGAACTAGTTTTCGCAACTTAACGTAGCGCGTGGGGCGTCCCTAGTGGCTCTGTCAAAGCAATTTGGTTCGTTTAGCTGTTATAGTTTTGGATCACAGCGAATAGAGTTAGTCTTGCTAGTCCGTAAATCAACGGACCGCGTCCCATTAAACACAGCTTCGTCGAGTCTATGACGCTCATACTCTACCATGACCGCGCCGGGACGACCGCCAACTCATAAATGAACGCCTAATAGAACCGAAAAGGGTCGGCGGCACAAAACTCCGGAACGTGGTCTGGGTTAACAAAGGCGCGATGATATTGTTCGTAGATCCCTGTTGGACTCTCCAACAAGTTTCCCGGAGGACTCGAGGTTCCAGGCCGAGTAAATAAAAGTTTTCTCGGGGTGGTGCCGGAAGGCGGGAAGTGGTGGTTAGGACAGATAATGACGAAAACAATGGATCGTGGAAGAGATCGCCCAGAGGTTCGATAGGATGTTACGCTACTTGTGTTCGAGGGGGAGACGGTTTCTACCTAGGCGGGTACCACAAAGCTGTTCTCTATTCTGGAAATTATGTACTCTGTTACTTGAATAAAATAAACAGCGGGGTACGCGGAT",
"ATCCTGACTACGGCGGTTTTCGTCTTGGGTAGGCACGGAGCTAGAGTATACACGGCAGCTCGTAGGGGGTCGATGCGTCTCGATTAGCTCGTTCCTATAGCTCAGCGATATCCCCGGGTTAAGAAGATTGCTCTCGTTACGCACTAGCCTCCGACTCGCGGGGCGTAACCAGTACAGTAAAAGACGCTAGAATCGACGCTTTCGCATAGTAGTCATTTAGAACCCGGGCTTAAACGATCGTACTTGATACACCCCGGGAGATGTGGATACCATTAAGTTAACCAGATCTATATGCGACCAGTCCTGCAGTAAAGATTGGCTGTCTTGGACTTGTATGCAAGCATAATCAGGGCAGAGGCAGTGGTCCGTTGCCTGAGGACGTCAAGAGTTCTCAGTCTAAAGTATTCCGGGGAAATAGTTAGTTGGCATAAGTCCGCCAAAGATCGCAGATGGTTAGTAGGTAACACTGGGGCCCTCCAGCTTAAGCCAAGCTAACTACGCTCAAGCAGGCTTTTTTTTTATGTTGAACAGAAAAAAAGGGGTTTTCACGCACACTTAGCCCTTTCTACGTAAGAGTCATTCTCAATACTGATGTCAGGA",
"AATTATACATAGGTGTGACTCTATGCTCGGCTATGGAAATAAGGTTCGCGCCGACACCTATGAAGAATTGTCACCCATGTTTTTGTGTCTATCAGCTTTGAGTGAGATTTGGTTTTCACGGGAGAAAGAGGATGTTCTCTGCGTGCGGACTCCTGAGACTTTGCTGAATGATGATGTAGCGGATCCACGAGGAACTGAGGTCCCGCAGCTCCGAGACAGGTGCTGATGCTTTGGCAACGATTTGAGGGCACAATTCCCGAGTACCTAGGATGGTATTCTGTATTGATTGGTTTTTGAAATGTGCTTGATTCGAACCAAGCGAGCAATTGACAAACGCTGTGCCTAGGTATACCTAAAATAAAACTGCGACAGTTGATCAAACATAAAGTAGAGGGGGTCCAAGTATCCATGAGTGATGCTTAGCACACCCTGCTCCCTGGACTTTTGGATTACCCCCTTCTAGCTTGCTTCTAGCTCAAGCTAAGACCTACCCCAATAAGAGGTAGCTAAGAACGGGGTCTGGGCAGTCATCAACGCCCGTGATCGTAAATCGGTCGTCCCACCGCACTCGCCGCGAATTACGAATAGCCATAGATGAGC",
"TCTAAAAATGGGGCGGCCAGTGAATAAAGCCTGCGCGTATTCGTAGCTGTTTACTCGGGAGACCGGCGCCCGAACAGCGCCCTGCCTAACGCCAGCTTACACCGATAGACGAACACGGTTGGGCTGATATACGTCGAACCTGCCTAACCTTAATACTTTCCCTAGTCAGAAGTTGGCCCGAACTTAAGCGTTCGAATGTAGGAGGACTATGAGAGCAAAGCGCGCGCCCGGTCATTTGCACAGAATTCACGTATGTAGTGTAGAGGCGAGACGGGTTTGTCGCGTACACTGCAGACCCAACAGTTTTACGGCAACACAATATCCGTCCAGCCGTAATACGAGCGCAAAGCACGTAGGGTCATCTGGCTAAAGAATTAGGCGCCACTCATTTTGACGGAGAGCGCTTTGCGATCAGATCAGTGGAGTCCAGATTTGATTGTAACTCACTTACCGCACGGCAACAACGCTCATTCCCGCTAATGTATGAGGTACAGGTTGCACTGGTCAGTTTAATGAAGGTCATAGAACACGGGTTTACGTGAATGCGTGTCGCCATCCTCGGCCGAAATGATGAGTTGCCAGGACCGATCTGGCGCCAGC",
"AGGTAAGGCTCGCCTCTACATCTCCGTACAAACTATCAGACGTAAAGAAAGCTGGAGGATTGCCAGCGAAAAGTACATAACACAAAGAACAAAAAGAGAAGGGGGTACGGGCTATTCGATCTAGATGGAGGCTAGGCAATAGAAGTTCGATCATCCATGGTAACTAGATATATGCTGAGAGCAAACGATCCCTAGTACCGCCTGTGTTATATGCCACCAATCTTTCTTCAGTTAGAAACCTCATTGTCGGGCGACACCAGGTCGATTCAAGAGGCGAGAGCCCATATGCTCGACCTATGGCGTGAACGCTAAGCGGCTGGAGCAAGAGAGGTGTATCCAACGACGGTTTTGAATTTACAATTCAGCCCACTGATATAAGCTGTATGGACTGACTCTGGAGGGACGCGCTGATATCTAAGGGCTTCGCGTACTAGGGTCACTACGGAAGCCATCGGCACTGTGCATCTTACAAAACGGACGTCCTTGACGGCCCTATGACCTTAGCACAAACGAATTGATGACCGAATGTACAGTACTTTGTGCTGGCTGAGCACTCCCTACCACGATCCGGCCAGCCGATCTGCGTCGAGGCTGCCACGC",
"AAGCTCAGCTAACTAGGCGTGAATAATAACGGAACACCTTAGGTAATGTTGGGGTCCTTACCACCATTTTACGTGGATCTCTAGACGGGCAGCACAAGCAGACGCTCAACGTAGTAATGCCAAGAAGAGATCTACTCCTGTGTTCACTTACATATATTCCCACTCAGAACCGCGTCTTCTGAACTGAGGAAGAAGTTACACTAACTGCACGAGATACCGGATCTGCACCTAGCCTGCTAGGCGTGGCACACGTAGCGCACCCTCACGGCTGCAATGGAATTTGCACAAAAACCAGCGCGTGGCGGATATTCCTCGTTTACAGAGTGGGTTGGAACATCCGGCGGTCCCGAGAGAACCGTCTTTCCGGTCGCCCATTTTATCAAAGATTGCAGTCTACTTGCCCGTATTCCTTGAGATGATTCGAAGGTCGAAATCGTAGCACATGGCTAACAATCCTGTTATTTATGCAGTAGCCGCGCCGCTTAGACGGCTTACCCCCGATATAGGGGAGCCCACCAGCTATGCCCTGGAAGGGACGATAAATAGCGTTGTGATTTATGATACCTTCACCAGCTTCGTACGTGCATAGAAAAGGAAGGG",
"TTCGTATCTTTCTCGGCGCCCTGATTCCAGTGATGGATTGTGAGGTCACTTCAAGTGAGATGTGTATTCCCAGCCAATCTATCCGTGTTAACTGATCCTAAACAGAGTGTGCCCAGATTAATGGGAACCCCAGTGTCAAGCGGGCCCTTAACACGGCCTGGTTAGATTCGTTTTAAGTGGGTCCTCTAACTCCTAACATTTTGACTTAAGGGTTTAACCGCTGACAGGCAGTAGCAACGGCTGTAGGGGAACACGAGGTTTTTTAATAAGTCTTGCAGTTTCATGCGGTTCTCACCAGAACGTTATAATCGCGAGTGCCCCGCTCAGGAATAGGATCAATGACGATTCTTATATCTCCGGAATTATGGTTACAGCTTCGTCAACGGCCTCAGGGTCGGGTTTTAAGCGGGGCCCGTTATCCAGAAATTACCCTGCATGGCAGGTCTACGCTAAAGTCCGAGCAAGAAAAAAGAGAGGAGTTTGCCCACGTGCCGCACACCCGGAGCTAGTCAGCATTGGTCTTCGAGAGATGCTCGCTGGACTCGGTTCATCTACTCGATCTAATTTTATGGCCGCCAACCATCAAAACGTATGACCTAA",
"TCTCATCCGTAGATTTAGTCCGGAGTGTTGAACAGCCCTCGGAGGTGCTACTAGCAATCACGAGATGCTAACGAGGAATATTTGGGATAGACGGTTCCTTCATGTTGTTCTGGGTACGCACTGCCGGCGAGTACCCCAGTGCCGAAACCGGTAAGAGTAAGTTCCTTAGGTTACGAGATTCCAGGCTTTTTGGGTAAGCGAGACCTACCCACTTGTTGCATCTACCCGTGTCTGTCAATCGCTGACTAGAACTGGTATCACGAGAAGAGAAACTTTCGATCTGTGCCCCATCAGTACCGAAGTTTGGTATAAATCGATGTGATATCCAAGACATGGAATAGCTTTCGCTCTTACGAGAGCATATGAAGGTTGCAACTAATTACCTATCTGATGTACGAAATTCAAGCTAAAGGGGGGTCAATCTCGTCCGAGTGCGACGGGGCAATAGCCCGGTACGATCTCCCATTTTCCCTTCCGGTACTCTACTGCTTTGGCGGGGTCGAGTTATCCGTGCGAACATTCAACCACCTCTGAGAACGGGGCCATAATGAACTGTGATCTTGATTCTACCTAAACACGCAGGACCAAAGCCTTCGCCGA"
};
// Expected motif start offset within each strand of _dna, in strand order. testCase reads
// these values but only consumes them in its optional diagnostic printout; no assertion
// compares the solver's result against them.
private static int[] _motifOffsets = new int[] { 56, 11, 576, 530, 382, 1, 90, 585, 465, 403 };
/**
 * Searches for one motif start offset per strand using the Gibbs solver.
 * <p>
 * Builds a factor graph with one discrete variable per strand, all connected to a single
 * {@code MotifScoreFunction} factor, then repeatedly burns in and samples until the solver's
 * best sample score drops to {@code cutoffScore} or 100 restarts have been performed.
 * After every round of sampling the solver's sample score is checked against the total
 * energy computed independently from the sample data layer.
 *
 * @param motifLength length of the motif; also determines the upper bound of each variable's domain
 * @param strands the DNA strands; one variable is created per strand
 * @param cutoffScore sampling stops early once the best sample score is less than or equal to this
 * @param useUpdateEnergy forwarded to the {@code MotifScoreFunction} attached to the graph
 * @return for each strand, the best sample index reported by that strand's Gibbs solver variable
 */
private int[] findMotifOffsets(int motifLength, List<String> strands, int cutoffScore, boolean useUpdateEnergy)
{
    final int nStrands = strands.size();
    final FactorGraph fg = new FactorGraph();

    // One variable per strand. The domain's upper bound is the strand length minus the motif
    // length, so a motif starting at any offset in the domain does not run past the strand's end.
    final Discrete[] offsetVars = new Discrete[nStrands];
    for (int i = 0; i < nStrands; ++i)
    {
        DiscreteDomain domain = DiscreteDomain.range(0, strands.get(i).length() - motifLength);
        Discrete offsetVar = new Discrete(domain);
        offsetVar.setName("motif-" + i);
        offsetVars[i] = offsetVar;
    }

    // A single factor over all offsets carries the whole score.
    final MotifScoreFunction scoreFunction = new MotifScoreFunction(motifLength, strands, useUpdateEnergy);
    fg.addFactor(scoreFunction, offsetVars);

    final GibbsSolverGraph gibbs = requireNonNull(fg.setSolverFactory(new GibbsSolver()));

    // NOTE(review): this array is never read afterwards. The lookups (and their casts) are kept
    // because it is not visible here whether fetching solver variables at this point has side
    // effects -- confirm before removing.
    final GibbsDiscrete[] solverVars = new GibbsDiscrete[nStrands];
    for (int i = 0; i < nStrands; ++i)
    {
        solverVars[i] = (GibbsDiscrete)gibbs.getSolverVariable(offsetVars[i]);
    }

    // Use seed for repeatable results. This value was chosen because it happens to hit the
    // answer with fewer restarts so the test won't run too long.
    gibbs.setSeed(15);

    // NOTE(review): restarts are driven manually by the loop below (at most 100); how the
    // numRandomRestarts option interacts with that is not visible here -- confirm.
    fg.setOption(GibbsOptions.numRandomRestarts, 2000);
    fg.setOption(GibbsOptions.numSamples, 20);
    fg.setOption(GibbsOptions.enableAnnealing, true);
    fg.initialize();

    for (int restart = 0; restart < 100; ++restart)
    {
        gibbs.burnIn(restart);

        if (restart != 0)
        {
            gibbs.sample(20);
        }
        else
        {
            // Test for infinite energy annealing bug 403.
            // Instead of running many samples to drive energy to infinity, just start with a
            // really low temperature, then restore the previous temperature afterwards.
            final double savedTemperature = gibbs.getTemperature();
            gibbs.setTemperature(1e-300);
            gibbs.sample(20);
            gibbs.setTemperature(savedTemperature);
        }

        final double sampleScore = gibbs.getSampleScore();

        // The sample layer must be internally consistent and hold a value for every variable.
        final ValueDataLayer sampleLayer = gibbs.getSampleLayer();
        TestDataLayer.assertInvariants(sampleLayer);
        for (Variable var : fg.getVariables())
        {
            assertNotNull(sampleLayer.get(var));
        }

        // The solver's score must agree with the energy computed from the data layers.
        final DataStack stack = new DataStack(new PriorDataLayer(fg), sampleLayer);
        assertEquals(sampleScore, stack.computeTotalEnergy(), 1e-15);

        if (gibbs.getBestSampleScore() <= cutoffScore)
        {
            // Good enough; 'restart' holds the number of restarts that were required.
            break;
        }
    }

    // Collect the best offset seen for each strand.
    final int[] bestOffsets = new int[nStrands];
    for (int i = 0; i < nStrands; ++i)
    {
        GibbsDiscrete solverVar = requireNonNull((GibbsDiscrete) gibbs.getSolverVariable(offsetVars[i]));
        bestOffsets[i] = solverVar.getBestSampleIndex();
    }
    return bestOffsets;
}
/**
 * Factor function whose energy is the consensus score of the currently selected motifs.
 * <p>
 * Argument {@code i} of the factor is the start offset of the motif within strand {@code i}.
 * For every position of the motif the function counts how often each of the four nucleotides
 * occurs across the selected motifs; the score is the sum, over all positions, of the number of
 * counted motifs minus the count of the most frequent nucleotide at that position.
 * <p>
 * The counts are instance state that survives between calls: {@code evalEnergy} rebuilds them
 * from scratch, whereas {@code updateEnergy} only adjusts them for the changed arguments and so
 * relies on the counts left behind by the preceding call.
 */
private static class MotifScoreFunction extends FactorFunction
{
    /** Nucleotide letters in the order used for the first index of the counts table. */
    private static final String NUCLEOTIDES = "ACGT";

    final boolean _useUpdateEnergy;
    private final int _motifLength;
    private final List<String> _strands;

    /** Occurrence counts indexed as {@code [nucleotide][position within motif]}. */
    private final int[][] _countsByNucleotide;

    /** Number of motifs currently included in the counts table. */
    private int _nIncremented;

    /**
     * @param motifLength number of characters in each motif
     * @param strands the strands the offsets refer to, in factor-argument order
     * @param useUpdateEnergy the value {@link #useUpdateEnergy} will return
     */
    private MotifScoreFunction(int motifLength, List<String> strands, boolean useUpdateEnergy)
    {
        _useUpdateEnergy = useUpdateEnergy;
        _strands = strands;
        _motifLength = motifLength;
        _countsByNucleotide = new int[NUCLEOTIDES.length()][motifLength];
    }

    /**
     * Recomputes the score from scratch: clears the counts, adds the motif selected by every
     * argument (read via {@code getInt}) and scores the resulting table.
     */
    @Override
    public double evalEnergy(Value[] values)
    {
        resetCounts();
        for (int i = 0; i < values.length; ++i)
        {
            shiftCounts(_strands.get(i), values[i].getInt(), +1);
        }
        return computeScore();
    }

    /** Returns the flag given at construction, independent of the arguments. */
    @Override
    public boolean useUpdateEnergy(Value[] values, int nChangedValues)
    {
        return _useUpdateEnergy;
    }

    /**
     * Incrementally rescores after some arguments changed: for each changed argument the motif
     * at its old offset is removed from the counts and the motif at its new offset is added.
     * {@code oldEnergy} is not used; the score is recomputed from the adjusted counts.
     */
    @Override
    public double updateEnergy(Value[] values, IndexedValue[] oldValues, double oldEnergy)
    {
        for (IndexedValue changed : oldValues)
        {
            final int argument = changed.getIndex();
            final String strand = _strands.get(argument);
            final int previousOffset = changed.getValue().getIndex();
            final int currentOffset = values[argument].getIndex();

            shiftCounts(strand, previousOffset, -1);
            shiftCounts(strand, currentOffset, +1);
        }
        return computeScore();
    }

    /**
     * Maps a nucleotide letter to its row in the counts table.
     *
     * @throws RuntimeException if {@code c} is not one of A, C, G or T
     */
    private static int nucleotideIndex(char c)
    {
        final int index = NUCLEOTIDES.indexOf(c);
        if (index < 0)
        {
            throw new RuntimeException("Bad character " + c);
        }
        return index;
    }

    /**
     * Adds {@code delta} (+1 to include, -1 to remove) to the motif tally and to the count of
     * each nucleotide of the motif that starts at {@code start} in {@code strand}.
     */
    private void shiftCounts(String strand, int start, int delta)
    {
        _nIncremented += delta;
        for (int position = 0; position < _motifLength; ++position)
        {
            _countsByNucleotide[nucleotideIndex(strand.charAt(start + position))][position] += delta;
        }
    }

    /** Forgets all counted motifs. */
    private void resetCounts()
    {
        _nIncremented = 0;
        for (int[] countsForNucleotide : _countsByNucleotide)
        {
            Arrays.fill(countsForNucleotide, 0);
        }
    }

    /**
     * Scores the current counts: for each motif position, the number of counted motifs minus
     * the count of the most frequent nucleotide there, summed over all positions.
     */
    private int computeScore()
    {
        int total = 0;
        for (int position = 0; position < _motifLength; ++position)
        {
            int mostFrequent = 0;
            for (int[] countsForNucleotide : _countsByNucleotide)
            {
                mostFrequent = Math.max(mostFrequent, countsForNucleotide[position]);
            }
            total += _nIncremented - mostFrequent;
        }
        return total;
    }
}
}