/*******************************************************************************
* Copyright 2012-2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.model.factors;
import static java.util.Objects.*;
import java.util.ArrayList;
import java.util.List;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.factorfunctions.core.TableFactorFunction;
import com.analog.lyric.dimple.model.core.EdgeState;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.JointDomainIndexer;
import com.analog.lyric.dimple.model.domains.JointDomainReindexer;
import com.analog.lyric.dimple.model.variables.Constant;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.IConstantOrVariable;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.util.misc.Internal;
/**
 * A factor with only {@link Discrete} variables.
 * <p>
 * The joint domain of the sibling variables, the joint domain of the factor
 * arguments and the factor table are each computed on first request and then
 * cached. All three caches are dropped when the factor's connections change;
 * the cached factor table is additionally dropped when the factor function
 * is replaced.
 */
public class DiscreteFactor extends Factor
{
    /*-------
     * State
     */

    /** Cached result of {@link #getDomainList()}; {@code null} until computed or after invalidation. */
    private @Nullable JointDomainIndexer _domainList = null;

    /** Cached result of {@link #getArgumentDomains()}; {@code null} until computed or after invalidation. */
    private @Nullable JointDomainIndexer _factorArgumentDomains = null;

    /** Cached result of {@link #getFactorTable()}; {@code null} until computed or after invalidation. */
    private @Nullable IFactorTable _factorTable = null;

    /*--------------
     * Construction
     */

    /**
     * Creates a discrete factor that uses the given factor function.
     *
     * @param factorFunc the function passed through to the {@link Factor} constructor.
     */
    @Internal
    public DiscreteFactor(FactorFunction factorFunc)
    {
        super(factorFunc);
    }

    /**
     * Copy constructor used by {@link #clone()}.
     * <p>
     * Only the cached sibling domain list is carried over; the argument-domain
     * and factor-table caches start out empty and are rebuilt on demand.
     *
     * @param that the factor to copy.
     */
    protected DiscreteFactor(DiscreteFactor that)
    {
        super(that);
        _domainList = that._domainList;
    }

    /**
     * Returns a new factor built with the copy constructor.
     */
    @Override
    public DiscreteFactor clone()
    {
        return new DiscreteFactor(this);
    }

    /*--------------
     * Node methods
     */

    /**
     * Same as the inherited method, with the result narrowed to {@link Discrete}.
     *
     * @throws ClassCastException if the node at index {@code i} is not a {@link Discrete}.
     */
    @Override
    public Discrete getConnectedNodeFlat(int i)
    {
        return (Discrete)super.getConnectedNodeFlat(i);
    }

    /**
     * Same as the inherited method, with the result narrowed to {@link Discrete}.
     *
     * @throws ClassCastException if the sibling at index {@code i} is not a {@link Discrete}.
     */
    @Override
    public Discrete getSibling(int i)
    {
        return (Discrete)super.getSibling(i);
    }

    /**
     * Discards every cached value so that it is recomputed on next access.
     */
    @Override
    protected void notifyConnectionsChanged()
    {
        // NOTE(review): the superclass implementation is not invoked here - confirm that is intended.
        _domainList = null;
        _factorArgumentDomains = null;
        _factorTable = null;
    }

    /*----------------
     * Factor methods
     */

    /**
     * Returns the joint domain formed from the domains of this factor's siblings,
     * in sibling order, using {@link #getDirectedTo()} as the directed indices.
     * The result is cached until the connections change.
     */
    @Override
    public JointDomainIndexer getDomainList()
    {
        JointDomainIndexer domainList = _domainList;
        if (domainList == null)
        {
            final int numVariables = getSiblingCount();
            final DiscreteDomain[] domains = new DiscreteDomain[numVariables];
            for (int i = 0; i < numVariables; i++)
            {
                domains[i] = getSibling(i).getDomain();
            }
            domainList = JointDomainIndexer.create(getDirectedTo(), domains);
            _domainList = domainList;
        }
        return domainList;
    }

