/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * WEuclideanLearner.java * Copyright (C) 2004 Mikhail Bilenko and Sugato Basu * */ package weka.clusterers.metriclearners; import java.util.*; import weka.core.*; import weka.core.metrics.*; import weka.clusterers.MPCKMeans; import weka.clusterers.InstancePair; /** * A closed-form learner for WeightedEuclidean * * @author Mikhail Bilenko (mbilenko@cs.utexas.edu) and Sugato Basu * (sugato@cs.utexas.edu * @version $Revision: 1.5 $ */ public class WEuclideanLearner extends MPCKMeansMetricLearner { public void resetLearner() { } /** if clusterIdx is -1, all instances are used * (a single metric for all clusters is used) */ public boolean trainMetric(int clusterIdx) throws Exception { Init(clusterIdx); double[] weights = new double[m_numAttributes]; int violatedConstraints = 0; int numInstances = 0; for (int instIdx = 0; instIdx < m_instances.numInstances(); instIdx++) { int assignment = m_clusterAssignments[instIdx]; // only instances assigned to this cluster are of importance if (assignment == clusterIdx || clusterIdx == -1) { numInstances++; if (clusterIdx < 0) { m_centroid = m_kmeans.getClusterCentroids().instance(assignment); } // accumulate variance Instance instance = m_instances.instance(instIdx); Instance diffInstance = m_metric.createDiffInstance(instance, m_centroid); for (int attr = 0; attr < m_numAttributes; attr++) { weights[attr] += diffInstance.value(attr); } // check all constraints for this instance Object list = m_instanceConstraintMap.get(new Integer(instIdx)); if (list != null) { // there are constraints associated with this instance ArrayList constraintList = (ArrayList) list; for (int i = 0; i < constraintList.size(); i++) { InstancePair pair = (InstancePair) constraintList.get(i); int linkType = pair.linkType; int firstIdx = pair.first; int secondIdx = pair.second; Instance instance1 = m_instances.instance(firstIdx); Instance instance2 = m_instances.instance(secondIdx); int otherIdx = (firstIdx == instIdx) ? m_clusterAssignments[secondIdx] : m_clusterAssignments[firstIdx]; if (otherIdx != -1) { // check whether the constraint is violated if (otherIdx != assignment && linkType == InstancePair.MUST_LINK) { diffInstance = m_metric.createDiffInstance(instance1, instance2); for (int attr = 0; attr < m_numAttributes; attr++) { weights[attr] += 0.5 * m_MLweight * diffInstance.value(attr); } } else if (otherIdx == assignment && linkType == InstancePair.CANNOT_LINK){ diffInstance = m_metric.createDiffInstance(instance1, instance2); for (int attr = 0; attr < m_numAttributes; attr++) { // this constraint will be counted twice, hence 0.5 weights[attr] += 0.5 * m_CLweight * m_maxCLDiffInstance.value(attr); weights[attr] -= 0.5 * m_CLweight * diffInstance.value(attr); } } } } } } } // System.out.println("Updating cluster " + clusterIdx // + " containing " + numInstances); // check the weights double [] newWeights = new double[m_numAttributes]; double [] currentWeights = m_metric.getWeights(); boolean needNewtonRaphson = false; for (int attr = 0; attr < m_numAttributes; attr++) { if (weights[attr] <= 0) { // check to avoid divide by 0 - TODO! System.out.println("Negative weight " + weights[attr] + " for clusterIdx=" + clusterIdx + "; using prev value=" + currentWeights[attr]); newWeights[attr] = currentWeights[attr]; // needNewtonRaphson = true; // break; } else { if (m_regularize) { // solution of quadratic equation - TODO! int n = m_instances.numInstances(); double ratio = (m_logTermWeight * n) / (2 * weights[attr]); newWeights[attr] = ratio + Math.sqrt(ratio*ratio + (m_regularizerTermWeight*n) /weights[attr]); } else { newWeights[attr] = m_logTermWeight * numInstances / weights[attr]; } } } // do NR if needed if (needNewtonRaphson) { System.out.println("GOING TO NEWTON-RAPHSON!!!\n"); newWeights = updateWeightsUsingNewtonRaphson(currentWeights, weights); } // PRINT routine // System.out.println("Total constraints violated: " + violatedConstraints/2 + "; weights are:"); // for (int attr=0; attr<numAttributes; attr++) { // System.out.print(newWeights[attr] + "\t"); // } // System.out.println(); // end PRINT routine m_metric.setWeights(newWeights); return true; } /** calculates weights using Newton Raphson, to satisfy the positivity constraint of each attribute weight, returns learned attribute weights. Note: currentAttrWeights is the inverted version of the current m_metric weights. */ protected double [] updateWeightsUsingNewtonRaphson (double [] currentAttrWeights, double [] invUnconstrainedAttrWeights) throws Exception { int numAttributes = currentAttrWeights.length; double [] iterAttrWeights = currentAttrWeights; // System.out.println("Updating Weights Using NewtonRaphson"); // do { // // sets new attribute weights using NR with line search for alpha // iterAttrWeights = nrWithLineSearchForAlpha(iterAttrWeights, // invUnconstrainedAttrWeights); // // set current attribute weight to m_metric, recalculate obj. fn. // m_OldObjective = m_Objective; // ((LearnableM_metric) m_m_metric).setWeights(iterAttrWeights); // calculateObjectiveFunction(); // } while (!convergenceCheck(m_OldObjective, m_Objective, false)); // objective function not guaranteed to monotonically decrease across NR iterations, so don't do convergence check return iterAttrWeights; } /** Does one NR step, calculates the alpha (using line search) that does not violate positivity constraint of each attribute weight, returns new values of attribute weights */ protected double [] nrWithLineSearchForAlpha(double [] currAttrWeights, double [] invUnconstrainedAttrWeights) throws Exception { int numAttributes = currAttrWeights.length; double [] raphsonWeights = new double[numAttributes]; double top = 1, bottom = 0, alpha = 1; boolean satisfiesConstraints = true; // // initial check for alpha = top // System.out.println("Evaluating at alpha=1"); // for (int attr = 0; attr < numAttributes; attr++) { // raphsonWeights[attr] = currAttrWeights[attr] * (1 - alpha * (currAttrWeights[attr] * invUnconstrainedAttrWeights[attr] - 1)); // if (raphsonWeights[attr] < 0) { // satisfiesConstraints = false; // System.out.println("Negative raphsonWeight for attr: " + attr + ", exiting loop"); // break; // } // // System.out.println("Curr weights: " + currAttrWeights[attr] + ", alpha: " + alpha + ", m_Objective: " + m_Objective); // // System.out.println("Raphson weights[" + attr +"] = " + raphsonWeights[attr]); // } // if (!satisfiesConstraints) { // // line search for alpha between bottom and top // // satisfiesConstraints is false at top, true at bottom // // we want max. alpha in [0,1] for which satisfiesConstraints is true // System.out.println("Starting line search for alpha"); // while ((top-bottom) > m_NRConvergenceDifference && bottom <= top) { // alpha = (bottom + top)/2; // satisfiesConstraints = true; // for (int attr = 0; attr < numAttributes; attr++) { // raphsonWeights[attr] = currAttrWeights[attr] * (1 - alpha * (currAttrWeights[attr] * invUnconstrainedAttrWeights[attr] - 1)); // if (raphsonWeights[attr] < 0) { // satisfiesConstraints = false; // System.out.println("Negative raphsonWeight for attr: " + attr + ", exiting loop"); // break; // } // // System.out.println("In line search ... curr weights: " + currAttrWeights[attr] + ", alpha: " + alpha + ", m_Objective: " + m_Objective); // // System.out.println("In line search ... raphson weights[" + attr +"] = " + raphsonWeights[attr]); // } // if (!satisfiesConstraints) { // top = alpha; // } else { // bottom = alpha; // } // System.out.println("Top: " + top + ", Bottom: " + bottom); // } // alpha = bottom; // System.out.println("Final alpha: " + alpha + ", final objective: " + m_Objective); // System.out.print("Final weights: "); // for (int attr = 0; attr < numAttributes; attr++) { // raphsonWeights[attr] = currAttrWeights[attr] * (1 - alpha * (currAttrWeights[attr] * invUnconstrainedAttrWeights[attr] - 1)); // System.out.print(raphsonWeights[attr] + "\t"); // } // System.out.println(); // } else { // System.out.println("Constraints satisfied"); // } return raphsonWeights; } // OLD CODE FOR MULTIPLE: // /** M-step of the KMeans clustering algorithm -- updates metric // * weights for the individual metrics. Invoked only whe metric is trainable // */ // protected boolean updateMultipleMetricWeightsEuclidean() throws Exception { // if (m_regularizeWeights) { // System.out.println("Regularized version, calling GD version of updateMultipleMetricWeightsEuclidean!"); // updateMultipleMetricWeightsEuclideanGD(); // } // int numAttributes = m_Instances.numAttributes(); // double[][] weights = new double[m_NumClusters][numAttributes]; // int []counts = new int[m_NumClusters]; // count how many instances are in each cluster // Instance diffInstance; // //begin debugging variance // boolean debugVariance = true; // double[][] trueWeights = new double[m_NumClusters][numAttributes]; // int [] majorityClasses = new int[m_NumClusters]; // int [][] classCounts = new int[m_NumClusters][m_TotalTrainWithLabels.numClasses()]; // // get the majority counts // // NB: m_TotalTrainWithLabels does *not* include unlabeled data, counts here are undersampled! // // assuming unlabeled data came from same distribution as m_TotalTrainWithLabels, counts are still valid... // for (int instIdx=0; instIdx<m_TotalTrainWithLabels.numInstances(); instIdx++) { // Instance fullInstance = m_TotalTrainWithLabels.instance(instIdx); // classCounts[m_ClusterAssignments[instIdx]][(int)(fullInstance.classValue())]++; // } // for (int i = 0; i < m_NumClusters; i++){ // int majorityClass = 0; // System.out.print("Cluster" + i + "\t" + classCounts[i][0]); // for (int j = 1; j < m_TotalTrainWithLabels.numClasses(); j++) { // System.out.print("\t" + classCounts[i][j]); // if (classCounts[i][j] > classCounts[i][majorityClass]) { // majorityClass = j; // } // } // System.out.println(); // majorityClasses[i] = majorityClass; // } // class MajorityChecker { // int [] m_majorityClasses = null; // public MajorityChecker(int [] majClasses) { m_majorityClasses = majClasses;} // public boolean belongsToMajority(Instances instances, int instIdx, int centroidIdx) { // // silly, must pass instance since can't access outer class fields otherwise from a local inner class // Instance fullInstance = instances.instance(instIdx); // int classValue = (int) fullInstance.classValue(); // if (classValue == m_majorityClasses[centroidIdx]) { // return true; // } else { // return false; // } // } // } // MajorityChecker majChecker = new MajorityChecker(majorityClasses); // //end debugging variance // int violatedConstraints = 0; // for (int instIdx=0; instIdx<m_Instances.numInstances(); instIdx++) { // int centroidIdx = m_ClusterAssignments[instIdx]; // diffInstance = m_metrics[centroidIdx].createDiffInstance(m_Instances.instance(instIdx), m_ClusterCentroids.instance(centroidIdx)); // for (int attr=0; attr<numAttributes; attr++) { // weights[centroidIdx][attr] += diffInstance.value(attr); // Mahalanobis components // if (debugVariance && instIdx < m_TotalTrainWithLabels.numInstances()) { // if (majChecker.belongsToMajority(m_TotalTrainWithLabels, instIdx, centroidIdx)) { // trueWeights[centroidIdx][attr] += diffInstance.value(attr); // } // } // } // counts[centroidIdx]++; // Object list = m_instanceConstraintHash.get(new Integer(instIdx)); // if (list != null) { // there are constraints associated with this instance // ArrayList constraintList = (ArrayList) list; // for (int i = 0; i < constraintList.size(); i++) { // InstancePair pair = (InstancePair) constraintList.get(i); // int firstIdx = pair.first; // int secondIdx = pair.second; // double cost = 0; // if (pair.linkType == InstancePair.MUST_LINK) { // cost = m_MLweight; // } else if (pair.linkType == InstancePair.CANNOT_LINK) { // cost = m_CLweight; // } // Instance instance1 = m_Instances.instance(firstIdx); // Instance instance2 = m_Instances.instance(secondIdx); // int otherIdx = (firstIdx == instIdx) ? m_ClusterAssignments[secondIdx] : m_ClusterAssignments[firstIdx]; // // check whether the constraint is violated // if (otherIdx != -1) { // if (otherIdx != centroidIdx && pair.linkType == InstancePair.MUST_LINK) { // violated must-link // if (m_verbose) { // System.out.println("Found violated must link between: " + firstIdx + " and " + secondIdx); // } // // we penalize weights for both clusters involved, splitting the penalty in half // Instance diffInstance1 = m_metrics[otherIdx].createDiffInstance(instance1, instance2); // Instance diffInstance2 = m_metrics[centroidIdx].createDiffInstance(instance1, instance2); // for (int attr=0; attr<numAttributes; attr++) { // double-counting constraints, hence 0.5*0.5 // weights[otherIdx][attr] += 0.25 * cost * diffInstance1.value(attr); // weights[centroidIdx][attr] += 0.25 * cost * diffInstance2.value(attr); // } // violatedConstraints++; // } // else if (otherIdx == centroidIdx && pair.linkType == InstancePair.CANNOT_LINK) { //violated cannot-link // if (m_verbose) { // System.out.println("Found violated cannot link between: " + firstIdx + " and " + secondIdx); // } // // we penalize weights for just one cluster involved // diffInstance = m_metrics[centroidIdx].createDiffInstance(instance1, instance2); // Instance cannotDiffInstance = m_metrics[otherIdx].createDiffInstance(m_maxCLPoints[centroidIdx][0], // m_maxCLPoints[centroidIdx][1]); // for (int attr=0; attr<numAttributes; attr++) { // double-counting constraints, hence 0.5 // weights[centroidIdx][attr] += 0.5 * cost * cannotDiffInstance.value(attr); // weights[centroidIdx][attr] -= 0.5 * cost * diffInstance.value(attr); // } // violatedConstraints++; // } // } // end while // } // } // } // System.out.println(" Total constraints violated: " + violatedConstraints/2 + "; per-cluster weights are:"); // // check if NR needed // double [][] newWeights = new double[m_NumClusters][numAttributes]; // double [][] currentWeights = new double[m_NumClusters][numAttributes]; // for (int i=0; i<m_NumClusters; i++) { // currentWeights[i] = ((LearnableMetric) m_metrics[i]).getWeights(); // } // for (int i=0; i<m_NumClusters; i++) { // boolean needNewtonRaphson = false; // for (int attr=0; attr<numAttributes; attr++) { // if (weights[i][attr] < 0) { // check to avoid divide by 0 // System.out.println("WARNING! Cluster " + i + ", attribute " + attr + " weight=" + weights[i][attr]); // Cluster currentCluster = (Cluster) getClusters().get(i); // System.out.println("\nCluster " + i + ": " + currentCluster.size() + " instances"); // if (currentCluster == null) { // System.out.println("(empty)"); // } // else { // for (int j=0; j<currentCluster.size(); j++) { // Instance instance = (Instance) currentCluster.get(j); // System.out.println("Instance: " + instance); // } // } // needNewtonRaphson = true; // break; // } else if (weights[i][attr] == 0) { // newWeights[i][attr] = currentWeights[i][attr]; // System.out.println("WARNING! Cluster " + i + ", attribute " + attr + " has 0 weight; keeping it as " + weights[i][attr]); // } else { // newWeights[i][attr] = m_logTermWeight * counts[i]/weights[i][attr]; // invert weights // if (debugVariance) { // trueWeights[i][attr] = counts[i]/trueWeights[i][attr]; // } // } // } // // uncomment next line for debugging NR // // needNewtonRaphson = true; // // do NR if needed // if (needNewtonRaphson) { // // weights not inverted here -- done in NR routine // newWeights[i] = updateWeightsUsingNewtonRaphson(currentWeights[i], weights[i]); // System.out.println(" (NR) "); // } // // PRINT routine // // System.out.print("\t" + i + "(" + counts[i] + "): "); // // for (int attr=0; attr<numAttributes; attr++) { // // if (debugVariance) { // // System.out.print(((float)trueWeights[i][attr]) + "/~/"); // // } // // System.out.print(((float)newWeights[i][attr]) + "\t"); // // } // // System.out.println(); // // System.out.println("\t\tMean: " + m_ClusterCentroids.instance(i)); // // end PRINT routine // ((LearnableMetric) m_metrics[i]).setWeights(newWeights[i]); // } // return true; // } /** * Gets the current settings of KL * * @return an array of strings suitable for passing to setOptions() */ public String [] getOptions() { String [] options = new String [1]; int current = 0; while (current < options.length) { options[current++] = ""; } return options; } public void setOptions(String[] options) throws Exception { // TODO: add later } public Enumeration listOptions() { // TODO: add later return null; } }