/*******************************************************************************
* Copyright (c) 2012 Michael Kutschke.
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* Michael Kutschke - initial API and implementation
******************************************************************************/
package org.eclipse.recommenders.jayes.factor;
import java.util.Arrays;
import org.eclipse.recommenders.jayes.factor.arraywrapper.DoubleArrayWrapper;
import org.eclipse.recommenders.jayes.factor.arraywrapper.IArrayWrapper;
import org.eclipse.recommenders.jayes.factor.opcache.DivisionCache;
import org.eclipse.recommenders.jayes.factor.opcache.ModuloCache;
import org.eclipse.recommenders.jayes.util.MathUtils;
/**
 * Base class for factors: a table of double values addressed through a list of dimensions, each of which is
 * associated with a dimension ID. Subclasses decide how a virtual (logical) position maps onto the backing
 * {@link IArrayWrapper} via {@link #getRealPosition(int)}.
 * <p>
 * A factor keeps one selection index per dimension; changing a selection invalidates the {@link Cut}, which is
 * lazily re-initialized by {@link #validateCut()} before any operation that iterates over the factor.
 * <p>
 * Values can be interpreted in log scale (see {@link #setLogScale(boolean)}); in that mode the prepared
 * multiplication adds values and the prepared summation performs a log-sum-exp.
 */
public abstract class AbstractFactor implements Cloneable {

    /**
     * Copies values from the given wrapper into this factor. The exact semantics are defined by the subclass.
     *
     * @param arrayWrapper
     *            source of the values
     */
    public abstract void copyValues(IArrayWrapper arrayWrapper);

    /**
     * Computes the position mapping used by {@link #multiplyPrepared(IArrayWrapper, int[])} and
     * {@link #sumPrepared(IArrayWrapper, int[])}. Those methods index the returned array with a real position of
     * this factor and use the entry as an index into the value array of the compatible factor.
     *
     * @param compatible
     *            the factor to combine with
     * @return the position mapping
     */
    public abstract int[] prepareMultiplication(AbstractFactor compatible);

    /**
     * Translates a virtual position into an index of the backing value array.
     *
     * @param virtualPosition
     *            the logical position
     * @return the index to use with {@link #values}
     */
    protected abstract int getRealPosition(int virtualPosition);

    /**
     * Fills the factor with the given value. NOTE(review): presumably every entry is set to {@code d}; the
     * implementation lives in the subclasses.
     *
     * @param d
     *            the fill value
     */
    public abstract void fill(double d);

    /** Size of each dimension. */
    protected int[] dimensions = new int[0];
    /** ID associated with each dimension, parallel to {@link #dimensions}. */
    protected int[] dimensionIDs = new int[0];
    /** Backing storage of the factor's values. */
    protected IArrayWrapper values = new DoubleArrayWrapper(0.0);
    /** Selected index per dimension; {@link #resetSelections()} sets every entry to -1. */
    protected int[] selections = new int[0];
    /** Iteration description for this factor, (re-)initialized on demand by {@link #validateCut()}. */
    protected Cut cut = new Cut(this);
    private boolean isCutValid = false;
    private boolean isLogScale = false;

    public AbstractFactor() {
        super();
    }

    /**
     * Replaces the backing value storage. The wrapper is stored as is, not copied. When assertions are enabled,
     * its length must equal the product of the current dimensions.
     *
     * @param values
     *            the new backing storage
     */
    public void setValues(IArrayWrapper values) {
        this.values = values;
        assert MathUtils.product(dimensions) == values.length();
    }

    /**
     * @return the backing value storage itself (not a copy)
     */
    public IArrayWrapper getValues() {
        return values;
    }

    /**
     * @param i
     *            a virtual position
     * @return the value stored at the real position corresponding to {@code i}
     */
    public double getValue(int i) {
        return values.getDouble(getRealPosition(i));
    }

