/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math;
import org.apache.mahout.math.function.Functions;
import org.junit.Before;
import org.junit.Test;
import java.util.Iterator;
/**
 * Shared test cases for {@link Vector} implementations.
 *
 * <p>Subclasses supply the implementation under test through
 * {@link #generateTestVector(int)}. Before every test the fixture is rebuilt as a
 * vector of size {@code 2 * values.length + 1} (7) whose odd indices 1, 3 and 5 are
 * set to 1.1, 2.2 and 3.3; the even indices are never set and the tests expect them
 * to read back as 0.0. The {@code gold} array spells out that expected content
 * index by index.</p>
 *
 * <p>{@code EPSILON} is not declared in this class; it is presumably inherited from
 * {@code MahoutTestCase}.</p>
 */
public abstract class AbstractTestVector extends MahoutTestCase {

  /** Values written into the fixture at indices 1, 3 and 5 (in that order). */
  private static final double[] values = {1.1, 2.2, 3.3};

  /** Expected content of the fixture, index by index. */
  private static final double[] gold = {0.0, 1.1, 0.0, 2.2, 0.0, 3.3, 0.0};

  /** The vector under test; rebuilt in {@link #setUp()} before each test. */
  private Vector test;

  /**
   * Creates an instance of the vector implementation under test.
   *
   * @param cardinality the size the new vector must have
   * @return a vector of the requested size; the tests expect unset entries to read as 0.0
   */
  abstract Vector generateTestVector(int cardinality);

  /**
   * @return the fixture vector built by {@link #setUp()}
   */
  Vector getTestVector() {
    return test;
  }

  /**
   * Builds the fixture: a vector of size 7 with {@code values[i]} stored at index
   * {@code 2 * i + 1}.
   */
  @Override
  @Before
  public void setUp() throws Exception {
    super.setUp();
    test = generateTestVector(2 * values.length + 1);
    for (int i = 0; i < values.length; i++) {
      test.set(2 * i + 1, values[i]);
    }
  }

  /** The fixture reports the size it was created with. */
  @Test
  public void testCardinality() {
    assertEquals("size", 7, test.size());
  }

  /**
   * Every element produced by {@code iterateNonZero()} and {@code iterator()} carries
   * the value expected at its index, for the fixture and for two
   * {@link RandomAccessSparseVector}s (one partly filled, one all zeros).
   */
  @Test
  public void testIterator() {
    Iterator<Vector.Element> iterator = test.iterateNonZero();
    checkIterator(iterator, gold);

    iterator = test.iterator();
    checkIterator(iterator, gold);

    double[] doubles = {0.0, 5.0, 0, 3.0};
    RandomAccessSparseVector zeros = new RandomAccessSparseVector(doubles.length);
    for (int i = 0; i < doubles.length; i++) {
      zeros.setQuick(i, doubles[i]);
    }
    iterator = zeros.iterateNonZero();
    checkIterator(iterator, doubles);
    iterator = zeros.iterator();
    checkIterator(iterator, doubles);

    doubles = new double[]{0.0, 0.0, 0, 0.0};
    zeros = new RandomAccessSparseVector(doubles.length);
    for (int i = 0; i < doubles.length; i++) {
      zeros.setQuick(i, doubles[i]);
    }
    iterator = zeros.iterateNonZero();
    checkIterator(iterator, doubles);
    iterator = zeros.iterator();
    checkIterator(iterator, doubles);
  }

  /**
   * Drains the iterator and asserts that each element's value is exactly
   * {@code values[element.index()]}. Only the elements actually produced are checked;
   * the number of elements is not verified.
   *
   * @param nzIter the iterator to drain
   * @param values expected value for every index the iterator may produce
   */
  private static void checkIterator(Iterator<Vector.Element> nzIter, double[] values) {
    while (nzIter.hasNext()) {
      Vector.Element elt = nzIter.next();
      assertEquals(elt.index() + " Value: " + values[elt.index()]
          + " does not equal: " + elt.get(), values[elt.index()], elt.get(), 0.0);
    }
  }

  /**
   * Asserts that the first {@code vector.size()} entries of the vector, read with
   * {@code getQuick}, are all 0.0 (within {@code EPSILON}).
   *
   * @param vector the vector expected to contain only zeros
   */
  private static void assertAllEntriesZero(Vector vector) {
    for (int i = 0; i < vector.size(); i++) {
      assertEquals("value[" + i + ']', 0.0, vector.getQuick(i), EPSILON);
    }
  }

