/******************************************************************************* * Copyright 2007, 2009 Stephen O'Rourke (stephen.orourke@sydney.edu.au) * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ package tml.vectorspace.factorisation; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import tml.utils.MatrixUtils; import Jama.Matrix; /** * Probabilistic latent semantic analysis (PLSA) * * An explanation of the algorithm can be found in the paper: * * Hofmann, T. (1999). Probabilistic Latent Semantic Indexing. Paper presented * at the Proceedings of the 22nd annual international ACM SIGIR conference on * Research and development in information retrieval. * * @author Stephen O'Rourke * */ public class ProbabilisticLatentSemanticAnalysis extends MatrixFactorisation { private final Log logger = LogFactory.getLog(getClass()); private static final double SMALL_VALUE = 10e-9; private Matrix Pz; private Matrix Pz_diag; private Matrix Pd_z; private Matrix Pw_z; private int maxIterations = 200; private double tolerence = 0.01; @Override public void process(Matrix x) { int m = x.getRowDimension(); int n = x.getColumnDimension(); int K2 = Math.min(K, Math.min(n, m) - 1); // initialise Pz, Pd_z, Pw_z Pz = new Matrix(K2, 1, 1); Pd_z = Matrix.random(n, K2); Pw_z = Matrix.random(m, K2); // normalise columns to sum to 1 Pz = MatrixUtils.normalizeColumnsL1(Pz); Pd_z = MatrixUtils.normalizeColumnsL1(Pd_z); Pw_z = MatrixUtils.normalizeColumnsL1(Pw_z); // initialise matrices for the posterior Matrix[] Pz_dw = new Matrix[K2]; for (int k = 0; k < K2; k++) { Pz_dw[k] = new Matrix(m, n); } double li_previous = 0; // start EM algorithm for (int l = 0; l < maxIterations; l++) { //E step, compute posterior on z for (int k = 0; k < K2; k++) { Pz_dw[k] = Pw_z.getMatrix(0, m - 1, k, k).times(Pd_z.getMatrix(0, n - 1, k, k).transpose()).times(Pz.get(k, 0)); } // normalise posterior for (int i = 0; i < m; i++) { for (int j = 0; j < n; j++) { double sum = 0; for (int k = 0; k < K2; k++) { sum += Pz_dw[k].get(i, j); } for (int k = 0; k < K2; k++) { Pz_dw[k].set(i, j, Pz_dw[k].get(i, j) / sum); } } } //M step, maximise log-likelihood for (int k = 0; k < K2; k++) { Matrix Pw_k = x.arrayTimes(Pz_dw[k]).times(new Matrix(n, 1, 1)); Pw_z.setMatrix(0, m - 1, k, k, Pw_k); } for (int k = 0; k < K2; k++) { Matrix Pd_k = x.arrayTimes(Pz_dw[k]).transpose().times(new Matrix(m, 1, 1)); Pd_z.setMatrix(0, n - 1, k, k, Pd_k); } Pz = Pd_z.transpose().times(new Matrix(n, 1, 1)); // normalise columns to sum to 1 Pw_z = MatrixUtils.normalizeColumnsL1(Pw_z); Pd_z = MatrixUtils.normalizeColumnsL1(Pd_z); Pz = MatrixUtils.normalizeColumnsL1(Pz); // calculate log-likelihood Pz_diag = Matrix.identity(K2, K2); for (int k = 0; k < K2; k++) { Pz_diag.set(k, k, Pz.get(k, 0)); } Matrix logMatrix = Pw_z.times(Pz_diag).times(Pd_z.transpose()); for (int i = 0; i < m; i++) { for (int j = 0; j < n; j++) { logMatrix.set(i, j, Math.log(logMatrix.get(i, j) + SMALL_VALUE)); } } double li = x.arrayTimes(logMatrix).norm1(); // add small value to Pw_z. Pw_z = Pw_z.plus(new Matrix(m, K2, SMALL_VALUE)); // check for convergence if (l > 1) { double change = li_previous - li; logger.debug(l + ".\t log-likelihood change " + change); if (change < tolerence) { break; } } li_previous = li; } decomposition = new SpaceDecomposition(); decomposition.setSkdata(Pz_diag.getArray()); decomposition.setUkdata(Pw_z.getArray()); decomposition.setVkdata(Pd_z.getArray()); } public int getMaxIterations() { return maxIterations; } public void setMaxIterations(int maxIterations) { this.maxIterations = maxIterations; } public double getTolerence() { return tolerence; } public void setTolerence(double tolerence) { this.tolerence = tolerence; } }