// Copyright 2013 Thomas Müller // This file is part of MarMoT, which is licensed under GPLv3. package marmot.test.core; import static org.junit.Assert.*; import marmot.core.ConcatFloatFeatureVector; import marmot.core.DenseArrayFloatFeatureVector; import marmot.core.FloatFeatureVector; import marmot.core.FloatWeights; import marmot.core.ZeroFloatFeatureVector; import org.junit.Test; public class FloatVectorTest { class TestWeights implements FloatWeights { public int num_tags_; public double[] weights_; public TestWeights(int num_tags, int dim) { num_tags_ = num_tags; weights_ = new double[dim * num_tags]; } @Override public int getFloatIndex(int feature, int tag_index) { return feature * num_tags_ + tag_index; } @Override public double getFloatWeight(int index) { return weights_[index]; } @Override public void updateFloatWeight(int index, double value) { weights_[index] += value; } } public void sumHelper(double[] expected, TestWeights fw, FloatFeatureVector vc, double update) { assertArrayEquals(expected, fw.weights_, 1e-5); double expected_sum = 0.0; for (double f : expected) { expected_sum += f * f / update; } double actual_sum = vc.getDotProduct(fw, 0, 0); assertEquals(expected_sum, actual_sum, 1e-5); } @Test public void test() { { TestWeights fw = new TestWeights(1, 3 * 2); double[] w = { 1.0, 0.5, 0.3 }; FloatFeatureVector v = new DenseArrayFloatFeatureVector(w); double[] w2 = { -1.0, -0.5, -0.3 }; FloatFeatureVector v2 = new DenseArrayFloatFeatureVector(w2); FloatFeatureVector v3 = new ConcatFloatFeatureVector(v, v2); v3.updateFloatWeight(fw, 0, 0, 1.0); double[] expected = { 1.0, 0.5, 0.3, -1.0, -0.5, -0.3 }; sumHelper(expected, fw, v3, 1.0); } { TestWeights fw = new TestWeights(1, 3 * 3); double[] w = { 1.0, 0.5, 0.3 }; FloatFeatureVector v = new DenseArrayFloatFeatureVector(w); double[] w2 = { -1.0, -0.5, -0.3 }; FloatFeatureVector v2 = new DenseArrayFloatFeatureVector(w2); double[] w3 = { -2.0, -0.7, -0.2 }; FloatFeatureVector v3 = new DenseArrayFloatFeatureVector(w3); FloatFeatureVector vc = new ConcatFloatFeatureVector(v, new ConcatFloatFeatureVector(v2, v3)); vc.updateFloatWeight(fw, 0, 0, 1.0); double[] expected = { 1.0, 0.5, 0.3, -1.0, -0.5, -0.3, -2.0, -0.7, -0.2 }; sumHelper(expected, fw, vc, 1.0); } } @Test public void zero() { { TestWeights fw = new TestWeights(1, 3 * 2); FloatFeatureVector v = new ZeroFloatFeatureVector(3); double[] w2 = { -1.0, -0.5, -0.3 }; FloatFeatureVector v2 = new DenseArrayFloatFeatureVector(w2); FloatFeatureVector v3 = new ConcatFloatFeatureVector(v, v2); v3.updateFloatWeight(fw, 0, 0, 1.0); double[] expected = { 0.0, 0.0, 0.0, -1.0, -0.5, -0.3 }; sumHelper(expected, fw, v3, 1.0); } { TestWeights fw = new TestWeights(1, 3 * 2); double[] w = { 1.0, 0.5, 0.3 }; FloatFeatureVector v = new DenseArrayFloatFeatureVector(w); FloatFeatureVector v2 = new ZeroFloatFeatureVector(3); FloatFeatureVector v3 = new ConcatFloatFeatureVector(v, v2); v3.updateFloatWeight(fw, 0, 0, 0.1); double[] expected = { .1, 0.05, 0.03, 0.0, 0.0, 0.0 }; sumHelper(expected, fw, v3, 0.1); } } }