/*******************************************************************************
* Copyright 2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.factorfunctions;
import static java.util.Objects.*;
import java.util.Collection;
import java.util.concurrent.atomic.AtomicReference;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.IndexedValue;
import com.analog.lyric.dimple.model.values.Value;
import cern.colt.map.OpenIntIntHashMap;
/**
 * Deterministic matrix product. This is a deterministic directed factor
 * (if smoothing is not enabled).
 * <p>
 * The three-argument constructor specifies the sizes of the input and output
 * matrices. The first two arguments are the number of rows and columns, respectively,
 * of the first input matrix. The third is the number of columns of the second input
 * matrix. The number of rows of the second input matrix must equal the number of
 * columns of the first input matrix.
 * <p>
 * Optional smoothing may be applied by providing a positive smoothing value as a
 * fourth constructor argument. If smoothing is enabled, the distribution is smoothed by
 * exp(-difference^2/smoothing), where difference is the distance between the
 * output value and the deterministic output value for the corresponding inputs.
 * <p>
 * The variables are ordered as follows in the argument list:
 * <ol>
 * <li>Output matrix (Nr x Nc, scanned by columns [because MATLAB assumes this])
 * <li>Input matrix 1 (Nr x Nx, scanned by columns [because MATLAB assumes this])
 * <li>Input matrix 2 (Nx x Nc, scanned by columns [because MATLAB assumes this])
 * </ol>
 * Either input matrix may instead be supplied as a single constant argument holding a
 * {@code double[][]} indexed as {@code [row][column]}.
 * <p>
 * {@link #evalDeterministic} writes scalar inputs into the instance-level scratch arrays
 * {@link #_in1} and {@link #_in2}, so concurrent evaluation on one instance would share
 * that storage.
 *
 * @since 0.05
 */
public class MatrixProduct extends FactorFunction
{
    /** Number of rows of the first input matrix and of the output matrix. */
    protected int _Nr;
    /** Inner dimension: columns of the first input matrix and rows of the second. */
    protected int _Nx;
    /** Number of columns of the second input matrix and of the output matrix. */
    protected int _Nc;
    /** Nr x Nx scratch array that receives the first input matrix when it is passed as scalars. */
    protected double[][] _in1;
    /** Nx x Nc scratch array that receives the second input matrix when it is passed as scalars. */
    protected double[][] _in2;
    /** Nr x Nc array allocated by the constructor; not otherwise referenced within this class. */
    protected double[][] _out;
    /** Reciprocal of the smoothing value; zero when smoothing is not enabled. */
    protected double _beta = 0;
    /** True when a positive smoothing value was given to the constructor. */
    protected boolean _smoothingSpecified = false;
    /** Value reported by {@link #updateDeterministicLimit(int)}. */
    private final int _updateDeterministicLimit;

    /**
     * Constructs an unsmoothed (deterministic) matrix product.
     *
     * @param Nr number of rows of the first input matrix and of the output
     * @param Nx number of columns of the first input matrix and rows of the second
     * @param Nc number of columns of the second input matrix and of the output
     */
    public MatrixProduct(int Nr, int Nx, int Nc)
    {
        this(Nr, Nx, Nc, 0);
    }

    /**
     * Constructs a matrix product with optional smoothing.
     *
     * @param Nr number of rows of the first input matrix and of the output
     * @param Nx number of columns of the first input matrix and rows of the second
     * @param Nc number of columns of the second input matrix and of the output
     * @param smoothing smoothing value; smoothing is enabled only if this is greater than zero
     */
    public MatrixProduct(int Nr, int Nx, int Nc, double smoothing)
    {
        super();
        _Nr = Nr;
        _Nx = Nx;
        _Nc = Nc;
        _in1 = new double[Nr][Nx];
        _in2 = new double[Nx][Nc];
        _out = new double[Nr][Nc];

        if (smoothing > 0)
        {
            _beta = 1 / smoothing;
            _smoothingSpecified = true;
            _updateDeterministicLimit = 0;
        }
        else
        {
            // A full update costs Nr*Nx*Nc multiply/adds. An incremental update will cost either
            // 2*Nr or 2*Nc depending on which input matrix contains the changed variable, so the
            // break-even count is (Nr*Nx*Nc) / (2*max(Nr,Nc)).
            //
            // Because Nr*Nc == max(Nr,Nc) * min(Nr,Nc), that quotient equals Nx*min(Nr,Nc)/2 exactly.
            // Computing it in this form (in long arithmetic) avoids both int overflow of the triple
            // product and a division by zero when Nr and Nc are both zero.
            final long limit = ((long)Nx * Math.min(Nr, Nc)) / 2;
            _updateDeterministicLimit = (int)Math.min(limit, Integer.MAX_VALUE);
        }
    }

