/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math.solver;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MahoutTestCase;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.junit.Test;
/**
 * Unit tests for {@link ConjugateGradientSolver}.
 *
 * <p>The exact iteration counts asserted below are tied to the fixed fixtures defined at the
 * bottom of this class; changing any fixture value will likely change them.
 */
public class TestConjugateGradientSolver extends MahoutTestCase {

  /**
   * Solves {@code A x = b} for the {@link #getA()} fixture with the two-argument
   * {@code solve} and checks that the residual is negligible and that exactly 10 iterations
   * (the dimension of the system) were used.
   */
  @Test
  public void testConjugateGradientSolver() {
    Matrix a = getA();
    Vector b = getB();
    ConjugateGradientSolver solver = new ConjugateGradientSolver();
    Vector x = solver.solve(a, b);
    assertEquals(0.0, residualDistance(a, x, b), EPSILON);
    assertEquals(0.0, solver.getResidualNorm(), ConjugateGradientSolver.DEFAULT_MAX_ERROR);
    assertEquals(10, solver.getIterations());
  }

  /**
   * Solves an ill-conditioned system twice -- once without a preconditioner and once with a
   * {@link JacobiConditioner} -- and checks that both converge, with the preconditioned run
   * needing one fewer iteration.
   */
  @Test
  public void testConditionedConjugateGradientSolver() {
    Matrix a = getIllConditionedMatrix();
    Vector b = getB();
    Preconditioner conditioner = new JacobiConditioner(a);
    ConjugateGradientSolver solver = new ConjugateGradientSolver();

    // Unpreconditioned solve: null in the preconditioner slot, at most 100 iterations.
    Vector x = solver.solve(a, b, null, 100, ConjugateGradientSolver.DEFAULT_MAX_ERROR);
    assertEquals(0.0, residualDistance(a, x, b), EPSILON);
    assertEquals(0.0, solver.getResidualNorm(), ConjugateGradientSolver.DEFAULT_MAX_ERROR);
    assertEquals(16, solver.getIterations());

    // The Jacobi preconditioner isn't very good, but it does result in one fewer iteration
    // to converge.
    Vector x2 = solver.solve(a, b, conditioner, 100, ConjugateGradientSolver.DEFAULT_MAX_ERROR);
    assertEquals(0.0, residualDistance(a, x2, b), EPSILON);
    assertEquals(0.0, solver.getResidualNorm(), ConjugateGradientSolver.DEFAULT_MAX_ERROR);
    assertEquals(15, solver.getIterations());
  }

  /**
   * Checks the two ways of stopping early: a looser error tolerance and a lower iteration cap.
   * Both should yield an approximate (not exact) solution after 7 iterations.
   */
  @Test
  public void testEarlyStop() {
    Matrix a = getA();
    Vector b = getB();
    ConjugateGradientSolver solver = new ConjugateGradientSolver();

    // Specifying a looser max error results in fewer iterations but a less accurate result.
    Vector x = solver.solve(a, b, null, 10, 0.1);
    double distance = residualDistance(a, x, b);
    assertTrue(distance > EPSILON);
    assertEquals(0.0, distance, 0.1); // should be equal to within the error specified
    assertEquals(7, solver.getIterations()); // fewer than the 10 needed for full convergence

    // A similar effect can be obtained by capping the number of iterations.
    x = solver.solve(a, b, null, 7, ConjugateGradientSolver.DEFAULT_MAX_ERROR);
    distance = residualDistance(a, x, b);
    assertTrue(distance > EPSILON);
    assertEquals(0.0, distance, 0.1);
    assertEquals(7, solver.getIterations());
  }

  /**
   * Returns how far {@code x} is from solving {@code A x = b}: the square root of
   * {@code (A x).getDistanceSquared(b)}.
   */
  private static double residualDistance(Matrix a, Vector x, Vector b) {
    return Math.sqrt(a.times(x).getDistanceSquared(b));
  }

  /**
   * Returns the 10x10 symmetric matrix used by the basic and early-stop tests. Values are in
   * column-major order (see {@link #reshape}), so each pair of source lines below is one column.
   */
  private static Matrix getA() {
    return reshape(new double[] {
        11.7155649822793997, -0.7125253363083646, 4.6473613961860183, 1.6020939468348456, -4.6789817799137134,
        -0.8140416763434970, -4.5995617505618345, -1.1749070042775340, -1.6747995811678336, 3.1922255171058342,
        -0.7125253363083646, 12.3400579683994867, -2.6498099427000645, 0.5264507222630669, 0.3783428369189767,
        -2.1170186159188811, 2.3695134252190528, 3.8182131490333013, 6.5285942298270347, 2.8564814419366353,
        4.6473613961860183, -2.6498099427000645, 16.1317933921668484, -0.0409475448061225, 1.4805687075608227,
        -2.9958076484628950, -2.5288893025027264, -0.9614557539842487, -2.2974738351519077, -1.5516184284572598,
        1.6020939468348456, 0.5264507222630669, -0.0409475448061225, 4.1946802122694482, -2.5210038046912198,
        0.6634899962909317, 0.4036187419205338, -0.2829211393003727, -0.2283091172980954, 1.1253516563552464,
        -4.6789817799137134, 0.3783428369189767, 1.4805687075608227, -2.5210038046912198, 19.4307361862733430,
        -2.5200132222091787, 2.3748511971444510, 11.6426598443305522, -0.1508136510863874, 4.3471343888063512,
        -0.8140416763434970, -2.1170186159188811, -2.9958076484628950, 0.6634899962909317, -2.5200132222091787,
        7.6712334419700747, -3.8687773629502851, -3.0453418711591529, -0.1155580876143619, -2.4025459467422121,
        -4.5995617505618345, 2.3695134252190528, -2.5288893025027264, 0.4036187419205338, 2.3748511971444510,
        -3.8687773629502851, 10.4681666057470082, 1.6527180866171229, 2.9341795819365384, -2.1708176372763099,
        -1.1749070042775340, 3.8182131490333013, -0.9614557539842487, -0.2829211393003727, 11.6426598443305522,
        -3.0453418711591529, 1.6527180866171229, 16.0050616934176233, 1.1689747208793086, 1.6665090945954870,
        -1.6747995811678336, 6.5285942298270347, -2.2974738351519077, -0.2283091172980954, -0.1508136510863874,
        -0.1155580876143619, 2.9341795819365384, 1.1689747208793086, 6.4794329751637481, -1.9197339981871877,
        3.1922255171058342, 2.8564814419366353, -1.5516184284572598, 1.1253516563552464, 4.3471343888063512,
        -2.4025459467422121, -2.1708176372763099, 1.6665090945954870, -1.9197339981871877, 18.9149021356344598
    }, 10, 10);
  }

