package weka.classifiers.bayes;
import weka.classifiers.Classifier;
import weka.classifiers.DistributionClassifier;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.Evaluation;
import weka.classifiers.UpdateableClassifier;
import java.io.*;
import java.util.*;
import weka.core.*;
import weka.estimators.*;
public class PBayes extends NaiveBayes{
public String Prob;
double [] m_probs;
/**
* Generates the classifier.
* @param instances set of instances serving as training data
* @exception Exception if the classifier has not been generated successfully
*/
public void buildClassifier(Instances instances) throws Exception {
if(instances.checkForStringAttributes())throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
if(instances.classAttribute().isNumeric())throw new UnsupportedClassTypeException("PBayes: Class is numeric!");
m_NumClasses = instances.numClasses();
if(m_NumClasses<0)throw new Exception ("Dataset has no class attribute");
// Copy the instances
m_Instances = new Instances(instances);
// Reserve space for the distributions
m_Distributions=new Estimator[m_Instances.numAttributes()-1][m_Instances.numClasses()];
m_ClassDistribution=new DiscreteEstimator(m_Instances.numClasses(),true);
int attIndex = 0;
Enumeration enum = m_Instances.enumerateAttributes();
while (enum.hasMoreElements()) {
Attribute attribute = (Attribute) enum.nextElement();
// If the attribute is numeric, determine the estimator
// numeric precision from differences between adjacent values
double numPrecision = DEFAULT_NUM_PRECISION;
if (attribute.type() == Attribute.NUMERIC) {
m_Instances.sort(attribute);
if ((m_Instances.numInstances() > 0)
&& !m_Instances.instance(0).isMissing(attribute)) {
double lastVal = m_Instances.instance(0).value(attribute);
double currentVal, deltaSum = 0;
int distinct = 0;
for (int i = 1; i < m_Instances.numInstances(); i++) {
Instance currentInst = m_Instances.instance(i);
if (currentInst.isMissing(attribute)) {
break;
}
currentVal = currentInst.value(attribute);
if (currentVal != lastVal) {
deltaSum += currentVal - lastVal;
lastVal = currentVal;
distinct++;
}
}
if (distinct > 0) {
numPrecision = deltaSum / distinct;
}
}
}
for (int j = 0; j < m_Instances.numClasses(); j++) {
switch (attribute.type()) {
case Attribute.NUMERIC:
if(m_UseKernelEstimator)m_Distributions[attIndex][j]=new KernelEstimator(numPrecision);
else m_Distributions[attIndex][j]=new NormalEstimator(numPrecision);
break;
case Attribute.NOMINAL:
m_Distributions[attIndex][j]=new DiscreteEstimator(attribute.numValues(),true);
break;
default:
throw new Exception("Attribute type unknown to PBayes");
}
}
attIndex++;
}
if(Prob.length()>0)updateClassifier(m_Instances,Prob);
else updateClassifier(m_Instances);
m_probs=new double[m_NumClasses];
updateClassifier();
// Save space
m_Instances = new Instances(m_Instances, 0);
}
public void updateClassifier(Instances instances,String Prob)throws Exception{
BufferedReader reader=new BufferedReader(new FileReader(Prob));
Enumeration enumInsts=instances.enumerateInstances();
Attribute classAttribute=instances.classAttribute();
while (enumInsts.hasMoreElements()) {
Instance instance = (Instance) enumInsts.nextElement();
String line=reader.readLine();
String [] lines=line.split("\\s+");
updateClassifier(instance,classAttribute.index(lines[1]),Double.parseDouble(lines[2]));
}
}
// NaiveBayes
public void updateClassifier(Instances instances)throws Exception{
Enumeration enumInsts=instances.enumerateInstances();
Attribute classAttribute=instances.classAttribute();
while (enumInsts.hasMoreElements()) {
Instance instance = (Instance) enumInsts.nextElement();
updateClassifier(instance);
}
}
/**
* Updates the classifier with the given instance.
* @param instance the new training instance to include in the model
* @exception Exception if the instance could not be incorporated in the model.
*/
public void updateClassifier(Instance instance,int classIndex,double prob)throws Exception{
int [] cs={classIndex,1-classIndex};
double [] ps={instance.weight()*prob,instance.weight()*(1-prob)};
Enumeration enumAtts=m_Instances.enumerateAttributes();
int attIndex=0;
while(enumAtts.hasMoreElements()){
Attribute attribute=(Attribute)enumAtts.nextElement();
if(!instance.isMissing(attribute)){
m_Distributions[attIndex][cs[0]].addValue(instance.value(attribute),ps[0]);
m_Distributions[attIndex][cs[1]].addValue(instance.value(attribute),ps[1]);
}
attIndex++;
}
m_ClassDistribution.addValue(cs[0],ps[0]);
m_ClassDistribution.addValue(cs[1],ps[1]);
}
// NaiveBayes
public void updateClassifier(Instance instance)throws Exception{
updateClassifier(instance,(int)instance.classValue(),1);
}
public void updateClassifier(){
for(int j=0;j<m_NumClasses;j++)m_probs[j]=m_ClassDistribution.getProbability(j);
}
/**
* Calculates the class membership probabilities for the given test instance.
* @param instance the instance to be classified
* @return predicted class probability distribution
* @exception Exception if there is a problem generating the prediction
*/
public double [] distributionForInstance(Instance instance)throws Exception{
double [] probs=new double[m_NumClasses];
System.arraycopy(m_probs,0,probs,0,m_NumClasses);
Enumeration enumAtts = instance.enumerateAttributes();
int attIndex=0,atts=0;
while(enumAtts.hasMoreElements()){
Attribute attribute=(Attribute)enumAtts.nextElement();
if(!instance.isMissing(attribute)){
atts++;
double temp,max=0;
for(int j=0;j<m_NumClasses;j++){
temp=Math.max(1e-75,m_Distributions[attIndex][j].getProbability(instance.value(attribute)));
probs[j]*=temp;
if(probs[j]>max)max=probs[j];
if(Double.isNaN(probs[j])){
throw new Exception("NaN returned from estimator for attribute "
+attribute.name()+":\n"
+m_Distributions[attIndex][j].toString());
}
}
if((max>0)&&(max<1e-75)){//Danger of probability underflow
for(int j=0;j<m_NumClasses;j++)probs[j]*=1e75;
}
}
attIndex++;
}
double power=1.0/atts;
//for(int i=0;i<m_NumClasses;i++)probs[i]=Math.pow(probs[i],power);
Utils.normalize(probs);
return probs;
}
/**
* Returns an enumeration describing the available options.
* @return an enumeration of all the available options.
*/
public Enumeration listOptions() {
Vector newVector = new Vector(2);
newVector.addElement(
new Option("\tUse kernel density estimator rather than normal\n"
+"\tdistribution for numeric attributes",
"K", 0,"-K"));
newVector.addElement(new Option("\tSet Probability data file. Each line represent one instance's probabilities to be all classes seperated by space.","Prob",0,"-Prob"));
return newVector.elements();
}
/**
* Parses a given list of options. Valid options are:<p>
*
* -K <br>
* Use kernel estimation for modelling numeric attributes rather than
* a single normal distribution.<p>
*
* @param options the list of options as an array of strings
* @exception Exception if an option is not supported
*/
public void setOptions(String[] options) throws Exception {
m_UseKernelEstimator = Utils.getFlag('K', options);
Prob=Utils.getOption("Prob",options);
}
/**
* Gets the current settings of the classifier.
*
* @return an array of strings suitable for passing to setOptions
*/
public String [] getOptions() {
String [] options = new String [2];
int current = 0;
options[current++]="-Prob";
if (m_UseKernelEstimator) {
options[current++] = "-K";
}
while (current < options.length) {
options[current++] = "";
}
return options;
}
/**
* Returns a description of the classifier.
*
* @return a description of the classifier as a string.
*/
public String toString() {
StringBuffer text = new StringBuffer();
text.append("PBayes Classifier");
if (m_Instances == null) {
text.append(": No model built yet.");
} else {
try {
for (int i = 0; i < m_Distributions[0].length; i++) {
text.append("\n\nClass " + m_Instances.classAttribute().value(i) +
": Prior probability = " + Utils.
doubleToString(m_ClassDistribution.getProbability(i),
4, 2) + "\n\n");
Enumeration enumAtts = m_Instances.enumerateAttributes();
int attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
text.append(attribute.name() + ": "
+ m_Distributions[attIndex][i]);
attIndex++;
}
}
} catch (Exception ex) {
text.append(ex.getMessage());
}
}
return text.toString();
}
/**
* Main method for testing this class.
*
* @param argv the options
*/
public static void main(String [] argv) {
try {
System.out.println(Evaluation.evaluateModel(new PBayes(), argv));
} catch (Exception e) {
e.printStackTrace();
System.err.println(e.getMessage());
}
}
}