/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* NaiveBayesSimpleSoft.java
* Copyright (C) 2003 Ray Mooney
*
*/
package weka.classifiers.bayes;
import weka.classifiers.*;
import java.io.*;
import java.util.*;
import weka.core.*;
/**
* Version of NaiveBayesSimple that supports training on SoftClassifiedInstances
* and WeightedInstances for use with SemiSupEM
*
* @author Ray Mooney (mooney@cs.utexas.edu)
*/
public class NaiveBayesSimpleSoft extends NaiveBayesSimple implements SoftClassifier, OptionHandler,
WeightedInstancesHandler {
/**
* Generates the classifier.
*
* @param instances set of instances serving as training data
* @exception Exception if the classifier has not been generated successfully
*/
public void buildClassifier(SoftClassifiedInstances instances) throws Exception {
int attIndex = 0;
double sum;
if (instances.checkForStringAttributes()) {
throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
}
if (instances.classAttribute().isNumeric()) {
throw new UnsupportedClassTypeException("Naive Bayes: Class is numeric!");
}
m_Instances = instances;
// Reserve space
m_Counts = new double[instances.numClasses()]
[instances.numAttributes() - 1][0];
m_Means = new double[instances.numClasses()]
[instances.numAttributes() - 1];
m_Devs = new double[instances.numClasses()]
[instances.numAttributes() - 1];
m_Priors = new double[instances.numClasses()];
Enumeration enum = instances.enumerateAttributes();
while (enum.hasMoreElements()) {
Attribute attribute = (Attribute) enum.nextElement();
if (attribute.isNominal()) {
for (int j = 0; j < instances.numClasses(); j++) {
m_Counts[j][attIndex] = new double[attribute.numValues()];
}
} else {
for (int j = 0; j < instances.numClasses(); j++) {
m_Counts[j][attIndex] = new double[1];
}
}
attIndex++;
}
// Compute counts and sums taking soft class labels into account
Enumeration enumInsts = instances.enumerateInstances();
while (enumInsts.hasMoreElements()) {
Instance instance = (Instance) enumInsts.nextElement();
Enumeration enumAtts = instances.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
for (int classNum = 0; classNum < instances.numClasses(); classNum++) {
double weightedClassProb = ((SoftClassifiedInstance)instance).getClassProbability(classNum)
* instance.weight();
if (!instance.isMissing(attribute)) {
if (attribute.isNominal()) {
m_Counts[classNum][attIndex]
[(int)instance.value(attribute)] += weightedClassProb;
} else {
m_Means[classNum][attIndex] +=
instance.value(attribute) * weightedClassProb;
m_Counts[classNum][attIndex][0] += weightedClassProb;
m_Devs[classNum][attIndex] += instance.value(attribute) *
instance.value(attribute) * weightedClassProb;
}
}
}
attIndex++;
}
for (int classNum = 0; classNum < instances.numClasses(); classNum++) {
m_Priors[classNum] += ((SoftClassifiedInstance)instance).getClassProbability(classNum)
* instance.weight();
}
}
// Compute means, and std deviations across complete datset for use
// when not sufficient class-specific info
double[] overallMeans = new double[instances.numAttributes() - 1];
double[] overallDevs = new double[instances.numAttributes() - 1];
double[] overallCounts = new double[instances.numAttributes() - 1];
Enumeration enumAtts = instances.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
if (attribute.isNumeric()) {
for (int j = 0; j < instances.numClasses(); j++) {
overallMeans[attIndex] += m_Means[j][attIndex];
overallDevs[attIndex] += m_Devs[j][attIndex];
overallCounts[attIndex] += m_Counts[j][attIndex][0];
}
if (overallCounts[attIndex] !=0)
overallMeans[attIndex] /= overallCounts[attIndex];
overallDevs[attIndex] = Math.sqrt(overallDevs[attIndex]/overallCounts[attIndex] -
overallMeans[attIndex]*overallMeans[attIndex]);
if (overallDevs[attIndex] <= m_minStdDev || Double.isNaN(overallDevs[attIndex]))
overallDevs[attIndex] = m_minStdDev;
}
attIndex ++;
}
// Compute conditional probs, means, and std deviations
enumAtts = instances.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
for (int j = 0; j < instances.numClasses(); j++) {
if (attribute.isNumeric()) {
if (m_Counts[j][attIndex][0] != 0) {
m_Means[j][attIndex] /= m_Counts[j][attIndex][0];
m_Devs[j][attIndex] = Math.sqrt(m_Devs[j][attIndex]/m_Counts[j][attIndex][0] -
m_Means[j][attIndex] * m_Means[j][attIndex]);
if (m_Devs[j][attIndex] <= m_minStdDev || Double.isNaN(m_Devs[j][attIndex]))
// Back-off to class independent Std dev if no data for class
m_Devs[j][attIndex] = overallDevs[attIndex];
} else { // Back-off to class independent stats if no data for class
m_Means[j][attIndex] = overallMeans[attIndex];
m_Devs[j][attIndex] = overallDevs[attIndex];
}
} else if (attribute.isNominal()) {
sum = Utils.sum(m_Counts[j][attIndex]);
for (int i = 0; i < attribute.numValues(); i++) {
m_Counts[j][attIndex][i] = Math.log((m_Counts[j][attIndex][i] + (m_m / (double)attribute.numValues()))
/ (sum + m_m));
}
}
}
attIndex++;
}
// Normalize priors with laplace smoothing
sum = Utils.sum(m_Priors);
for (int j = 0; j < instances.numClasses(); j++)
m_Priors[j] = Math.log ( (m_Priors[j] + (m_m /(double)instances.numClasses()))
/ (sum + m_m));
}
}