/*
* NaiveBayes.java
* Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
package tr.gov.ulakbim.jDenetX.classifiers;
import tr.gov.ulakbim.jDenetX.classifiers.attributes.AttributeClassObserver;
import tr.gov.ulakbim.jDenetX.classifiers.attributes.GaussianNumericAttributeClassObserver;
import tr.gov.ulakbim.jDenetX.classifiers.attributes.NominalAttributeClassObserver;
import tr.gov.ulakbim.jDenetX.core.AutoExpandVector;
import tr.gov.ulakbim.jDenetX.core.DoubleVector;
import tr.gov.ulakbim.jDenetX.core.Measurement;
import tr.gov.ulakbim.jDenetX.core.StringUtils;
import weka.core.Instance;
/**
 * Incremental Naive Bayes classifier.
 *
 * <p>The model consists of a per-class tally of observed instance weight and
 * one {@link AttributeClassObserver} per model attribute. Prediction multiplies
 * the class prior by the per-attribute conditional probabilities reported by
 * the observers, i.e. it treats the attributes as independent given the class.
 */
public class NaiveBayes extends AbstractClassifier {

    private static final long serialVersionUID = 1L;

    /**
     * Human-readable summary of this classifier.
     * NOTE(review): the "hiding" suppression suggests a superclass declares a
     * field with the same name - confirm against AbstractClassifier.
     */
    @SuppressWarnings("hiding")
    public static final String classifierPurposeString = "Naive Bayes classifier: performs classic bayesian prediction while making naive assumption that all inputs are independent.";

    /** Accumulated instance weight per class index. */
    protected DoubleVector observedClassDistribution;

    /** Observer for each model attribute index; entries are created lazily during training. */
    protected AutoExpandVector<AttributeClassObserver> attributeObservers;

    /**
     * Discards everything learned so far by replacing the class tally and the
     * attribute observers with fresh, empty containers.
     */
    @Override
    public void resetLearningImpl() {
        this.observedClassDistribution = new DoubleVector();
        this.attributeObservers = new AutoExpandVector<AttributeClassObserver>();
    }

    /**
     * Returns the purpose string supplied by the superclass.
     *
     * @return whatever {@code super.getPurposeString()} returns
     */
    @Override
    public String getPurposeString() {
        return super.getPurposeString();
    }

    /**
     * Updates the model with a single training instance.
     *
     * <p>The instance weight is added to the tally of its class, and every
     * attribute value is passed to the observer for that attribute. An observer
     * is created on first use: a nominal observer for nominal attributes, a
     * numeric observer otherwise. Missing values are not filtered here; they are
     * handed to the observer as-is.
     *
     * @param inst the training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        final int classIndex = (int) inst.classValue();
        final double weight = inst.weight();
        this.observedClassDistribution.addToValue(classIndex, weight);

        // The loop covers numAttributes() - 1 model attributes, i.e. all but one
        // of the instance's attributes (the index mapping decides which).
        final int numModelAttributes = inst.numAttributes() - 1;
        for (int i = 0; i < numModelAttributes; i++) {
            int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst);
            AttributeClassObserver obs = this.attributeObservers.get(i);
            if (obs == null) {
                obs = inst.attribute(instAttIndex).isNominal()
                        ? newNominalClassObserver()
                        : newNumericClassObserver();
                this.attributeObservers.set(i, obs);
            }
            obs.observeAttributeClass(inst.value(instAttIndex), classIndex, weight);
        }
    }

    /**
     * Scores the given instance against every observed class.
     *
     * @param inst the instance to classify
     * @return one unnormalized score per class index, as computed by
     *         {@link #doNaiveBayesPrediction}
     */
    public double[] getVotesForInstance(Instance inst) {
        return doNaiveBayesPrediction(inst, this.observedClassDistribution,
                this.attributeObservers);
    }