    /**
     * Sets the dimension sizes (the argument is copied), clears all selections and invalidates the cut. The value
     * storage is re-allocated through {@link IArrayWrapper#newArray(int)} only when the product of the dimensions
     * exceeds its current length; a larger existing storage is kept. The dimension ID array is truncated or
     * zero-padded to the new number of dimensions.
     *
     * @param dimensions
     *            size of each dimension
     */
    public void setDimensions(int... dimensions) {
        this.dimensions = Arrays.copyOf(dimensions, dimensions.length);
        selections = new int[dimensions.length];
        resetSelections();
        int length = MathUtils.product(dimensions);
        if (length > values.length()) {
            values.newArray(length);
        }
        dimensionIDs = Arrays.copyOf(dimensionIDs, dimensions.length);
    }

    /**
     * @return the internal dimension array (not a copy)
     */
    public int[] getDimensions() {
        return dimensions;
    }

    /**
     * Tells the factor which variables the dimensions correspond to. Uniqueness and consistency of size is
     * assumed. The argument is copied.
     *
     * @param ids
     *            one ID per dimension
     */
    public void setDimensionIDs(int... ids) {
        dimensionIDs = ids.clone();
    }

    /**
     * @return the internal dimension ID array (not a copy)
     */
    public int[] getDimensionIDs() {
        return dimensionIDs;
    }

