/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.cf.taste.impl.recommender.knn;
import java.util.Arrays;
public final class ConjugateGradientOptimizer implements Optimizer {
private static final double CONVERGENCE_LIMIT = 0.1;
private static final int MAX_ITERATIONS = 1000;
/**
* <p>
* Conjugate gradient optimization. Matlab code:
* </p>
*
* <p>
*
* <pre>
* function [x] = conjgrad(A,b,x0)
* x = x0;
* r = b - A*x0;
* w = -r;
* for i = 1:size(A);
* z = A*w;
* a = (r'*w)/(w'*z);
* x = x + a*w;
* r = r - a*z;
* if( norm(r) < 1e-10 )
* break;
* end
* B = (r'*z)/(w'*z);
* w = -r + B*w;
* end
* end
* </pre>
*
* </p>
*
* @param matrix
* matrix nxn positions
* @param b
* vector b, n positions
* @return vector of n weights
*/
@Override
public double[] optimize(double[][] matrix, double[] b) {
int k = b.length;
double[] x = new double[k];
double[] r = new double[k];
double[] w = new double[k];
double[] z = new double[k];
Arrays.fill(x, 3.0 / k);
// r = b - A*x0;
// w = -r;
for (int i = 0; i < k; i++) {
double v = 0.0;
double[] ai = matrix[i];
for (int j = 0; j < k; j++) {
v += ai[j] * x[j];
}
double ri = b[i] - v;
r[i] = ri;
w[i] = -ri;
}
for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
// z = A*w;
for (int i = 0; i < k; i++) {
double v = 0.0;
double[] ai = matrix[i];
for (int j = 0; j < k; j++) {
v += ai[j] * w[j];
}
z[i] = v;
}
// a = (r'*w)/(w'*z);
double anum = 0.0;
double aden = 0.0;
for (int i = 0; i < k; i++) {
anum += r[i] * w[i];
aden += w[i] * z[i];
}
double a = anum / aden;
// x = x + a*w;
// r = r - a*z;
for (int i = 0; i < k; i++) {
x[i] += a * w[i];
r[i] -= a * z[i];
}
// stop when residual is close to 0
double rdot = 0.0;
for (int i = 0; i < k; i++) {
double value = r[i];
rdot += value * value;
}
if (rdot <= CONVERGENCE_LIMIT) {
break;
}
// B = (r'*z)/(w'*z);
double bnum = 0.0;
double bden = 0.0;
for (int i = 0; i < k; i++) {
double zi = z[i];
bnum += r[i] * zi;
bden += w[i] * zi;
}
double B = bnum / bden;
// w = -r + B*w;
for (int i = 0; i < k; i++) {
w[i] = -r[i] + B * w[i];
}
}
return x;
}
}