    /**
     * Computes the energy as the sum of squared differences between the output arguments and
     * the product of the input arguments.
     *
     * @param arguments output values followed by input values, in the order described in the class
     *        documentation
     * @return with smoothing, the squared error divided by the smoothing value; without smoothing,
     *         zero if the outputs exactly equal the product and positive infinity otherwise
     */
    @Override
    public final double evalEnergy(Value[] arguments)
    {
        // Compute the expected output
        final Value[] expectedResult = evalDeterministicToCopy(arguments);

        // Compare the output to the expected output
        final int numOutputArguments = _Nr * _Nc;
        double error = 0;
        for (int i = 0; i < numOutputArguments; i++)
        {
            final double diff = arguments[i].getDouble() - expectedResult[i].getDouble();
            error += diff * diff;
        }

        if (_smoothingSpecified)
        {
            return error * _beta;
        }
        return (error == 0) ? 0 : Double.POSITIVE_INFINITY;
    }

    /** Always true: the factor is directed toward the output matrix entries. */
    @Override
    public final boolean isDirected()
    {
        return true;
    }

    /**
     * Returns the indices of the output arguments, which are the first {@code Nr * Nc}
     * arguments in order.
     */
    @Override
    public final int[] getDirectedToIndices()
    {
        final int[] indexList = new int[_Nr * _Nc];
        for (int i = 0; i < indexList.length; i++)
        {
            indexList[i] = i;
        }
        return indexList;
    }

    /**
     * Returns the output indices that depend on the given input edge: one output column for a
     * cell of the second input matrix, one output row for a cell of the first.
     *
     * @param factor the factor using this function; used to determine whether each input matrix
     *        is flattened into scalar edges or passed as a single constant
     * @param inputEdge argument index of the input
     * @return output argument indices affected by {@code inputEdge}
     * @throws DimpleException if the factor's argument count does not match any supported layout
     */
    @SuppressWarnings("deprecation")
    @Override
    public final int[] getDirectedToIndicesForInput(Factor factor, int inputEdge)
    {
        final int outRows = _Nr;
        final int outCols = _Nc;
        final int outSize = outRows * outCols;
        final int in1Rows = _Nr;
        final int in1Cols = _Nx;
        final int in1Size = in1Rows * in1Cols;
        final int in2Rows = _Nx;
        final int in2Cols = _Nc;
        final int in2Size = in2Rows * in2Cols;

        final int numEdges = factor.getArgumentCount();
        final int nInputEdges = numEdges - outSize;

        final int in1Offset = outSize;
        int in2Offset = in1Offset;

        if (nInputEdges == in1Size + in2Size ||
            (nInputEdges == in1Size + 1 && factor.hasConstantAtIndexOfType(numEdges - 1, double[][].class)))
        {
            // First input matrix is flattened out.
            in2Offset += in1Size;
        }
        else if (nInputEdges == 2 ||
            (nInputEdges == in2Size + 1 && factor.hasConstantAtIndexOfType(in1Offset, double[][].class)))
        {
            // First input matrix is a constant
            in2Offset += 1;
        }
        else
        {
            throw new DimpleException("Bad number of edges %d for MatrixProduct (Nr=%d, Nx=%d, Nc=%d)",
                numEdges, _Nr, _Nx, _Nc);
        }

        if (inputEdge >= in2Offset)
        {
            // Edge from second input matrix -- changes column of output
            final int[] to = new int[outRows];
            final int col = (inputEdge - in2Offset) / in2Rows;
            for (int row = 0, outIndex = col * outRows; row < outRows; ++row, ++outIndex)
            {
                to[row] = outIndex;
            }
            return to;
        }
        else
        {
            // Edge from first input matrix -- changes row of output
            final int[] to = new int[outCols];
            final int row = (inputEdge - in1Offset) % in1Rows;
            for (int col = 0, outIndex = row; col < outCols; ++col, outIndex += outRows)
            {
                to[col] = outIndex;
            }
            return to;
        }
    }

    /** True unless smoothing was specified in the constructor. */
    @Override
    public final boolean isDeterministicDirected()
    {
        return !_smoothingSpecified;
    }

