/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.benchmarks.hmm;
import java.util.Random;
import com.analog.lyric.benchmarking.Benchmark;
import com.analog.lyric.dimple.factorfunctions.core.FactorTable;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.options.BPOptions;
import com.analog.lyric.dimple.solvers.gibbs.GibbsOptions;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver;
import com.analog.lyric.dimple.solvers.minsum.MinSumSolver;
import com.analog.lyric.dimple.solvers.optimizedupdate.UpdateApproach;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolver;
/**
 * Benchmarks that build a chain-structured (HMM-shaped) factor graph with randomly weighted
 * factor tables and solve it with the Gibbs, sum-product and min-sum solvers.
 * <p>
 * Benchmark method names encode their configuration as
 * {@code hmm<Solver><stages>x<stateDomainOrder>x<observationDomainOrder>}; an
 * {@code Optimized} suffix marks the variants that select {@link UpdateApproach#OPTIMIZED}
 * instead of {@link UpdateApproach#NORMAL}.
 * <p>
 * The sample and iteration counts passed by each benchmark were chosen by the original
 * authors to aim for roughly one second of execution time.
 */
@SuppressWarnings({"null", "deprecation"})
public class hmmBenchmark
{
    /**
     * Seed of the generator that fills the factor tables. Every graph gets its own
     * generator started from this seed, so the weights a benchmark sees do not depend on
     * which other benchmarks built graphs earlier in the same JVM.
     */
    private static final long RANDOM_SEED = 0L;

    // ------------------------------------------------------------------ Gibbs

    /** Gibbs solver, 100 stages, 4 states, 4 observation values, 600000 samples. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmGibbs100x4x4()
    {
        return runGibbs(600000, 100, 4, 4);
    }

    /** Gibbs solver, 100000 stages, 4 states, 4 observation values, 300 samples. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmGibbs100000x4x4()
    {
        return runGibbs(300, 100000, 4, 4);
    }

    /** Gibbs solver, 1000 stages, 1000 states, 1000 observation values, 1500 samples. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmGibbs1000x1000x1000()
    {
        return runGibbs(1500, 1000, 1000, 1000);
    }

    // ------------------------------------------------------------ Sum-product

    /** Sum-product solver, normal updates, 100 stages, 4x4 domains, 1200 iterations. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmSumProduct100x4x4()
    {
        return runSumProduct(1200, UpdateApproach.NORMAL, 100, 4, 4);
    }

    /** Sum-product solver, optimized updates, 100 stages, 4x4 domains, 1200 iterations. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmSumProduct100x4x4Optimized()
    {
        return runSumProduct(1200, UpdateApproach.OPTIMIZED, 100, 4, 4);
    }

    /** Sum-product solver, normal updates, 100000 stages, 4x4 domains, 240 iterations. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmSumProduct100000x4x4()
    {
        return runSumProduct(240, UpdateApproach.NORMAL, 100000, 4, 4);
    }

    /** Sum-product solver, optimized updates, 100000 stages, 4x4 domains, 240 iterations. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmSumProduct100000x4x4Optimized()
    {
        return runSumProduct(240, UpdateApproach.OPTIMIZED, 100000, 4, 4);
    }

    /** Sum-product solver, normal updates, 1000 stages, 1000x1000 domains, 3 iterations. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmSumProduct1000x1000x1000()
    {
        return runSumProduct(3, UpdateApproach.NORMAL, 1000, 1000, 1000);
    }

    /** Sum-product solver, optimized updates, 1000 stages, 1000x1000 domains, 3 iterations. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmSumProduct1000x1000x1000Optimized()
    {
        return runSumProduct(3, UpdateApproach.OPTIMIZED, 1000, 1000, 1000);
    }

    // ---------------------------------------------------------------- Min-sum

    /** Min-sum solver, normal updates, 100 stages, 4x4 domains, 6000 iterations. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmMinSum100x4x4()
    {
        return runMinSum(6000, UpdateApproach.NORMAL, 100, 4, 4);
    }

    /** Min-sum solver, optimized updates, 100 stages, 4x4 domains, 6000 iterations. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmMinSum100x4x4Optimized()
    {
        return runMinSum(6000, UpdateApproach.OPTIMIZED, 100, 4, 4);
    }

    /** Min-sum solver, normal updates, 100000 stages, 4x4 domains, 360 iterations. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmMinSum100000x4x4()
    {
        return runMinSum(360, UpdateApproach.NORMAL, 100000, 4, 4);
    }

    /** Min-sum solver, optimized updates, 100000 stages, 4x4 domains, 360 iterations. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmMinSum100000x4x4Optimized()
    {
        return runMinSum(360, UpdateApproach.OPTIMIZED, 100000, 4, 4);
    }

    /** Min-sum solver, normal updates, 1000 stages, 1000x1000 domains, 3 iterations. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmMinSum1000x1000x1000()
    {
        return runMinSum(3, UpdateApproach.NORMAL, 1000, 1000, 1000);
    }

    /** Min-sum solver, optimized updates, 1000 stages, 1000x1000 domains, 3 iterations. */
    @Benchmark(warmupIterations = 0, iterations = 1)
    public boolean hmmMinSum1000x1000x1000Optimized()
    {
        return runMinSum(3, UpdateApproach.OPTIMIZED, 1000, 1000, 1000);
    }

