/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * GDMetricLearner.java
 * Copyright (C) 2002 Mikhail Bilenko
 *
 */

package weka.core.metrics;

import java.util.*;
import java.io.Serializable;
import java.io.*;
import java.text.SimpleDateFormat;
import java.text.DecimalFormat;
import java.text.NumberFormat;

import weka.classifiers.*;
import weka.classifiers.functions.*;
import weka.core.*;
import weka.attributeSelection.*;

/**
 * GDMetricLearner - sets the weights of a metric
 * using gradient descent over pairwise training constraints.
 *
 * @author Mikhail Bilenko (mbilenko@cs.utexas.edu)
 * @version $Revision: 1.1 $
 */
public class GDMetricLearner extends MetricLearner implements Serializable, OptionHandler {

  /** The metric that the classifier was used to learn, useful for external-calculation based metrics */
  protected LearnableMetric m_metric = null;

  /** Maximum number of gradient-descent iterations */
  protected int m_maxIterations = 20;

  /** The learning rate (gradient step coefficient) */
  protected double m_learningRate = 0.0000001;

  /** The training data */
  protected Instances m_instances = null;

  /** List of TrainingPair objects produced by the pairwise selector */
  protected ArrayList m_pairList = null;

  /** Number of same-class (positive) pairs to generate for training */
  protected int m_numPosPairs = 200;

  /** Number of different-class (negative) pairs to generate for training */
  protected int m_numNegPairs = 200;

  /** The convergence criterion for total weight updates.
   *  NOTE(review): 10e-5 equals 1e-4, not 1e-5 — confirm the intended threshold. */
  protected double m_epsilon = 10e-5;

  /** The pairwise selector used by the metric */
  protected PairwiseSelector m_selector = new RandomPairwiseSelector();

