/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* MahalanobisLearner.java
* Copyright (C) 2004 Mikhail Bilenko and Sugato Basu
*
*/
package weka.clusterers.metriclearners;
import java.util.*;
import weka.core.*;
import weka.core.metrics.*;
import weka.clusterers.MPCKMeans;
import weka.clusterers.InstancePair;
import Jama.Matrix;
/**
* A closed-form based learner for Mahalanobis
*
* @author Mikhail Bilenko (mbilenko@cs.utexas.edu) and Sugato Basu
* (sugato@cs.utexas.edu)
* @version $Revision: 1.5 $ */
public class MahalanobisLearner extends MPCKMeansMetricLearner {
/** min difference of objective function values for convergence*/
protected double m_minDet = 1e-5;
public void resetLearner() {
}
/** if clusterIdx is -1, all instances are used
* (a single metric for all clusters is used) */
public boolean trainMetric(int clusterIdx) throws Exception {
Init(clusterIdx);
Matrix updateMatrix = new Matrix(m_numAttributes, m_numAttributes);
int violatedConstraints = 0;
int numInstances = 0;
WeightedMahalanobis metric = (WeightedMahalanobis) m_metric;
Matrix maxMatrix = null;
if (m_instanceConstraintMap.size() > 0) {
if (clusterIdx == -1) {
maxMatrix = metric.createDiffMatrix(m_kmeans.m_maxCLPoints[0][0],
m_kmeans.m_maxCLPoints[0][1]);
} else {
maxMatrix = metric.createDiffMatrix(m_kmeans.m_maxCLPoints[clusterIdx][0],
m_kmeans.m_maxCLPoints[clusterIdx][1]);
}
maxMatrix = maxMatrix.times(0.5);
}
for (int instIdx = 0; instIdx < m_instances.numInstances(); instIdx++) {
int assignment = m_clusterAssignments[instIdx];
// only instances assigned to this cluster are of importance
if (assignment == clusterIdx || clusterIdx == -1) {
numInstances++;
if (clusterIdx < 0) {
m_centroid = m_kmeans.getClusterCentroids().instance(assignment);
}
Instance instance = m_instances.instance(instIdx);
Matrix diffMatrix = metric.createDiffMatrix(instance, m_centroid);
updateMatrix = updateMatrix.plus(diffMatrix);
// go through violated constraints
Object list = m_instanceConstraintMap.get(new Integer(instIdx));
if (list != null) { // there are constraints associated with this instance
ArrayList constraintList = (ArrayList) list;
for (int i = 0; i < constraintList.size(); i++) {
InstancePair pair = (InstancePair) constraintList.get(i);
int linkType = pair.linkType;
int firstIdx = pair.first;
int secondIdx = pair.second;
Instance instance1 = m_instances.instance(firstIdx);
Instance instance2 = m_instances.instance(secondIdx);
int otherIdx = (firstIdx == instIdx) ? m_clusterAssignments[secondIdx]
: m_clusterAssignments[firstIdx];
// check whether the constraint is violated
if (otherIdx != -1 ) {
if (otherIdx != assignment && linkType == InstancePair.MUST_LINK) {
diffMatrix = metric.createDiffMatrix(instance1, instance2);
diffMatrix = diffMatrix.times(0.5);
updateMatrix = updateMatrix.plus(diffMatrix);
violatedConstraints++;
} else if (otherIdx == assignment && linkType == InstancePair.CANNOT_LINK) {
diffMatrix = metric.createDiffMatrix(instance1, instance2);
diffMatrix = diffMatrix.times(0.5);
updateMatrix = updateMatrix.plus(maxMatrix);
updateMatrix = updateMatrix.minus(diffMatrix);
violatedConstraints++;
}
} // end while
}
}
}
}
updateMatrix = updateMatrix.times(1.0/numInstances);
double updateDet = updateMatrix.det();
int maxIterations = 1000;
int currIteration = 1;
Matrix newWeights = null;
// System.out.println("UPDATE weights: " + " (violated constraints: " + violatedConstraints + ")");
// for (int i = 0; i < updateMatrix.getArray().length; i++) {
// for (int j = 0; j < updateMatrix.getArray()[i].length; j++) {
// System.out.print((float)updateMatrix.getArray()[i][j] + "\t");
// }
// System.out.println();
// }
// check that the update matrix is non-singular
while (Math.abs(updateDet) < m_minDet && currIteration++ < maxIterations) {
Matrix regularizer = Matrix.identity(m_numAttributes, m_numAttributes);
regularizer = regularizer.times(updateMatrix.trace() * 0.01);
updateMatrix = updateMatrix.plus(regularizer);
System.out.print("\t" + currIteration + ". Singular update matrix, DET=" + (float)updateDet);
updateDet = updateMatrix.det();
System.out.println("; after regularization DET=" + (float)updateDet);
}
if (currIteration >= maxIterations) { // if the matrix is irrepairable, return to identity matrix
System.out.println("\n\nCOULDN'T REGULARIZE; GOING TO IDENTITY\n\n");
newWeights = Matrix.identity(m_numAttributes, m_numAttributes);
} else {
newWeights = updateMatrix.inverse();
}
// // check that matrix is positive semi-definite
// currIteration = 0;
// double det = newWeights.det();
// Matrix weightsSquare = newWeights.chol().getL();
// double sqDet = weightsSquare.det();
// while ((det < 0 || Math.abs(det) < m_ObjFunConvergenceDifference
// || Math.abs(sqDet) < m_ObjFunConvergenceDifference || Double.isNaN(sqDet))
// && currIteration++ < maxIterations) {
// // make sure the the matrix is symmetric positive definite
// if (det < 0) {
// EigenvalueDecomposition ed = newWeights.eig();
// Matrix eigenVectorsMatrix = ed.getV();
// double[] evalues = ed.getRealEigenvalues();
// double [][] evaluesM = new double[evalues.length][evalues.length];
// for (int i = 0; i < evalues.length; i++) {
// if (evalues[i] < 0) {
// evalues[i] = -evalues[i];
// } else {
// evaluesM[i][i] = evalues[i];
// }
// }
// Matrix eigenValuesMatrix = new Matrix(evaluesM);
// // update the weights: A' = V' * E * V
// newWeights = ((eigenVectorsMatrix.transpose()).times(eigenValuesMatrix)).times(eigenVectorsMatrix);
// System.out.println("\tNegative determinant; projecting for subsequent regularization");
// }
// // the weights matrix may end up singular (if determinant was negative, or det(updateMatrix) was very large
// sqDet = newWeights.chol().getL().det();
// det = newWeights.det();
// if (Math.abs(det) < m_ObjFunConvergenceDifference || Math.abs(sqDet) < m_ObjFunConvergenceDifference
// || Double.isNaN(sqDet)) {
// Matrix regularizer = Matrix.identity(m_numAttributes, m_numAttributes);
// regularizer = regularizer.times(newWeights.trace() * 0.01);
// newWeights = newWeights.plus(regularizer); // W = W + 0.01tr(W) * I
// System.out.println("\tsingular matrix, det=" + ((float)det) + ", sqDet=" + ((float)sqDet) +
// "\tafter FIXING AND REGULARIZATION det=" + newWeights.det());
// det = newWeights.det();
// sqDet = newWeights.chol().getL().det();
// }
// }
// // if the matrix is irrepairable, return to identity matrix
// if (currIteration >= maxIterations) {
// newWeights = Matrix.identity(m_numAttributes, m_numAttributes);
// }
metric.setWeights(newWeights);
// project all the instances for subsequent calculation of max-points for cannot-link penalties
for (int instIdx=0; instIdx<m_instances.numInstances(); instIdx++) {
if (clusterIdx < 0 || m_clusterAssignments[instIdx] == clusterIdx) {
metric.projectInstance(m_instances.instance(instIdx));
}
}
return true;
}
/**
* Gets the current settings of KL
*
* @return an array of strings suitable for passing to setOptions()
*/
public String [] getOptions() {
String [] options = new String [1];
int current = 0;
while (current < options.length) {
options[current++] = "";
}
return options;
}
public void setOptions(String[] options) throws Exception {
// TODO: add later
}
public Enumeration listOptions() {
// TODO: add later
return null;
}
}
// protected void updateMetricWeightsMahalanobisGD() throws Exception {
// WeightedMahalanobis metric = (WeightedMahalanobis) m_metric;
// int numAttributes = m_Instances.numAttributes();
// Instance diffInstance;
// int violatedConstraints = 0;
// Matrix newWeights = metric.getWeightsMatrix().copy();
// // Do the GD
// int iteration = 0;
// boolean converged = false;
// // precompute the update matrix for maxCannotLinkInstance
// double[][] maxCLUpdate = new double[numAttributes][numAttributes];
// Instance maxCLDiffInstance = null;
// if (m_maxCLPoints != null) {
// maxCLDiffInstance = metric.createDiffInstance(m_maxCLPoints[0][0],
// m_maxCLPoints[0][1]);
// for (int i = 0; i < numAttributes; i++) {
// for (int j = 0; j <=i; j++) {
// maxCLUpdate[i][j] =
// maxCLUpdate[j][i] =
// maxCLDiffInstance.value(i) *maxCLDiffInstance.value(j);
// }
// }
// }
// // store the constant part of the gradient:
// double[][] gradientConst = new double[numAttributes][numAttributes];
// for (int instIdx = 0; instIdx < m_Instances.numInstances(); instIdx++) {
// // the (x-m)(x-m)' part
// int centroidIdx = m_ClusterAssignments[instIdx];
// Instance centroid = m_ClusterCentroids.instance(centroidIdx);
// diffInstance = metric.createDiffInstance(m_Instances.instance(instIdx),
// centroid);
// for (int i = 0; i < numAttributes; i++) {
// for (int j = 0; j <= i; j++) {
// gradientConst[i][j] =
// gradientConst[j][i] = diffInstance.value(i) * diffInstance.value(j);
// }
// }
// // the violated constraints
// Object list = m_instanceConstraintHash.get(new Integer(instIdx));
// if (list != null) { // there are constraints associated with this instance
// ArrayList constraintList = (ArrayList) list;
// for (int constrIdx = 0; constrIdx < constraintList.size(); constrIdx++) {
// InstancePair pair = (InstancePair) constraintList.get(constrIdx);
// int firstIdx = pair.first;
// int secondIdx = pair.second;
// double cost = 0;
// if (pair.linkType == InstancePair.MUST_LINK) {
// cost = m_MLweight;
// } else if (pair.linkType == InstancePair.CANNOT_LINK) {
// cost = m_CLweight;
// }
// Instance instance1 = m_Instances.instance(firstIdx);
// Instance instance2 = m_Instances.instance(secondIdx);
// int otherIdx = (firstIdx == instIdx) ? m_ClusterAssignments[secondIdx]
// : m_ClusterAssignments[firstIdx];
// if (otherIdx == -1) {
// throw new Exception("One of the instances is unassigned in "
// + "updateMetricWeightsMahalanobisGD");
// }
// // check whether the constraint is violated
// if (otherIdx != centroidIdx &&
// pair.linkType == InstancePair.MUST_LINK) {
// diffInstance = metric.createDiffInstance(instance1, instance2);
// for (int i = 0; i < numAttributes; i++) {
// for (int j = 0; j <= i; j++) {
// gradientConst[i][j] =
// gradientConst[j][i] =
// 0.5 * cost * diffInstance.value(i) * diffInstance.value(j);
// }
// }
// violatedConstraints++;
// } else if (otherIdx == centroidIdx &&
// pair.linkType == InstancePair.CANNOT_LINK) {
// diffInstance = metric.createDiffInstance(instance1, instance2);
// for (int i = 0; i < numAttributes; i++) {
// for (int j = 0; j <= i; j++) {
// gradientConst[i][j] =
// gradientConst[j][i] =
// 0.5 * cost *
// (maxCLUpdate[i][j] -
// diffInstance.value(i) * diffInstance.value(j));
// }
// }
// violatedConstraints++;
// }
// }
// }
// }
// Matrix constUpdate = new Matrix(gradientConst);
// while (iteration < m_maxGDIterations && !converged) {
// // calculate the gradient
// Matrix update = constUpdate.copy();
// // factor in the A^-1
// Matrix Ai = newWeights.inverse();
// Ai.timesEquals(m_logTermWeight);
// update.minusEquals(Ai);
// // regularization (-1/sum(a_ij)^2)
// double regularizer = 0;
// for (int i = 0; i < numAttributes; i++) {
// for (int j = 0; j <= i; j++) {
// regularizer += 2.0/(newWeights.get(i, j) * newWeights.get(i, j));
// }
// }
// // correct for double-counted diagonal
// for (int i = 0; i < numAttributes; i++) {
// regularizer -= 1.0/newWeights.get(i, i);
// }
// regularizer *= m_currregularizerTermWeight;
// for (int i = 0; i < numAttributes; i++) {
// for (int j = 0; j < numAttributes; j++) {
// update.set(i, j, update.get(i,j) - regularizer);
// }
// }
// // update
// update.timesEquals(m_currEta);
// newWeights.minusEquals(update);
// // anneal if necessary and check for convergence
// m_currEta = m_currEta * m_etaDecayRate;
// // check for convergence
// double norm = update.norm1();
// System.out.println(iteration + ": norm=" + norm);
// if (norm < 0.0001) {
// converged = true;
// }
// iteration++;
// }
// // We're done, set the weights to newWeights
// }
// MULTIPLE:
// /** M-step of the KMeans clustering algorithm -- updates metric
// * weights. Invoked only when metric is an instance of Mahalanobis
// * @return value true if everything was alright; false if there was
// miserable failure and clustering needs to be restarted */
// protected boolean updateMultipleMetricWeightsMahalanobis() throws Exception {
// if (m_regularizeWeights) {
// System.out.println("Regularized version, calling GD version of updateMultipleMetricWeightsMahalanobisGD!");
// updateMultipleMetricWeightsMahalanobisGD();
// }
// int numAttributes = m_Instances.numAttributes();
// if (m_Instances.classIndex() >= 0) {
// numAttributes--;
// }
// Matrix [] updateMatrices = new Matrix[m_metrics.length];
// for (int i = 0; i < updateMatrices.length; i++) {
// updateMatrices[i] = new Matrix(numAttributes, numAttributes);
// }
// int violatedConstraints = 0;
// int [] counts = new int[updateMatrices.length];
// for (int instIdx=0; instIdx<m_Instances.numInstances(); instIdx++) {
// int centroidIdx = m_ClusterAssignments[instIdx];
// Matrix diffMatrix = ((WeightedMahalanobis) m_metrics[centroidIdx]).createDiffMatrix(m_Instances.instance(instIdx),
// m_ClusterCentroids.instance(centroidIdx));
// updateMatrices[centroidIdx] = updateMatrices[centroidIdx].plus(diffMatrix);
// counts[centroidIdx]++;
// // go through violated constraints
// Object list = m_instanceConstraintHash.get(new Integer(instIdx));
// if (list != null) { // there are constraints associated with this instance
// ArrayList constraintList = (ArrayList) list;
// for (int i = 0; i < constraintList.size(); i++) {
// InstancePair pair = (InstancePair) constraintList.get(i);
// int firstIdx = pair.first;
// int secondIdx = pair.second;
// Instance instance1 = m_Instances.instance(firstIdx);
// Instance instance2 = m_Instances.instance(secondIdx);
// int otherIdx = (firstIdx == instIdx) ? m_ClusterAssignments[secondIdx] : m_ClusterAssignments[firstIdx];
// // check whether the constraint is violated
// if (otherIdx != -1) {
// if (otherIdx != centroidIdx && pair.linkType == InstancePair.MUST_LINK) {
// Matrix diffMatrix1 = ((WeightedMahalanobis) m_metrics[centroidIdx]).createDiffMatrix(instance1, instance2);
// diffMatrix1 = diffMatrix1.times(0.25);
// Matrix diffMatrix2 = ((WeightedMahalanobis) m_metrics[otherIdx]).createDiffMatrix(instance1, instance2);
// diffMatrix2 = diffMatrix2.times(0.25);
// updateMatrices[centroidIdx] = updateMatrices[centroidIdx].plus(diffMatrix1);
// updateMatrices[otherIdx] = updateMatrices[otherIdx].plus(diffMatrix2);
// violatedConstraints++;
// } else if (otherIdx == centroidIdx && pair.linkType == InstancePair.CANNOT_LINK) {
// diffMatrix = ((WeightedMahalanobis) m_metrics[centroidIdx]).createDiffMatrix(instance1, instance2);
// Matrix maxMatrix = ((WeightedMahalanobis) m_metrics[centroidIdx]).createDiffMatrix(m_maxCLPoints[centroidIdx][0],
// m_maxCLPoints[centroidIdx][1]);
// diffMatrix = diffMatrix.times(0.5);
// maxMatrix = maxMatrix.times(0.5);
// updateMatrices[centroidIdx] = updateMatrices[centroidIdx].plus(maxMatrix);
// updateMatrices[centroidIdx] = updateMatrices[centroidIdx].minus(diffMatrix);
// violatedConstraints++;
// }
// } // end while
// }
// }
// }
// int [][] classCounts = new int[m_NumClusters][m_TotalTrainWithLabels.numClasses()];
// // NB: m_TotalTrainWithLabels does *not* include unlabeled data, counts here are undersampled!
// // assuming unlabeled data came from same distribution as m_TotalTrainWithLabels, counts are still valid...
// for (int instIdx=0; instIdx<m_TotalTrainWithLabels.numInstances(); instIdx++) {
// Instance fullInstance = m_TotalTrainWithLabels.instance(instIdx);
// classCounts[m_ClusterAssignments[instIdx]][(int)(fullInstance.classValue())]++;
// }
// for (int i = 0; i < m_NumClusters; i++){
// System.out.print("Cluster " + i + "(" + counts[i] + ")\t" + classCounts[i][0]);
// for (int j = 1; j < m_TotalTrainWithLabels.numClasses(); j++) {
// System.out.print("\t" + classCounts[i][j]);
// }
// System.out.println();
// }
// // now update the actual weight matrices
// for (int i = 0; i < updateMatrices.length; i++) {
// int maxIterations = 100;
// if (counts[i] == 0) {
// //System.out.println("Cluster " + i + " has lost all instances; leaving weights as is");
// updateMatrices[i] = Matrix.identity(numAttributes, numAttributes);
// counts[i] = 1;
// //System.err.println("IRREPAIRABLE COVARIANCE MATRIX, RESTARTING");
// //return false;
// }
// updateMatrices[i] = updateMatrices[i].times(1.0/counts[i]);
// double updateDet = updateMatrices[i].det();
// int currIteration = 0;
// Matrix newWeights = null;
// // check that the update matrix is non-singular
// while (Math.abs(updateDet) < m_NRConvergenceDifference && currIteration++ < maxIterations) {
// Matrix regularizer = Matrix.identity(numAttributes, numAttributes);
// regularizer = regularizer.times(updateMatrices[i].trace() * 0.01);
// updateMatrices[i] = updateMatrices[i].plus(regularizer);
// System.out.print(i + "\tsingular UPDATE matrix, DET=" + ((float)updateDet));
// updateDet = updateMatrices[i].det();
// System.out.println("; after regularization DET=" + ((float)updateDet));
// // System.out.println("ACTUAL weights: ");
// // double[][] m_weights = updateMatrices[i].getArray();
// // for (int l = 0; l < m_weights.length; l++) {
// // for (int j = 0; j < m_weights[l].length; j++) {
// // System.out.print(((float)m_weights[l][j]) + "\t");
// // }
// // System.out.println();
// // }
// }
// if (currIteration >= maxIterations) { // if the matrix is irrepairable, return to identity matrix
// newWeights = Matrix.identity(numAttributes, numAttributes);
// System.err.println("IRREPAIRABLE UPDATE MATRIX, RESTARTING");
// } else {
// newWeights = updateMatrices[i].inverse();
// }
// ((WeightedMahalanobis) m_metrics[i]).setWeights(newWeights);
// // project all the instances for subsequent calculation of max-points for cannot-link penalties
// // TODO: we are projecting ALL instances just in case... possibly can optimize in the future
// for (int instIdx=0; instIdx<m_Instances.numInstances(); instIdx++) {
// ((WeightedMahalanobis) m_metrics[i]).projectInstance(m_Instances.instance(instIdx));
// }
// }
// return true;
// }