/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.TimesFunction;
import org.junit.Test;
import java.util.Iterator;
/**
 * Unit tests for {@link VectorView}.
 *
 * <p>The fixture is a view of {@link #CARDINALITY} elements, starting at index {@link #OFFSET},
 * over a {@link DenseVector} built from {@link #values}. Expected results are therefore derived
 * from {@code values[OFFSET + i]} for view index {@code i}.
 *
 * <p>NOTE(review): several tests mutate the view and then compare against {@code values}; this
 * presumes the {@code DenseVector} constructor does not share the array it is given - confirm.
 */
public final class TestVectorView extends MahoutTestCase {

  /** Number of elements exposed by the view under test. */
  private static final int CARDINALITY = 3;

  /** Index in the backing vector at which the view under test starts. */
  private static final int OFFSET = 1;

  /** Source data for the backing vector. */
  private final double[] values = {0.0, 1.1, 2.2, 3.3, 4.4, 5.5};

  /** The view under test: elements {@code [OFFSET, OFFSET + CARDINALITY)} of the backing vector. */
  private final Vector test = new VectorView(new DenseVector(values), OFFSET, CARDINALITY);

  /**
   * Asserts that {@code iter} yields exactly the values in {@code gold}, in order.
   *
   * <p>Both the individual values and the total number of elements are checked, so an iterator
   * that stops early or runs long is reported as an assertion failure.
   *
   * @param iter the iterator to drain
   * @param gold the expected element values, in iteration order
   */
  private static void checkIterator(Iterator<Vector.Element> iter, double[] gold) {
    int seen = 0;
    while (iter.hasNext()) {
      Vector.Element elt = iter.next();
      // Guard before indexing gold so a too-long iteration fails with a clear message
      // instead of an ArrayIndexOutOfBoundsException.
      assertTrue("iterator produced more than the expected " + gold.length + " elements",
          seen < gold.length);
      assertEquals(elt.index() + " Value: " + gold[seen]
          + " does not equal: " + elt.get(), gold[seen], elt.get(), 0.0);
      seen++;
    }
    // Without this check an iterator that yields too few (or zero) elements would pass.
    assertEquals("number of elements iterated", gold.length, seen);
  }

  /**
   * Computes the sum of squares of the source values covered by the view under test.
   *
   * @return the sum of {@code values[i] * values[i]} for {@code i} in the view's window
   */
  private double viewSumOfSquares() {
    double sum = 0.0;
    for (int i = OFFSET; i < OFFSET + CARDINALITY; i++) {
      sum += values[i] * values[i];
    }
    return sum;
  }

  /** The view reports the cardinality it was constructed with. */
  @Test
  public void testCardinality() {
    assertEquals("size", CARDINALITY, test.size());
  }

  /** A clone of the view holds the same element values as the view. */
  @Test
  public void testCopy() throws Exception {
    Vector copy = test.clone();
    for (int i = 0; i < test.size(); i++) {
      assertEquals("copy [" + i + ']', test.get(i), copy.get(i), EPSILON);
    }
  }

  /** {@code get(i)} returns the source value at {@code OFFSET + i}. */
  @Test
  public void testGet() throws Exception {
    for (int i = 0; i < test.size(); i++) {
      assertEquals("get [" + i + ']', values[i + OFFSET], test.get(i), EPSILON);
    }
  }

  /** {@code get} at index {@code size()} is rejected. */
  @Test(expected = IndexException.class)
  public void testGetOver() {
    test.get(test.size());
  }

  /** {@code get} at a negative index is rejected. */
  @Test(expected = IndexException.class)
  public void testGetUnder() {
    test.get(-1);
  }

  /**
   * Checks {@code iterator()} and {@code iterateNonZero()} on two views: one that starts at
   * {@code OFFSET} (no zero element in its window) and one that starts at index 0 (whose first
   * element is 0.0 and is expected to be absent from the non-zero iteration).
   */
  @Test
  public void testIterator() throws Exception {
    VectorView view = new VectorView(new DenseVector(values), OFFSET, CARDINALITY);
    double[] gold = {1.1, 2.2, 3.3};
    checkIterator(view.iterator(), gold);
    checkIterator(view.iterateNonZero(), gold);

    view = new VectorView(new DenseVector(values), 0, CARDINALITY);
    checkIterator(view.iterator(), new double[] {0.0, 1.1, 2.2});
    checkIterator(view.iterateNonZero(), new double[] {1.1, 2.2});
  }

