/******************************************************************************* * Copyright 2007, 2009 Jorge Villalon (jorge.villalon@uai.cl), Stephen O'Rourke (stephen.orourke@sydney.edu.au) * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ package tml.vectorspace.factorisation; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import Jama.Matrix; /** * NMF with Kullback-Leibler divergence minimisation). * * Details of this algorithm can be found in the paper: * * Lee, D. D., & Seung, H. S. (2001). Algorithms for Non-negative Matrix * Factorization. Paper presented at the Proceedings of the 2000 Conference on * Advances in Neural Information Processing Systems. */ public class NonnegativeMatrixFactorisationKL extends MatrixFactorisation { protected final Log logger = LogFactory.getLog(getClass()); protected Matrix initialW = null; protected Matrix initialH = null; protected Matrix w; protected Matrix h; protected static final double SMALL_VALUE = 10e-9; protected int maxIterations = 200; @Override public void process(Matrix v) { int m = v.getRowDimension(); int n = v.getColumnDimension(); int K2 = Math.min(K, Math.min(n, m) - 1); v = v.times(1 / v.norm1()); // initialise h if (initialH != null) { h = initialH.copy(); } else { h = Matrix.random(K2, n); } // initialise w if (initialW != null) { w = initialW.copy(); } else { w = Matrix.random(m, K2); } // perform update iterations double fnorm_previous = v.minus(w.times(h)).norm2(); for (int l = 0; l < maxIterations; l++) { // simultaneous update of w and h Matrix wh = w.times(h); Matrix h_copy = h.copy(); for (int c = 0; c < K2; c++) { // update h for (int j = 0; j < n; j++) { double sum1 = 0, sum2 = 0; for (int i = 0; i < m; i++) { sum1 += w.get(i, c) * v.get(i, j) / (wh.get(i, j) + SMALL_VALUE); sum2 += w.get(i, c); } h.set(c, j, h.get(c, j) * sum1 / sum2); } // update w for (int i = 0; i < m; i++) { double sum1 = 0, sum2 = 0; for (int j = 0; j < n; j++) { sum1 += h_copy.get(c, j) * v.get(i, j) / (wh.get(i, j) + SMALL_VALUE); sum2 += h_copy.get(c, j); } w.set(i, c, w.get(i, c) * sum1 / sum2); } } // check if converged double fnorm = v.minus(w.times(h)).norm2(); double change = Math.abs(fnorm_previous - fnorm); logger.debug(l + "\t change " + change); if (change <= SMALL_VALUE) { break; } fnorm_previous = fnorm; } decomposition = new SpaceDecomposition(); decomposition.setSkdata(Matrix.identity(K2, K2).getArray()); decomposition.setUkdata(w.getArray()); decomposition.setVkdata(h.transpose().getArray()); } public int getMaxIterations() { return maxIterations; } public void setMaxIterations(int maxIterations) { this.maxIterations = maxIterations; } }