/*
* avenir: Predictive analytic based on Hadoop Map Reduce
* Author: Pranab Ghosh
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package org.avenir.bayesian;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.chombo.util.BinCount;
import org.chombo.util.FeatureCount;
/**
* Bayesian model related probability distributions
* @author pranab
*
*/
public class BayesianModel {
private List<FeaturePosterior> featurePosteriors = new ArrayList<FeaturePosterior>();
private List<FeatureCount> featurePriors = new ArrayList<FeatureCount>();
private int count;
/**
* @param classValue
* @return
*/
public double getClassPriorProb(String classValue) {
FeaturePosterior feaPost = getFeaturePosterior(classValue);
return feaPost.getProb();
}
/**
* @param featureValues
* @return
*/
public double getFeaturePriorProb(List<Pair<Integer, Object>> featureValues) {
double prob = 1.0;
for (Pair<Integer, Object> feature : featureValues) {
FeatureCount feaCount = getFeatureCount( feature.getLeft());
if (feature.getRight() instanceof String) {
//categorical or binned numerical
prob *= feaCount.getProb((String)feature.getRight());
} else {
//continuous numerical
prob *= feaCount.getProb((Integer)feature.getRight());
}
}
return prob;
}
/**
* @param classVal
* @param featureValues
* @return
*/
public double getFeaturePostProb(String classVal, List<Pair<Integer, Object>> featureValues) {
FeaturePosterior feaPost = getFeaturePosterior(classVal);
double prob = feaPost.getFeaturePostProb(featureValues);
return prob;
}
/**
* @param classValue
* @param count
*/
public void addClassPrior(String classValue, int count) {
FeaturePosterior feaPost = getFeaturePosterior(classValue);
feaPost.addCount(count);
}
/**
* @param ordinal
* @param bin
* @param count
*/
public void addFeaturePrior(int ordinal, String bin, int count) {
FeatureCount feaCount = getFeatureCount( ordinal);
BinCount binCount = new BinCount(bin, count);
feaCount.addBinCount(binCount);
}
/**
* @param ordinal
* @param mean
* @param stdDev
*/
public void setFeaturePriorParaemeters(int ordinal, long mean, long stdDev) {
FeatureCount feaCount = getFeatureCount( ordinal);
feaCount.setDistrParameters(mean, stdDev);
}
/**
* @param classValue
* @param ordinal
* @param bin
* @param count
*/
public void addFeaturePosterior(String classValue, int ordinal, String bin, int count) {
FeaturePosterior feaPost = getFeaturePosterior(classValue);
FeatureCount feaCount = feaPost.getFeatureCount( ordinal);
BinCount binCount = new BinCount(bin, count);
feaCount.addBinCount(binCount);
}
/**
* @param classValue
* @param ordinal
* @param mean
* @param stdDev
*/
public void setFeaturePosteriorParaemeters(String classValue, int ordinal, long mean, long stdDev) {
FeaturePosterior feaPost = getFeaturePosterior(classValue);
FeatureCount feaCount = feaPost.getFeatureCount(ordinal);
feaCount.setDistrParameters(mean, stdDev);
}
/**
* @param ordinal
* @return
*/
private FeatureCount getFeatureCount(int ordinal) {
FeatureCount feaCount = null;
for (FeatureCount thisFeaCount : featurePriors) {
if (thisFeaCount.getOrdinal() == ordinal) {
feaCount = thisFeaCount;
break;
}
}
if (null == feaCount) {
feaCount = new FeatureCount(ordinal, "");
featurePriors.add(feaCount);
}
return feaCount;
}
/**
* @param classValue
* @return
*/
private FeaturePosterior getFeaturePosterior(String classValue) {
FeaturePosterior feaPost = null;
for (FeaturePosterior thisFeaPost : featurePosteriors) {
if (thisFeaPost.getClassValue().equals(classValue)) {
feaPost = thisFeaPost;
break;
}
}
if (null == feaPost) {
feaPost = new FeaturePosterior();
feaPost.setClassValue(classValue);
featurePosteriors.add(feaPost);
}
return feaPost;
}
/**
* @return
*/
public List<FeaturePosterior> getFeaturePosteriors() {
return featurePosteriors;
}
/**
* @param featurePosteriors
*/
public void setFeaturePosteriors(List<FeaturePosterior> featurePosteriors) {
this.featurePosteriors = featurePosteriors;
}
/**
* @return
*/
public List<FeatureCount> getFeaturePriors() {
return featurePriors;
}
/**
* @param featurePriors
*/
public void setFeaturePriors(List<FeatureCount> featurePriors) {
this.featurePriors = featurePriors;
}
/**
* @return
*/
public int getCount() {
return count;
}
/**
* @param count
*/
public void setCount(int count) {
this.count = count;
}
/**
*
*/
public void finishUp() {
//total count by adding all class prior counts
count = 0;
for (FeaturePosterior thisFeaPost : featurePosteriors) {
count += thisFeaPost.getCount();
}
//class prior and feature posterior
for (FeaturePosterior thisFeaPost : featurePosteriors) {
thisFeaPost.normalize(count);
}
//feature prior
for (FeatureCount thisFeaCount : featurePriors) {
thisFeaCount.normalize(count);
}
}
}