/*******************************************************************************
* Copyright 2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.core;
import static java.util.Objects.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Random;
import com.analog.lyric.collect.BitSetUtil;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.util.misc.Misc;
/**
 * Base class for estimators that search for factor-table weights of a
 * {@link FactorGraph} by repeatedly applying {@link #runStep(FactorGraph)} and
 * keeping the weights that yield the lowest Bethe free energy.
 * <p>
 * Subclasses supply the per-step update rule; {@link #run(int, int)} drives the
 * restarts and restores the best weights found into the estimated tables.
 */
public abstract class ParameterEstimator
{
    private final FactorGraph _fg;
    private IFactorTable [] _tables;
    private Random _r;
    private final HashMap<IFactorTable,ArrayList<Factor>> _table2factors;
    private boolean _forceKeep;

    /**
     * Creates an estimator and indexes every factor of {@code fg} by the factor
     * table it currently uses.
     *
     * @param fg     graph whose factors are indexed and which is solved by {@link #run(int, int)}
     * @param tables factor tables whose weights are randomized/restored by {@link #run(int, int)}
     * @param r      random source used when randomizing weights on restarts
     */
    public ParameterEstimator(FactorGraph fg, IFactorTable [] tables, Random r)
    {
        _fg = fg;
        _tables = tables;
        _r = r;

        // Group the graph's factors by the table they share (one lookup per factor).
        HashMap<IFactorTable,ArrayList<Factor>> table2factors = new HashMap<IFactorTable, ArrayList<Factor>>();
        for (Factor f : fg.getFactors())
        {
            IFactorTable ft = f.getFactorTable();
            ArrayList<Factor> sharing = table2factors.get(ft);
            if (sharing == null)
            {
                sharing = new ArrayList<Factor>();
                table2factors.put(ft, sharing);
            }
            sharing.add(f);
        }

        // Note: no directionality check happens here; BaumWelch performs its own
        // consistency check in its constructor.
        _table2factors = table2factors;
    }

    /**
     * Replaces the random source used for weight randomization on restarts.
     *
     * @param r new random source
     */
    public void setRandom(Random r)
    {
        _r = r;
    }

    /**
     * @return the internal (live, mutable) map from each factor table found in the
     *         graph at construction time to the factors that used it
     */
    public HashMap<IFactorTable,ArrayList<Factor>> getTable2Factors()
    {
        return _table2factors;
    }

    /**
     * @return the array of tables being estimated; after {@link #run(int, int)} this
     *         is the de-duplicated array produced by that call
     */
    public IFactorTable [] getTables()
    {
        return _tables;
    }

    /**
     * Returns a new array holding a {@code clone()} of each given table, in the same order.
     *
     * @param fts tables to snapshot
     * @return array of clones, index-aligned with {@code fts}
     */
    IFactorTable [] saveFactorTables(IFactorTable [] fts)
    {
        IFactorTable [] savedFts = new IFactorTable[fts.length];
        for (int i = 0; i < fts.length; i++)
            savedFts[i] = fts[i].clone();
        return savedFts;
    }

    /**
     * Returns a new array containing each distinct table of {@code factorTables} once.
     * <p>
     * First-occurrence order is preserved so that the order in which
     * {@link #run(int, int)} consumes random numbers per table does not depend on
     * hash iteration order (which would make seeded runs non-reproducible).
     *
     * @param factorTables tables, possibly containing duplicates
     * @return de-duplicated tables in first-occurrence order
     */
    IFactorTable [] unique(IFactorTable [] factorTables)
    {
        LinkedHashSet<IFactorTable> distinct =
            new LinkedHashSet<IFactorTable>(Arrays.asList(factorTables));
        return distinct.toArray(new IFactorTable[distinct.size()]);
    }

    /**
     * @return the factor graph supplied at construction
     */
    public FactorGraph getFactorGraph()
    {
        return _fg;
    }

    /**
     * When set to {@code true}, {@link #run(int, int)} keeps the result of every
     * pass regardless of its Bethe free energy, so the final pass always wins.
     *
     * @param val whether to unconditionally keep each pass's weights
     */
    public void setForceKeep(boolean val)
    {
        _forceKeep = val;
    }