  /** 
Create a new gradient descent metric learner * @param classifierName the name of the classifier class to be used */ public GDMetricLearner() { } /** * Train a given metric using given training instances * * @param metric the metric to train * @param instances data to train the metric on * @exception Exception if training has gone bad. */ public void trainMetric(LearnableMetric metric, Instances instances) throws Exception { // If the data doesn't have a class attribute, bail if (instances.classIndex() < 0 || instances.numInstances() < 2) { metric.m_trained = false; System.out.println("Problem with training data"); return; } if (metric.getExternal()) { throw new Exception("GDMetricLearner cannot be used as an external distance metric!"); } System.out.println(getTimestamp() + " Starting to calculate weights over " + metric.getNumAttributes() +" attributes"); m_metric = metric; m_instances = instances; m_pairList = m_selector.createPairList(m_instances, m_numPosPairs, m_numNegPairs, metric); int numWeights = metric.getNumAttributes(); double[] currentWeights = new double[numWeights]; Arrays.fill(currentWeights,1.0/numWeights); metric.setWeights(currentWeights); int iterCount = 0; boolean converged = false; while (iterCount < m_maxIterations && !converged) { // calculate the gradient vector double [] gradients = calculateGradients(currentWeights); double updateTotal = 0; // update the weights for (int i = 0; i < numWeights; i++) { // System.out.println("update: " + gradients[i]); double update = m_learningRate * gradients[i]; updateTotal += Math.abs(update); currentWeights[i] += update; } currentWeights = normalizeWeights(currentWeights); metric.setWeights(currentWeights); // check convergence if (updateTotal <= m_epsilon) { converged = true; } iterCount++; } printTopAttributes(currentWeights, 10, iterCount); System.out.println(getTimestamp() + " Gradient descent complete after " + iterCount + " iterations"); metric.m_trained = true; } /** A helper function that 
calculates the current gradient value * @param weights the current weights vector * @return the values of the partial derivatives */ protected double[] calculateGradients(double[] weights) throws Exception { double [] gradients = new double[weights.length]; // calculate the gradients for (int i = 0; i < m_pairList.size(); i++) { TrainingPair pair = (TrainingPair) m_pairList.get(i); double[] pairGradients = m_metric.getGradients(pair.instance1, pair.instance2); // System.out.println(pair.instance1 + "\t" + pair.instance2 + "\t" + pair.positive); for (int j = 0; j < gradients.length; j++) { //System.out.print(gradients[j] + "(" + pairGradients[j] + ")\t"); gradients[j] = gradients[j] + (pair.positive ? pairGradients[j] : -pairGradients[j]); } // System.out.println(); } return gradients; } /** Normalize weights * @param weights an unnormalized array of weights * @return a normalized array of weights */ protected double[] normalizeWeights(double[] weights) { double sum = 0; for (int i = 0; i < weights.length; i++) { if (weights[i] < 0) { weights[i] = 0; } else { sum += weights[i]; } } double [] newWeights = new double[weights.length]; for (int i = 0; i < weights.length; i++) { newWeights[i] = weights[i] / sum; } return newWeights; } /** Get the norm-2 length of an instance assuming all attributes are numeric * and utilizing the attribute weights * @returns norm-2 length of an instance */ public double lengthWeighted(Instance instance, double[] weights) { int classIndex = instance.classIndex(); double length = 0; if (instance instanceof SparseInstance) { // remap classIndex to an internal index if (classIndex >= 0) { classIndex = ((SparseInstance)instance).locateIndex(classIndex); } for (int i = 0; i < instance.numValues(); i++) { if (i != classIndex) { double value = instance.valueSparse(i); length += weights[i] * value * value; } } } else { // non-sparse instance double[] values = instance.toDoubleArray(); for (int i = 0; i < values.length; i++) { if (i != classIndex) 
{ length += weights[i] * values[i] * values[i]; } } } return Math.sqrt(length); } /** * Use the Classifier for an estimation of similarity * @param instance1 first instance of a pair * @param instance2 second instance of a pair * @returns sim an approximate similarity obtained from the classifier */ public double getSimilarity(Instance instance1, Instance instance2) throws Exception{ throw new Exception("GDMetricLearner cannot be used as an external distance metric!"); } /** * Use the Classifier for an estimation of distance * @param instance1 first instance of a pair * @param instance2 second instance of a pair * @returns an approximate distance obtained from the classifier */ public double getDistance(Instance instance1, Instance instance2) throws Exception{ throw new Exception("GDMetricLearner cannot be used as an external distance metric!"); } /** Set the convergence criterion * @param epsilon the maximum sum of weight updates required for GD to converge */ public void setEpsilon(double epsilon) { m_epsilon = epsilon; } /** Get the convergence criterion * @return the maximum sum of weight updates required for GD to converge */ public double getEpsilon() { return m_epsilon; } /** Set the learning rate * @param learningRate the gradient update coefficient */ public void setLearningRate(double learningRate) { m_learningRate = learningRate; } /** Get the learning rate * @return the gradient update coefficient */ public double getLearningRate() { return m_learningRate; } /** Set the maximum number of update iterations rate * @param maxIterations the maximum number of gradient updates */ public void setMaxIterations(int maxIterations) { m_maxIterations = maxIterations; } /** Get the maximum number of update iterations rate * @return the maximum number of gradient updates */ public int getMaxIterations() { return m_maxIterations; } /** Set the number of same-class training pairs * @param numPosPairs the number of same-class training pairs to create for training */ 
public void setNumPosPairs(int numPosPairs) {
  this.m_numPosPairs = numPosPairs;
}

/** Get the number of same-class training pairs
 * @return the number of same-class training pairs to create for training
 */
public int getNumPosPairs() {
  return this.m_numPosPairs;
}

/** Set the number of different-class training pairs
 * @param numNegPairs the number of different-class training pairs to create for training
 */
public void setNumNegPairs(int numNegPairs) {
  this.m_numNegPairs = numNegPairs;
}

/** Get the number of different-class training pairs
 * @return the number of different-class training pairs to create for training
 */
public int getNumNegPairs() {
  return this.m_numNegPairs;
}

/** Set the pairwise selector
 * @param selector the selector for training pairs
 */
public void setSelector(PairwiseSelector selector) {
  this.m_selector = selector;
}

/** Get the pairwise selector
 * @return the selector for training pairs
 */
public PairwiseSelector getSelector() {
  return this.m_selector;
}

/**
 * Gets the current settings of the metric learner.
 *
 * @return an array of strings suitable for passing to setOptions()
 */
public String[] getOptions() {
  String[] options = new String[25];
  int pos = 0;
  options[pos++] = "-e";
  options[pos++] = "" + m_epsilon;
  options[pos++] = "-p";
  options[pos++] = "" + m_numPosPairs;
  options[pos++] = "-n";
  options[pos++] = "" + m_numNegPairs;
  options[pos++] = "-i";
  options[pos++] = "" + m_maxIterations;
  options[pos++] = "-l";
  options[pos++] = "" + m_learningRate;
  options[pos++] = "-S";
  options[pos++] = m_selector.getClass().getName();
  // pad the remaining slots with empty strings
  Arrays.fill(options, pos, options.length, "");
  return options;
}

