/*******************************************************************************
* Copyright 2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.factorfunctions;
import static java.util.Objects.*;
import java.util.Collection;
import java.util.concurrent.atomic.AtomicReference;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.IndexedValue;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.util.misc.Matlab;
/**
 * Deterministic matrix-vector product. This is a deterministic directed factor
 * (if smoothing is not enabled).
 * <p>
 * The constructors take the length of the input vector and the length of the output
 * vector, respectively, and optionally a smoothing value.
 * <p>
 * Optional smoothing may be applied by providing a positive smoothing value in the
 * constructor. If smoothing is enabled, the distribution is smoothed by
 * exp(-difference^2/smoothing), where difference is the Euclidean distance between the
 * output vector and the deterministic output vector for the corresponding inputs.
 * <p>
 * The variables are ordered as follows in the argument list:
 * <ol>
 * <li>Output vector ({@code outLength} scalar arguments)
 * <li>Input matrix: either {@code outLength*inLength} scalar arguments scanned by columns
 * (column-major, because MATLAB assumes this), or a single constant {@code double[][]}
 * indexed as {@code [row][col]}
 * <li>Input vector: either {@code inLength} scalar arguments, or a single constant
 * {@code double[]}
 * </ol>
 */
public class MatrixVectorProduct extends FactorFunction
{
	protected int _inLength;
	protected int _outLength;
	// Scratch buffers reused by evalDeterministic to avoid per-call allocation.
	protected double[][] _matrix;
	protected double[] _inVector;
	protected double[] _outVector;
	// 1/smoothing when smoothing is enabled; otherwise unused.
	protected double _beta = 0;
	protected boolean _smoothingSpecified = false;
	// Value returned by updateDeterministicLimit (see constructor for the rationale).
	private final int _updateDeterministicLimit;

	/**
	 * Construct a deterministic (unsmoothed) matrix-vector product.
	 *
	 * @param inLength length of the input vector (number of matrix columns)
	 * @param outLength length of the output vector (number of matrix rows)
	 */
	@Matlab
	public MatrixVectorProduct(int inLength, int outLength)
	{
		this(inLength, outLength, 0);
	}

	/**
	 * Construct a matrix-vector product with optional smoothing.
	 *
	 * @param inLength length of the input vector (number of matrix columns)
	 * @param outLength length of the output vector (number of matrix rows)
	 * @param smoothing if positive, the energy becomes
	 *        {@code sum((output - expected)^2) / smoothing} instead of a hard 0/infinity constraint;
	 *        values &lt;= 0 disable smoothing
	 */
	public MatrixVectorProduct(int inLength, int outLength, double smoothing)
	{
		super();
		_inLength = inLength;
		_outLength = outLength;
		_matrix = new double[_outLength][_inLength];
		_inVector = new double[_inLength];
		_outVector = new double[_outLength];

		if (smoothing > 0)
		{
			_beta = 1 / smoothing;
			_smoothingSpecified = true;
			_updateDeterministicLimit = 0;
		}
		else
		{
			// A full update requires inLength*outLength multiply/adds. An incremental update
			// will take either 2 for changes to input matrix, or outLength*2 for changes to
			// input vector. So for matrix input changes the limit should be inLength*outLength/2
			// and for vector input changes the limit should be inLength/2. We will use the min of
			// these two for now:
			_updateDeterministicLimit = inLength / 2;
		}
	}

	/**
	 * Energy of the given argument values.
	 * <p>
	 * Recomputes the expected outputs from the inputs and sums the squared differences with the
	 * actual outputs. With smoothing the energy is that sum times 1/smoothing; without smoothing
	 * it is 0 on an exact match and positive infinity otherwise.
	 */
	@Override
	public final double evalEnergy(Value[] arguments)
	{
		// Compute the expected output
		final Value[] expectedResult = evalDeterministicToCopy(arguments);

		// Compare the output to the expected output
		final int outLength = _outLength;
		double error = 0;
		for (int i = 0; i < outLength; i++)
		{
			final double diff = arguments[i].getDouble() - expectedResult[i].getDouble();
			error += diff*diff;
		}

		if (_smoothingSpecified)
			return error*_beta;
		else
			return (error == 0) ? 0 : Double.POSITIVE_INFINITY;
	}

	@Override
	public final boolean isDirected() {return true;}

	/** The outputs are the first {@code outLength} arguments: indices {@code 0..outLength-1}. */
	@Override
	public final int[] getDirectedToIndices()
	{
		int[] indexList = new int[_outLength];
		for (int i = 0; i < _outLength; i++)
			indexList[i] = i;
		return indexList;
	}

