/*
* Copyright (c) 2015 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package sklearn.naive_bayes;
import java.util.List;
import org.dmg.pmml.GaussianDistribution;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.naive_bayes.BayesInput;
import org.dmg.pmml.naive_bayes.BayesInputs;
import org.dmg.pmml.naive_bayes.BayesOutput;
import org.dmg.pmml.naive_bayes.NaiveBayesModel;
import org.dmg.pmml.naive_bayes.TargetValueCount;
import org.dmg.pmml.naive_bayes.TargetValueCounts;
import org.dmg.pmml.naive_bayes.TargetValueStat;
import org.dmg.pmml.naive_bayes.TargetValueStats;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sklearn.ClassDictUtil;
import sklearn.Classifier;
/**
 * Encodes a Gaussian naive Bayes classifier as a PMML {@link NaiveBayesModel}.
 *
 * <p>
 * The model is assembled from three array attributes of the underlying object:
 * {@code theta_} and {@code sigma_} (two-dimensional, read column by column and
 * used as the per-class means and variances of each feature) and
 * {@code class_count_} (used as the per-class target value counts).
 * </p>
 */
public class GaussianNB extends Classifier {

	public GaussianNB(String module, String name){
		super(module, name);
	}

	/**
	 * @return The size of the second dimension of the {@code theta_} array.
	 */
	@Override
	public int getNumberOfFeatures(){
		return getThetaShape()[1];
	}

	/**
	 * Builds the naive Bayes model: one {@link BayesInput} per feature column of
	 * {@code theta_}/{@code sigma_}, plus a {@link BayesOutput} carrying the class counts.
	 *
	 * @param schema Supplies the label (cast to {@link CategoricalLabel}) and the features, addressed by column index.
	 *
	 * @return A classification model with a threshold of {@code 0d} and a probability output.
	 */
	@Override
	public NaiveBayesModel encodeModel(Schema schema){
		int[] thetaShape = getThetaShape();

		int rows = thetaShape[0];
		int columns = thetaShape[1];

		List<? extends Number> theta = getTheta();
		List<? extends Number> sigma = getSigma();

		CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

		BayesInputs bayesInputs = new BayesInputs();

		for(int column = 0; column < columns; column++){
			Feature feature = schema.getFeature(column);

			// Both matrices are sliced the same way: one value per class for the current feature
			List<? extends Number> columnMeans = CMatrixUtil.getColumn(theta, rows, columns, column);
			List<? extends Number> columnVariances = CMatrixUtil.getColumn(sigma, rows, columns, column);

			bayesInputs.addBayesInputs(encodeBayesInput(feature, categoricalLabel, columnMeans, columnVariances));
		}

		List<Integer> classCount = getClassCount();

		BayesOutput bayesOutput = new BayesOutput(categoricalLabel.getName(), null)
			.setTargetValueCounts(encodeTargetValueCounts(categoricalLabel.getValues(), classCount));

		return new NaiveBayesModel(0d, MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema), bayesInputs, bayesOutput)
			.setOutput(ModelUtil.createProbabilityOutput(schema));
	}

	/**
	 * @return The {@code class_count_} array, converted to integers.
	 */
	public List<Integer> getClassCount(){
		return ValueUtil.asIntegers((List)ClassDictUtil.getArray(this, "class_count_"));
	}

	/**
	 * @return The {@code theta_} array as a flat list of numbers.
	 */
	public List<? extends Number> getTheta(){
		return (List)ClassDictUtil.getArray(this, "theta_");
	}

	/**
	 * @return The {@code sigma_} array as a flat list of numbers.
	 */
	public List<? extends Number> getSigma(){
		return (List)ClassDictUtil.getArray(this, "sigma_");
	}

	/**
	 * @return The shape of the {@code theta_} array, requested as a two-dimensional shape.
	 */
	private int[] getThetaShape(){
		return ClassDictUtil.getShape(this, "theta_", 2);
	}

	/**
	 * Creates the Bayes input for a single feature.
	 *
	 * @param feature The feature; it is converted to a continuous feature first.
	 * @param categoricalLabel The label whose values name the target categories.
	 * @param means One mean per target category.
	 * @param variances One variance per target category.
	 */
	private static BayesInput encodeBayesInput(Feature feature, CategoricalLabel categoricalLabel, List<? extends Number> means, List<? extends Number> variances){
		ContinuousFeature continuousFeature = feature.toContinuousFeature();

		return new BayesInput(continuousFeature.getName())
			.setTargetValueStats(encodeTargetValueStats(categoricalLabel.getValues(), means, variances));
	}

	/**
	 * Pairs every target category with a Gaussian distribution.
	 *
	 * @param values The target categories.
	 * @param means The distribution means, parallel to {@code values}.
	 * @param variances The distribution variances, parallel to {@code values}.
	 */
	private static TargetValueStats encodeTargetValueStats(List<String> values, List<? extends Number> means, List<? extends Number> variances){
		TargetValueStats result = new TargetValueStats();

		// The three lists are consumed in lockstep below
		ClassDictUtil.checkSize(values, means, variances);

		for(int index = 0; index < values.size(); index++){
			GaussianDistribution distribution = new GaussianDistribution(
				ValueUtil.asDouble(means.get(index)),
				ValueUtil.asDouble(variances.get(index))
			);

			TargetValueStat stat = new TargetValueStat(values.get(index))
				.setContinuousDistribution(distribution);

			result.addTargetValueStats(stat);
		}

		return result;
	}

	/**
	 * Pairs every target category with its count.
	 *
	 * @param values The target categories.
	 * @param counts The counts, parallel to {@code values}.
	 */
	private static TargetValueCounts encodeTargetValueCounts(List<String> values, List<Integer> counts){
		TargetValueCounts result = new TargetValueCounts();

		// The two lists are consumed in lockstep below
		ClassDictUtil.checkSize(values, counts);

		for(int index = 0; index < values.size(); index++){
			result.addTargetValueCounts(new TargetValueCount(values.get(index), counts.get(index)));
		}

		return result;
	}
}