/**
 * Parses a given list of options. Valid options are:<p>
 *
 * -B classifierstring
 */
public void setOptions(String[] options) throws Exception {
  // no options are currently parsed
}

/**
 * Gets a string containing current date and time.
 *
 * @return a string containing the date and time.
*/ protected static String getTimestamp() { return (new SimpleDateFormat("HH:mm:ss:")).format(new Date()); } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(0); return newVector.elements(); } /** Obtain a textual description of the metriclearner * @return a textual description of the metric learner */ public String toString() { return new String("GDMetricLearner " + concatStringArray(getOptions())); } /** A little helper to create a single String from an array of Strings * @param strings an array of strings * @returns a single concatenated string, separated by commas */ public static String concatStringArray(String[] strings) { String result = new String(); for (int i = 0; i < strings.length; i++) { result = result + "\"" + strings[i] + "\" "; } return result; } /** Create a lists of pairs of two kinds: pairs of instances belonging to same class, * and pairs of instances belonging to different classes. 
* */ protected ArrayList createPairList(Instances instances, int numPosPairs, int numNegPairs) { ArrayList pairList = new ArrayList(); // A hashmap where each class will be mapped to a list of instnaces belonging to it HashMap classInstanceMap = new HashMap(); // A list of classes, each element is the double value of the class attribute ArrayList classValueList = new ArrayList(); // go through all instances, hashing them into lists corresponding to each class Enumeration enum = instances.enumerateInstances(); while (enum.hasMoreElements()) { Instance instance = (Instance) enum.nextElement(); if (instance.classIsMissing()) { System.err.println("Instance has missing class!!!"); continue; } Double classValue = new Double(instance.classValue()); // if this class has been seen, add instance to its list if (classInstanceMap.containsKey(classValue)) { ArrayList classInstanceList = (ArrayList) classInstanceMap.get(classValue); classInstanceList.add(instance); } else { // create a new list of instances for a previously unseen class ArrayList classInstanceList = new ArrayList(); classInstanceList.add(instance); classInstanceMap.put(classValue, classInstanceList); classValueList.add(classValue); } } // Create the desired number of random positive instances int numClasses = classInstanceMap.size(); Random random = new Random(); for (int i = 0; i < numPosPairs; i++) { // select a random class... 
TODO: probability must be proportional to the number of instances int class1 = random.nextInt(numClasses); ArrayList list = (ArrayList) classInstanceMap.get(classValueList.get(class1)); int idx1 = random.nextInt(list.size()); int idx2; do { idx2 = random.nextInt(list.size()); } while (idx1 == idx2); Instance instance1 = (Instance) list.get(idx1); Instance instance2 = (Instance) list.get(idx2); TrainingPair posPair = new TrainingPair(instance1, instance2, true, 0); pairList.add(posPair); } // Create negative diff-instances if (numClasses > 1) { random = new Random(); for (int i = 0; i < numNegPairs; i++) { // select two random distinct classes int class1 = random.nextInt(numClasses); int class2 = random.nextInt(numClasses); while (class2 == class1) { class2 = random.nextInt(numClasses); } ArrayList list1 = (ArrayList) classInstanceMap.get(classValueList.get(class1)); Instance instance1 = (Instance) list1.get(random.nextInt(list1.size())); ArrayList list2 = (ArrayList) classInstanceMap.get(classValueList.get(class2)); Instance instance2 = (Instance) list2.get(random.nextInt(list2.size())); TrainingPair negPair = new TrainingPair(instance1, instance2, false, 0); pairList.add(negPair); } } return pairList; } /** Print the heaviest-weighted attributes for a given set of weights */ public void printTopAttributes(double[] weights, int n, int iteration) { // Print top weights - to be moved out into a separate function System.out.println(iteration + " top components:"); int[] sortedIndeces = Utils.sort(weights); for (int i = sortedIndeces.length-1; i > sortedIndeces.length-n && i >=0; i--) { int idx = sortedIndeces[i]; System.out.println((sortedIndeces.length-1-i) + ": " + idx + ":" + m_instances.attribute(idx).name() + "(" + weights[idx] + ")"); } } }