/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    NaiveBayesSimple.java
 *    Copyright (C) 1999 Eibe Frank
 *
 */

package weka.classifiers.bayes;

import weka.classifiers.Classifier;
import weka.classifiers.DistributionClassifier;
import weka.classifiers.Evaluation;
import java.io.*;
import java.util.*;
import weka.core.*;

/**
 * Class for building and using a simple Naive Bayes classifier.
 * Numeric attributes are modelled by a normal distribution. For more
 * information, see <p>
 *
 * Richard Duda and Peter Hart (1973). <i>Pattern
 * Classification and Scene Analysis</i>. Wiley, New York.
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz), Ray Mooney (mooney@cs.utexas.edu)
 * @version $Revision: 1.6 $
 *
 * Changes by Ray Mooney: handle a minimum standard deviation, back off to the
 * class-independent mean and standard deviation when there is no
 * class-specific data, calculate with logs of probabilities to avoid
 * underflow, switch to m-estimate smoothing rather than simple Laplace
 * smoothing to avoid over-smoothing, and handle weighted instances.
 */
public class NaiveBayesSimple extends DistributionClassifier
  implements OptionHandler, WeightedInstancesHandler {

  /** All the counts for nominal attributes. */
  protected double[][][] m_Counts;

  /** The means for numeric attributes. */
  protected double[][] m_Means;

  /** The standard deviations for numeric attributes. */
  protected double[][] m_Devs;

  /** The prior probabilities of the classes. */
  protected double[] m_Priors;

  /** The instances used for training. */
  protected Instances m_Instances;

  /** Constant for the normal distribution. */
  protected static double NORM_CONST = Math.sqrt(2 * Math.PI);

  /** Default minimum standard deviation. */
  protected double m_minStdDev = 1e-6;

  /** The m parameter for the m-estimate, corresponding to the size of the pseudosample. */
  protected double m_m = 1.0;

  /**
   * Resets to the default options.
   */
  protected void resetOptions() {

    m_minStdDev = 1e-6;
    m_m = 1.0;
  }

  /**
   * Returns a string describing this classifier.
   *
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {

    return "Simple Bayesian algorithm assuming conditional independence";
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(2);

    newVector.addElement(new Option(
	      "\tM: Controls amount of Laplace smoothing "
	      + "\t(Default = 1)",
	      "M", 1, "-M <value>"));
    newVector.addElement(new Option(
	      "\tminimum allowable standard deviation "
	      + "for normal density computation "
	      + "\n\t(default 1e-6)",
	      "D", 1, "-D <num>"));

    return newVector.elements();
  }
  /**
   * Parses a given list of options.
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    resetOptions();
    String mString = Utils.getOption('M', options);
    if (mString.length() != 0) {
      setM(Double.parseDouble(mString));
    }
    String optionString = Utils.getOption('D', options);
    if (optionString.length() != 0) {
      setMinStdDev((new Double(optionString)).doubleValue());
    }
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String minStdDevTipText() {

    return "set minimum allowable standard deviation";
  }

  /**
   * Sets the minimum value for the standard deviation used when calculating
   * the normal density. Raising this value can help prevent arithmetic
   * overflow resulting from multiplying large densities (arising from small
   * standard deviations) when there are many singleton or near-singleton
   * values.
   *
   * @param m minimum value for the standard deviation
   */
  public void setMinStdDev(double m) {

    m_minStdDev = m;
  }

  /**
   * Gets the minimum allowable standard deviation.
   *
   * @return the minimum allowable standard deviation
   */
  public double getMinStdDev() {

    return m_minStdDev;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String mTipText() {

    return "set amount of smoothing (m in m-estimate)";
  }

  /** Gets the m parameter that controls the amount of smoothing. */
  public double getM() {

    return m_m;
  }

  /** Sets the m parameter that controls the amount of smoothing. */
  public void setM(double m) {

    m_m = m;
  }

  /**
   * Gets the current settings.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions() {

    String[] options = new String[4];
    int current = 0;
    options[current++] = "-M";
    options[current++] = "" + getM();
    options[current++] = "-D";
    options[current++] = "" + getMinStdDev();
    return options;
  }
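  // Usage sketch for the option handling above (illustration only, not part of
  // the original class): -M is parsed into setM() and -D into setMinStdDev(),
  // so the following two forms are equivalent.
  //
  //   NaiveBayesSimple nb = new NaiveBayesSimple();
  //   nb.setOptions(new String[] {"-M", "2.0", "-D", "1e-3"});
  //   // same effect as: nb.setM(2.0); nb.setMinStdDev(1e-3);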
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    int attIndex = 0;
    double sum;

    if (instances.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
    }
    if (instances.classAttribute().isNumeric()) {
      throw new UnsupportedClassTypeException("Naive Bayes: Class is numeric!");
    }
    m_Instances = instances;

    // Reserve space
    m_Counts = new double[instances.numClasses()][instances.numAttributes() - 1][0];
    m_Means = new double[instances.numClasses()][instances.numAttributes() - 1];
    m_Devs = new double[instances.numClasses()][instances.numAttributes() - 1];
    m_Priors = new double[instances.numClasses()];
    Enumeration attEnum = instances.enumerateAttributes();
    while (attEnum.hasMoreElements()) {
      Attribute attribute = (Attribute) attEnum.nextElement();
      if (attribute.isNominal()) {
	for (int j = 0; j < instances.numClasses(); j++) {
	  m_Counts[j][attIndex] = new double[attribute.numValues()];
	}
      } else {
	for (int j = 0; j < instances.numClasses(); j++) {
	  m_Counts[j][attIndex] = new double[1];
	}
      }
      attIndex++;
    }

    // Compute counts and sums
    Enumeration enumInsts = instances.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      Instance instance = (Instance) enumInsts.nextElement();
      int classNum = (int) instance.classValue();
      Enumeration enumAtts = instances.enumerateAttributes();
      attIndex = 0;
      while (enumAtts.hasMoreElements()) {
	Attribute attribute = (Attribute) enumAtts.nextElement();
	if (!instance.isMissing(attribute)) {
	  if (attribute.isNominal()) {
	    m_Counts[classNum][attIndex][(int) instance.value(attribute)]
	      += instance.weight();
	  } else {
	    m_Means[classNum][attIndex] += instance.value(attribute) * instance.weight();
	    m_Counts[classNum][attIndex][0] += instance.weight();
	    m_Devs[classNum][attIndex] += instance.value(attribute)
	      * instance.value(attribute) * instance.weight();
	  }
	}
	attIndex++;
      }
      m_Priors[classNum] += instance.weight();
    }

    // Compute means and std deviations across the complete dataset, for use
    // when there is not sufficient class-specific data
    double[] overallMeans = new double[instances.numAttributes() - 1];
    double[] overallDevs = new double[instances.numAttributes() - 1];
    double[] overallCounts = new double[instances.numAttributes() - 1];
    Enumeration enumAtts = instances.enumerateAttributes();
    attIndex = 0;
    while (enumAtts.hasMoreElements()) {
      Attribute attribute = (Attribute) enumAtts.nextElement();
      if (attribute.isNumeric()) {
	for (int j = 0; j < instances.numClasses(); j++) {
	  overallMeans[attIndex] += m_Means[j][attIndex];
	  overallDevs[attIndex] += m_Devs[j][attIndex];
	  overallCounts[attIndex] += m_Counts[j][attIndex][0];
	}
	if (overallCounts[attIndex] != 0)
	  overallMeans[attIndex] /= overallCounts[attIndex];
	// Std dev from the accumulated sum of squares: sqrt(E[X^2] - E[X]^2)
	overallDevs[attIndex] =
	  Math.sqrt(overallDevs[attIndex] / overallCounts[attIndex]
		    - overallMeans[attIndex] * overallMeans[attIndex]);
	if (overallDevs[attIndex] <= m_minStdDev || Double.isNaN(overallDevs[attIndex]))
	  overallDevs[attIndex] = m_minStdDev;
      }
      attIndex++;
    }

    // Compute conditional probs, means, and std deviations
    enumAtts = instances.enumerateAttributes();
    attIndex = 0;
    while (enumAtts.hasMoreElements()) {
      Attribute attribute = (Attribute) enumAtts.nextElement();
      for (int j = 0; j < instances.numClasses(); j++) {
	if (attribute.isNumeric()) {
	  if (m_Counts[j][attIndex][0] != 0) {
	    m_Means[j][attIndex] /= m_Counts[j][attIndex][0];
	    m_Devs[j][attIndex] =
	      Math.sqrt(m_Devs[j][attIndex] / m_Counts[j][attIndex][0]
			- m_Means[j][attIndex] * m_Means[j][attIndex]);
	    if (m_Devs[j][attIndex] <= m_minStdDev || Double.isNaN(m_Devs[j][attIndex]))
	      // Back off to the class-independent std dev if no data for this class
	      m_Devs[j][attIndex] = overallDevs[attIndex];
	  } else {
	    // Back off to class-independent stats if no data for this class
	    m_Means[j][attIndex] = overallMeans[attIndex];
	    m_Devs[j][attIndex] = overallDevs[attIndex];
	  }
	} else if (attribute.isNominal()) {
	  // m-estimate smoothing: (count + m/numValues) / (total + m), stored as a log
	  sum = Utils.sum(m_Counts[j][attIndex]);
	  for (int i = 0; i < attribute.numValues(); i++) {
	    m_Counts[j][attIndex][i] =
	      Math.log((m_Counts[j][attIndex][i] + (m_m / (double) attribute.numValues()))
		       / (sum + m_m));
	  }
	}
      }
      attIndex++;
    }

    // Normalize priors with m-estimate smoothing
    sum = Utils.sum(m_Priors);
    for (int j = 0; j < instances.numClasses(); j++)
      m_Priors[j] = Math.log((m_Priors[j] + (m_m / (double) instances.numClasses()))
			     / (sum + m_m));
    // System.out.println(toString());
  }

  /**
   * Calculates the class membership probabilities for the given test instance.
   * Returns a vector of unnormalized logs of probabilities for computational reasons.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if the distribution can't be computed
   */
  public double[] unNormalizedDistributionForInstance(Instance instance)
    throws Exception {

    double[] probs = new double[instance.numClasses()];
    int attIndex;

    for (int j = 0; j < instance.numClasses(); j++) {
      // Note: the constant 1 added here is the same for every class, so it
      // cancels out when the log probabilities are normalized
      probs[j] = 1;
      Enumeration enumAtts = instance.enumerateAttributes();
      attIndex = 0;
      while (enumAtts.hasMoreElements()) {
	Attribute attribute = (Attribute) enumAtts.nextElement();
	if (!instance.isMissing(attribute)) {
	  if (attribute.isNominal()) {
	    probs[j] += m_Counts[j][attIndex][(int) instance.value(attribute)];
	  } else {
	    probs[j] += normalDens(instance.value(attribute),
				   m_Means[j][attIndex], m_Devs[j][attIndex]);
	  }
	}
	attIndex++;
      }
      probs[j] += m_Priors[j];
    }
    return probs;
  }

  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if the distribution can't be computed
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    double[] logProbs = unNormalizedDistributionForInstance(instance);
    normalizeLogs(logProbs);
    return logProbs;
  }

  /**
   * Converts an unnormalized vector of logs of probabilities into a normalized
   * distribution that sums to one.
   */
  public static void normalizeLogs(double[] logProbs) {

    // To avoid underflow problems, first scale logProbs by the maximum before
    // converting out of log space
    double max = logProbs[Utils.maxIndex(logProbs)];
    for (int i = 0; i < logProbs.length; i++) {
      logProbs[i] = Math.exp(logProbs[i] - max);
    }
    Utils.normalize(logProbs);
  }
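  // Worked example for normalizeLogs (illustration only, not part of the
  // original class): for logProbs = {log(0.2), log(0.6)} the maximum is
  // log(0.6), so the vector becomes {exp(log(0.2) - log(0.6)), exp(0)}
  // = {1/3, 1}; after Utils.normalize it is {0.25, 0.75}, i.e.
  // 0.2/(0.2+0.6) and 0.6/(0.2+0.6).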
  /**
   * Returns a description of the classifier.
   *
   * @return a description of the classifier as a string
   */
  public String toString() {

    if (m_Instances == null) {
      return "Naive Bayes (simple): No model built yet.";
    }
    try {
      StringBuffer text = new StringBuffer("Naive Bayes (simple)");
      int attIndex;

      for (int i = 0; i < m_Instances.numClasses(); i++) {
	text.append("\n\nClass " + m_Instances.classAttribute().value(i)
		    + ": P(C) = "
		    + Utils.doubleToString(Math.exp(m_Priors[i]), 10, 8)
		    + "\n\n");
	Enumeration enumAtts = m_Instances.enumerateAttributes();
	attIndex = 0;
	while (enumAtts.hasMoreElements()) {
	  Attribute attribute = (Attribute) enumAtts.nextElement();
	  text.append("Attribute " + attribute.name() + "\n");
	  if (attribute.isNominal()) {
	    for (int j = 0; j < attribute.numValues(); j++) {
	      text.append(attribute.value(j) + "\t");
	    }
	    text.append("\n");
	    for (int j = 0; j < attribute.numValues(); j++)
	      text.append(Utils.doubleToString(Math.exp(m_Counts[i][attIndex][j]), 10, 8)
			  + "\t");
	  } else {
	    text.append("Mean: "
			+ Utils.doubleToString(m_Means[i][attIndex], 10, 8) + "\t");
	    text.append("Standard Deviation: "
			+ Utils.doubleToString(m_Devs[i][attIndex], 10, 8));
	  }
	  text.append("\n\n");
	  attIndex++;
	}
      }
      return text.toString();
    } catch (Exception e) {
      return "Can't print Naive Bayes classifier!";
    }
  }

  /**
   * Density function of the normal distribution, returning the log of the density.
   */
  protected double normalDens(double x, double mean, double stdDev) {

    double diff = x - mean;
    return Math.log(1 / (NORM_CONST * stdDev)) - (diff * diff / (2 * stdDev * stdDev));
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String[] argv) {

    Classifier scheme;
    try {
      scheme = new NaiveBayesSimple();
      System.out.println(Evaluation.evaluateModel(scheme, argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}
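// Example command line (a sketch, with a hypothetical dataset file name):
// main() delegates to Weka's standard Evaluation.evaluateModel harness, which
// takes the usual -t option for the training ARFF file, plus this class's
// -M and -D options.
//
//   java weka.classifiers.bayes.NaiveBayesSimple -t weather.arff -M 1.0 -D 1e-6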