	/**
	 * Output indices affected by a change to the given input edge.
	 * <p>
	 * A scalar matrix element at column-major position {@code x = inputEdge - outLength} only
	 * affects output row {@code x % outLength}. A change to the input vector affects every
	 * output, which is signalled by returning {@code null} (same as all directed-to indices).
	 *
	 * @throws DimpleException if the factor's argument count does not match any supported layout
	 */
	@SuppressWarnings("deprecation")
	@Override
	public final @Nullable int[] getDirectedToIndicesForInput(Factor factor, int inputEdge)
	{
		final int outLength = _outLength;
		final int inLength = _inLength;
		final int nEdges = factor.getArgumentCount();
		final int nInputEdges = nEdges - outLength;
		final int matrixSize = outLength * inLength;
		final int vectorSize = inLength;
		final int matrixOffset = outLength;

		// Determine where the input vector starts: after the scalar matrix elements, or after a
		// single constant double[][] matrix argument.
		int vectorOffset = matrixOffset;
		if (nInputEdges == matrixSize + vectorSize ||
			nInputEdges == matrixSize + 1 && factor.hasConstantAtIndexOfType(nEdges - 1, double[].class))
		{
			vectorOffset += matrixSize;
		}
		else if (nInputEdges == 2 ||
			nInputEdges == vectorSize + 1 && factor.hasConstantAtIndexOfType(matrixOffset, double[][].class))
		{
			vectorOffset += 1;
		}
		else
		{
			throw new DimpleException("Bad number of edges %d for MatrixVectorProduct (inLength=%d, outLength=%d)",
				nEdges, inLength, outLength);
		}

		if (inputEdge >= vectorOffset)
		{
			// Same as full output edges
			return null;
		}
		else
		{
			// Matrix is scanned by columns, so the row is the column-major index modulo outLength.
			return new int[] { (inputEdge - matrixOffset) % outLength };
		}
	}

	@Override
	public final boolean isDeterministicDirected() {return !_smoothingSpecified;}

	/**
	 * Compute the output vector from the input matrix and vector, overwriting the first
	 * {@code outLength} entries of {@code arguments}.
	 */
	@Override
	public final void evalDeterministic(Value[] arguments)
	{
		int argIndex = _outLength;	// Skip the outputs
		final int inLength = _inLength;
		final int outLength = _outLength;

		// Get the matrix values
		double[][] matrix = _matrix;
		if (arguments[argIndex].getObject() instanceof double[][])	// Constant matrix is passed as a single argument
			matrix = (double[][])requireNonNull(arguments[argIndex++].getObject());
		else
		{
			// Scalar matrix elements are scanned by columns (column-major)
			for (int col = 0; col < inLength; col++)
				for (int row = 0; row < outLength; row++)
					matrix[row][col] = arguments[argIndex++].getDouble();
		}

		// Get the input vector values
		double[] inVector = _inVector;
		if (arguments[argIndex].getObject() instanceof double[])	// Constant vector is passed as a single argument
			inVector = arguments[argIndex++].getDoubleArray();
		else
		{
			for (int i = 0; i < inLength; i++)
				inVector[i] = arguments[argIndex++].getDouble();
		}

		// Compute the output
		final double[] outVector = _outVector;
		for (int row = 0; row < outLength; row++)
		{
			double sum = 0;
			final double[] rowValues = matrix[row];
			for (int col = 0; col < inLength; col++)
				sum += rowValues[col] * inVector[col];
			outVector[row] = sum;
		}

		// Replace the output values
		for (int i = 0; i < outLength; i++)
			arguments[i].setDouble(outVector[i]);
	}

	/** Returns the limit computed in the constructor; {@code numEdges} is not used. */
	@Override
	public final int updateDeterministicLimit(int numEdges)
	{
		return _updateDeterministicLimit;
	}