  /** {@code set} changes only the addressed element of the view. */
  @Test
  public void testSet() throws Exception {
    test.set(2, 4.5);
    for (int i = 0; i < test.size(); i++) {
      double expected = i == 2 ? 4.5 : values[OFFSET + i];
      assertEquals("set [" + i + ']', expected, test.get(i), EPSILON);
    }
  }

  /** {@code getNumNondefaultElements()} equals the view's cardinality for this fixture. */
  @Test
  public void testSize() throws Exception {
    assertEquals("size", CARDINALITY, test.getNumNondefaultElements());
  }

  /** A sub-view of the view addresses the source values at {@code OFFSET + 1 + i}. */
  @Test
  public void testViewPart() throws Exception {
    Vector part = test.viewPart(1, 2);
    assertEquals("part size", 2, part.getNumNondefaultElements());
    for (int i = 0; i < part.size(); i++) {
      assertEquals("part[" + i + ']', values[OFFSET + i + 1], part.get(i), EPSILON);
    }
  }

  /** {@code viewPart} with a negative offset is rejected. */
  @Test(expected = IndexException.class)
  public void testViewPartUnder() {
    test.viewPart(-1, CARDINALITY);
  }

  /** {@code viewPart} whose window runs past the end of the view is rejected. */
  @Test(expected = IndexException.class)
  public void testViewPartOver() {
    test.viewPart(2, CARDINALITY);
  }

  /** {@code viewPart} with a length larger than the whole backing data is rejected. */
  @Test(expected = IndexException.class)
  public void testViewPartCardinality() {
    test.viewPart(1, values.length + 1);
  }

  /** The dot product of the view with itself is the sum of squares of its elements. */
  @Test
  public void testDot() throws Exception {
    double res = test.dot(test);
    assertEquals("dot", viewSumOfSquares(), res, EPSILON);
  }

  /** {@code dot} with a vector of a different size is rejected. */
  @Test(expected = CardinalityException.class)
  public void testDotCardinality() {
    test.dot(new DenseVector(test.size() + 1));
  }

  /** {@code normalize()} divides every element by the square root of the sum of squares. */
  @Test
  public void testNormalize() throws Exception {
    Vector res = test.normalize();
    double mag = Math.sqrt(viewSumOfSquares());
    for (int i = 0; i < test.size(); i++) {
      assertEquals("normalize [" + i + ']', values[OFFSET + i] / mag, res.get(i), EPSILON);
    }
  }

  /** Subtracting the view from itself yields an all-zero vector of the same size. */
  @Test
  public void testMinus() throws Exception {
    Vector val = test.minus(test);
    assertEquals("size", CARDINALITY, val.size());
    for (int i = 0; i < test.size(); i++) {
      assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
    }
  }

  /** {@code plus(double)} adds the scalar to every element. */
  @Test
  public void testPlusDouble() throws Exception {
    Vector val = test.plus(1);
    assertEquals("size", CARDINALITY, val.size());
    for (int i = 0; i < test.size(); i++) {
      assertEquals("get [" + i + ']', values[OFFSET + i] + 1, val.get(i), EPSILON);
    }
  }

  /** Adding the view to itself doubles every element. */
  @Test
  public void testPlusVector() throws Exception {
    Vector val = test.plus(test);
    assertEquals("size", CARDINALITY, val.size());
    for (int i = 0; i < test.size(); i++) {
      assertEquals("get [" + i + ']', values[OFFSET + i] * 2, val.get(i), EPSILON);
    }
  }

  /** {@code plus} with a vector of a different size is rejected. */
  @Test(expected = CardinalityException.class)
  public void testPlusVectorCardinality() {
    test.plus(new DenseVector(test.size() + 1));
  }

  /** {@code times(double)} multiplies every element by the scalar. */
  @Test
  public void testTimesDouble() throws Exception {
    Vector val = test.times(3);
    assertEquals("size", CARDINALITY, val.size());
    for (int i = 0; i < test.size(); i++) {
      assertEquals("get [" + i + ']', values[OFFSET + i] * 3, val.get(i), EPSILON);
    }
  }

  /** {@code divide(double)} divides every element by the scalar. */
  @Test
  public void testDivideDouble() throws Exception {
    Vector val = test.divide(3);
    assertEquals("size", CARDINALITY, val.size());
    for (int i = 0; i < test.size(); i++) {
      assertEquals("get [" + i + ']', values[OFFSET + i] / 3, val.get(i), EPSILON);
    }
  }

  /** {@code times(Vector)} with the view itself squares every element. */
  @Test
  public void testTimesVector() throws Exception {
    Vector val = test.times(test);
    assertEquals("size", CARDINALITY, val.size());
    for (int i = 0; i < test.size(); i++) {
      double source = values[OFFSET + i];
      assertEquals("get [" + i + ']', source * source, val.get(i), EPSILON);
    }
  }