    // ---------------------------------------------------------------- Helpers

    /**
     * Creates a graph that uses the Gibbs solver with the given sample count, then builds
     * and solves the HMM on it.
     *
     * @param numSamples value stored under {@link GibbsOptions#numSamples}
     * @param stages number of state/observation pairs in the chain
     * @param stateDomainOrder number of values in the state domain
     * @param observationDomainOrder number of values in the observation domain
     * @return always {@code false}, the value every benchmark in this class reports
     */
    private boolean runGibbs(int numSamples, int stages, int stateDomainOrder,
        int observationDomainOrder)
    {
        FactorGraph fg = new FactorGraph();
        fg.setSolverFactory(new GibbsSolver());
        fg.setOption(GibbsOptions.numSamples, numSamples);
        hmmInference(fg, stages, stateDomainOrder, observationDomainOrder);
        return false;
    }

    /**
     * Creates a graph that uses the sum-product solver and runs the HMM benchmark on it.
     *
     * @param iterations value stored under {@link BPOptions#iterations}
     * @param updateApproach value stored under {@link BPOptions#updateApproach}
     * @param stages number of state/observation pairs in the chain
     * @param stateDomainOrder number of values in the state domain
     * @param observationDomainOrder number of values in the observation domain
     * @return always {@code false}
     */
    private boolean runSumProduct(int iterations, UpdateApproach updateApproach, int stages,
        int stateDomainOrder, int observationDomainOrder)
    {
        FactorGraph fg = new FactorGraph();
        fg.setSolverFactory(new SumProductSolver());
        return runBeliefPropagation(fg, iterations, updateApproach, stages, stateDomainOrder,
            observationDomainOrder);
    }

    /**
     * Creates a graph that uses the min-sum solver and runs the HMM benchmark on it.
     *
     * @param iterations value stored under {@link BPOptions#iterations}
     * @param updateApproach value stored under {@link BPOptions#updateApproach}
     * @param stages number of state/observation pairs in the chain
     * @param stateDomainOrder number of values in the state domain
     * @param observationDomainOrder number of values in the observation domain
     * @return always {@code false}
     */
    private boolean runMinSum(int iterations, UpdateApproach updateApproach, int stages,
        int stateDomainOrder, int observationDomainOrder)
    {
        FactorGraph fg = new FactorGraph();
        fg.setSolverFactory(new MinSumSolver());
        return runBeliefPropagation(fg, iterations, updateApproach, stages, stateDomainOrder,
            observationDomainOrder);
    }

    /**
     * Applies the shared {@link BPOptions} settings to a graph whose solver factory has
     * already been set, then builds and solves the HMM on it.
     *
     * @param fg graph to configure and solve; its solver factory is set by the caller
     * @param iterations value stored under {@link BPOptions#iterations}
     * @param updateApproach value stored under {@link BPOptions#updateApproach}
     * @param stages number of state/observation pairs in the chain
     * @param stateDomainOrder number of values in the state domain
     * @param observationDomainOrder number of values in the observation domain
     * @return always {@code false}
     */
    private boolean runBeliefPropagation(FactorGraph fg, int iterations,
        UpdateApproach updateApproach, int stages, int stateDomainOrder,
        int observationDomainOrder)
    {
        fg.setOption(BPOptions.iterations, iterations);
        fg.setOption(BPOptions.updateApproach, updateApproach);
        hmmInference(fg, stages, stateDomainOrder, observationDomainOrder);
        return false;
    }