    /**
     * Looks up the position of a dimension ID.
     *
     * @param id
     *            the dimension ID to search for
     * @return the index of the first dimension carrying {@code id}, or -1 if there is none
     */
    protected int getDimensionFromID(int id) {
        for (int i = 0; i < dimensionIDs.length; i++) {
            if (dimensionIDs[i] == id) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Records {@code index} as the selection for the dimension with the given ID. The cut is invalidated only if
     * the selection actually changes.
     *
     * @param dimensionID
     *            ID of the dimension to select in
     * @param index
     *            the selection to store
     * @throws ArrayIndexOutOfBoundsException
     *             if no dimension carries {@code dimensionID}
     */
    public void select(int dimensionID, int index) {
        int dim = getDimensionFromID(dimensionID);
        if (selections[dim] != index) {
            selections[dim] = index;
            isCutValid = false;
        }
    }

    /**
     * Sets every selection to -1 and invalidates the cut.
     */
    public void resetSelections() {
        Arrays.fill(selections, -1);
        isCutValid = false;
    }

    /**
     * Switches the interpretation of the stored values. Only the flag is changed; stored values are not converted.
     *
     * @param isLogScale
     *            whether the values are to be treated as logarithms
     */
    public void setLogScale(boolean isLogScale) {
        this.isLogScale = isLogScale;
    }

    public boolean isLogScale() {
        return isLogScale;
    }

    /**
     * Marginalizes out all variables except the one with ID {@code sumDimensionID}. The stored values are summed
     * directly; the log-scale flag is not consulted.
     *
     * @param sumDimensionID
     *            ID of the dimension to keep, or -1 for the last dimension (default)
     * @return a new array with one entry per index of the kept dimension
     */
    @Deprecated
    public double[] marginalizeAllBut(int sumDimensionID) {
        validateCut();
        if (sumDimensionID == -1) {
            sumDimensionID = dimensionIDs[dimensionIDs.length - 1];
        }
        int sumDimension = getDimensionFromID(sumDimensionID);
        double[] result = new double[dimensions[sumDimension]];
        // the divisor is the number of entries spanned by all dimensions after the kept one
        int divisor = MathUtils.productOfRange(dimensions, sumDimension + 1, dimensions.length);
        DivisionCache division = new DivisionCache(divisor);
        sumToBucket(cut, 0, division, new ModuloCache(result.length), result);
        return result;
    }

    /**
     * Walks the cut recursively and adds every visited value to the bucket obtained by applying {@code division}
     * and then {@code modulo} to its virtual position.
     */
    private void sumToBucket(Cut cut, int offset, DivisionCache division, ModuloCache modulo, double[] result) {
        if (cut.getSubCut() == null) {
            int last = cut.getEnd() + offset;
            for (int i = cut.getStart() + offset; i < last; i += cut.getStepSize()) {
                int j = getRealPosition(i);
                int targetPos = modulo.apply(division.apply(i));
                result[targetPos] += values.getDouble(j);
            }
        } else {
            Cut c = cut.getSubCut();
            for (int i = 0; i < cut.getLength(); i += cut.getSubtreeStepsize()) {
                sumToBucket(c, offset + i, division, modulo, result);
            }
        }
    }

    /**
     * Multiply the factors. Only compatible factors are allowed, meaning ones that have a subset of the variables
     * of this factor (assume consistent dimension ID / size pairs).
     *
     * @param compatible
     *            the factor whose values are multiplied into this one
     */
    public void multiplyCompatible(AbstractFactor compatible) {
        int[] positions = prepareMultiplication(compatible);
        multiplyPrepared(compatible.values, positions);
    }

    /**
     * Multiplies the given values into this factor, using a mapping obtained from
     * {@link #prepareMultiplication(AbstractFactor)}. In log scale the values are added instead of multiplied.
     *
     * @param compatibleValues
     *            values of the compatible factor
     * @param positions
     *            for each real position of this factor, the index into {@code compatibleValues}
     */
    public void multiplyPrepared(IArrayWrapper compatibleValues, int[] positions) {
        validateCut();
        if (!isLogScale) {
            multiplyPrepared(cut, 0, compatibleValues, positions);
        } else {
            multiplyPreparedLog(cut, 0, compatibleValues, positions);
        }
    }

    /** Recursive worker of the linear-scale multiplication: multiplies in place along the cut. */
    private void multiplyPrepared(Cut cut, int offset, IArrayWrapper compatibleValues, int[] positions) {
        if (cut.getSubCut() == null) {
            int last = cut.getEnd() + offset;
            for (int i = cut.getStart() + offset; i < last; i += cut.getStepSize()) {
                int j = getRealPosition(i);
                values.mulAssign(j, compatibleValues, positions[j]);
            }
        } else {
            Cut c = cut.getSubCut();
            for (int i = 0; i < cut.getLength(); i += cut.getSubtreeStepsize()) {
                multiplyPrepared(c, offset + i, compatibleValues, positions);
            }
        }
    }

    /**
     * Sums this factor's values into the given target, using a mapping obtained from
     * {@link #prepareMultiplication(AbstractFactor)}. The target is overwritten: it is zero-filled first. In log
     * scale the result is the logarithm of the sum of the exponentiated values.
     *
     * @param compatibleFactorValues
     *            target storage that receives the sums
     * @param preparedOperation
     *            for each real position of this factor, the index into {@code compatibleFactorValues}
     */
    public void sumPrepared(IArrayWrapper compatibleFactorValues, int[] preparedOperation) {
        validateCut();
        compatibleFactorValues.fill(0);
        if (!isLogScale) {
            sumPrepared(cut, 0, compatibleFactorValues, preparedOperation);
        } else {
            sumPreparedLog(compatibleFactorValues, preparedOperation);
        }
    }

    /** Recursive worker of the linear-scale summation: accumulates into the target along the cut. */
    private void sumPrepared(Cut cut, int offset, IArrayWrapper compatibleFactorValues, int[] positions) {
        if (cut.getSubCut() == null) {
            int last = cut.getEnd() + offset;
            for (int i = cut.getStart() + offset; i < last; i += cut.getStepSize()) {
                int j = getRealPosition(i);
                compatibleFactorValues.addAssign(positions[j], values, j);
            }
        } else {
            Cut c = cut.getSubCut();
            for (int i = 0; i < cut.getLength(); i += cut.getSubtreeStepsize()) {
                sumPrepared(c, offset + i, compatibleFactorValues, positions);
            }
        }
    }

    /**
     * Log-scale summation (log-sum-exp). All values are shifted by the largest visited value before being
     * exponentiated, so the biggest term is exactly 1 and the sum cannot overflow; the shift is added back after
     * taking the logarithm.
     */
    private void sumPreparedLog(IArrayWrapper compatibleFactorValues, int[] positions) {
        double max = findMax(cut, 0, Double.NEGATIVE_INFINITY);
        if (max == Double.NEGATIVE_INFINITY) {
            // Nothing above -infinity was visited. Shifting by -infinity would yield NaN, whereas a shift of 0
            // turns every term into exp(-infinity) = 0 and every result into log(0) = -infinity.
            max = 0;
        }
        sumPreparedLog(cut, 0, compatibleFactorValues, positions, max);
        for (int i = 0; i < compatibleFactorValues.length(); i++) {
            compatibleFactorValues.set(i, Math.log(compatibleFactorValues.getDouble(i)) + max);
        }
    }

    /**
     * Finds the largest value visited by the cut.
     *
     * @param max
     *            the largest value seen so far; pass {@link Double#NEGATIVE_INFINITY} to start
     * @return {@code max}, or a visited value greater than it
     */
    private double findMax(Cut cut, int offset, double max) {
        if (cut.getSubCut() == null) {
            int last = cut.getEnd() + offset;
            for (int i = cut.getStart() + offset; i < last; i += cut.getStepSize()) {
                double value = values.getDouble(getRealPosition(i));
                // NaN never compares greater, so it is skipped
                if (value > max) {
                    max = value;
                }
            }
        } else {
            Cut c = cut.getSubCut();
            for (int i = 0; i < cut.getLength(); i += cut.getSubtreeStepsize()) {
                // the recursive call never returns less than the maximum handed to it
                max = findMax(c, offset + i, max);
            }
        }
        return max;
    }

    /** Recursive worker of the log-scale summation: accumulates {@code exp(value - max)} into the target. */
    private void sumPreparedLog(Cut cut, int offset, IArrayWrapper compatibleFactorValues, int[] positions, double max) {
        if (cut.getSubCut() == null) {
            int last = cut.getEnd() + offset;
            for (int i = cut.getStart() + offset; i < last; i += cut.getStepSize()) {
                int j = getRealPosition(i);
                compatibleFactorValues.addAssign(positions[j], Math.exp(values.getDouble(j) - max));
            }
        } else {
            Cut c = cut.getSubCut();
            for (int i = 0; i < cut.getLength(); i += cut.getSubtreeStepsize()) {
                sumPreparedLog(c, offset + i, compatibleFactorValues, positions, max);
            }
        }
    }

    /** Recursive worker of the log-scale multiplication: adds in place along the cut. */
    private void multiplyPreparedLog(Cut cut, int offset, IArrayWrapper compatibleValues, int[] positions) {
        if (cut.getSubCut() == null) {
            int last = cut.getEnd() + offset;
            for (int i = cut.getStart() + offset; i < last; i += cut.getStepSize()) {
                int j = getRealPosition(i);
                values.addAssign(j, compatibleValues, positions[j]);
            }
        } else {
            Cut c = cut.getSubCut();
            for (int i = 0; i < cut.getLength(); i += cut.getSubtreeStepsize()) {
                multiplyPreparedLog(c, offset + i, compatibleValues, positions);
            }
        }
    }

    /**
     * Re-initializes the cut if a selection changed since it was last initialized.
     */
    protected void validateCut() {
        if (!isCutValid) {
            cut.initialize();
            isCutValid = true;
        }
    }

    /**
     * Creates a copy that shares no array state declared in this class with the original: values, selections,
     * dimensions and dimension IDs are cloned, and the copy gets its own, not yet initialized cut. Fields added
     * by subclasses are only copied shallowly by {@link Object#clone()}.
     */
    @Override
    public AbstractFactor clone() {
        AbstractFactor f = null;
        try {
            f = (AbstractFactor) super.clone();
        } catch (CloneNotSupportedException x) {
            // should not be possible to happen
            throw new RuntimeException(x);
        }
        // The getters hand out these arrays directly, so the clone must not share them with the original.
        f.dimensions = dimensions.clone();
        f.dimensionIDs = dimensionIDs.clone();
        f.values = values.clone();
        f.selections = selections.clone();
        f.cut = new Cut(f);
        f.isCutValid = false;
        return f;
    }

    /**
     * Adds the logarithm of the given factor's values to this factor's values. Every index of the backing value
     * array is processed directly; neither the cut nor {@link #getRealPosition(int)} is involved, and the
     * log-scale flag is not checked. NOTE(review): presumably this factor holds log-scale values while
     * {@code factor} holds linear-scale values - confirm against callers.
     *
     * @param factor
     *            the compatible factor to multiply in
     */
    public void multiplyCompatibleToLog(AbstractFactor factor) {
        int[] positions = prepareMultiplication(factor);
        for (int i = 0; i < values.length(); i++) {
            values.addAssign(i, Math.log(factor.values.getDouble(positions[i])));
        }
    }

    /**
     * @return approximated memory requirements on top of the value array
     */
    public abstract int getOverhead();
}