  /** Returns the 10-element right-hand side {@code b} shared by every test. */
  private static Vector getB() {
    return new DenseVector(new double[] {
        -0.552252, 0.038430, 0.058392, -1.234496, 1.240369, 0.373649, 0.505113, 0.503723, 1.215340, -0.391908
    });
  }

  /**
   * Returns a 10x10 ill-conditioned matrix for the preconditioner test. It is symmetric up to
   * rounding in the last digit of some mirrored entries. Values are column-major (see
   * {@link #reshape}), two source lines per column.
   */
  private static Matrix getIllConditionedMatrix() {
    return reshape(new double[] {
        0.00695278043678842, 0.09911830022078683, 0.01309584636255063, 0.00652917453032394, 0.04337631487735064,
        0.14232165273321387, 0.05808722912361313, -0.06591965049732287, 0.06055771542862332, 0.00577423310349649,
        0.09911830022078683, 1.50071402418061428, 0.14988743575884242, 0.07195514527480981, 0.63747362341752722,
        1.30711819020414688, 0.82151609385115953, -0.72616125524587938, 1.03490136002022948, 0.12800239664439328,
        0.01309584636255063, 0.14988743575884242, 0.04068462583124965, 0.02147022047006482, 0.07388113580146650,
        0.58070223915076002, 0.11280336266257514, -0.21690068430020618, 0.04065087561300068, -0.00876895259593769,
        0.00652917453032394, 0.07195514527480981, 0.02147022047006482, 0.01140105250542524, 0.03624164348693958,
        0.31291554581393255, 0.05648457235205666, -0.11507583016077780, 0.01475756130709823, -0.00584453679519805,
        0.04337631487735064, 0.63747362341752722, 0.07388113580146649, 0.03624164348693959, 0.27491543200760571,
        0.73410543168748121, 0.36120630002843257, -0.36583546331208316, 0.41472509341940017, 0.04581458758255480,
        0.14232165273321387, 1.30711819020414666, 0.58070223915076002, 0.31291554581393255, 0.73410543168748121,
        9.02536073121807014, 1.25426385582883104, -3.16186335125594642, -0.19740140818905436, -0.26613760880058035,
        0.05808722912361314, 0.82151609385115953, 0.11280336266257514, 0.05648457235205667, 0.36120630002843257,
        1.25426385582883126, 0.48661058451606820, -0.57030511336562195, 0.49151280464818098, 0.04428280690189127,
        -0.06591965049732286, -0.72616125524587938, -0.21690068430020618, -0.11507583016077781, -0.36583546331208316,
        -3.16186335125594642, -0.57030511336562195, 1.16270815038078945, -0.14837898963724327, 0.05917203395002889,
        0.06055771542862331, 1.03490136002022926, 0.04065087561300068, 0.01475756130709823, 0.41472509341940023,
        -0.19740140818905436, 0.49151280464818103, -0.14837898963724327, 0.86693820682049716, 0.14089688752570340,
        0.00577423310349649, 0.12800239664439328, -0.00876895259593769, -0.00584453679519805, 0.04581458758255480,
        -0.26613760880058035, 0.04428280690189126, 0.05917203395002889, 0.14089688752570340, 0.02901858439788401
    }, 10, 10);
  }

  /**
   * Builds a {@code rows x columns} dense matrix from {@code values} laid out in column-major
   * order: the first {@code rows} values fill column 0, the next {@code rows} fill column 1, and
   * so on.
   *
   * @param values matrix entries in column-major order; must contain exactly
   *     {@code rows * columns} elements
   * @param rows number of rows
   * @param columns number of columns
   * @return the populated matrix
   * @throws IllegalArgumentException if {@code values.length != rows * columns}; without this
   *     check a short fixture array would silently leave trailing entries at zero
   */
  private static Matrix reshape(double[] values, int rows, int columns) {
    if (values.length != rows * columns) {
      throw new IllegalArgumentException("Expected " + rows * columns + " values for a " + rows
          + "x" + columns + " matrix but got " + values.length);
    }
    Matrix m = new DenseMatrix(rows, columns);
    int i = 0;
    for (double v : values) {
      // Column-major: row cycles fastest, column advances every `rows` values.
      m.set(i % rows, i / rows, v);
      i++;
    }
    return m;
  }
}