/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    Prototype.java
 *    Copyright (C) 2004 Raymond J. Mooney
 *
 */

package weka.classifiers.misc;

import weka.classifiers.Classifier;
import weka.classifiers.DistributionClassifier;
import weka.classifiers.Evaluation;
import java.io.*;
import java.util.*;
import weka.core.*;
import weka.core.metrics.*;

/**
 * Class for building and using a simple prototype classifier.
 * Computes an average/mean/prototype vector for each class.
 * New examples are classified based on computing the distance
 * from the instance feature vector to the closest prototype.
 *
 * For real-valued attributes, the standard vector mean and Euclidean distance are used.
 * To handle nominal attributes, the distribution of values for each category
 * is computed (as in naive Bayes) as part of the prototype. The distance along a
 * nominal attribute from an instance with a value V for this attribute to the
 * prototype for a given class is then: 1 - P(V|class).
 *
 * In order to make each attribute contribute equally to the distance, values
 * are normalized to [0,1] by setting NormalizeAttributes, which is enabled by default.
 *
 * Predicted class probabilities (needed to make this a DistributionClassifier) are
 * assumed to be inversely proportional to the distances from the prototypes.
 *
 * Borrows some structure from NaiveBayesSimple.
 *
 * @author Ray Mooney (mooney@cs.utexas.edu)
 * @version $Revision: 1.1 $
 */
public class Prototype extends DistributionClassifier
  implements WeightedInstancesHandler, OptionHandler {

  /** All the counts for nominal attributes. */
  protected double [][][] m_Counts;

  /** The means for numeric attributes. */
  protected double [][] m_Means;

  /** The range (from min to max) taken on by each of the numeric attributes. */
  protected double [] m_Ranges;

  /** The instances used for training. */
  protected Instances m_Instances;

  /** If set, normalize all real attribute values to between 0 and 1 so that
   *  each dimension contributes equally to distance. */
  protected boolean m_NormalizeAttributes = true;
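
  /*
   * Indexing convention (see buildClassifier below): for nominal attributes,
   * m_Counts[class][attIndex][value] ends up holding P(value | class) after
   * training; for numeric attributes, the single slot
   * m_Counts[class][attIndex][0] accumulates the total instance weight and
   * m_Means[class][attIndex] holds the corresponding weighted mean.
   */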

  /**
   * Returns an enumeration describing the available options. <p>
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions () {
    Vector newVector = new Vector(1);
    newVector.addElement(new Option("\tNormalize attribute values to [0,1].",
                                    "N", 0, "-N"));
    return newVector.elements();
  }

  /**
   * Parses a given list of options.
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions (String[] options) throws Exception {
    setNormalizeAttributes(Utils.getFlag('N', options));
  }

  /**
   * Gets the current settings.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions () {
    String [] options = new String [1];
    int current = 0;
    if (m_NormalizeAttributes)
      options[current++] = "-N";
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  public void setNormalizeAttributes (boolean v) {
    m_NormalizeAttributes = v;
  }

  public boolean getNormalizeAttributes () {
    return m_NormalizeAttributes;
  }

  public String normalizeAttributesTipText() {
    return "Scale all real-valued attributes to the range [0,1] to equalize contribution of all attributes.";
  }

  /**
   * Returns a string describing this classifier.
   *
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Simple algorithm that computes an average or prototype example for each class "+
      "and then classifies instances based on distance to closest prototype";
  }

  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    if (instances.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
    }
    if (instances.classAttribute().isNumeric()) {
      throw new UnsupportedClassTypeException("Only nominal class allowed");
    }
    m_Instances = instances;

    // Reserve space
    m_Counts = new double[instances.numClasses()][instances.numAttributes() - 1][0];
    m_Means = new double[instances.numClasses()][instances.numAttributes() - 1];
    m_Ranges = new double[instances.numAttributes() - 1];
    double[] maxs = new double[instances.numAttributes() - 1];
    double[] mins = new double[instances.numAttributes() - 1];
    Enumeration enu = instances.enumerateAttributes();
    int attIndex = 0;
    while (enu.hasMoreElements()) {
      Attribute attribute = (Attribute) enu.nextElement();
      maxs[attIndex] = Double.NEGATIVE_INFINITY;
      mins[attIndex] = Double.POSITIVE_INFINITY;
      if (attribute.isNominal()) {
        for (int j = 0; j < instances.numClasses(); j++) {
          m_Counts[j][attIndex] = new double[attribute.numValues()];
        }
      } else {
        for (int j = 0; j < instances.numClasses(); j++) {
          m_Counts[j][attIndex] = new double[1];
        }
      }
      attIndex++;
    }

    // Compute counts and sums
    Enumeration enumInsts = instances.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      Instance instance = (Instance) enumInsts.nextElement();
      int classNum = (int) instance.classValue();
      Enumeration enumAtts = instances.enumerateAttributes();
      attIndex = 0;
      while (enumAtts.hasMoreElements()) {
        Attribute attribute = (Attribute) enumAtts.nextElement();
        if (!instance.isMissing(attribute)) {
          if (attribute.isNominal()) {
            m_Counts[classNum][attIndex][(int)instance.value(attribute)] += instance.weight();
          } else {
            double value = instance.value(attribute);
            m_Means[classNum][attIndex] += value * instance.weight();
            m_Counts[classNum][attIndex][0] += instance.weight();
            if (m_NormalizeAttributes) {
              if (value < mins[attIndex])
                mins[attIndex] = value;
              if (value > maxs[attIndex])
                maxs[attIndex] = value;
            }
          }
        }
        attIndex++;
      }
    }
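
    // A class may have no observed (non-missing) values for some numeric attribute;
    // in that case m_Counts[class][attIndex][0] stays 0 and the dataset-wide mean
    // computed below is used as that class's prototype value for that attribute.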

    // Compute means across the complete dataset for use
    // when there is not sufficient class-specific info
    double[] overallMeans = new double[instances.numAttributes() - 1];
    double[] overallCounts = new double[instances.numAttributes() - 1];
    Enumeration enumAtts = instances.enumerateAttributes();
    attIndex = 0;
    while (enumAtts.hasMoreElements()) {
      Attribute attribute = (Attribute) enumAtts.nextElement();
      if (attribute.isNumeric()) {
        for (int j = 0; j < instances.numClasses(); j++) {
          overallMeans[attIndex] += m_Means[j][attIndex];
          overallCounts[attIndex] += m_Counts[j][attIndex][0];
        }
        if (overallCounts[attIndex] != 0)
          overallMeans[attIndex] /= overallCounts[attIndex];
      }
      attIndex++;
    }

    // Compute conditional probs and means
    enumAtts = instances.enumerateAttributes();
    attIndex = 0;
    double sum = 0;
    while (enumAtts.hasMoreElements()) {
      Attribute attribute = (Attribute) enumAtts.nextElement();
      m_Ranges[attIndex] = maxs[attIndex] - mins[attIndex];
      for (int j = 0; j < instances.numClasses(); j++) {
        if (attribute.isNumeric()) {
          if (m_Counts[j][attIndex][0] != 0) {
            m_Means[j][attIndex] /= m_Counts[j][attIndex][0];
          } else {
            // Back off to class-independent stats if no data for this class
            m_Means[j][attIndex] = overallMeans[attIndex];
          }
        } else if (attribute.isNominal()) {
          sum = Utils.sum(m_Counts[j][attIndex]);
          if (sum != 0)
            for (int i = 0; i < attribute.numValues(); i++) {
              m_Counts[j][attIndex][i] = m_Counts[j][attIndex][i] / sum;
            }
        }
      }
      attIndex++;
    }
    // System.out.println(toString());
  }

  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if distribution can't be computed
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    double[] dists = new double[instance.numClasses()];
    for (int j = 0; j < instance.numClasses(); j++) {
      Enumeration enumAtts = instance.enumerateAttributes();
      int attIndex = 0;
      while (enumAtts.hasMoreElements()) {
        Attribute attribute = (Attribute) enumAtts.nextElement();
        if (!instance.isMissing(attribute)) {
          if (attribute.isNominal()) {
            // Nominal distance is based on 1 - P(value | class)
            dists[j] += Math.pow((1 - m_Counts[j][attIndex][(int)instance.value(attribute)]), 2);
          } else {
            double diff = (instance.value(attribute) - m_Means[j][attIndex]);
            if (m_NormalizeAttributes) {
              // Scaling by the range is equivalent to scaling all attributes to [0,1]
              // and equalizes the contribution of all attributes
              if (m_Ranges[attIndex] == 0)
                diff = 1.0;
              else
                diff = diff / m_Ranges[attIndex];
            }
            dists[j] += Math.pow(diff, 2);
          }
        }
        attIndex++;
      }
      // Use inverse of Euclidean distance from the prototype as a similarity that is
      // normalized to a probability distribution
      dists[j] = Math.sqrt(dists[j]);
      if (dists[j] == 0.0)
        dists[j] = Double.MAX_VALUE;
      else
        dists[j] = 1 / dists[j];
    }
    Utils.normalize(dists);
    return dists;
  }
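
  /*
   * For example (illustrative numbers only): in distributionForInstance above,
   * two classes whose prototype distances for a test instance are 0.5 and 2.0
   * get inverse distances 2.0 and 0.5, which Utils.normalize() turns into the
   * class probabilities 0.8 and 0.2.
   */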

  /**
   * Returns a description of the classifier.
   *
   * @return a description of the classifier as a string.
   */
  public String toString() {

    if (m_Instances == null) {
      return "No model built yet.";
    }
    try {
      StringBuffer text = new StringBuffer("Prototype Model");
      int attIndex;
      for (int i = 0; i < m_Instances.numClasses(); i++) {
        text.append("\n\nClass " + m_Instances.classAttribute().value(i) + "\n\n");
        Enumeration enumAtts = m_Instances.enumerateAttributes();
        attIndex = 0;
        while (enumAtts.hasMoreElements()) {
          Attribute attribute = (Attribute) enumAtts.nextElement();
          text.append("Attribute " + attribute.name() + "\n");
          if (attribute.isNominal()) {
            for (int j = 0; j < attribute.numValues(); j++) {
              text.append(attribute.value(j) + "\t");
            }
            text.append("\n");
            for (int j = 0; j < attribute.numValues(); j++)
              text.append(Utils.doubleToString(m_Counts[i][attIndex][j], 10, 8) + "\t");
          } else {
            text.append("Mean: " + Utils.doubleToString(m_Means[i][attIndex], 10, 8) + "\t");
          }
          text.append("\n\n");
          attIndex++;
        }
      }
      return text.toString();
    } catch (Exception e) {
      return "Can't print Prototype classifier!";
    }
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {

    Classifier scheme;
    try {
      scheme = new Prototype();
      System.out.println(Evaluation.evaluateModel(scheme, argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}
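
/*
 * Minimal usage sketch. The classifier is normally run from the command line
 * through the main() method above, e.g.
 *   java weka.classifiers.misc.Prototype -t train.arff -N
 * The class below shows equivalent programmatic use; the file name
 * "train.arff" is only a placeholder, and everything else relies solely on
 * the methods defined above and standard weka.core calls.
 */
class PrototypeUsageExample {

  public static void main(String[] args) throws Exception {

    // Load a dataset from a (hypothetical) ARFF file and use the last attribute as the class.
    Instances data = new Instances(new BufferedReader(new FileReader("train.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    // Build the prototype classifier; attribute normalization is on by default.
    Prototype proto = new Prototype();
    proto.setNormalizeAttributes(true);
    proto.buildClassifier(data);

    // Print the learned prototypes and the class distribution for the first instance.
    System.out.println(proto);
    double[] dist = proto.distributionForInstance(data.instance(0));
    for (int i = 0; i < dist.length; i++) {
      System.out.println(data.classAttribute().value(i) + "\t" + dist[i]);
    }
  }
}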