    /**
     * Reports model-specific measurements.
     *
     * @return always {@code null}; this classifier provides no measurements of its own
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    /**
     * Appends a textual description of the model to {@code out}.
     *
     * <p>For every class index the output lists the observed weight, its share
     * of the total weight, and the string form of each attribute observer.
     *
     * @param out    buffer that receives the description
     * @param indent indentation level passed to the {@link StringUtils} helpers
     */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        // Nothing inside the loop modifies the distribution, so the total is invariant.
        final double totalWeight = this.observedClassDistribution.sumOfValues();
        final int numClasses = this.observedClassDistribution.numValues();
        final int numObservers = this.attributeObservers.size();

        for (int i = 0; i < numClasses; i++) {
            final double classWeight = this.observedClassDistribution.getValue(i);

            StringUtils.appendIndented(out, indent, "Observations for ");
            out.append(getClassNameString());
            out.append(" = ");
            out.append(getClassLabelString(i));
            out.append(":");
            StringUtils.appendNewlineIndented(out, indent + 1,
                    "Total observed weight = ");
            out.append(classWeight);
            out.append(" / prob = ");
            out.append(classWeight / totalWeight);

            for (int j = 0; j < numObservers; j++) {
                StringUtils.appendNewlineIndented(out, indent + 1,
                        "Observations for ");
                out.append(getAttributeNameString(j));
                out.append(": ");
                out.append(this.attributeObservers.get(j));
            }
            StringUtils.appendNewline(out);
        }
    }

    /**
     * Indicates whether this classifier is randomizable.
     *
     * @return always {@code false}
     */
    public boolean isRandomizable() {
        return false;
    }

    /**
     * Creates the observer used for nominal attributes. Subclasses may override
     * this to plug in a different observer.
     *
     * @return a new {@link NominalAttributeClassObserver}
     */
    protected AttributeClassObserver newNominalClassObserver() {
        return new NominalAttributeClassObserver();
    }

    /**
     * Creates the observer used for non-nominal attributes. Subclasses may
     * override this to plug in a different observer.
     *
     * @return a new {@link GaussianNumericAttributeClassObserver}
     */
    protected AttributeClassObserver newNumericClassObserver() {
        return new GaussianNumericAttributeClassObserver();
    }

    /**
     * Computes Naive Bayes scores for an instance from the supplied statistics.
     *
     * <p>Each class score starts as the class weight divided by the total
     * observed weight and is then multiplied by the conditional probability
     * reported by every attribute observer. Attributes that have no observer, or
     * whose value is missing in {@code inst}, are skipped. The scores are not
     * normalized.
     *
     * <p>The division by the total weight is unguarded: if the distribution has
     * entries but their sum is zero, the scores are not finite.
     *
     * @param inst                      the instance to score
     * @param observedClassDistribution accumulated weight per class index
     * @param attributeObservers        observer per model attribute index; entries may be {@code null}
     * @return an array with one score per entry of {@code observedClassDistribution}
     */
    public static double[] doNaiveBayesPrediction(Instance inst,
            DoubleVector observedClassDistribution,
            AutoExpandVector<AttributeClassObserver> attributeObservers) {
        double[] votes = new double[observedClassDistribution.numValues()];
        final double observedClassSum = observedClassDistribution.sumOfValues();
        final int numModelAttributes = inst.numAttributes() - 1;

        for (int classIndex = 0; classIndex < votes.length; classIndex++) {
            // Prior: share of the total observed weight that belongs to this class.
            votes[classIndex] = observedClassDistribution.getValue(classIndex)
                    / observedClassSum;
            for (int attIndex = 0; attIndex < numModelAttributes; attIndex++) {
                int instAttIndex = modelAttIndexToInstanceAttIndex(attIndex, inst);
                AttributeClassObserver obs = attributeObservers.get(attIndex);
                if ((obs != null) && !inst.isMissing(instAttIndex)) {
                    votes[classIndex] *= obs.probabilityOfAttributeValueGivenClass(
                            inst.value(instAttIndex), classIndex);
                }
            }
        }
        // TODO: need logic to prevent underflow? A long product of small
        // probabilities can round to 0.0 for every class.
        return votes;
    }

    /**
     * Memory-management hook; intentionally does nothing in this classifier.
     *
     * @param currentByteSize unused
     * @param maxByteSize     unused
     */
    public void manageMemory(int currentByteSize, int maxByteSize) {
    }
}