/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.math.solver; import org.apache.mahout.math.DenseMatrix; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.MahoutTestCase; import org.apache.mahout.math.Matrix; import org.apache.mahout.math.Vector; import org.junit.Test; public final class LSMRTest extends MahoutTestCase { @Test public void basics() { Matrix m = hilbert(5); // make sure it is the hilbert matrix we know and love assertEquals(1, m.get(0, 0), 0); assertEquals(0.5, m.get(0, 1), 0); assertEquals(1 / 6.0, m.get(2, 3), 1.0e-9); Vector x = new DenseVector(new double[]{5, -120, 630, -1120, 630}); Vector b = new DenseVector(5); b.assign(1); assertEquals(0, m.times(x).minus(b).norm(2), 1.0e-9); LSMR r = new LSMR(); Vector x1 = r.solve(m, b); // the ideal solution is [5 -120 630 -1120 630] but the 5x5 hilbert matrix // has a condition number of almost 500,000 and the normal equation condition // number is that squared. This means that we don't get the exact answer with // a fast iterative solution. // Thus, we have to check the residuals rather than testing that the answer matched // the ideal. assertEquals(0, m.times(x1).minus(b).norm(2), 1.0e-2); assertEquals(0, m.transpose().times(m).times(x1).minus(m.transpose().times(b)).norm(2), 1.0e-7); // and we need to check that the error estimates are pretty good. assertEquals(m.times(x1).minus(b).norm(2), r.getResidualNorm(), 1.0e-5); assertEquals(m.transpose().times(m).times(x1).minus(m.transpose().times(b)).norm(2), r.getNormalEquationResidual(), 1.0e-9); } private static Matrix hilbert(int n) { Matrix r = new DenseMatrix(n, n); for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { r.set(i, j, 1.0 / (i + j + 1)); } } return r; } /* private Matrix overDetermined(int n) { Random rand = RandomUtils.getRandom(); Matrix r = new DenseMatrix(2 * n, n); for (int i = 0; i < 2 * n; i++) { for (int j = 0; j < n; j++) { r.set(i, j, rand.nextGaussian()); } } return r; } */ }