/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * PrototypeMetric.java * Copyright (C) 2004 Raymond J. Mooney * */ package weka.classifiers.misc; import weka.classifiers.Classifier; import weka.classifiers.DistributionClassifier; import weka.classifiers.Evaluation; import java.io.*; import java.util.*; import weka.core.*; import weka.core.metrics.*; /** * Prototype learner for purely real-valued instances that uses * a general weka.core.metrics.Metric. * Computes an average/mean/prototype vector for each class. * New examples are classified based on computing distance * from the instance feature vector to the closest prototype using * this Metric. * * By defaults acts as Rocchio-style classifier that uses cosine similarity * Assuming text data arff file is already TFIDF weighted. * * For example see: * Joachims, Thorsten, A Probabilistic Analysis of the Rocchio Algorithm with TFIDF * for Text Categorization. Proceedings of International Conference on Machine Learning * (ICML), 1997. * * @author Ray Mooney (mooney@cs.utexas.edu) * @version $Revision: 1.1 $ * */ public class PrototypeMetric extends DistributionClassifier implements OptionHandler{ /** Metric to be used to compare intances to prototype instance */ protected Metric m_Metric = new WeightedDotP(); /** Prototype instance for each class */ protected Instance[] m_Prototypes; /** The instances used for training. */ protected Instances m_Instances; /** * Set the distance metric * * @param s the metric */ public void setMetric (Metric m) { m_Metric = m; } /** * Get the distance metric * * @returns the distance metric used */ public Metric getMetric () { return m_Metric; } /** * Returns an enumeration describing the available options.. <p> * * @return an enumeration of all the available options. * **/ public Enumeration listOptions () { Vector newVector = new Vector(1); newVector.addElement(new Option( "\tUse a specific distance metric. (Default=WeightedDotP)\n", "M", 1, "-M")); return newVector.elements(); } /** * Parses a given list of options. * @param options the list of options as an array of strings * @exception Exception if an option is not supported * **/ public void setOptions (String[] options) throws Exception { String metricString = Utils.getOption('M', options); if (metricString.length() != 0) { String[] metricSpec = Utils.splitOptions(metricString); String metricName = metricSpec[0]; metricSpec[0] = ""; System.out.println("Metric name: " + metricName + "\nMetric parameters: " + concatStringArray(metricSpec)); setMetric(Metric.forName(metricName, metricSpec)); } } /** * Gets the current settings. * * @return an array of strings suitable for passing to setOptions() */ public String[] getOptions () { String [] options = new String [50]; int current = 0; options[current++] = "-M"; options[current++] = Utils.removeSubstring(m_Metric.getClass().getName(), "weka.core.metrics."); if (m_Metric instanceof OptionHandler) { String[] metricOptions = ((OptionHandler)m_Metric).getOptions(); for (int i = 0; i < metricOptions.length; i++) { options[current++] = metricOptions[i]; } } while (current < options.length) { options[current++] = ""; } return options; } /** * Returns a string describing this clusterer * @return a description of the evaluator suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Simple algorithm that computes an average or prototype example for each class "+ "and then classifies instances based on distance to closest prototype using a given metric"; } /** * Generates the classifier. * * @param instances set of instances serving as training data * @exception Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { if (instances.checkForStringAttributes() || instances.checkForNominalAttributes()) { throw new UnsupportedAttributeTypeException("Only handles numeric attributes"); } if (instances.classAttribute().isNumeric()) { throw new UnsupportedClassTypeException("Only nominal class allowed"); } m_Instances = instances; // Create initial prototype instance for each class m_Prototypes = new Instance[instances.numClasses()]; Instances[] classPartitions = classPartitionInstances(instances); for (int j = 0; j < instances.numClasses(); j++) { m_Prototypes[j] = meanInstance(classPartitions[j]); } m_Metric.buildMetric(instances); // System.out.println(toString()); } /** Partition instances into a set for each class */ public Instances[] classPartitionInstances (Instances instances) { Instances[] classPartitions = new Instances[instances.numClasses()]; for (int j = 0; j < instances.numClasses(); j++) { classPartitions[j] = new Instances(instances, instances.numInstances()); } Enumeration enumInsts = instances.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance instance = (Instance) enumInsts.nextElement(); int classNum = (int)instance.classValue(); classPartitions[classNum].add(instance); } return classPartitions; } /** Compute a mean instance for all the instances in a set */ public Instance meanInstance(Instances instances) { double [] meanVector; if (instances.numInstances() !=0 && instances.firstInstance() instanceof SparseInstance) meanVector = meanVectorSparse(instances); else meanVector = meanVectorFull(instances); // global centroid is generally dense Instance meanInstance = new Instance(1.0, meanVector); meanInstance.setDataset(instances); return meanInstance; } /** Compute mean vector for non-sparse instances using meanOrMode method on Instances */ protected double[] meanVectorFull (Instances instances) { double [] meanVector = new double[m_Instances.numAttributes()]; for (int j = 0; j < instances.numAttributes(); j++) { meanVector[j] = instances.meanOrMode(j); // uses usual meanOrMode } return meanVector; } /** Efficiently compute a mean vector for a set of sparse instances */ protected double[] meanVectorSparse (Instances instances) { int numAttributes = instances.numAttributes(); double[] meanVector = new double[numAttributes]; double totalWeight = 0; for (int j=0; j<instances.numInstances(); j++) { SparseInstance inst = (SparseInstance) (instances.instance(j)); totalWeight += inst.weight(); for (int i=0; i<inst.numValues(); i++) { int index = inst.index(i); meanVector[index] += inst.weight() * inst.valueSparse(i); } } for (int k=0; k<numAttributes; k++) { meanVector[k] = meanVector[k] / totalWeight; } return meanVector; } /** * Calculates the class membership probabilities for the given test instance. * * @param instance the instance to be classified * @return predicted class probability distribution * @exception Exception if distribution can't be computed */ public double[] distributionForInstance(Instance instance) throws Exception { double[] sim = new double[instance.numClasses()]; for (int j = 0; j < instance.numClasses(); j++) { sim[j] = m_Metric.similarity(instance, m_Prototypes[j]); } if (Utils.sum(sim) == 0) // If 0 similarity to all class prototypes just use uniform class distribution for (int j = 0; j < instance.numClasses(); j++) sim[j] = 1; Utils.normalize(sim); return sim; } /** A little helper to create a single String from an array of Strings * @param strings an array of strings * @returns a single concatenated string, separated by commas */ public static String concatStringArray(String[] strings) { String result = new String(); for (int i = 0; i < strings.length; i++) { result = result + "\"" + strings[i] + "\" "; } return result; } /** * Returns a description of the classifier. * * @return a description of the classifier as a string. */ public String toString() { if (m_Instances == null) { return "No model built yet."; } try { StringBuffer text = new StringBuffer("Prototype Model"); for (int i = 0; i < m_Instances.numClasses(); i++) { text.append("\n\nClass " + m_Instances.classAttribute().value(i)); text.append(m_Prototypes[i]); } return text.toString(); } catch (Exception e) { return "Can't print Prototype classifier!"; } } /** * Main method for testing this class. * * @param argv the options */ public static void main(String [] argv) { Classifier scheme; try { scheme = new Prototype(); System.out.println(Evaluation.evaluateModel(scheme, argv)); } catch (Exception e) { System.err.println(e.getMessage()); } } }