/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
 * WEuclideanGDLearner.java
* Copyright (C) 2004 Mikhail Bilenko and Sugato Basu
*
*/
package weka.clusterers.metriclearners;
import java.util.*;
import weka.core.*;
import weka.core.metrics.*;
import weka.clusterers.MPCKMeans;
import weka.clusterers.InstancePair;
/**
* A gradient-descent based learner for WeightedEuclidean
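 * metric weights, used by MPCKMeans. Each call to trainMetric
 * accumulates per-attribute gradients of the clustering objective
 * (distances of instances to their assigned centroids, a log term that
 * keeps weights positive, and penalties for violated must-link /
 * cannot-link constraints) and takes one gradient-descent step.
 *
 * <p>Sketch of typical use from within MPCKMeans. The setter names and
 * no-arg construction below are hypothetical placeholders for whatever
 * wiring GDMetricLearner actually provides; only trainMetric is defined
 * in this file:
 * <pre>
 *   WeightedEuclidean metric = new WeightedEuclidean();
 *   MPCKMeans mpckmeans = new MPCKMeans();
 *   WEuclideanGDLearner learner = new WEuclideanGDLearner();
 *   learner.setMetric(metric);        // hypothetical setter
 *   learner.setClusterer(mpckmeans);  // hypothetical setter
 *   learner.trainMetric(-1);          // -1 = single metric for all clusters
 * </pre>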
*
* @author Mikhail Bilenko (mbilenko@cs.utexas.edu) and Sugato Basu
* (sugato@cs.utexas.edu)
* @version $Revision: 1.5 $ */
public class WEuclideanGDLearner extends GDMetricLearner {
  /** Updates the metric weights by gradient descent.
   * If clusterIdx is -1, all instances are used
   * (a single metric is trained for all clusters); otherwise only
   * instances assigned to cluster clusterIdx are used. */
public boolean trainMetric(int clusterIdx) throws Exception {
Init(clusterIdx);
double [] gradients = new double[m_numAttributes];
Instance diffInstance;
int violatedConstraints = 0;
double [] currentWeights = m_metric.getWeights();
    int numInstances = 0; // number of instances assigned to this cluster (all, if clusterIdx == -1)
double [] regularizerComponents = new double[m_numAttributes];
if (m_regularize) {
regularizerComponents = InitRegularizerComponents(currentWeights);
}
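    // The loop below accumulates the per-attribute gradient of the
    // weighted-Euclidean objective.  Assuming createDiffInstance returns
    // per-attribute squared differences (as in the diagonal-Mahalanobis
    // formulation), the gradient w.r.t. weight w_d collects:
    //   + (x_d - mu_d)^2                 for every instance in the cluster
    //   - m_logTermWeight / w_d          (log term keeping w_d positive)
    //   + 0.5 * m_MLweight * (x1_d - x2_d)^2                  per violated must-link
    //   + 0.5 * m_CLweight * (maxCLDiff_d - (x1_d - x2_d)^2)  per violated cannot-link
    // where maxCLDiff_d is taken from m_maxCLDiffInstance; the 0.5 factors
    // compensate for each constraint being visited from both of its endpoints.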
for (int instIdx = 0; instIdx < m_instances.numInstances(); instIdx++) {
int assignment = m_clusterAssignments[instIdx];
// only instances assigned to this cluster are of importance
if (assignment == clusterIdx || clusterIdx == -1) {
Instance instance = m_instances.instance(instIdx);
numInstances++;
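        // When one metric is shared across clusters (clusterIdx == -1), use
        // the centroid of this instance's own cluster; otherwise m_centroid
        // was presumably set for the target cluster by Init(clusterIdx).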
if (clusterIdx < 0) {
m_centroid = m_kmeans.getClusterCentroids().instance(assignment);
}
diffInstance = m_metric.createDiffInstance(instance, m_centroid);
for (int attr = 0; attr < m_numAttributes; attr++) {
gradients[attr] += diffInstance.value(attr); // Euclidean components
if (currentWeights[attr] > 0) {
gradients[attr] -= m_logTermWeight/currentWeights[attr]; // log components
// if (m_regularize) {
// regularizerComponents[attr] = m_regularizerTermWeight *
// m_metric.getRegularizer().gradient(currentWeights[attr]);
// } else {
// regularizerComponents[attr] = 0;
// }
}
}
// go through violated constraints
Object list = m_instanceConstraintMap.get(new Integer(instIdx));
if (list != null) { // there are constraints associated with this instance
ArrayList constraintList = (ArrayList) list;
for (int i = 0; i < constraintList.size(); i++) {
InstancePair pair = (InstancePair) constraintList.get(i);
int linkType = pair.linkType;
int firstIdx = pair.first;
int secondIdx = pair.second;
Instance instance1 = m_instances.instance(firstIdx);
Instance instance2 = m_instances.instance(secondIdx);
int otherIdx = (firstIdx == instIdx) ?
m_clusterAssignments[secondIdx] : m_clusterAssignments[firstIdx];
// check whether the constraint is violated
if (otherIdx != -1) {
if (otherIdx != assignment && linkType == InstancePair.MUST_LINK) {
diffInstance = m_metric.createDiffInstance(instance1, instance2);
for (int attr = 0; attr < m_numAttributes; attr++) {
gradients[attr] += 0.5 * m_MLweight * diffInstance.value(attr);
}
violatedConstraints++;
} else if (otherIdx == assignment && linkType == InstancePair.CANNOT_LINK){
diffInstance = m_metric.createDiffInstance(instance1, instance2);
for (int attr = 0; attr < m_numAttributes; attr++) {
// this constraint will be counted twice, hence 0.5
gradients[attr] += 0.5 * m_CLweight * m_maxCLDiffInstance.value(attr);
gradients[attr] -= 0.5 * m_CLweight * diffInstance.value(attr);
}
violatedConstraints++;
}
            } // end if (otherIdx != -1)
}
}
}
}
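    // Take one gradient-descent step.  GDUpdate (inherited from
    // GDMetricLearner) is expected to apply roughly
    //   w_d <- w_d - eta * gradients[d], adjusted by regularizerComponents[d]
    // and clipped at zero, mirroring the inline update in the retired
    // multiple-metric code kept below for reference.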
double [] newWeights = GDUpdate(currentWeights, gradients, regularizerComponents);
m_metric.setWeights(newWeights);
System.out.println(" Total constraints violated: " + violatedConstraints/2);
return true;
}
  // OLD CODE FOR MULTIPLE METRICS (retained for reference)
// /** M-step of the KMeans clustering algorithm -- updates Euclidean
// * metric weights for the individual metrics using gradient
// * descent. Invoked only when m_regularizeWeights is true and
// * metric is trainable */
// protected boolean updateMultipleMetricWeightsEuclideanGD() throws Exception {
// // SUGATO: Added regularization code to updateMultipleMetricWeightsEuclideanGD
// int numAttributes = m_Instances.numAttributes();
// double[][] gradients = new double[m_NumClusters][numAttributes];
// double[][] regularizerComponents = new double[m_NumClusters][numAttributes];
// int [] clusterCounts = new int[m_NumClusters]; // count how many instances are in each cluster
// Instance diffInstance;
// double [][] currentWeights = new double[m_NumClusters][numAttributes];
// for (int i=0; i<m_NumClusters; i++) {
// currentWeights[i] = ((LearnableMetric) m_metrics[i]).getWeights();
// }
// //begin debugging variance
// boolean debugVariance = true;
// double[][] trueWeights = new double[m_NumClusters][numAttributes];
// int [] majorityClasses = new int[m_NumClusters];
// int [][] classCounts = new int[m_NumClusters][m_TotalTrainWithLabels.numClasses()];
// // get the majority counts
// // NB: m_TotalTrainWithLabels does *not* include unlabeled data, counts here are undersampled!
// // assuming unlabeled data came from same distribution as m_TotalTrainWithLabels, counts are still valid...
// for (int instIdx=0; instIdx<m_TotalTrainWithLabels.numInstances(); instIdx++) {
// Instance fullInstance = m_TotalTrainWithLabels.instance(instIdx);
// classCounts[m_ClusterAssignments[instIdx]][(int)(fullInstance.classValue())]++;
// }
// for (int i = 0; i < m_NumClusters; i++){
// int majorityClass = 0;
// System.out.print("Cluster" + i + "\t" + classCounts[i][0]);
// for (int j = 1; j < m_TotalTrainWithLabels.numClasses(); j++) {
// System.out.print("\t" + classCounts[i][j]);
// if (classCounts[i][j] > classCounts[i][majorityClass]) {
// majorityClass = j;
// }
// }
// System.out.println();
// majorityClasses[i] = majorityClass;
// }
// class MajorityChecker {
// int [] m_majorityClasses = null;
// public MajorityChecker(int [] majClasses) { m_majorityClasses = majClasses;}
// public boolean belongsToMajority(Instances instances, int instIdx, int centroidIdx) {
// // silly, must pass instance since can't access outer class fields otherwise from a local inner class
// Instance fullInstance = instances.instance(instIdx);
// int classValue = (int) fullInstance.classValue();
// if (classValue == m_majorityClasses[centroidIdx]) {
// return true;
// } else {
// return false;
// }
// }
// }
// MajorityChecker majChecker = new MajorityChecker(majorityClasses);
// //end debugging variance
// int violatedConstraints = 0;
// for (int i=0; i<m_NumClusters; i++){
// for (int attr=0; attr<numAttributes; attr++) {
// regularizerComponents[i][attr] = 0;
// }
// }
// for (int instIdx=0; instIdx<m_Instances.numInstances(); instIdx++) {
// int centroidIdx = m_ClusterAssignments[instIdx];
// diffInstance = m_metrics[centroidIdx].createDiffInstance(m_Instances.instance(instIdx), m_ClusterCentroids.instance(centroidIdx));
// for (int attr=0; attr<numAttributes; attr++) {
// gradients[centroidIdx][attr] += diffInstance.value(attr); // Mahalanobis components
// if (currentWeights[centroidIdx][attr] > 0) {
// gradients[centroidIdx][attr] -= 1/currentWeights[centroidIdx][attr]; // log components
// if (m_regularizeWeights) {
// regularizerComponents[centroidIdx][attr] += m_currregularizerTermWeight/(currentWeights[centroidIdx][attr] * currentWeights[centroidIdx][attr]);
// } else {
// regularizerComponents[centroidIdx][attr] = 0;
// }
// }
// if (debugVariance && instIdx < m_TotalTrainWithLabels.numInstances()) {
// if (majChecker.belongsToMajority(m_TotalTrainWithLabels, instIdx, centroidIdx)) {
// trueWeights[centroidIdx][attr] += diffInstance.value(attr);
// }
// }
// }
// clusterCounts[centroidIdx]++;
// Object list = m_instanceConstraintHash.get(new Integer(instIdx));
// if (list != null) { // there are constraints associated with this instance
// ArrayList constraintList = (ArrayList) list;
// for (int i = 0; i < constraintList.size(); i++) {
// InstancePair pair = (InstancePair) constraintList.get(i);
// int firstIdx = pair.first;
// int secondIdx = pair.second;
// double cost = 0;
// if (pair.linkType == InstancePair.MUST_LINK) {
// cost = m_MLweight;
// } else if (pair.linkType == InstancePair.CANNOT_LINK) {
// cost = m_CLweight;
// }
// Instance instance1 = m_Instances.instance(firstIdx);
// Instance instance2 = m_Instances.instance(secondIdx);
// int otherIdx = (firstIdx == instIdx) ? m_ClusterAssignments[secondIdx] : m_ClusterAssignments[firstIdx];
// // check whether the constraint is violated
// if (otherIdx != -1) {
// if (otherIdx != centroidIdx && pair.linkType == InstancePair.MUST_LINK) { // violated must-link
// if (m_verbose) {
// System.out.println("Found violated must link between: " + firstIdx + " and " + secondIdx);
// }
// // we penalize weights for both clusters involved, splitting the penalty in half
// Instance diffInstance1 = m_metrics[otherIdx].createDiffInstance(instance1, instance2);
// Instance diffInstance2 = m_metrics[centroidIdx].createDiffInstance(instance1, instance2);
// for (int attr=0; attr<numAttributes; attr++) { // double-counting constraints, hence 0.5*0.5
// gradients[otherIdx][attr] += 0.25 * cost * diffInstance1.value(attr);
// gradients[centroidIdx][attr] += 0.25 * cost * diffInstance2.value(attr);
// }
// violatedConstraints++;
// }
// else if (otherIdx == centroidIdx && pair.linkType == InstancePair.CANNOT_LINK) { //violated cannot-link
// if (m_verbose) {
// System.out.println("Found violated cannot link between: " + firstIdx + " and " + secondIdx);
// }
// // we penalize weights for just one cluster involved
// diffInstance = m_metrics[centroidIdx].createDiffInstance(instance1, instance2);
// Instance cannotDiffInstance = m_metrics[otherIdx].createDiffInstance(m_maxCLPoints[centroidIdx][0],
// m_maxCLPoints[centroidIdx][1]);
// for (int attr=0; attr<numAttributes; attr++) { // double-counting constraints, hence 0.5
// gradients[centroidIdx][attr] += 0.5 * cost * cannotDiffInstance.value(attr);
// gradients[centroidIdx][attr] -= 0.5 * cost * diffInstance.value(attr);
// }
// violatedConstraints++;
// }
// } // end while
// }
// }
// }
// double [][] newWeights = new double[m_metrics.length][numAttributes];
// for (int i = 0; i < m_metrics.length; i++) {
// for (int attr=0; attr<numAttributes; attr++) {
// gradients[i][attr] *= m_currEta;
// if (gradients[i][attr] > regularizerComponents[i][attr]) { // to take into account the direction of the gradient descent update
// newWeights[i][attr] = currentWeights[i][attr] - gradients[i][attr] + regularizerComponents[i][attr];
// } else {
// newWeights[i][attr] = currentWeights[i][attr] + gradients[i][attr] - regularizerComponents[i][attr];
// }
// if (newWeights[i][attr] < 0) {
// System.out.println("prevented negative weight " + newWeights[i][attr] + " for attribute " + m_Instances.attribute(attr).name());
// newWeights[i][attr] = 0;
// } else if (newWeights[i][attr] == 0) {
// System.out.println("zero weight for attribute " + m_Instances.attribute(attr).name());
// }
// }
// newWeights[i] = ClusterUtils.normalize(newWeights[i]);
// ((LearnableMetric) m_metrics[i]).setWeights(newWeights[i]);
// // PRINT top weights
// System.out.println("Cluster " + i + " (" + clusterCounts[i] + ")");
// TreeMap map = new TreeMap(Collections.reverseOrder());
// for (int j = 0; j < newWeights[i].length; j++) {
// map.put(new Double(newWeights[i][j]), new Integer(j));
// }
// Iterator it = map.entrySet().iterator();
// for (int j=0; j < 10 && it.hasNext(); j++) {
// Map.Entry entry = (Map.Entry) it.next();
// int idx = ((Integer)entry.getValue()).intValue();
// System.out.println("\t" + m_Instances.attribute(idx).name() + "\t" + newWeights[i][idx]
// + "\tgradient=" + gradients[i][idx] + "\tregularizer=" + regularizerComponents[i][idx]);
// }
// // end PRINT top weights
// }
// m_currEta = m_currEta * m_etaDecayRate;
// // m_currregularizerTermWeight *= m_etaDecayRate;
// // PRINT routine
// // System.out.println("Total constraints violated: " + violatedConstraints/2 + "; weights are:");
// // for (int attr=0; attr<numAttributes; attr++) {
// // System.out.print(newWeights[attr] + "\t");
// // }
// // System.out.println();
// // end PRINT routine
// return true;
// }
  /**
   * Gets the current settings of WEuclideanGDLearner
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String [] getOptions() {
    // no learner-specific options are defined yet
    String [] options = new String [0];
    return options;
  }
public void setOptions(String[] options) throws Exception {
// TODO: add later
}
  public Enumeration listOptions() {
    // TODO: add learner-specific options later; until then return an
    // empty enumeration rather than null so callers can iterate safely
    return new Vector().elements();
  }
}