    /**
     * Returns the joint domain formed from the domains of this factor's arguments,
     * in argument order. A {@link Constant} argument contributes a single-element
     * domain holding its value; every other argument is cast to {@link Discrete}
     * and contributes its own domain. The result is cached until the connections change.
     *
     * @throws NullPointerException if a constant argument's value object is {@code null}.
     * @throws ClassCastException if a non-constant argument is not a {@link Discrete}.
     */
    @Override
    public JointDomainIndexer getArgumentDomains()
    {
        JointDomainIndexer domainList = _factorArgumentDomains;
        if (domainList == null)
        {
            // NOTE(review): the superclass result is discarded; presumably this call is kept
            // for a check or side effect it performs - confirm before removing.
            super.getArgumentDomains();

            final int numArgs = getArgumentCount();
            final DiscreteDomain[] domains = new DiscreteDomain[numArgs];
            for (int i = 0; i < numArgs; i++)
            {
                final IConstantOrVariable arg = getArgument(i);
                if (arg instanceof Constant)
                {
                    // Make a single-element discrete domain for the constant value.
                    domains[i] = DiscreteDomain.create(requireNonNull(((Constant)arg).value().getObject()));
                }
                else
                {
                    domains[i] = ((Discrete)arg).getDomain();
                }
            }

            // NOTE(review): getDirectedTo() is also what getDomainList() uses for sibling
            // positions; verify it is still valid for argument positions when constants are present.
            domainList = JointDomainIndexer.create(getDirectedTo(), domains);
            _factorArgumentDomains = domainList;
        }
        return domainList;
    }

    /**
     * Returns the factor table for this factor, computing and caching it on first use.
     * <p>
     * When the factor function is a {@link TableFactorFunction}, its table is used. If this
     * factor has constants and that table's dimension count differs from the sibling count,
     * the table is converted to drop the constant dimensions. For any other function the
     * table is obtained from {@code FactorFunction.getFactorTable(this)}.
     */
    @Override
    public IFactorTable getFactorTable()
    {
        IFactorTable table = _factorTable;
        if (table == null)
        {
            final FactorFunction func = getFactorFunction();
            if (func instanceof TableFactorFunction)
            {
                table = ((TableFactorFunction)func).getFactorTable();
                if (hasConstants() && table.getDimensions() != getSiblingCount())
                {
                    // Get rid of the constant dimensions. convert() hands back the reduced
                    // table as its return value, so the result has to be kept; otherwise the
                    // function's full-dimension table would be cached and returned.
                    table = table.convert(
                        JointDomainReindexer.createRemover(table.getDomainIndexer(), getConstantIndices()));
                }
            }
            else
            {
                table = func.getFactorTable(this);
            }
            _factorTable = table;
        }
        return table;
    }

    /**
     * Replaces the factor function and discards the cached factor table.
     *
     * @param function the new factor function.
     */
    @Override
    public void setFactorFunction(FactorFunction function)
    {
        super.setFactorFunction(function);
        _factorTable = null;
    }

    /**
     * Returns the possible belief indices reported by the solver object
     * obtained from {@code requireSolver}.
     */
    public int[][] getPossibleBeliefIndices()
    {
        return requireSolver("getPossibleBeliefIndices").getPossibleBeliefIndices();
    }

    /**
     * @return always {@code true}.
     */
    @Override
    public boolean isDiscrete()
    {
        return true;
    }