    /**
     * Computes the matrix product of the input arguments and stores it in the output arguments.
     * <p>
     * Scalar inputs are copied into the scratch arrays {@link #_in1} and {@link #_in2}; an input
     * supplied as a single {@code double[][]} constant is used directly.
     *
     * @param arguments output values (overwritten) followed by input values, all scanned by columns
     */
    @Override
    public final void evalDeterministic(Value[] arguments)
    {
        final int Nr = _Nr;
        final int Nx = _Nx;
        final int Nc = _Nc;

        double[][] in1 = _in1;
        double[][] in2 = _in2;

        int argIndex = Nr * Nc; // Skip the outputs

        // Get the first input matrix values
        if (arguments[argIndex].getObject() instanceof double[][]) // Constant matrix is passed as a single argument
        {
            in1 = (double[][])requireNonNull(arguments[argIndex++].getObject());
        }
        else
        {
            for (int x = 0; x < Nx; x++) // Scan by columns
            {
                for (int r = 0; r < Nr; r++)
                {
                    in1[r][x] = arguments[argIndex++].getDouble();
                }
            }
        }

        // Get the second input matrix values
        if (arguments[argIndex].getObject() instanceof double[][]) // Constant matrix is passed as a single argument
        {
            in2 = (double[][])requireNonNull(arguments[argIndex++].getObject());
        }
        else
        {
            for (int c = 0; c < Nc; c++) // Scan by columns
            {
                for (int x = 0; x < Nx; x++)
                {
                    in2[x][c] = arguments[argIndex++].getDouble();
                }
            }
        }

        // Compute the output and replace the output values
        int outIndex = 0;
        for (int c = 0; c < Nc; c++) // Scan by columns
        {
            for (int r = 0; r < Nr; r++)
            {
                final double[] in1r = in1[r];
                double sum = 0;
                for (int x = 0; x < Nx; x++)
                {
                    sum += in1r[x] * in2[x][c];
                }
                arguments[outIndex++].setDouble(sum);
            }
        }
    }

    /**
     * Returns the limit computed in the constructor: zero when smoothing is enabled, otherwise
     * the estimated number of changed inputs at which a full update becomes cheaper than
     * incremental updates.
     */
    @Override
    public final int updateDeterministicLimit(int numEdges)
    {
        return _updateDeterministicLimit;
    }

