/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2008 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program. If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.learner.functions.kernel.rvm;

import java.util.Iterator;
import java.util.LinkedList;

import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelBasisFunction;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelRadial;
import com.rapidminer.operator.learner.functions.kernel.rvm.util.SECholeskyDecomposition;

import Jama.Matrix;

/**
 * The standard (slow, non-scaling) classification RVM (see, e.g., Tipping, 2001).
 *
 * @author Piotr Kasprzak, Ingo Mierswa
 * @version $Id: RVMClassification.java,v 1.3 2008/05/09 19:22:57 ingomierswa Exp $
 */
public class RVMClassification extends RVMBase {

    /** Constructor */
    public RVMClassification(ClassificationProblem problem, Parameter parameter) {
        super(problem, parameter);
    }

    /** The hard work is done here. */
    public Model learn() {

        ClassificationProblem problem = (ClassificationProblem) this.problem;

        int numExamples = problem.getProblemSize();
        int numBases    = numExamples + 1;

        /** Set iteration control parameters */
        int prune_point = 50;                                           // percentage of maxIterations after which to switch to analytic pruning
        prune_point     = parameter.maxIterations * prune_point / 100;  // (currently not used further in this implementation)

        /**
         * Create design/basis matrix PHI (N x M) with:
         *   N: number of examples
         *   M: number of examples + 1 (because of the bias in the first column)
         *   PHI(n, m) = phi_m(x_n) = K_m(x_n, x_m)   [K being the kernel; x_n, x_m the n-th/m-th example]
         */
        double[][] x                  = problem.getInputVectors();
        KernelBasisFunction[] kernels = problem.getKernels();
        double[][] PHI                = new double[numExamples][numBases];

        int i, j;
        for (j = 0; j < numBases - 1; j++) {
            for (i = 0; i < numExamples; i++) {
                PHI[i][j + 1] = kernels[j + 1].eval(x[i]);
            }
        }

        // Set bias
        for (i = 0; i < numExamples; i++) {
            PHI[i][0] = 1.0;
        }

        /** Init weights/alpha/gamma/index vectors */
        double[] alphas = new double[numBases];   // inverse variances for the weights
        for (j = 0; j < numBases; j++) {
            alphas[j] = parameter.initAlpha;
        }

        /** Matrix-ize everything */
        Matrix matrixPHI           = new Matrix(PHI);
        Matrix vectorAlpha         = new Matrix(alphas, numBases);
        Matrix vectorWeights       = new Matrix(numBases, 1, 0.0);
        Matrix prunedVectorWeights = null;

        Matrix matrixU    = null;   // Cholesky factor of the IRLS-Hessian; its inverse yields Sigma (Laplace approximation)
        Matrix matrixUInv = null;   // the inverse of the above

        LinkedList<Integer> unprunedIndicesList = null;   // list of indices of unpruned alphas/weights/basis functions (= relevance vectors)
        int[] unprunedIndicesArray              = null;   // array version of the above; also used to map indices of a pruned
                                                          // vector back to those of the unpruned vector
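        /*
         * Added background sketch (following Tipping's RVM formulation; not part of the
         * original comments): with y_n = sigmoid(phi(x_n)' * w) and A = diag(alpha), the
         * main loop below alternates between (a) finding the most probable weights w_mp
         * by minimizing the negative log posterior
         *
         *     E(w) = - sum_n [ t_n * log(y_n) + (1 - t_n) * log(1 - y_n) ] + 1/2 * w' * A * w
         *
         * via IRLS, and (b) re-estimating the hyperparameters alpha from the Laplace
         * approximation around w_mp. The variable 'error' below is E(w) divided by the
         * number of examples.
         */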
        /** The main iteration */
        for (i = 1; i <= parameter.maxIterations; i++) {

            // Prune the basis functions / weights associated with too large alphas
            unprunedIndicesList = new LinkedList<Integer>();
            for (j = 0; j < numBases; j++) {
                if (vectorAlpha.get(j, 0) >= parameter.alpha_max) {
                    // pruned
                } else {
                    unprunedIndicesList.add(Integer.valueOf(j));
                }
            }
            unprunedIndicesArray = new int[unprunedIndicesList.size()];
            Iterator<Integer> iter = unprunedIndicesList.iterator();
            for (j = 0; j < unprunedIndicesList.size(); j++) {
                unprunedIndicesArray[j] = iter.next().intValue();
            }

            Matrix prunedMatrixPHI   = matrixPHI.getMatrix(0, matrixPHI.getRowDimension() - 1, unprunedIndicesArray);
            Matrix prunedVectorAlpha = vectorAlpha.getMatrix(unprunedIndicesArray, 0, 0);
            prunedVectorWeights      = vectorWeights.getMatrix(unprunedIndicesArray, 0, 0);

            /**
             * Use the iteratively reweighted least squares (IRLS) algorithm to find the
             * maximum w_mp of the marginal likelihood with respect to the weights.
             */
            double minGradientChange = 1e-6;             // convergence criterion for IRLS
            double minLambda         = Math.pow(2, -8);  // maximum overshoot criterion

            Matrix matrixAlphaDiag = new Matrix(prunedVectorAlpha.getRowDimension(), prunedVectorAlpha.getRowDimension(), 0);
            for (j = 0; j < prunedVectorAlpha.getRowDimension(); j++) {
                matrixAlphaDiag.set(j, j, prunedVectorAlpha.get(j, 0));
            }

            Matrix vectorY = prunedMatrixPHI.times(prunedVectorWeights);   // Y = sigmoid(PHI * w)
            for (int k = 0; k < vectorY.getRowDimension(); k++) {
                vectorY.set(k, 0, sigmoid(vectorY.get(k, 0)));
            }

            /** Instead of maximising the marginal likelihood function, we minimize the negative log */
            double dataTerm = 0;
            int[] t = problem.getTargetVectors();
            for (int k = 0; k < t.length; k++) {
                if (t[k] == 1) {
                    dataTerm -= Math.log(vectorY.get(k, 0));
                } else {
                    dataTerm -= Math.log(1.0 - vectorY.get(k, 0));
                }
            }

            double penaltyTerm = 0;
            for (int k = 0; k < prunedVectorAlpha.getRowDimension(); k++) {
                penaltyTerm += prunedVectorAlpha.get(k, 0) * prunedVectorWeights.get(k, 0) * prunedVectorWeights.get(k, 0);
            }
            double error = (dataTerm + penaltyTerm / 2.0) / problem.getProblemSize();

            for (j = 0; j < 25; j++) {

                /**
                 * Calculate IRLS-Hessian:
                 *   Hessian = X' * W * X
                 *           = PHI' * W_IRLS * PHI
                 *           = (W_IRLS * PHI)' * PHI
                 */
                Matrix matrixIRLSWeights = new Matrix(prunedMatrixPHI.getRowDimension(), prunedMatrixPHI.getRowDimension(), 0);
                for (int k = 0; k < matrixIRLSWeights.getRowDimension(); k++) {
                    matrixIRLSWeights.set(k, k, vectorY.get(k, 0) * (1.0 - vectorY.get(k, 0)));
                }
                Matrix matrixHessian = prunedMatrixPHI.transpose().times(matrixIRLSWeights).times(prunedMatrixPHI);
                matrixHessian.plusEquals(matrixAlphaDiag);

                /** Calculate the IRLS-gradient */
                Matrix vectorE = new Matrix(vectorY.getRowDimension(), 1, 0.0);
                for (int k = 0; k < vectorY.getRowDimension(); k++) {
                    vectorE.set(k, 0, t[k] - vectorY.get(k, 0));
                }
                Matrix vectorPenalty = (Matrix) prunedVectorAlpha.clone();
                for (int k = 0; k < vectorPenalty.getRowDimension(); k++) {
                    vectorPenalty.set(k, 0, vectorPenalty.get(k, 0) * prunedVectorWeights.get(k, 0));
                }
                Matrix vectorGradient = prunedMatrixPHI.transpose().times(vectorE).minus(vectorPenalty);
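                /*
                 * Added note: with B = W_IRLS = diag(y_n * (1 - y_n)) these are the standard
                 * penalized-logistic-regression Newton ingredients,
                 *
                 *     Hessian  H = PHI' * B * PHI + A
                 *     gradient g = PHI' * (t - y) - A * w
                 *
                 * The IRLS step below is therefore the Newton step delta_w = H^-1 * g,
                 * computed from a Cholesky-type factorization H = U * U' as
                 * delta_w = U'^-1 * (U^-1 * g).
                 */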
                /** Check for badly conditioned Hessian */
                // if (j == 0) {
                //     double cond = matrixHessian.cond();
                //     if (cond < Math.pow(2, -52)) {   // compare condition to mantissa precision of the "double" type
                //         LogService.logMessage("(IRLS) ill-conditioned Hessian: cond = " + (new Double(cond)).toString(), LogService.STATUS);
                //         LogService.logMessage("(IRLS) returning back to hyperparameter estimation ...", LogService.STATUS);
                //
                //         /** Let's hope we have something to continue with ;) ... */
                //         break;
                //     }
                // }

                /** Get inverse factor of the Hessian */
                SECholeskyDecomposition CD = new SECholeskyDecomposition(matrixHessian.getArray());
                matrixU    = CD.getPTR().times(CD.getL());
                matrixUInv = matrixU.inverse();

                /** Check for IRLS convergence */
                if (j >= 2 && (vectorGradient.normF() / prunedVectorWeights.getRowDimension()) < minGradientChange) {
                    break;
                }

                /** Make an IRLS step */
                Matrix vectorDeltaWeights = matrixUInv.transpose().times((matrixUInv.times(vectorGradient)));
                double lambda = 1;

                /** Prevent maximum overshooting */
                while (lambda > minLambda) {

                    /** w_new = w + lambda * w_delta */
                    Matrix vectorNewWeights = ((Matrix) prunedVectorWeights.clone()).plus(vectorDeltaWeights.times(lambda));

                    /** Recalculate error with respect to w_new */
                    vectorY = prunedMatrixPHI.times(vectorNewWeights);   // Y = sigmoid(PHI * w)
                    for (int k = 0; k < vectorY.getRowDimension(); k++) {
                        vectorY.set(k, 0, sigmoid(vectorY.get(k, 0)));
                    }
                    dataTerm = 0;
                    for (int k = 0; k < t.length; k++) {
                        if (t[k] == 1) {
                            dataTerm -= Math.log(vectorY.get(k, 0));
                        } else {
                            dataTerm -= Math.log(1.0 - vectorY.get(k, 0));
                        }
                    }
                    penaltyTerm = 0;
                    for (int k = 0; k < prunedVectorAlpha.getRowDimension(); k++) {
                        penaltyTerm += prunedVectorAlpha.get(k, 0) * vectorNewWeights.get(k, 0) * vectorNewWeights.get(k, 0);
                    }
                    double error_new = (dataTerm + penaltyTerm / 2.0) / problem.getProblemSize();

                    /** If overshot, back off */
                    if (error_new > error) {
                        lambda = lambda / 2.0;
                        continue;
                    }

                    /** Everything's nice, we can continue with the next IRLS step */
                    error = error_new;   // remember the accepted error so subsequent steps back off against it
                    prunedVectorWeights = vectorNewWeights;
                    break;
                }
            }

            // Get diagonal elements of the covariance matrix SIGMA = H^-1 = U^-T * U^-1,
            // i.e. SIGMA_jj = sum_k (U^-1)_kj^2
            double[] diagSIGMA = new double[matrixUInv.getRowDimension()];
            for (j = 0; j < diagSIGMA.length; j++) {
                double value = 0;
                for (int k = 0; k < diagSIGMA.length; k++) {
                    value += matrixUInv.get(k, j) * matrixUInv.get(k, j);
                }
                diagSIGMA[j] = value;
            }

            // Calculate gammas: gamma = 1 - alpha * SIGMA_ii
            double[] gammas = new double[diagSIGMA.length];
            for (j = 0; j < gammas.length; j++) {
                gammas[j] = 1.0 - prunedVectorAlpha.get(j, 0) * diagSIGMA[j];
            }

            // Get log alphas (needed for the abort check below)
            double[] logAlphas = new double[prunedVectorAlpha.getRowDimension()];
            for (j = 0; j < logAlphas.length; j++) {
                logAlphas[j] = Math.log(prunedVectorAlpha.get(j, 0));
            }

            // Alpha update: alpha = gamma / mu^2 = gamma / w^2
            for (j = 0; j < prunedVectorAlpha.getRowDimension(); j++) {
                double newAlpha = gammas[j] / (prunedVectorWeights.get(j, 0) * prunedVectorWeights.get(j, 0));
                prunedVectorAlpha.set(j, 0, newAlpha);
            }

            // Check for iteration abort
            double maxLogAlphaChange = 0;
            for (j = 0; j < logAlphas.length; j++) {
                double change = Math.abs(logAlphas[j] - Math.log(prunedVectorAlpha.get(j, 0)));
                if (change > maxLogAlphaChange)
                    maxLogAlphaChange = change;
            }
            if (maxLogAlphaChange < parameter.min_delta_log_alpha) {
                break;
            }

            // Update the (unpruned) alpha vector with the corresponding values from the pruned alpha vector
            for (j = 0; j < prunedVectorAlpha.getRowDimension(); j++) {
                vectorAlpha.set(unprunedIndicesArray[j], 0, prunedVectorAlpha.get(j, 0));
            }
        }
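        /*
         * Added note: the bases that survived pruning are the relevance vectors. The
         * model assembled below predicts via
         *
         *     p(t = 1 | x) = sigmoid( sum_j w_j * K(x_j, x) )   [one summand being the bias w_0, if unpruned]
         *
         * using only the surviving weights and kernel basis functions.
         */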
        /** Create model */
        double[] finalWeights              = new double[unprunedIndicesArray.length];
        KernelBasisFunction[] finalKernels = new KernelBasisFunction[unprunedIndicesArray.length];
        boolean bias = false;
        for (j = 0; j < unprunedIndicesArray.length; j++) {
            finalWeights[j] = prunedVectorWeights.get(j, 0);
            if (unprunedIndicesArray[j] == 0) {
                // bias wasn't pruned
                bias = true;
                finalKernels[j] = new KernelBasisFunction(new KernelRadial());
            } else {
                finalKernels[j] = kernels[unprunedIndicesArray[j]];
            }
        }

        Model model = new Model(finalWeights, finalKernels, bias, false);
        return model;
    }

    /** Sigmoid link function */
    public double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Identify the RVM */
    public String toString() {
        return "Classification-RVM";
    }
}
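/*
 * Usage sketch (added; the concrete construction of ClassificationProblem and
 * Parameter is hypothetical here, as it depends on the surrounding RapidMiner code):
 *
 *     ClassificationProblem problem = ...;   // input vectors x, targets t in {0, 1}, kernel basis functions
 *     Parameter parameter = ...;             // initAlpha, alpha_max, min_delta_log_alpha, maxIterations
 *     Model model = new RVMClassification(problem, parameter).learn();
 *
 * The returned Model holds the weights and kernel basis functions of the relevance
 * vectors, plus a flag indicating whether the bias term survived pruning.
 */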