/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math;
import org.apache.mahout.math.function.Functions;
import org.junit.Test;
public class WeightedVectorTest extends AbstractVectorTest {
@Test
public void testLength() {
Vector v = new DenseVector(new double[]{0.9921337470551008, 1.0031004325833064, 0.9963963182745947});
Centroid c = new Centroid(3, new DenseVector(v), 2);
assertEquals(c.getVector().getLengthSquared(), c.getLengthSquared(), 1.0e-17);
// previously, this wouldn't clear the cached squared length value correctly which would cause bad distances
c.set(0, -1);
System.out.printf("c = %.9f\nv = %.9f\n", c.getLengthSquared(), c.getVector().getLengthSquared());
assertEquals(c.getVector().getLengthSquared(), c.getLengthSquared(), 1.0e-17);
}
@Override
public Vector vectorToTest(int size) {
return new WeightedVector(new DenseVector(size), 4.52, 345);
}
@Test
public void testOrdering() {
WeightedVector v1 = new WeightedVector(new DenseVector(new double[]{1, 2, 3}), 5.41, 31);
WeightedVector v2 = new WeightedVector(new DenseVector(new double[]{1, 2, 3}), 5.00, 31);
WeightedVector v3 = new WeightedVector(new DenseVector(new double[]{1, 3, 3}), 5.00, 31);
WeightedVector v4 = v1.clone();
WeightedVectorComparator comparator = new WeightedVectorComparator();
assertTrue(comparator.compare(v1, v2) > 0);
assertTrue(comparator.compare(v3, v1) < 0);
assertTrue(comparator.compare(v3, v2) > 0);
assertEquals(0, comparator.compare(v4, v1));
assertEquals(0, comparator.compare(v1, v1));
}
@Test
public void testProjection() {
Vector v1 = new DenseVector(10).assign(Functions.random());
WeightedVector v2 = new WeightedVector(v1, v1, 31);
assertEquals(v1.dot(v1), v2.getWeight(), 1.0e-13);
assertEquals(31, v2.getIndex());
Matrix y = new DenseMatrix(10, 4).assign(Functions.random());
Matrix q = new QRDecomposition(y.viewPart(0, 10, 0, 3)).getQ();
Vector nullSpace = y.viewColumn(3).minus(q.times(q.transpose().times(y.viewColumn(3))));
WeightedVector v3 = new WeightedVector(q.viewColumn(0).plus(q.viewColumn(1)), nullSpace, 1);
assertEquals(0, v3.getWeight(), 1.0e-13);
Vector qx = q.viewColumn(0).plus(q.viewColumn(1)).normalize();
WeightedVector v4 = new WeightedVector(qx, q.viewColumn(0), 2);
assertEquals(Math.sqrt(0.5), v4.getWeight(), 1.0e-13);
WeightedVector v5 = WeightedVector.project(q.viewColumn(0), qx);
assertEquals(Math.sqrt(0.5), v5.getWeight(), 1.0e-13);
}
@Override
public void testSize() {
assertEquals("size", 3, getTestVector().getNumNonZeroElements());
}
@Override
Vector generateTestVector(int cardinality) {
return new WeightedVector(new DenseVector(cardinality), 3.14, 53);
}
}