    /**
     * Populates {@code fg} with an HMM of the requested size, solves it, and reads back one
     * state value and the graph score.
     *
     * @param fg graph to populate and solve; solver and options are set by the caller
     * @param stages number of state/observation pairs in the chain
     * @param stateDomainOrder number of values in the state domain
     * @param observationDomainOrder number of values in the observation domain
     */
    @SuppressWarnings("unused")
    private void hmmInference(FactorGraph fg, int stages, int stateDomainOrder,
        int observationDomainOrder)
    {
        HmmGraph hmm = new HmmGraph(fg, stages, stateDomainOrder, observationDomainOrder);
        fg.solve();

        // The results are read but deliberately not used; presumably this keeps result
        // extraction inside the timed region.
        // NOTE(review): this reads the state at index 1, so it requires stages >= 2. The
        // local is named v0 -- confirm whether index 0 was intended.
        Integer v0 = (Integer) hmm.getStates()[1].getValue();
        double score = fg.getScore();
    }

    /**
     * Builds the HMM structure inside a {@link FactorGraph}: one state variable and one
     * observation variable per stage, a state/observation factor at every stage, and a
     * state/state factor between each pair of consecutive stages.
     */
    private static class HmmGraph
    {
        private final Discrete[] _states;
        private final Discrete[] _observations;

        /**
         * Creates the variables and adds the factors to {@code fg}.
         *
         * @param fg graph that receives the factors
         * @param stages number of state/observation pairs to create
         * @param stateDomainOrder the state domain is
         *        {@code DiscreteDomain.range(0, stateDomainOrder - 1)}
         * @param observationDomainOrder the observation domain is
         *        {@code DiscreteDomain.range(0, observationDomainOrder - 1)}
         */
        public HmmGraph(FactorGraph fg, int stages, int stateDomainOrder,
            int observationDomainOrder)
        {
            // A fresh, identically seeded generator per graph makes the table weights a
            // function of the graph dimensions only.
            Random rng = new Random(RANDOM_SEED);

            DiscreteDomain stateDomain = DiscreteDomain.range(0, stateDomainOrder - 1);
            DiscreteDomain observationDomain =
                DiscreteDomain.range(0, observationDomainOrder - 1);

            _states = new Discrete[stages];
            _observations = new Discrete[stages];

            // Both tables are created once and shared by every stage of the chain.
            IFactorTable stateToStateTransitionFactorTable =
                randomFactorTable(rng, stateDomain, stateDomain);
            IFactorTable stateToObservationTransitionFactorTable =
                randomFactorTable(rng, stateDomain, observationDomain);

            for (int i = 0; i < _states.length; i++)
            {
                _states[i] = new Discrete(stateDomain);
                _observations[i] = new Discrete(observationDomain);
                fg.addFactor(stateToObservationTransitionFactorTable, _states[i],
                    _observations[i]);

                // The first stage has no predecessor to link to.
                if (i > 0)
                {
                    fg.addFactor(stateToStateTransitionFactorTable, _states[i - 1],
                        _states[i]);
                }
            }
        }

        /**
         * Creates a factor table over {@code domains} whose dense weights are drawn from
         * {@code rng.nextDouble()}.
         *
         * @param rng source of the weights
         * @param domains domains of the table's dimensions, in order
         * @return the populated table
         */
        private static IFactorTable randomFactorTable(Random rng, DiscreteDomain... domains)
        {
            IFactorTable result = FactorTable.create(domains);
            double[] weights = new double[result.getDomainIndexer().getCardinality()];
            for (int i = 0; i < weights.length; i++)
            {
                weights[i] = rng.nextDouble();
            }
            result.setWeightsDense(weights);
            return result;
        }

        /**
         * @return the observation variables, one per stage; this is the internal array,
         *         not a copy
         */
        @SuppressWarnings("unused")
        public Discrete[] getObservations()
        {
            return _observations;
        }

        /**
         * @return the state variables, one per stage; this is the internal array, not a
         *         copy
         */
        public Discrete[] getStates()
        {
            return _states;
        }
    }
}