/**
* Copyright (c) 2011 Michael Kutschke.
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* Michael Kutschke - initial API and implementation.
*/
package org.eclipse.recommenders.jayes.factor;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.*;
import org.eclipse.recommenders.jayes.factor.arraywrapper.DoubleArrayWrapper;
import org.eclipse.recommenders.jayes.util.MathUtils;
import org.eclipse.recommenders.testing.jayes.ArrayUtils;
import org.junit.Test;
/**
 * Unit tests for {@link DenseFactor}.
 * <p>
 * Covers:
 * <ul>
 * <li>marginalization, with and without values selected via {@code select};</li>
 * <li>multiplication with dense and sparse factors;</li>
 * <li>prepared summation;</li>
 * <li>copying values under a selection;</li>
 * <li>zero-dimensional factors.</li>
 * </ul>
 * <p>
 * Every {@code assertArrayEquals} call passes the expected array first and the actual array
 * second, matching JUnit's {@code (expecteds, actuals, delta)} signature, so that failure
 * messages label the values correctly.
 */
// NOTE(review): this suppression is class-wide. It presumably covers the deprecated
// org.junit.Assert.assertThat and/or deprecated factor APIs; narrow it once confirmed.
@SuppressWarnings("deprecation")
public class DenseFactorTest {

    /** Absolute tolerance used for all floating-point array comparisons. */
    private static final double TOLERANCE = 0.00001;

    /**
     * Builds the shared 2x2x2 test distribution, flattened with {@code ArrayUtils.flatten}.
     * The expected values in the tests below assume the last dimension varies fastest.
     */
    private static double[] distribution2x2x2() {
        // @formatter:off
        return ArrayUtils.flatten(new double[][][] {
                { { 0.5, 0.5 }, { 1.0, 0.0 } },
                { { 0.4, 0.6 }, { 0.3, 0.7 } } });
        // @formatter:on
    }

    /**
     * Creates a {@link DenseFactor} with dimension IDs 0, 1, 2, each of size 2.
     * No values are set.
     */
    private static AbstractFactor create2x2x2Factor() {
        AbstractFactor factor = new DenseFactor();
        factor.setDimensionIDs(0, 1, 2);
        factor.setDimensions(2, 2, 2);
        return factor;
    }

    /**
     * Marginalizes {@code factor} onto {@code dimensionId} and normalizes the result.
     * In these tests, {@code -1} appears to denote the last dimension.
     */
    private static double[] normalizedMarginal(AbstractFactor factor, int dimensionId) {
        return MathUtils.normalize(factor.marginalizeAllBut(dimensionId));
    }

    @Test
    public void testSum() {
        AbstractFactor factor = new DenseFactor();
        factor.setDimensionIDs(0, 1);
        factor.setDimensions(2, 2);
        factor.setValues(new DoubleArrayWrapper(0.5, 0.5, 1.0, 0.0));

        // Summing out dimension 0 gives [1.5, 0.5], i.e. [0.75, 0.25] after normalization.
        assertArrayEquals(new double[] { 0.75, 0.25 }, normalizedMarginal(factor, -1), TOLERANCE);
    }

    @Test
    public void testSelectAndSum() {
        AbstractFactor factor = create2x2x2Factor();
        factor.setValues(new DoubleArrayWrapper(distribution2x2x2()));

        factor.select(0, 0);
        assertArrayEquals(new double[] { 0.75, 0.25 }, normalizedMarginal(factor, -1), TOLERANCE);

        // Presumably -1 clears the selection on dimension 0. The expected [1.3, 0.7] / 2
        // only holds if both values of dimension 0 contribute again.
        factor.select(0, -1);
        factor.select(1, 1);
        assertArrayEquals(new double[] { 0.65, 0.35 }, normalizedMarginal(factor, -1), TOLERANCE);
    }

    @Test
    public void testSumMiddle1() {
        AbstractFactor factor = create2x2x2Factor();
        factor.setValues(new DoubleArrayWrapper(distribution2x2x2()));
        factor.select(2, 0);

        assertArrayEquals(new double[] { 0.9 / 2.2, 1.3 / 2.2 }, normalizedMarginal(factor, 1), TOLERANCE);
    }