  /**
   * Values written through {@code Vector.Element.set} while iterating are visible on a
   * second pass, for both {@code iterateNonZero()} and {@code iterator()}.
   */
  @Test
  public void testIteratorSet() {
    Vector clone = test.clone();
    Iterator<Vector.Element> it = clone.iterateNonZero();
    while (it.hasNext()) {
      Vector.Element e = it.next();
      e.set(e.get() * 2.0);
    }
    it = clone.iterateNonZero();
    while (it.hasNext()) {
      Vector.Element e = it.next();
      assertEquals(test.get(e.index()) * 2.0, e.get(), EPSILON);
    }

    clone = test.clone();
    it = clone.iterator();
    while (it.hasNext()) {
      Vector.Element e = it.next();
      e.set(e.get() * 2.0);
    }
    it = clone.iterator();
    while (it.hasNext()) {
      Vector.Element e = it.next();
      assertEquals(test.get(e.index()) * 2.0, e.get(), EPSILON);
    }
  }

  /** {@code clone()} yields a vector with the same value at every index. */
  @Test
  public void testCopy() {
    Vector copy = test.clone();
    for (int i = 0; i < test.size(); i++) {
      assertEquals("copy [" + i + ']', test.get(i), copy.get(i), EPSILON);
    }
  }

  /** {@code get} returns 0.0 at even indices and the stored value at odd indices. */
  @Test
  public void testGet() {
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, test.get(i), EPSILON);
      } else {
        assertEquals("get [" + i + ']', values[i / 2], test.get(i), EPSILON);
      }
    }
  }

  /** {@code get(size())} is expected to throw {@code IndexException}. */
  @Test(expected = IndexException.class)
  public void testGetOver() {
    test.get(test.size());
  }

  /** {@code get(-1)} is expected to throw {@code IndexException}. */
  @Test(expected = IndexException.class)
  public void testGetUnder() {
    test.get(-1);
  }

  /** {@code set(3, 4.5)} changes index 3 only; every other entry keeps its value. */
  @Test
  public void testSet() {
    test.set(3, 4.5);
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, test.get(i), EPSILON);
      } else if (i == 3) {
        assertEquals("set [" + i + ']', 4.5, test.get(i), EPSILON);
      } else {
        assertEquals("set [" + i + ']', values[i / 2], test.get(i), EPSILON);
      }
    }
  }

  /** {@code getNumNondefaultElements()} is expected to be 3 for the fixture. */
  @Test
  public void testSize() {
    assertEquals("size", 3, test.getNumNondefaultElements());
  }

  /**
   * {@code viewPart(1, 2)} reports two non-default elements and its entry {@code i}
   * equals the fixture's entry {@code i + 1}.
   */
  @Test
  public void testViewPart() {
    Vector part = test.viewPart(1, 2);
    assertEquals("part size", 2, part.getNumNondefaultElements());
    for (int i = 0; i < part.size(); i++) {
      assertEquals("part[" + i + ']', test.get(i + 1), part.get(i), EPSILON);
    }
  }

  /** {@code viewPart(-1, 3)} is expected to throw {@code IndexException}. */
  @Test(expected = IndexException.class)
  public void testViewPartUnder() {
    test.viewPart(-1, values.length);
  }

  /** {@code viewPart(2, 7)} is expected to throw {@code IndexException}. */
  @Test(expected = IndexException.class)
  public void testViewPartOver() {
    test.viewPart(2, 7);
  }

  /** {@code viewPart(1, 8)} is expected to throw {@code IndexException}. */
  @Test(expected = IndexException.class)
  public void testViewPartCardinality() {
    test.viewPart(1, 8);
  }

  /** A fresh {@code RandomAccessSparseVector(4)} has size 4 and reads 0.0 everywhere. */
  @Test
  public void testSparseDoubleVectorInt() {
    Vector val = new RandomAccessSparseVector(4);
    assertEquals("size", 4, val.size());
    for (int i = 0; i < 4; i++) {
      assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
    }
  }

  /** The fixture dotted with itself equals the sum of the squares of its values. */
  @Test
  public void testDot() {
    double res = test.dot(test);
    double expected = 3.3 * 3.3 + 2.2 * 2.2 + 1.1 * 1.1;
    assertEquals("dot", expected, res, EPSILON);
  }

  /** With indices 1 and 3 zeroed in a clone, only index 5 contributes to the dot product. */
  @Test
  public void testDot2() {
    Vector test2 = test.clone();
    test2.set(1, 0.0);
    test2.set(3, 0.0);
    assertEquals(3.3 * 3.3, test2.dot(test), EPSILON);
  }

  /** {@code dot} with a vector one element longer is expected to throw {@code CardinalityException}. */
  @Test(expected = CardinalityException.class)
  public void testDotCardinality() {
    test.dot(new DenseVector(test.size() + 1));
  }

  /** {@code normalize()} divides every stored value by the Euclidean length of the fixture. */
  @Test
  public void testNormalize() {
    Vector val = test.normalize();
    double mag = Math.sqrt(1.1 * 1.1 + 2.2 * 2.2 + 3.3 * 3.3);
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
      } else {
        // Message previously read just "dot", which named the wrong operation and hid the index.
        assertEquals("get [" + i + ']', values[i / 2] / mag, val.get(i), EPSILON);
      }
    }
  }

  /**
   * Checks {@code minus}: x - x is all zeros, (x - x) - x is the negation of x, and
   * (x + c) - x is c everywhere for c = 1 and c = -1.
   */
  @Test
  public void testMinus() {
    Vector val = test.minus(test);
    assertEquals("size", test.size(), val.size());
    for (int i = 0; i < test.size(); i++) {
      assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
    }

    val = test.minus(test).minus(test);
    assertEquals("cardinality", test.size(), val.size());
    for (int i = 0; i < test.size(); i++) {
      // val should be -test, so the two must cancel out.
      assertEquals("get [" + i + ']', 0.0, val.get(i) + test.get(i), EPSILON);
    }

    Vector val1 = test.plus(1);
    val = val1.minus(test);
    for (int i = 0; i < test.size(); i++) {
      assertEquals("get [" + i + ']', 1.0, val.get(i), EPSILON);
    }

    val1 = test.plus(-1);
    val = val1.minus(test);
    for (int i = 0; i < test.size(); i++) {
      assertEquals("get [" + i + ']', -1.0, val.get(i), EPSILON);
    }
  }

  /** {@code plus(1)} adds 1 to every entry, including the zero ones. */
  @Test
  public void testPlusDouble() {
    Vector val = test.plus(1);
    assertEquals("size", test.size(), val.size());
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 1.0, val.get(i), EPSILON);
      } else {
        assertEquals("get [" + i + ']', values[i / 2] + 1.0, val.get(i), EPSILON);
      }
    }
  }

  /** The fixture plus itself doubles every stored value and leaves zeros at zero. */
  @Test
  public void testPlusVector() {
    Vector val = test.plus(test);
    assertEquals("size", test.size(), val.size());
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
      } else {
        assertEquals("get [" + i + ']', values[i / 2] * 2.0, val.get(i), EPSILON);
      }
    }
  }

  /** {@code plus} with a vector one element longer is expected to throw {@code CardinalityException}. */
  @Test(expected = CardinalityException.class)
  public void testPlusVectorCardinality() {
    test.plus(new DenseVector(test.size() + 1));
  }

  /** {@code times(3)} triples every stored value and leaves zeros at zero. */
  @Test
  public void testTimesDouble() {
    Vector val = test.times(3);
    assertEquals("size", test.size(), val.size());
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
      } else {
        assertEquals("get [" + i + ']', values[i / 2] * 3.0, val.get(i), EPSILON);
      }
    }
  }

  /** {@code divide(3)} divides every stored value by 3 and leaves zeros at zero. */
  @Test
  public void testDivideDouble() {
    Vector val = test.divide(3);
    assertEquals("size", test.size(), val.size());
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
      } else {
        assertEquals("get [" + i + ']', values[i / 2] / 3.0, val.get(i), EPSILON);
      }
    }
  }

  /** The fixture times itself squares every stored value element-wise. */
  @Test
  public void testTimesVector() {
    Vector val = test.times(test);
    assertEquals("size", test.size(), val.size());
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
      } else {
        assertEquals("get [" + i + ']', values[i / 2] * values[i / 2], val.get(i), EPSILON);
      }
    }
  }

  /** {@code times} with a vector one element longer is expected to throw {@code CardinalityException}. */
  @Test(expected = CardinalityException.class)
  public void testTimesVectorCardinality() {
    test.times(new DenseVector(test.size() + 1));
  }

  /** {@code zSum()} equals the sum of the values stored in the fixture. */
  @Test
  public void testZSum() {
    double expected = 0;
    for (double value : values) {
      expected += value;
    }
    assertEquals("wrong zSum", expected, test.zSum(), EPSILON);
  }

  /** {@code a.getDistanceSquared(b)} agrees with {@code a.minus(b).getLengthSquared()}. */
  @Test
  public void testGetDistanceSquared() {
    Vector other = new RandomAccessSparseVector(test.size());
    other.set(1, -2);
    other.set(2, -5);
    other.set(3, -9);
    other.set(4, 1);
    double expected = test.minus(other).getLengthSquared();
    assertTrue("a.getDistanceSquared(b) != a.minus(b).getLengthSquared",
        Math.abs(expected - test.getDistanceSquared(other)) < 10.0E-7);
  }

  /** {@code assign(0)} zeroes every entry of the vector. */
  @Test
  public void testAssignDouble() {
    test.assign(0);
    // Check the whole vector, not just the first values.length entries.
    assertAllEntriesZero(test);
  }

  /** Assigning an all-zero array of matching length zeroes every entry of the vector. */
  @Test
  public void testAssignDoubleArray() {
    double[] array = new double[test.size()];
    test.assign(array);
    // Check the whole vector, not just the first values.length entries.
    assertAllEntriesZero(test);
  }

  /** Assigning an array one element too long is expected to throw {@code CardinalityException}. */
  @Test(expected = CardinalityException.class)
  public void testAssignDoubleArrayCardinality() {
    double[] array = new double[test.size() + 1];
    test.assign(array);
  }

  /** Assigning a freshly created {@code DenseVector} of matching size zeroes every entry. */
  @Test
  public void testAssignVector() {
    Vector other = new DenseVector(test.size());
    test.assign(other);
    // Check the whole vector, not just the first values.length entries.
    assertAllEntriesZero(test);
  }

  /** Assigning a vector one element too short is expected to throw {@code CardinalityException}. */
  @Test(expected = CardinalityException.class)
  public void testAssignVectorCardinality() {
    Vector other = new DenseVector(test.size() - 1);
    test.assign(other);
  }

  /** {@code assign(Functions.NEGATE)} negates every stored value and leaves zeros at zero. */
  @Test
  public void testAssignUnaryFunction() {
    test.assign(Functions.NEGATE);
    // Verify every index; fixture index i (odd) holds values[i / 2].
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("value[" + i + ']', 0.0, test.getQuick(i), EPSILON);
      } else {
        assertEquals("value[" + i + ']', -values[i / 2], test.getQuick(i), EPSILON);
      }
    }
  }

  /** {@code assign(test, Functions.PLUS)} doubles every stored value in place. */
  @Test
  public void testAssignBinaryFunction() {
    test.assign(test, Functions.PLUS);
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, test.get(i), EPSILON);
      } else {
        // values[i / 2] is the value stored at odd index i (values[i - 1] only matched for i == 1).
        assertEquals("value[" + i + ']', 2 * values[i / 2], test.getQuick(i), EPSILON);
      }
    }
  }

  /** {@code assign(Functions.plus(4))} adds 4 to every entry, including the zero ones. */
  @Test
  public void testAssignBinaryFunction2() {
    test.assign(Functions.plus(4));
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 4.0, test.get(i), EPSILON);
      } else {
        // values[i / 2] is the value stored at odd index i (values[i - 1] only matched for i == 1).
        assertEquals("value[" + i + ']', values[i / 2] + 4, test.getQuick(i), EPSILON);
      }
    }
  }

  /** {@code assign(Functions.mult(4))} multiplies every stored value by 4 and leaves zeros at zero. */
  @Test
  public void testAssignBinaryFunction3() {
    test.assign(Functions.mult(4));
    for (int i = 0; i < test.size(); i++) {
      if (i % 2 == 0) {
        assertEquals("get [" + i + ']', 0.0, test.get(i), EPSILON);
      } else {
        // values[i / 2] is the value stored at odd index i (values[i - 1] only matched for i == 1).
        assertEquals("value[" + i + ']', values[i / 2] * 4, test.getQuick(i), EPSILON);
      }
    }
  }

  /** {@code like()} returns a vector of an assignable class and the same size. */
  @Test
  public void testLike() {
    Vector other = test.like();
    assertTrue("not like", test.getClass().isAssignableFrom(other.getClass()));
    assertEquals("size", test.size(), other.size());
  }

  /**
   * {@code cross(test)} yields a size-by-size matrix whose entry (row, col) is
   * {@code test[row] * test[col]}.
   */
  @Test
  public void testCrossProduct() {
    Matrix result = test.cross(test);
    assertEquals("row size", test.size(), result.rowSize());
    assertEquals("col size", test.size(), result.columnSize());
    for (int row = 0; row < result.rowSize(); row++) {
      for (int col = 0; col < result.columnSize(); col++) {
        assertEquals("cross[" + row + "][" + col + ']', test.getQuick(row)
            * test.getQuick(col), result.getQuick(row, col), EPSILON);
      }
    }
  }
}