/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math;
import java.util.Iterator;
import java.util.Random;
import com.google.common.collect.Iterables;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.jet.random.Normal;
import org.apache.mahout.math.random.MultiNormal;
import org.junit.Before;
import org.junit.Test;
/**
 * Makes sure that a vector under test behaves the same as a {@link DenseVector} or
 * {@link RandomAccessSparseVector} (according to whether it is dense or sparse). Most
 * floating-point comparisons are made within a small tolerance.
 *
 * <p>A new vector implementation can extend this class, implementing
 * {@link #vectorToTest(int)} and {@link #generateTestVector(int)}, to gain a high degree of
 * confidence that it is working correctly.</p>
 */
public abstract class AbstractVectorTest<T extends Vector> extends MahoutTestCase {

  /** Tolerance used when comparing the vector under test against reference vectors. */
  private static final double FUZZ = 1.0e-13;
  /** Values that {@link #setUp()} stores at the odd indices 1, 3 and 5 of the fixture. */
  private static final double[] values = {1.1, 2.2, 3.3};
  /** Expected full contents of the fixture vector built by {@link #setUp()}. */
  private static final double[] gold = {0.0, 1.1, 0.0, 2.2, 0.0, 3.3, 0.0};

  /** Fixture of size {@code 2 * values.length + 1}; rebuilt before every test by {@link #setUp()}. */
  private Vector test;

  /**
   * Asserts that every element produced by {@code nzIter} holds exactly the value found at the
   * same index of {@code expected}.
   *
   * @param nzIter   iterator over vector elements to check
   * @param expected reference values indexed by element index
   */
  private static void checkIterator(Iterator<Vector.Element> nzIter, double[] expected) {
    while (nzIter.hasNext()) {
      Vector.Element elt = nzIter.next();
      assertEquals(elt.index() + " Value: " + expected[elt.index()]
          + " does not equal: " + elt.get(), expected[elt.index()], elt.get(), 0.0);
    }
  }

  /**
   * Creates an instance of the vector implementation under test with the given size. The
   * initial contents are up to the implementation (NOTE(review): tests here overwrite or only
   * inspect them; confirm expectations per subclass).
   *
   * @param size number of elements of the vector
   * @return a new vector of the type under test
   */
  public abstract T vectorToTest(int size);

