/*******************************************************************************
* Copyright 2012 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.factorfunctions.core;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import org.eclipse.jdt.annotation.NonNullByDefault;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.collect.Tuple2;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.JointDomainIndexer;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.util.misc.Internal;
import com.google.common.cache.AbstractLoadingCache;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import net.jcip.annotations.Immutable;
/**
* Represents the product of multiple factor functions over subsets of the variables.
* <p>
* Used for joining factors.
* <p>
* @since 0.05
* @author Christopher Barber
*/
@Internal
public class JointFactorFunction extends FactorFunction
{
/*-------
* State
*/
/**
* @since 0.05
*/
@Immutable
@Internal
public static class Functions extends AbstractList<Tuple2<FactorFunction, int[]>>
{
private final Tuple2<FactorFunction,int[]>[] _functions;
private final int _hashCode;
/*--------------
* Construction
*/
public Functions(List<Tuple2<FactorFunction, int[]>> functions)
{
_functions = functions.toArray(new Tuple2[functions.size()]);
// The order of the functions doesn't really matter, so sort by function name.
// Since there could be duplicate or empty names, this does not really provide
// a canonical ordering, but it is better than nothing.
Arrays.sort(_functions, new Comparator<Tuple2<FactorFunction, int[]>> () {
@Override
@NonNullByDefault(false)
public int compare(Tuple2<FactorFunction, int[]> f1, Tuple2<FactorFunction, int[]> f2)
{
return f1.first.getName().compareTo(f2.first.getName());
}
});
_hashCode = Arrays.hashCode(_functions);
}
/*----------------
* Object methods
*/
@Override
public boolean equals(@Nullable Object other)
{
if (other == this)
{
return true;
}
if (other instanceof Functions)
{
final Functions that = (Functions)other;
return this._hashCode == that._hashCode && Arrays.equals(this._functions, that._functions);
}
return false;
}
@Override
public int hashCode()
{
return _hashCode;
}
/*--------------
* List methods
*/
@Override
public Tuple2<FactorFunction, int[]> get(int index)
{
return _functions[index];
}
@Override
public int size()
{
return _functions.length;
}
}
private final Functions _functions;
private final int _newNumInputs;
@NonNullByDefault(false)
private static class Loader extends CacheLoader<Functions, JointFactorFunction>
{
private static final Loader INSTANCE = new Loader();
@Override
public JointFactorFunction load(Functions functions) throws Exception
{
return new JointFactorFunction(functions);
}
}
/*--------------
* Construction
*/
/**
* @since 0.05
*/
@Internal
public JointFactorFunction(Functions functions)
{
this(buildName(functions), functions);
}
/**
* @since 0.05
*/
@Internal
public JointFactorFunction(String name, Functions functions)
{
super(name);
_functions = functions;
int maxIndex = 0;
for (Tuple2<FactorFunction, int[]> tuple : functions._functions)
{
for (int index : tuple.second)
{
maxIndex = Math.max(maxIndex, index);
}
}
_newNumInputs = maxIndex + 1;
}
private static String buildName(Functions functions)
{
StringBuilder builder = new StringBuilder();
for (int i = 0, end = functions._functions.length; i < end; ++i)
{
if (i > 0)
{
builder.append("+");
}
builder.append(functions._functions[i].first.getName());
}
return builder.toString();
}
private static class Cache extends AbstractLoadingCache<Functions,JointFactorFunction>
{
private static ConcurrentMap<Functions, JointFactorFunction> _map =
new ConcurrentHashMap<Functions, JointFactorFunction>();
@Override
public JointFactorFunction get(@Nullable Functions key) throws ExecutionException
{
JointFactorFunction function = _map.get(key);
if (function == null)
{
try
{
_map.putIfAbsent(key, Loader.INSTANCE.load(key));
}
catch (Exception ex)
{
throw new RuntimeException(ex);
}
function = _map.get(key);
}
return function;
}
@Override
public JointFactorFunction getIfPresent(@Nullable Object key)
{
return _map.get(key);
}
}
@Internal
public static LoadingCache<Functions,JointFactorFunction> createCache()
{
// FIXME: We cannot use Guava's CacheBuilder when run from MATLAB because MATLAB's static Java class
// path includes an ancient version of the Guava Objects class that is incompatible with the one
// needed by CacheBuilder. <Grrrr>
// return CacheBuilder.newBuilder().build(Loader.INSTANCE);
return new Cache();
}
@Internal
public static JointFactorFunction getFromCache(LoadingCache<Functions,JointFactorFunction> cache, Functions key)
{
try
{
return cache.get(key);
}
catch (ExecutionException ex)
{
throw new RuntimeException(ex);
}
}
/*------------------------
* FactorFunction methods
*/
@Override
public double evalEnergy(Value[] input)
{
//Make sure length of inputs is correct
if (input.length != _newNumInputs)
throw new DimpleException("expected " + _newNumInputs + " args but got " + input.length);
double energy = 0.0;
for (Tuple2<FactorFunction, int[]> tuple : _functions._functions)
{
final FactorFunction function = tuple.first;
final int[] inputIndicesForFunction = tuple.second;
// TODO: Use a cache of reusable array objects instead of allocating every time.
energy += function.evalEnergy(ArrayUtil.copyFromIndices(input, inputIndicesForFunction));
}
return energy;
}
@Override
public double evalEnergy(Object... input)
{
//Make sure length of inputs is correct
if (input.length != _newNumInputs)
throw new DimpleException("expected " + _newNumInputs + " args");
double energy = 0.0;
for (Tuple2<FactorFunction, int[]> tuple : _functions._functions)
{
final FactorFunction function = tuple.first;
final int[] inputIndicesForFunction = tuple.second;
// TODO: Use a cache of reusable array objects instead of allocating every time.
energy += function.evalEnergy(ArrayUtil.copyFromIndices(input, inputIndicesForFunction));
}
return energy;
}
@Override
protected IFactorTable createTableForDomains(JointDomainIndexer domains)
{
final int nFunctions = _functions.size();
final ArrayList<Tuple2<IFactorTable,int[]>> tables = new ArrayList<Tuple2<IFactorTable,int[]>>(nFunctions);
final DiscreteDomain[] domainArray = domains.toArray(new DiscreteDomain[domains.size()]);
for (int i = 0; i < nFunctions; ++i)
{
final Tuple2<FactorFunction,int[]> functionTuple = _functions.get(i);
final FactorFunction function = functionTuple.first;
final int[] indices = functionTuple.second;
final JointDomainIndexer factorDomains =
JointDomainIndexer.create(ArrayUtil.copyFromIndices(domainArray, indices));
final IFactorTable factorTable = function.createTableForDomains(factorDomains);
tables.add(Tuple2.create(factorTable, indices));
}
return Objects.requireNonNull(FactorTable.product(tables));
}
}