/*******************************************************************************
* Copyright 2007, 2009 Jorge Villalon (jorge.villalon@uai.cl), Stephen O'Rourke (stephen.orourke@sydney.edu.au)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package tml.vectorspace.factorisation;
import Jama.Matrix;
/**
 * NMF with Euclidean distance minimisation.
 *
 * <p>Factorises a matrix {@code V} (m x n) into two factors {@code W} (m x k) and
 * {@code H} (k x n) such that {@code W * H} approximates {@code V}. The factors are
 * refined with the multiplicative update rules for the Euclidean (Frobenius)
 * distance {@code ||V - WH||}:
 *
 * <pre>
 *   H &lt;- H .* (W^T V) ./ (W^T W H)
 *   W &lt;- W .* (V H^T) ./ (W H H^T)
 * </pre>
 *
 * <p>H is updated first and W is then updated using the new H, as in the paper.
 * After each iteration the columns of W are scaled to unit length, and the
 * corresponding rows of H are scaled inversely so the product W * H is unchanged.
 *
 * <p>Details of this algorithm can be found in the paper:
 *
 * Lee, D. D., &amp; Seung, H. S. (2001). Algorithms for Non-negative Matrix
 * Factorization. Paper presented at the Proceedings of the 2000 Conference on
 * Advances in Neural Information Processing Systems.
 */
public class NonnegativeMatrixFactorisationED extends NonnegativeMatrixFactorisationKL {

	/**
	 * Factorises {@code v} and stores the result in {@code decomposition}.
	 *
	 * <p>The factorisation rank is {@code min(K, min(rows, columns) - 1)}.
	 * {@code initialH}/{@code initialW} are copied and used as starting factors when
	 * set; otherwise random matrices are used. Iteration stops after
	 * {@code maxIterations} rounds, or earlier once the Frobenius distance
	 * {@code ||v - wh||} changes by no more than {@code SMALL_VALUE} between two
	 * consecutive iterations.
	 *
	 * <p>On return, {@code decomposition} holds {@code U = w}, {@code S = identity}
	 * and {@code V = h^T}.
	 *
	 * @param v the matrix to factorise; the multiplicative updates assume it has no
	 *          negative entries (not checked here)
	 */
	@Override
	public void process(Matrix v) {
		int m = v.getRowDimension();
		int n = v.getColumnDimension();
		int rank = Math.min(K, Math.min(n, m) - 1);

		h = (initialH != null) ? initialH.copy() : Matrix.random(rank, n);
		w = (initialW != null) ? initialW.copy() : Matrix.random(m, rank);

		// The Euclidean NMF objective is the Frobenius norm of the residual.
		// The spectral norm (norm2) measures something else and needs an SVD per call.
		double previousDistance = v.minus(w.times(h)).normF();
		for (int l = 0; l < maxIterations; l++) {
			// Alternating updates: W is updated using the freshly updated H.
			multiplicativeUpdateH(v, w, h, rank);
			multiplicativeUpdateW(v, w, h, rank);
			normaliseColumnsOfW(w, h, rank);

			// check if converged
			double distance = v.minus(w.times(h)).normF();
			double change = Math.abs(previousDistance - distance);
			logger.debug(l + ".\t change " + change + "\t distance " + distance);
			if (change <= SMALL_VALUE) {
				break;
			}
			previousDistance = distance;
		}

		decomposition = new SpaceDecomposition();
		decomposition.setSkdata(Matrix.identity(rank, rank).getArray());
		decomposition.setUkdata(w.getArray());
		decomposition.setVkdata(h.transpose().getArray());
	}

	/**
	 * Applies the Euclidean multiplicative update to the first {@code rank} rows of
	 * {@code factorH} in place: {@code H <- H .* (W^T V) ./ (W^T W H + SMALL_VALUE)}.
	 */
	private void multiplicativeUpdateH(Matrix v, Matrix factorW, Matrix factorH, int rank) {
		Matrix wt = factorW.transpose();
		Matrix numerator = wt.times(v);                        // W^T V
		Matrix denominator = wt.times(factorW).times(factorH); // (W^T W) H, small k x k product first
		int n = factorH.getColumnDimension();
		for (int c = 0; c < rank; c++) {
			for (int j = 0; j < n; j++) {
				double value = factorH.get(c, j) * numerator.get(c, j)
						/ (denominator.get(c, j) + SMALL_VALUE);
				factorH.set(c, j, value);
			}
		}
	}

	/**
	 * Applies the Euclidean multiplicative update to the first {@code rank} columns
	 * of {@code factorW} in place: {@code W <- W .* (V H^T) ./ (W H H^T + SMALL_VALUE)}.
	 */
	private void multiplicativeUpdateW(Matrix v, Matrix factorW, Matrix factorH, int rank) {
		Matrix ht = factorH.transpose();
		Matrix numerator = v.times(ht);                         // V H^T
		Matrix denominator = factorW.times(factorH.times(ht));  // W (H H^T), avoids the m x n product W H
		int m = factorW.getRowDimension();
		for (int c = 0; c < rank; c++) {
			for (int i = 0; i < m; i++) {
				double value = factorW.get(i, c) * numerator.get(i, c)
						/ (denominator.get(i, c) + SMALL_VALUE);
				factorW.set(i, c, value);
			}
		}
	}

	/**
	 * Scales each of the first {@code rank} columns of {@code factorW} to unit
	 * Euclidean length. The matching row of {@code factorH} is multiplied by the
	 * same norm, so the product W * H is preserved.
	 */
	private void normaliseColumnsOfW(Matrix factorW, Matrix factorH, int rank) {
		int m = factorW.getRowDimension();
		int n = factorH.getColumnDimension();
		for (int c = 0; c < rank; c++) {
			double sumOfSquares = 0;
			for (int i = 0; i < m; i++) {
				double x = factorW.get(i, c);
				sumOfSquares += x * x;
			}
			double norm = Math.sqrt(sumOfSquares);
			if (norm == 0) {
				// An all-zero column cannot be normalised; leave it as is instead of producing NaN.
				continue;
			}
			for (int i = 0; i < m; i++) {
				factorW.set(i, c, factorW.get(i, c) / norm);
			}
			for (int j = 0; j < n; j++) {
				factorH.set(c, j, factorH.get(c, j) * norm);
			}
		}
	}
}