  /**
   * Compares arithmetic, distance, norm, aggregation and assignment operations of the vector under
   * test against dense and sparse reference copies holding the same random values.
   */
  @Test
  public void testSimpleOps() {
    T v0 = vectorToTest(20);
    Random gen = RandomUtils.getRandom();
    Vector v1 = v0.assign(new Normal(0, 1, gen));

    // verify that v0 and v1 share and are identical
    assertEquals(v0.get(12), v1.get(12), 0);
    v0.set(12, gen.nextDouble());
    assertEquals(v0.get(12), v1.get(12), 0);
    assertSame(v0, v1);

    Vector v2 = vectorToTest(20).assign(new Normal(0, 1, gen));

    // reference copies in both dense and sparse form
    Vector dv1 = new DenseVector(v1);
    Vector dv2 = new DenseVector(v2);
    Vector sv1 = new RandomAccessSparseVector(v1);
    Vector sv2 = new RandomAccessSparseVector(v2);

    assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(v2)), FUZZ);
    assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(dv2)), FUZZ);
    assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(sv2)), FUZZ);
    assertEquals(0, dv1.plus(dv2).getDistanceSquared(sv1.plus(v2)), FUZZ);

    assertEquals(0, dv1.times(dv2).getDistanceSquared(v1.times(v2)), FUZZ);
    assertEquals(0, dv1.times(dv2).getDistanceSquared(v1.times(dv2)), FUZZ);
    assertEquals(0, dv1.times(dv2).getDistanceSquared(v1.times(sv2)), FUZZ);
    assertEquals(0, dv1.times(dv2).getDistanceSquared(sv1.times(v2)), FUZZ);

    assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(v2)), FUZZ);
    assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(dv2)), FUZZ);
    assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(sv2)), FUZZ);
    assertEquals(0, dv1.minus(dv2).getDistanceSquared(sv1.minus(v2)), FUZZ);

    double z = gen.nextDouble();
    assertEquals(0, dv1.divide(z).getDistanceSquared(v1.divide(z)), 1.0e-12);
    assertEquals(0, dv1.times(z).getDistanceSquared(v1.times(z)), 1.0e-12);
    assertEquals(0, dv1.plus(z).getDistanceSquared(v1.plus(z)), 1.0e-12);

    assertEquals(dv1.dot(dv2), v1.dot(v2), FUZZ);
    assertEquals(dv1.dot(dv2), v1.dot(dv2), FUZZ);
    assertEquals(dv1.dot(dv2), v1.dot(sv2), FUZZ);
    assertEquals(dv1.dot(dv2), sv1.dot(v2), FUZZ);
    assertEquals(dv1.dot(dv2), dv1.dot(v2), FUZZ);

    // first attempt has no cached distances
    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(v2), FUZZ);
    assertEquals(dv1.getDistanceSquared(dv2), dv1.getDistanceSquared(v2), FUZZ);
    assertEquals(dv1.getDistanceSquared(dv2), sv1.getDistanceSquared(v2), FUZZ);
    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(dv2), FUZZ);
    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(sv2), FUZZ);

    // now repeat with cached sizes
    assertEquals(dv1.getLengthSquared(), v1.getLengthSquared(), FUZZ);
    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(v2), FUZZ);
    assertEquals(dv1.getDistanceSquared(dv2), dv1.getDistanceSquared(v2), FUZZ);
    assertEquals(dv1.getDistanceSquared(dv2), sv1.getDistanceSquared(v2), FUZZ);
    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(dv2), FUZZ);
    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(sv2), FUZZ);

    assertEquals(dv1.minValue(), v1.minValue(), FUZZ);
    assertEquals(dv1.minValueIndex(), v1.minValueIndex());

    assertEquals(dv1.maxValue(), v1.maxValue(), FUZZ);
    assertEquals(dv1.maxValueIndex(), v1.maxValueIndex());

    // normalize() must return a new vector and leave v1 untouched
    Vector nv1 = v1.normalize();

    assertEquals(0, dv1.getDistanceSquared(v1), FUZZ);
    assertEquals(1, nv1.norm(2), FUZZ);
    assertEquals(0, dv1.normalize().getDistanceSquared(nv1), FUZZ);

    nv1 = v1.normalize(1);
    assertEquals(0, dv1.getDistanceSquared(v1), FUZZ);
    assertEquals(1, nv1.norm(1), FUZZ);
    assertEquals(0, dv1.normalize(1).getDistanceSquared(nv1), FUZZ);

    assertEquals(dv1.norm(0), v1.norm(0), FUZZ);
    assertEquals(dv1.norm(1), v1.norm(1), FUZZ);
    assertEquals(dv1.norm(1.5), v1.norm(1.5), FUZZ);
    assertEquals(dv1.norm(2), v1.norm(2), FUZZ);

    assertEquals(dv1.zSum(), v1.zSum(), FUZZ);

    assertEquals(3.1 * v1.size(), v1.assign(3.1).zSum(), FUZZ);
    assertEquals(0, v1.plus(-3.1).norm(1), FUZZ);
    v1.assign(dv1);
    assertEquals(0, v1.getDistanceSquared(dv1), FUZZ);

    // the two subtractions accumulate: 3.4 followed by 1.1 gives a total shift of 4.5
    assertEquals(dv1.zSum() - dv1.size() * 3.4, v1.assign(Functions.minus(3.4)).zSum(), FUZZ);
    assertEquals(dv1.zSum() - dv1.size() * 4.5, v1.assign(Functions.MINUS, 1.1).zSum(), FUZZ);
    v1.assign(dv1);

    assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.assign(v2, Functions.MINUS)), FUZZ);
    v1.assign(dv1);

    assertEquals(dv1.norm(2), Math.sqrt(v1.aggregate(Functions.PLUS, Functions.pow(2))), FUZZ);
    assertEquals(dv1.dot(dv2), v1.aggregate(v2, Functions.PLUS, Functions.MULT), FUZZ);

    assertEquals(dv1.viewPart(5, 10).zSum(), v1.viewPart(5, 10).zSum(), FUZZ);

    Vector v3 = v1.clone();

    // must be the right type ... tricky to tell that in the face of type erasure
    assertTrue(v0.getClass().isAssignableFrom(v3.getClass()));
    assertTrue(v3.getClass().isAssignableFrom(v0.getClass()));

    assertEquals(0, v1.getDistanceSquared(v3), FUZZ);
    assertNotSame(v1, v3);
    // clearing the clone must not affect the original
    v3.assign(0);
    assertEquals(0, dv1.getDistanceSquared(v1), FUZZ);
    assertEquals(0, v3.getLengthSquared(), FUZZ);

    // logNormalize needs non-negative entries
    dv1.assign(Functions.ABS);
    v1.assign(Functions.ABS);
    assertEquals(0, dv1.logNormalize().getDistanceSquared(v1.logNormalize()), FUZZ);
    assertEquals(0, dv1.logNormalize(1.5).getDistanceSquared(v1.logNormalize(1.5)), FUZZ);

    // TODO(review): cross and getNumNondefaultElements are not compared against the
    // reference vectors here.
    for (Vector.Element element : v1.all()) {
      assertEquals(dv1.get(element.index()), element.get(), 0);
      assertEquals(dv1.get(element.index()), v1.get(element.index()), 0);
      assertEquals(dv1.get(element.index()), v1.getQuick(element.index()), 0);
    }
  }

  /**
   * Creates a vector of the implementation under test with the given cardinality. It is
   * expected to start out all zeros: {@link #testToString()} asserts {@code "{}"} for a freshly
   * generated vector.
   *
   * @param cardinality number of elements of the vector
   * @return a new, empty vector of the type under test
   */
  abstract Vector generateTestVector(int cardinality);

  /** @return the fixture vector built by {@link #setUp()} for the current test */
  Vector getTestVector() {
    return test;
  }

  /**
   * Builds the fixture: a vector of size {@code 2 * values.length + 1} (7) whose odd indices
   * 1, 3, 5 hold {@link #values} and whose even indices are zero (see {@link #gold}).
   */
  @Override
  @Before
  public void setUp() throws Exception {
    super.setUp();
    test = generateTestVector(2 * values.length + 1);
    for (int i = 0; i < values.length; i++) {
      test.set(2 * i + 1, values[i]);
    }
  }

  @Test
  public void testCardinality() {
    assertEquals("size", 7, test.size());
  }

  /** Checks that both iteration views report the stored values, including for all-zero vectors. */
  @Test
  public void testIterator() {
    Iterator<Vector.Element> iterator = test.nonZeroes().iterator();
    checkIterator(iterator, gold);

    iterator = test.all().iterator();
    checkIterator(iterator, gold);

    double[] doubles = {0.0, 5.0, 0, 3.0};
    RandomAccessSparseVector zeros = new RandomAccessSparseVector(doubles.length);
    for (int i = 0; i < doubles.length; i++) {
      zeros.setQuick(i, doubles[i]);
    }
    // use the same iteration API as for the fixture above
    iterator = zeros.nonZeroes().iterator();
    checkIterator(iterator, doubles);
    iterator = zeros.all().iterator();
    checkIterator(iterator, doubles);

    doubles = new double[]{0.0, 0.0, 0, 0.0};
    zeros = new RandomAccessSparseVector(doubles.length);
    for (int i = 0; i < doubles.length; i++) {
      zeros.setQuick(i, doubles[i]);
    }
    iterator = zeros.nonZeroes().iterator();
    checkIterator(iterator, doubles);
    iterator = zeros.all().iterator();
    checkIterator(iterator, doubles);
  }

  /** Checks that {@code Element.set} during iteration writes through to the vector. */
  @Test
  public void testIteratorSet() {
    Vector clone = test.clone();
    for (Element e : clone.nonZeroes()) {
      e.set(e.get() * 2.0);
    }
    for (Element e : clone.nonZeroes()) {
      assertEquals(test.get(e.index()) * 2.0, e.get(), EPSILON);
    }
    clone = test.clone();
    for (Element e : clone.all()) {
      e.set(e.get() * 2.0);
    }
    for (Element e : clone.all()) {
      assertEquals(test.get(e.index()) * 2.0, e.get(), EPSILON);
    }
  }

  @Test
  public void testCopy() {
    Vector copy = test.clone();
    for (int i = 0; i < test.size(); i++) {
      assertEquals("copy [" + i + ']', test.get(i), copy.get(i), EPSILON);
    }
  }

  @Test
  public void testGet() {
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, test.get(i), EPSILON);
      } else {
        assertEquals("get [" + i + ']', values[i / 2], test.get(i), EPSILON);
      }
    }
  }

  @Test(expected = IndexException.class)
  public void testGetOver() {
    test.get(test.size());
  }

  @Test(expected = IndexException.class)
  public void testGetUnder() {
    test.get(-1);
  }

  @Test
  public void testSet() {
    test.set(3, 4.5);
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, test.get(i), EPSILON);
      } else if (i == 3) {
        assertEquals("set [" + i + ']', 4.5, test.get(i), EPSILON);
      } else {
        assertEquals("set [" + i + ']', values[i / 2], test.get(i), EPSILON);
      }
    }
  }

  @Test
  public void testSize() {
    assertEquals("size", 3, test.getNumNondefaultElements());
  }

  @Test
  public void testViewPart() {
    Vector part = test.viewPart(1, 2);
    assertEquals("part size", 2, part.getNumNondefaultElements());
    for (int i = 0; i < part.size(); i++) {
      assertEquals("part[" + i + ']', test.get(i + 1), part.get(i), EPSILON);
    }
  }

  @Test(expected = IndexException.class)
  public void testViewPartUnder() {
    test.viewPart(-1, values.length);
  }

  @Test(expected = IndexException.class)
  public void testViewPartOver() {
    test.viewPart(2, 7);
  }

  @Test(expected = IndexException.class)
  public void testViewPartCardinality() {
    test.viewPart(1, 8);
  }

  @Test
  public void testSparseDoubleVectorInt() {
    Vector val = new RandomAccessSparseVector(4);
    assertEquals("size", 4, val.size());
    for (int i = 0; i < 4; i++) {
      assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
    }
  }

  @Test
  public void testDot() {
    double res = test.dot(test);
    double expected = 3.3 * 3.3 + 2.2 * 2.2 + 1.1 * 1.1;
    assertEquals("dot", expected, res, EPSILON);
  }

  @Test
  public void testDot2() {
    Vector test2 = test.clone();
    test2.set(1, 0.0);
    test2.set(3, 0.0);
    assertEquals(3.3 * 3.3, test2.dot(test), EPSILON);
  }

  @Test(expected = CardinalityException.class)
  public void testDotCardinality() {
    test.dot(new DenseVector(test.size() + 1));
  }

  @Test
  public void testNormalize() {
    Vector val = test.normalize();
    double mag = Math.sqrt(1.1 * 1.1 + 2.2 * 2.2 + 3.3 * 3.3);
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
      } else {
        assertEquals("dot", values[i / 2] / mag, val.get(i), EPSILON);
      }
    }
  }

  @Test
  public void testMinus() {
    Vector val = test.minus(test);
    assertEquals("size", test.size(), val.size());
    for (int i = 0; i < test.size(); i++) {
      assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
    }

    val = test.minus(test).minus(test);
    assertEquals("cardinality", test.size(), val.size());
    for (int i = 0; i < test.size(); i++) {
      assertEquals("get [" + i + ']', 0.0, val.get(i) + test.get(i), EPSILON);
    }

    Vector val1 = test.plus(1);
    val = val1.minus(test);
    for (int i = 0; i < test.size(); i++) {
      assertEquals("get [" + i + ']', 1.0, val.get(i), EPSILON);
    }

    val1 = test.plus(-1);
    val = val1.minus(test);
    for (int i = 0; i < test.size(); i++) {
      assertEquals("get [" + i + ']', -1.0, val.get(i), EPSILON);
    }
  }

  @Test
  public void testPlusDouble() {
    Vector val = test.plus(1);
    assertEquals("size", test.size(), val.size());
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 1.0, val.get(i), EPSILON);
      } else {
        assertEquals("get [" + i + ']', values[i / 2] + 1.0, val.get(i), EPSILON);
      }
    }
  }

  @Test
  public void testPlusVector() {
    Vector val = test.plus(test);
    assertEquals("size", test.size(), val.size());
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
      } else {
        assertEquals("get [" + i + ']', values[i / 2] * 2.0, val.get(i), EPSILON);
      }
    }
  }

  @Test(expected = CardinalityException.class)
  public void testPlusVectorCardinality() {
    test.plus(new DenseVector(test.size() + 1));
  }

  @Test
  public void testTimesDouble() {
    Vector val = test.times(3);
    assertEquals("size", test.size(), val.size());
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
      } else {
        assertEquals("get [" + i + ']', values[i / 2] * 3.0, val.get(i), EPSILON);
      }
    }
  }

  @Test
  public void testDivideDouble() {
    Vector val = test.divide(3);
    assertEquals("size", test.size(), val.size());
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
      } else {
        assertEquals("get [" + i + ']', values[i / 2] / 3.0, val.get(i), EPSILON);
      }
    }
  }

  @Test
  public void testTimesVector() {
    Vector val = test.times(test);
    assertEquals("size", test.size(), val.size());
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
      } else {
        assertEquals("get [" + i + ']', values[i / 2] * values[i / 2], val.get(i), EPSILON);
      }
    }
  }

  @Test(expected = CardinalityException.class)
  public void testTimesVectorCardinality() {
    test.times(new DenseVector(test.size() + 1));
  }

  @Test
  public void testZSum() {
    double expected = 0;
    for (double value : values) {
      expected += value;
    }
    assertEquals("wrong zSum", expected, test.zSum(), EPSILON);
  }

  @Test
  public void testGetDistanceSquared() {
    Vector other = new RandomAccessSparseVector(test.size());
    other.set(1, -2);
    other.set(2, -5);
    other.set(3, -9);
    other.set(4, 1);
    double expected = test.minus(other).getLengthSquared();
    assertTrue("a.getDistanceSquared(b) != a.minus(b).getLengthSquared",
        Math.abs(expected - test.getDistanceSquared(other)) < 10.0E-7);
  }

  /** Assigning a constant must overwrite every element, not just the first few. */
  @Test
  public void testAssignDouble() {
    test.assign(0);
    for (int i = 0; i < test.size(); i++) {
      assertEquals("value[" + i + ']', 0.0, test.getQuick(i), EPSILON);
    }
  }

  @Test
  public void testAssignDoubleArray() {
    double[] array = new double[test.size()];
    test.assign(array);
    for (int i = 0; i < test.size(); i++) {
      assertEquals("value[" + i + ']', 0.0, test.getQuick(i), EPSILON);
    }
  }

  @Test(expected = CardinalityException.class)
  public void testAssignDoubleArrayCardinality() {
    double[] array = new double[test.size() + 1];
    test.assign(array);
  }

  @Test
  public void testAssignVector() {
    Vector other = new DenseVector(test.size());
    test.assign(other);
    for (int i = 0; i < test.size(); i++) {
      assertEquals("value[" + i + ']', 0.0, test.getQuick(i), EPSILON);
    }
  }

  @Test(expected = CardinalityException.class)
  public void testAssignVectorCardinality() {
    Vector other = new DenseVector(test.size() - 1);
    test.assign(other);
  }

  /** Negation must flip every stored value and leave the zeros at zero. */
  @Test
  public void testAssignUnaryFunction() {
    test.assign(Functions.NEGATE);
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        // -0.0 compares equal to 0.0 within the delta
        assertEquals("get [" + i + ']', 0.0, test.getQuick(i), EPSILON);
      } else {
        assertEquals("value[" + i + ']', -values[i / 2], test.getQuick(i), EPSILON);
      }
    }
  }

  /** Adding the vector to itself must double every odd-index value across the whole vector. */
  @Test
  public void testAssignBinaryFunction() {
    test.assign(test, Functions.PLUS);
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, test.get(i), EPSILON);
      } else {
        // fixture index i holds values[i / 2]
        assertEquals("value[" + i + ']', 2 * values[i / 2], test.getQuick(i), EPSILON);
      }
    }
  }

  /** Adding a constant affects every element, including the zeros. */
  @Test
  public void testAssignBinaryFunction2() {
    test.assign(Functions.plus(4));
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 4.0, test.get(i), EPSILON);
      } else {
        assertEquals("value[" + i + ']', values[i / 2] + 4, test.getQuick(i), EPSILON);
      }
    }
  }

  /** Multiplying by a constant scales the stored values and keeps the zeros at zero. */
  @Test
  public void testAssignBinaryFunction3() {
    test.assign(Functions.mult(4));
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, test.get(i), EPSILON);
      } else {
        assertEquals("value[" + i + ']', values[i / 2] * 4, test.getQuick(i), EPSILON);
      }
    }
  }

  @Test
  public void testLike() {
    Vector other = test.like();
    assertTrue("not like", test.getClass().isAssignableFrom(other.getClass()));
    assertEquals("size", test.size(), other.size());
  }

  @Test
  public void testCrossProduct() {
    Matrix result = test.cross(test);
    assertEquals("row size", test.size(), result.rowSize());
    assertEquals("col size", test.size(), result.columnSize());
    for (int row = 0; row < result.rowSize(); row++) {
      for (int col = 0; col < result.columnSize(); col++) {
        assertEquals("cross[" + row + "][" + col + ']', test.getQuick(row)
            * test.getQuick(col), result.getQuick(row, col), EPSILON);
      }
    }
  }

  /** Checks that {@code all()} and {@code nonZeroes()} agree with size, zSum and each other. */
  @Test
  public void testIterators() {
    final T v0 = vectorToTest(20);

    double sum = 0;
    int elements = 0;
    int nonZero = 0;
    for (Element element : v0.all()) {
      elements++;
      sum += element.get();
      if (element.get() != 0) {
        nonZero++;
      }
    }

    int nonZeroIterated = Iterables.size(v0.nonZeroes());
    assertEquals(20, elements);
    assertEquals(v0.size(), elements);
    assertEquals(nonZeroIterated, nonZero);
    assertEquals(v0.zSum(), sum, 0);
  }

  /**
   * Checks that adding tiny Gaussian noise yields a strictly positive squared distance whenever
   * {@code fuzz * fuzz} is representable relative to 1, and never a negative one.
   */
  @Test
  public void testSmallDistances() {
    for (double fuzz : new double[]{1.0e-5, 1.0e-6, 1.0e-7, 1.0e-8, 1.0e-9, 1.0e-10}) {
      MultiNormal x = new MultiNormal(fuzz, new ConstantVector(0, 20));
      for (int i = 0; i < 10000; i++) {
        final T v1 = vectorToTest(20);
        Vector v2 = v1.plus(x.sample());
        if (1 + fuzz * fuzz > 1) {
          String msg = String.format("fuzz = %.1g, >", fuzz);
          assertTrue(msg, v1.getDistanceSquared(v2) > 0);
          assertTrue(msg, v2.getDistanceSquared(v1) > 0);
        } else {
          String msg = String.format("fuzz = %.1g, >=", fuzz);
          assertTrue(msg, v1.getDistanceSquared(v2) >= 0);
          assertTrue(msg, v2.getDistanceSquared(v1) >= 0);
        }
      }
    }
  }

  /**
   * Checks the sparse {@code {index:value,...}} string form. Previously this method had no
   * {@code @Test} annotation, so JUnit 4 never ran it.
   */
  @Test
  public void testToString() {
    Vector w;

    w = generateTestVector(20);
    w.set(0, 1.1);
    w.set(13, 100500.);
    w.set(19, 3.141592);
    assertEquals("{0:1.1,13:100500.0,19:3.141592}", w.toString());

    w = generateTestVector(12);
    w.set(10, 0.1);
    assertEquals("{10:0.1}", w.toString());

    w = generateTestVector(12);
    assertEquals("{}", w.toString());
  }
}