    @Test
    public void testSumMiddle2() {
        AbstractFactor factor = create2x2x2Factor();
        factor.setValues(new DoubleArrayWrapper(distribution2x2x2()));
        factor.select(0, 1);
        factor.select(2, 1);

        assertArrayEquals(new double[] { 0.6 / 1.3, 0.7 / 1.3 }, normalizedMarginal(factor, 1), TOLERANCE);
    }

    @Test
    public void testMultiplication() {
        AbstractFactor f1 = create2x2x2Factor();
        // @formatter:off
        f1.setValues(new DoubleArrayWrapper(ArrayUtils.flatten(new double[][][] {
                { { 0.5, 0.5 }, { 0.5, 0.5 } },
                { { 0.5, 0.5 }, { 0.5, 0.5 } } })));
        // @formatter:on

        // f2 is defined over dimensions (2, 0) in that order. Its values form an identity
        // matrix, so the product keeps only the entries where dimension 2 equals dimension 0.
        AbstractFactor f2 = new DenseFactor();
        f2.setDimensionIDs(2, 0);
        f2.setDimensions(2, 2);
        f2.setValues(new DoubleArrayWrapper(1.0, 0.0, 0.0, 1.0));

        f1.multiplyCompatible(f2);

        assertArrayEquals(new double[] { 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0, 0.5 },
                f1.getValues().toDoubleArray(), TOLERANCE);
    }

    @Test
    public void testPreparedSum() {
        AbstractFactor f = new DenseFactor();
        f.setDimensionIDs(0, 1, 2);
        f.setDimensions(4, 4, 4);
        f.fill(1);

        AbstractFactor f2 = new DenseFactor();
        f2.setDimensionIDs(2);
        f2.setDimensions(4);

        // The prepared summation into f2 must agree with the direct marginal of f.
        f.sumPrepared(f2.getValues(), f.prepareMultiplication(f2));

        assertArrayEquals(f.marginalizeAllBut(-1), f2.getValues().toDoubleArray(), TOLERANCE);
    }

    @Test
    public void testCopy() {
        AbstractFactor f = create2x2x2Factor();
        f.select(2, 1);

        // No ArrayIndexOutOfBoundsException should be thrown.
        f.copyValues(new DoubleArrayWrapper(1, 1, 1, 1, 1, 1, 1, 1));

        // Odd flat indices presumably correspond to dimension 2 == 1, the selected value.
        for (int oddIndex = 1; oddIndex < f.getValues().length(); oddIndex += 2) {
            assertThat(f.getValue(oddIndex), is(1.0));
        }
    }

    @Test
    public void testMultiplySparseFactor() {
        AbstractFactor f = create2x2x2Factor();
        f.setValues(new DoubleArrayWrapper(distribution2x2x2()));
        AbstractFactor f2 = SparseFactor.fromFactor(f);

        // Multiplying by a sparse copy of itself squares every entry.
        f.multiplyCompatible(f2);

        // @formatter:off
        assertArrayEquals(
                ArrayUtils.flatten(new double[][][] {
                        { { 0.5 * 0.5, 0.5 * 0.5 }, { 1.0 * 1.0, 0.0 * 0.0 } },
                        { { 0.4 * 0.4, 0.6 * 0.6 }, { 0.3 * 0.3, 0.7 * 0.7 } } }),
                f.getValues().toDoubleArray(), TOLERANCE);
        // @formatter:on
    }

    @Test
    public void testZeroDimensional() {
        AbstractFactor dense = new DenseFactor();
        dense.setDimensionIDs();
        dense.setDimensions();
        dense.setValues(new DoubleArrayWrapper(2));

        AbstractFactor dense2 = new DenseFactor();
        dense2.setDimensionIDs();
        dense2.setDimensions();
        dense2.setValues(new DoubleArrayWrapper(3));

        // Two scalar factors multiply to a single entry: 2 * 3.
        dense2.multiplyCompatible(dense);

        assertArrayEquals(new double[] { 6 }, dense2.getValues().toDoubleArray(), TOLERANCE);
    }
}