	/**
	 * Incrementally update the outputs after some inputs changed.
	 * <p>
	 * Only scalar (non-constant) inputs can be updated incrementally. If any changed index refers
	 * to a constant argument, or if both matrix and vector elements changed in the same update,
	 * this delegates to the superclass (full) update.
	 *
	 * @return true if the outputs were updated incrementally; otherwise the result of the
	 *         superclass update
	 * @throws IndexOutOfBoundsException if a changed index refers to an output or lies beyond
	 *         {@code values}
	 */
	@Override
	public final boolean updateDeterministic(Value[] values, Collection<IndexedValue> oldValues, AtomicReference<int[]> changedOutputsHolder)
	{
		final int inLength = _inLength;
		final int outLength = _outLength;
		final int matrixOffset = outLength;

		// A constant matrix/vector arrives as a single double[][] / double[] argument.
		final Object objAtMatrixOffset = values[matrixOffset].getObject();
		final double[][] matrix = objAtMatrixOffset instanceof double[][] ? (double[][])objAtMatrixOffset : null;
		final int vectorOffset = matrixOffset + (matrix != null ? 1 : inLength * outLength);
		final Object objAtVectorOffset = values[vectorOffset].getObject();
		final double[] vector = objAtVectorOffset instanceof double[] ? (double[])objAtVectorOffset : null;

		// Range of argument indices holding scalar (incrementally updatable) inputs.
		final int minSupportedIndex = matrix == null ? matrixOffset : (vector == null ? vectorOffset : values.length);
		final int maxSupportedIndex = vector == null ? values.length : (matrix == null ? vectorOffset : matrixOffset);

		// First pass: validate every change before touching any output.
		boolean matrixChanged = false;
		boolean vectorChanged = false;
		for (IndexedValue old : oldValues)
		{
			final int changedIndex = old.getIndex();
			if (changedIndex < matrixOffset || changedIndex >= values.length)
			{
				throw new IndexOutOfBoundsException();
			}
			if (changedIndex < minSupportedIndex || changedIndex >= maxSupportedIndex)
			{
				return super.updateDeterministic(values, oldValues, changedOutputsHolder);
			}
			if (values[changedIndex].getDouble() != old.getValue().getDouble())
			{
				if (changedIndex >= vectorOffset)
					vectorChanged = true;
				else
					matrixChanged = true;
			}
		}

		if (matrixChanged && vectorChanged)
		{
			// Each delta below is computed from the *current* value of the other factor in the
			// product. If both a matrix element and the vector element in the same column changed,
			// the cross term (newM - oldM) * (newV - oldV) would be counted twice, so recompute.
			return super.updateDeterministic(values, oldValues, changedOutputsHolder);
		}

		// Second pass: apply the deltas.
		for (IndexedValue old : oldValues)
		{
			final int changedIndex = old.getIndex();
			final double newInput = values[changedIndex].getDouble();
			final double oldInput = old.getValue().getDouble();
			if (newInput == oldInput)
			{
				continue;
			}

			if (changedIndex >= vectorOffset)
			{
				applyVectorElementChange(values, matrix, matrixOffset, changedIndex - vectorOffset, oldInput, newInput);
			}
			else
			{
				applyMatrixElementChange(values, vector, vectorOffset, changedIndex - matrixOffset, oldInput, newInput);
			}
		}

		return true;
	}

	// Adjust every output for a change to input vector element 'col'. Uses the constant matrix if
	// non-null, otherwise reads column 'col' of the scalar (column-major) matrix arguments.
	private void applyVectorElementChange(Value[] values, @Nullable double[][] matrix, int matrixOffset,
		int col, double oldInput, double newInput)
	{
		final int outLength = _outLength;
		if (matrix != null)
		{
			for (int row = 0; row < outLength; ++row)
			{
				final Value outputValue = values[row];
				final double oldOutput = outputValue.getDouble();
				final double matrixValue = matrix[row][col];
				outputValue.setDouble(oldOutput - matrixValue * oldInput + matrixValue * newInput);
			}
		}
		else
		{
			int matrixIndex = matrixOffset + col * outLength;
			for (int row = 0; row < outLength; ++row, ++matrixIndex)
			{
				final Value outputValue = values[row];
				final double oldOutput = outputValue.getDouble();
				final double matrixValue = values[matrixIndex].getDouble();
				outputValue.setDouble(oldOutput - matrixValue * oldInput + matrixValue * newInput);
			}
		}
	}

	// Adjust the single affected output for a change to the scalar matrix element at column-major
	// position 'x'. Uses the constant vector if non-null, otherwise the scalar vector arguments.
	private void applyMatrixElementChange(Value[] values, @Nullable double[] vector, int vectorOffset,
		int x, double oldInput, double newInput)
	{
		final int outLength = _outLength;
		final int col = x / outLength;
		final int row = x - col * outLength;
		final Value outputValue = values[row];
		final double oldOutput = outputValue.getDouble();
		final double inVectorVal = vector != null ? vector[col] : values[vectorOffset + col].getDouble();
		outputValue.setDouble(oldOutput - inVectorVal * oldInput + inVectorVal * newInput);
	}
}