  /** {@code times} with a vector of a different size is rejected. */
  @Test(expected = CardinalityException.class)
  public void testTimesVectorCardinality() {
    test.times(new DenseVector(test.size() + 1));
  }

  /** {@code zSum()} equals the sum of the source values inside the view's window. */
  @Test
  public void testZSum() {
    double expected = 0;
    for (int i = OFFSET; i < OFFSET + CARDINALITY; i++) {
      expected += values[i];
    }
    assertEquals("wrong zSum", expected, test.zSum(), EPSILON);
  }

  /** {@code assign(double)} sets every element of the view to the given value. */
  @Test
  public void testAssignDouble() {
    test.assign(0);
    for (int i = 0; i < test.size(); i++) {
      assertEquals("value[" + i + ']', 0.0, test.getQuick(i), EPSILON);
    }
  }

  /** {@code assign(double[])} copies a same-sized (all-zero) array into the view. */
  @Test
  public void testAssignDoubleArray() throws Exception {
    double[] array = new double[test.size()];
    test.assign(array);
    for (int i = 0; i < test.size(); i++) {
      assertEquals("value[" + i + ']', 0.0, test.getQuick(i), EPSILON);
    }
  }

  /** {@code assign(double[])} with an array of a different length is rejected. */
  @Test(expected = CardinalityException.class)
  public void testAssignDoubleArrayCardinality() {
    double[] array = new double[test.size() + 1];
    test.assign(array);
  }

  /** {@code assign(Vector)} copies a same-sized (all-zero) vector into the view. */
  @Test
  public void testAssignVector() throws Exception {
    Vector other = new DenseVector(test.size());
    test.assign(other);
    for (int i = 0; i < test.size(); i++) {
      assertEquals("value[" + i + ']', 0.0, test.getQuick(i), EPSILON);
    }
  }

  /** {@code assign(Vector)} with a vector of a different size is rejected. */
  @Test(expected = CardinalityException.class)
  public void testAssignVectorCardinality() {
    Vector other = new DenseVector(test.size() - 1);
    test.assign(other);
  }

  /** Assigning {@code Functions.NEGATE} negates every element in place. */
  @Test
  public void testAssignUnaryFunction() {
    test.assign(Functions.NEGATE);
    for (int i = 0; i < test.size(); i++) {
      assertEquals("value[" + i + ']', -values[i + OFFSET], test.getQuick(i), EPSILON);
    }
  }

  /** Assigning the view to itself with {@code Functions.PLUS} doubles every element in place. */
  @Test
  public void testAssignBinaryFunction() throws Exception {
    test.assign(test, Functions.PLUS);
    for (int i = 0; i < test.size(); i++) {
      assertEquals("value[" + i + ']', 2 * values[i + OFFSET], test.getQuick(i), EPSILON);
    }
  }

  /** Assigning {@code Functions.PLUS} with a scalar adds the scalar to every element in place. */
  @Test
  public void testAssignBinaryFunction2() throws Exception {
    test.assign(Functions.PLUS, 4);
    for (int i = 0; i < test.size(); i++) {
      assertEquals("value[" + i + ']', values[i + OFFSET] + 4, test.getQuick(i), EPSILON);
    }
  }

  /** Assigning a {@link TimesFunction} with a scalar multiplies every element in place. */
  @Test
  public void testAssignBinaryFunction3() throws Exception {
    test.assign(new TimesFunction(), 4);
    for (int i = 0; i < test.size(); i++) {
      assertEquals("value[" + i + ']', values[i + OFFSET] * 4, test.getQuick(i), EPSILON);
    }
  }

  /** {@code like()} on a view returns another {@link VectorView}. */
  @Test
  public void testLike() {
    assertTrue("not like", test.like() instanceof VectorView);
  }

  /**
   * {@code cross} of the view with itself is a square matrix whose {@code (row, col)} entry is
   * the product of the view's elements at {@code row} and {@code col}.
   */
  @Test
  public void testCrossProduct() {
    Matrix result = test.cross(test);
    assertEquals("row size", test.size(), result.rowSize());
    assertEquals("col size", test.size(), result.columnSize());
    for (int row = 0; row < result.rowSize(); row++) {
      for (int col = 0; col < result.columnSize(); col++) {
        double expected = test.getQuick(row) * test.getQuick(col);
        assertEquals("cross[" + row + "][" + col + ']', expected,
            result.getQuick(row, col), EPSILON);
      }
    }
  }
}