    /**
     * Replaces the given variables with a single joint variable in this factor.
     * <p>
     * Variables in {@code variablesToJoin} that are not yet siblings of this factor are
     * first added: the factor table is extended with their domains and an edge is created
     * for each. The table is then rebuilt with the joined variables collapsed into
     * {@code newJoint}'s domain, the edges to the joined variables are removed, and an
     * edge to {@code newJoint} is added.
     *
     * @param variablesToJoin the variables to merge; each is cast to {@link Discrete} where
     *        its domain is needed.
     * @param newJoint the variable that takes their place; cast to {@link Discrete}.
     * @throws ClassCastException if a variable that needs a discrete domain is not a {@link Discrete}.
     */
    @Override
    public void replaceVariablesWithJoint(Variable [] variablesToJoin, Variable newJoint)
    {
        assertNotFrozen();

        // Support a mixture of variables referred to in this factor and previously not referred to.
        // First figure out which variables are not currently siblings of this factor.
        final List<Variable> newVariables = new ArrayList<>();
        final int initialSiblingCount = getSiblingCount();
        for (Variable candidate : variablesToJoin)
        {
            boolean exists = false;
            for (int j = 0; j < initialSiblingCount; j++)
            {
                if (getConnectedNodeFlat(j).equals(candidate))
                {
                    exists = true;
                    break;
                }
            }
            if (!exists)
            {
                newVariables.add(candidate);
            }
        }

        // Extend the table with the domains of the new variables and connect them.
        if (!newVariables.isEmpty())
        {
            final DiscreteDomain[] newDomains = new DiscreteDomain[newVariables.size()];
            for (int i = 0; i < newDomains.length; i++)
            {
                newDomains[i] = ((Discrete)newVariables.get(i)).getDiscreteDomain();
            }

            final IFactorTable extendedTable = getFactorTable().createTableWithNewVariables(newDomains);
            setFactorFunction(TableFactorFunction.forFactor(this, extendedTable));
            for (Variable v : newVariables)
            {
                addEdge(this, v);
            }
        }

        // Read the sibling count again: edges may have been added above, and the remaining
        // steps must see every sibling, including the newly connected ones.
        final int siblingCount = getSiblingCount();

        // factorVarIndices[k] is the sibling index of the k-th joined variable (ascending);
        // indexToJointIndex[j] is the position k assigned to variablesToJoin[j].
        final int[] factorVarIndices = new int[variablesToJoin.length];
        final int[] indexToJointIndex = new int[variablesToJoin.length];
        int index = 0;
        for (int i = 0; i < siblingCount; i++)
        {
            final Discrete sibling = getConnectedNodeFlat(i);
            for (int j = 0; j < variablesToJoin.length; j++)
            {
                if (sibling.equals(variablesToJoin[j]))
                {
                    factorVarIndices[index] = i;
                    indexToJointIndex[j] = index;
                    index++;
                    break;
                }
            }
        }

        // Collect the domains of all current siblings.
        final DiscreteDomain[] allDomains = new DiscreteDomain[siblingCount];
        for (int i = 0; i < allDomains.length; i++)
        {
            allDomains[i] = getConnectedNodeFlat(i).getDiscreteDomain();
        }

        // Create the joined table and install it.
        final IFactorTable joinedTable = getFactorTable().joinVariablesAndCreateNewTable(
            factorVarIndices,
            indexToJointIndex,
            allDomains,
            ((Discrete)newJoint).getDiscreteDomain());
        setFactorFunction(new TableFactorFunction(getFactorFunction().getName(), joinedTable));

        // Remove old edges in descending order (the indices were recorded in ascending order
        // above), so that indices still waiting to be removed are not shifted.
        for (int i = factorVarIndices.length; --i >= 0;)
        {
            final EdgeState edge = getSiblingEdgeState(factorVarIndices[i]);
            removeSiblingEdge(edge);
        }

        // Add the new joint variable.
        addEdge(this, newJoint);
    }

    /**
     * Returns a string made of the literal {@code "TableFactor ["}, this factor's label,
     * {@code "] "} and the string form of the factor table.
     */
    public String getFactorTableString()
    {
        return "TableFactor [" + getLabel() + "] " + getFactorTable();
    }

    /**
     * Returns the belief reported by the solver object obtained from {@code requireSolver}.
     *
     * @throws ClassCastException if the solver's belief is not a {@code double[]}.
     */
    public double [] getBelief()
    {
        return (double[])requireSolver("getBelief").getBelief();
    }

    /**
     * Installs {@code table} as this factor's table by wrapping it in a
     * {@link TableFactorFunction} created for this factor.
     *
     * @param table the new factor table.
     */
    void setFactorTable(IFactorTable table)
    {
        assertNotFrozen();
        setFactorFunction(TableFactorFunction.forFactor(this, table));
        // setFactorFunction() clears the cached table, so it is assigned afterwards.
        _factorTable = table;
        if (!hasConstants())
        {
            // If there are no constants, then the argument domains are the same as the table's.
            _factorArgumentDomains = table.getDomainIndexer();
        }
    }
}