    /**
     * Runs the estimation.
     * <p>
     * The loop executes {@code numRestarts + 1} passes: the first pass starts from
     * the tables' current weights, every later pass starts from randomized weights.
     * Each pass performs {@code numSteps} calls to {@link #runStep(FactorGraph)},
     * then solves the graph and compares its Bethe free energy against the best so
     * far. On return the estimated tables hold the best weights recorded.
     *
     * @param numRestarts number of additional randomized restarts after the initial pass
     * @param numSteps    number of {@link #runStep(FactorGraph)} calls per pass
     */
    public void run(int numRestarts, int numSteps)
    {
        // Make sure the factor table list is unique.
        _tables = unique(_tables);

        // Baseline: Bethe free energy and weights before any estimation step.
        _fg.solve();
        double currentBFE = _fg.getBetheFreeEnergy();
        IFactorTable [] bestFactorTables = saveFactorTables(_tables);

        for (int i = 0; i <= numRestarts; i++)
        {
            // Every pass except the first starts from random weights.
            if (i != 0)
            {
                for (int j = 0; j < _tables.length; j++)
                {
                    _tables[j].randomizeWeights(_r);
                    if (_tables[j].isDirected())
                        _tables[j].normalizeConditional();
                }
            }

            for (int j = 0; j < numSteps; j++)
            {
                runStep(_fg);
            }

            _fg.solve();
            double newBetheFreeEnergy = _fg.getBetheFreeEnergy();

            // Keep this pass's weights if they are better (or if keeping is forced).
            if (newBetheFreeEnergy < currentBFE || _forceKeep)
            {
                currentBFE = newBetheFreeEnergy;
                bestFactorTables = saveFactorTables(_tables);
            }
        }

        // Set weights to the best answer (index-aligned with _tables).
        for (int i = 0; i < _tables.length; i++)
        {
            _tables[i].copy(bestFactorTables[i]);
        }
    }

    /**
     * Performs one estimation step on the given graph.
     *
     * @param fg graph to update (the graph supplied at construction when called from
     *           {@link #run(int, int)})
     */
    public abstract void runStep(FactorGraph fg);

    /**
     * Estimator whose step replaces each factor table's weights with the summed
     * beliefs of the factors sharing that table. Requires all factors to be
     * directed, with identical directions among factors sharing a table.
     */
    public static class BaumWelch extends ParameterEstimator
    {
        /**
         * @param fg     graph to estimate
         * @param tables factor tables to estimate
         * @param r      random source for restarts
         * @throws DimpleException if any factor is undirected, or if factors sharing
         *                         a table do not all have the same directed-to indices
         */
        public BaumWelch(FactorGraph fg, IFactorTable[] tables, Random r)
        {
            super(fg, tables, r);

            for (Map.Entry<IFactorTable, ArrayList<Factor>> entry : getTable2Factors().entrySet())
            {
                IFactorTable table = entry.getKey();
                int [] direction = null;

                for (Factor f : entry.getValue())
                {
                    // Debugging hook for an unexpected table/factor mismatch.
                    // NOTE(review): presumably has no functional effect — confirm.
                    if (f.getFactorTable() != table)
                    {
                        Misc.breakpoint();
                    }

                    int [] tmp = f.getDirectedTo();
                    if (tmp == null)
                        throw new DimpleException("Baum Welch only works with directed Factors");

                    if (direction == null)
                    {
                        // First factor for this table defines the reference direction.
                        direction = tmp;
                    }
                    else if (!Arrays.equals(tmp, direction))
                    {
                        throw new DimpleException("Directions must be the same for all factors sharing a Factor Table");
                    }
                }
            }
        }

        /**
         * Solves the graph, then for every indexed factor table replaces its sparse
         * weights with the element-wise sum of the beliefs of the factors sharing it
         * and, when the first factor is directed, makes the table conditional on that
         * factor's directed-to indices.
         *
         * @param fg graph to solve
         */
        @Override
        public void runStep(FactorGraph fg)
        {
            fg.solve();

            for (Map.Entry<IFactorTable, ArrayList<Factor>> entry : getTable2Factors().entrySet())
            {
                IFactorTable ft = entry.getKey();
                ArrayList<Factor> factors = entry.getValue();

                // Accumulate the beliefs of all factors sharing this table
                // (a sum, not an average).
                double [] sum = new double[ft.sparseSize()];
                for (Factor f : factors)
                {
                    // Debugging hook for an unexpected table/factor mismatch.
                    // NOTE(review): presumably has no functional effect — confirm.
                    if (f.getFactorTable() != ft)
                    {
                        Misc.breakpoint();
                    }

                    double [] belief = (double[])requireNonNull(f.getSolver()).getBelief();
                    for (int i = 0; i < sum.length; i++)
                        sum[i] += belief[i];
                }

                // Direction is read from the first factor before the weights are replaced.
                Factor firstFactor = factors.get(0);
                int [] directedTo = firstFactor.getDirectedTo();
                int [] directedFrom = firstFactor.getDirectedFrom();

                ft.replaceWeightsSparse(sum);

                if (directedTo != null && directedFrom != null)
                {
                    ft.makeConditional(
                        BitSetUtil.bitsetFromIndices(directedTo.length + directedFrom.length, directedTo));
                }
            }
        }
    }
}