    /**
     * Updates the outputs incrementally for the changed inputs listed in {@code oldValues},
     * falling back to the superclass implementation when an incremental update is not possible.
     * <p>
     * Each changed input cell adjusts one row or one column of the output by the difference
     * between its old and new contribution.
     *
     * @param values current output values followed by the current (already changed) input values
     * @param oldValues previous values of the inputs that changed, with their argument indices
     * @param changedOutputsHolder set to the indices of the modified outputs, but only when the
     *        update was completed incrementally
     * @return true if the update was done incrementally, otherwise the result of the superclass
     *         implementation
     * @throws IndexOutOfBoundsException if a changed index is not an input index
     */
    @Override
    public final boolean updateDeterministic(Value[] values, Collection<IndexedValue> oldValues,
        AtomicReference<int[]> changedOutputsHolder)
    {
        boolean incremental = false;

        final int outRows = _Nr;
        final int outCols = _Nc;
        final int outSize = outRows * outCols;
        final int in1Rows = _Nr;
        final int in1Cols = _Nx;
        final int in1Size = in1Rows * in1Cols;
        final int in2Rows = _Nx;

        final int in1Offset = outSize;
        final Object objAtIn1Offset = values[in1Offset].getObject();
        final double[][] in1Matrix = objAtIn1Offset instanceof double[][] ? (double[][])objAtIn1Offset : null;

        final int in2Offset = in1Offset + (in1Matrix == null ? in1Size : 1);
        final Object objAtIn2Offset = values[in2Offset].getObject();
        final double[][] in2Matrix = objAtIn2Offset instanceof double[][] ? (double[][])objAtIn2Offset : null;

        // Range of argument indices that refer to scalar (non-constant) input cells.
        final int minSupportedIndex = in1Matrix == null ? in1Offset : (in2Matrix == null ? in2Offset : values.length);
        final int maxSupportedIndex = in2Matrix == null ? values.length : (in1Matrix == null ? in2Offset : in1Offset);

        final OpenIntIntHashMap changedOutputSet = new OpenIntIntHashMap(Math.max(outRows, outCols));

        // The per-cell adjustment below multiplies the changed cell's delta by the CURRENT value of
        // the other matrix. If cells from both matrices changed in the same call, a product term
        // a*b with both factors changed would be adjusted by a_new*(b_new-b_old) + b_new*(a_new-a_old),
        // which differs from the true change a_new*b_new - a_old*b_old by (a_new-a_old)*(b_new-b_old).
        // Track which matrices have changed so that case can fall back to a full update.
        boolean in1CellChanged = false;
        boolean in2CellChanged = false;

        doIncremental:
        {
            if (in1Matrix != null && in2Matrix != null)
            {
                // Both inputs are constants; there is nothing to update incrementally.
                break doIncremental;
            }

            for (IndexedValue old : oldValues)
            {
                final int changedIndex = old.getIndex();

                if (changedIndex < in1Offset || changedIndex >= values.length)
                {
                    throw new IndexOutOfBoundsException();
                }

                if (changedIndex < minSupportedIndex || changedIndex >= maxSupportedIndex)
                {
                    break doIncremental;
                }

                final double newInput = values[changedIndex].getDouble();
                final double oldInput = old.getValue().getDouble();
                if (newInput == oldInput)
                {
                    continue;
                }

                if (changedIndex >= in2Offset)
                {
                    // Second matrix cell changed - changes column of output matrix
                    if (in1CellChanged)
                    {
                        // Mixed changes: incremental result could be wrong, so recompute fully.
                        break doIncremental;
                    }
                    in2CellChanged = true;

                    final int x = changedIndex - in2Offset;
                    final int col = x / in2Rows;
                    final int in2Row = x - col * in2Rows;
                    final int in1Col = in2Row;

                    int row = 0;
                    int outIndex = col * outRows;
                    if (in1Matrix != null)
                    {
                        for (; row < outRows; ++row, ++outIndex)
                        {
                            final Value outputValue = values[outIndex];
                            final double oldOutput = outputValue.getDouble();
                            final double in1Value = in1Matrix[row][in1Col];
                            outputValue.setDouble(oldOutput - in1Value * oldInput + in1Value * newInput);
                            changedOutputSet.put(outIndex, outIndex);
                        }
                    }
                    else
                    {
                        // First matrix is flattened by columns, so column in1Col starts here.
                        int in1Index = in1Offset + in1Col * in1Rows;
                        for (; row < outRows; ++row, ++outIndex, ++in1Index)
                        {
                            final Value outputValue = values[outIndex];
                            final double oldOutput = outputValue.getDouble();
                            final double in1Value = values[in1Index].getDouble();
                            outputValue.setDouble(oldOutput - in1Value * oldInput + in1Value * newInput);
                            changedOutputSet.put(outIndex, outIndex);
                        }
                    }
                }
                else
                {
                    // First matrix cell changed - changes row of output matrix
                    if (in2CellChanged)
                    {
                        // Mixed changes: incremental result could be wrong, so recompute fully.
                        break doIncremental;
                    }
                    in1CellChanged = true;

                    final int x = changedIndex - in1Offset;
                    final int in1Col = x / in1Rows;
                    final int row = x - in1Col * in1Rows;
                    final int in2Row = in1Col;

                    int col = 0;
                    int outIndex = row;
                    if (in2Matrix != null)
                    {
                        final double[] rowValues = in2Matrix[in2Row];
                        for (; col < outCols; ++col, outIndex += outRows)
                        {
                            final Value outputValue = values[outIndex];
                            final double oldOutput = outputValue.getDouble();
                            final double in2Value = rowValues[col];
                            outputValue.setDouble(oldOutput - in2Value * oldInput + in2Value * newInput);
                            changedOutputSet.put(outIndex, outIndex);
                        }
                    }
                    else
                    {
                        // Second matrix is flattened by columns, so row in2Row is strided by in2Rows.
                        int in2Index = in2Offset + in2Row;
                        for (; col < outCols; ++col, outIndex += outRows, in2Index += in2Rows)
                        {
                            final Value outputValue = values[outIndex];
                            final double oldOutput = outputValue.getDouble();
                            final double in2Value = values[in2Index].getDouble();
                            outputValue.setDouble(oldOutput - in2Value * oldInput + in2Value * newInput);
                            changedOutputSet.put(outIndex, outIndex);
                        }
                    }
                }
            }

            incremental = true;
        }

        if (incremental)
        {
            // Only report a subset of changed outputs when the incremental update actually completed;
            // a partial list would be misleading once the fallback below has run.
            changedOutputsHolder.set(changedOutputSet.keys().elements());
            return true;
        }

        return super.updateDeterministic(values, oldValues, changedOutputsHolder);
    }
}