/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.classifier.bayes;
import rapaio.core.tools.DVector;
import rapaio.data.Frame;
import rapaio.data.Var;
import rapaio.data.VarType;
import rapaio.data.filter.FFilter;
import rapaio.ml.classifier.AbstractClassifier;
import rapaio.ml.classifier.CFit;
import rapaio.ml.classifier.bayes.estimator.*;
import rapaio.ml.common.Capabilities;
import rapaio.sys.WS;
import rapaio.util.Tag;
import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Logger;
import java.util.stream.IntStream;
/**
* Naive Bayes Classifier.
* <p>
*
* @author <a href="mailto:padreati@yahoo.com>Aurelian Tutuianu</a>
*/
public class NaiveBayes extends AbstractClassifier {
private static final long serialVersionUID = -7602854063045679683L;
private static final Logger logger = Logger.getLogger(NaiveBayes.class.getName());
// algorithm parameters
public static Tag<PriorSupplier> PRIORS_MLE = Tag.valueOf("PRIORS_MLE", (df, weights, nb) -> {
Map<String, Double> priors = new HashMap<>();
DVector dv = DVector.fromWeights(false, df.var(nb.firstTargetName()), weights, nb.firstTargetLevels());
dv.normalize();
for (int i = 1; i < nb.firstTargetLevels().length; i++) {
priors.put(nb.firstTargetLevels()[i], dv.get(i));
}
return priors;
});
public static Tag<PriorSupplier> PRIORS_UNIFORM = Tag.valueOf("PRIORS_UNIFORM", (df, weights, nb) -> {
Map<String, Double> priors = new HashMap<>();
double p = 1.0 / nb.firstTargetLevels().length;
for (int i = 1; i < nb.firstTargetLevels().length; i++) {
priors.put(nb.firstTargetLevels()[i], p);
}
return priors;
});
private double laplaceSmoother = 1;
private BinaryEstimator binEstimator = new MultinomialPmf();
private NumericEstimator numEstimator = new GaussianPdf();
// prediction artifacts
private NominalEstimator nomEstimator = new MultinomialPmf();
private Tag<PriorSupplier> priorSupplier = PRIORS_MLE;
private Map<String, Double> priors;
private Map<String, NumericEstimator> numMap;
private Map<String, NominalEstimator> nomMap;
private Map<String, BinaryEstimator> binMap;
@Override
public NaiveBayes newInstance() {
return new NaiveBayes()
.withBinEstimator(binEstimator)
.withNumEstimator(numEstimator)
.withNomEstimator(nomEstimator)
.withLaplaceSmoother(laplaceSmoother)
.withPriorSupplier(priorSupplier);
}
@Override
public String name() {
return "NaiveBayes";
}
@Override
public String fullName() {
return name() + "(numEstimator=" + numEstimator.name() + ", nomEstimator=" + nomEstimator.name() + ")";
}
@Override
public Capabilities capabilities() {
return new Capabilities()
.withInputCount(0, 1_000_000)
.withInputTypes(VarType.NOMINAL, VarType.NUMERIC, VarType.INDEX, VarType.BINARY)
.withTargetCount(1, 1)
.withTargetTypes(VarType.NOMINAL)
.withAllowMissingTargetValues(false)
.withAllowMissingInputValues(true);
}
public NaiveBayes withBinEstimator(BinaryEstimator binEstimator) {
this.binEstimator = binEstimator;
return this;
}
public NaiveBayes withNumEstimator(NumericEstimator numEstimator) {
this.numEstimator = numEstimator;
return this;
}
public NaiveBayes withNomEstimator(NominalEstimator nomEstimator) {
this.nomEstimator = nomEstimator;
return this;
}
public NaiveBayes withLaplaceSmoother(double laplaceSmoother) {
this.laplaceSmoother = laplaceSmoother;
return this;
}
public double laplaceSmoother() {
return laplaceSmoother;
}
public NaiveBayes withPriorSupplier(Tag<PriorSupplier> priorSupplier) {
this.priorSupplier = priorSupplier;
return this;
}
@Override
protected boolean coreTrain(Frame df, Var weights) {
// build priors
priors = PRIORS_MLE.get().learnPriors(df, weights, this);
// build conditional probabilities
nomMap = new ConcurrentHashMap<>();
numMap = new ConcurrentHashMap<>();
binMap = new ConcurrentHashMap<>();
logger.fine("start learning...");
Arrays.stream(df.varNames()).parallel().forEach(
testCol -> {
if (firstTargetName().equals(testCol)) {
return;
}
if (df.var(testCol).type().isBinary()) {
BinaryEstimator estimator = binEstimator.newInstance();
estimator.learn(this, df, weights, firstTargetName(), testCol);
binMap.put(testCol, estimator);
return;
}
if (df.var(testCol).type().isNumeric()) {
NumericEstimator estimator = numEstimator.newInstance();
estimator.learn(df, firstTargetName(), testCol);
numMap.put(testCol, estimator);
return;
}
if (df.var(testCol).type().isNominal()) {
NominalEstimator estimator = nomEstimator.newInstance();
estimator.learn(this, df, weights, firstTargetName(), testCol);
nomMap.put(testCol, estimator);
}
});
logger.fine("learning phase finished");
return true;
}
@Override
protected CFit coreFit(Frame df, final boolean withClasses, final boolean withDensities) {
logger.fine("start fitting values...");
CFit pred = CFit.build(this, df, withClasses, withDensities);
IntStream.range(0, df.rowCount()).parallel().forEach(
i -> {
DVector dv = DVector.empty(false, firstTargetLevels());
for (int j = 1; j < firstTargetLevels().length; j++) {
double sumLog = Math.log(priors.get(firstTargetLevel(j)));
for (String testCol : numMap.keySet()) {
if (df.missing(i, testCol))
continue;
sumLog += Math.log(numMap.get(testCol).cpValue(df.value(i, testCol), firstTargetLevel(j)));
}
for (String testCol : nomMap.keySet()) {
if (df.missing(i, testCol)) {
continue;
}
sumLog += Math.log(nomMap.get(testCol).cpValue(df.label(i, testCol), firstTargetLevel(j)));
}
for (String testCol : binMap.keySet()) {
if (df.missing(i, testCol)) {
continue;
}
sumLog += Math.log(binMap.get(testCol).cpValue(df.label(i, testCol), firstTargetLevel(j)));
}
dv.increment(j, Math.exp(sumLog));
}
dv.normalize();
if (withClasses) {
pred.firstClasses().setIndex(i, dv.findBestIndex());
}
if (withDensities) {
for (int j = 1; j < firstTargetLevels().length; j++) {
pred.firstDensity().setValue(i, j, dv.get(j));
}
}
});
logger.fine("fitting phase finished.");
return pred;
}
@Override
public String summary() {
StringBuilder sb = new StringBuilder();
sb.append("NaiveBayes model\n");
sb.append("================\n\n");
sb.append("Description:\n");
sb.append(fullName()).append("\n\n");
sb.append("Capabilities:\n");
sb.append(capabilities().summary()).append("\n");
sb.append("Learned model:\n");
if (!hasLearned()) {
sb.append("Learning phase not called\n\n");
return sb.toString();
}
sb.append(baseSummary());
sb.append("prior probabilities:\n");
String targetName = firstTargetName();
Arrays.stream(firstTargetLevels()).skip(1).forEach(label -> sb.append("> P(").append(targetName).append("='").append(label).append("')=").append(WS.formatFlex(priors.get(label))).append("\n"));
if (!numMap.isEmpty()) {
sb.append("numerical estimators:\n");
numMap.entrySet().forEach(e -> sb.append("> ").append(e.getKey()).append(" : ").append(e.getValue().learningInfo()).append("\n"));
}
if (!nomMap.isEmpty()) {
sb.append("nominal estimators:\n");
nomMap.entrySet().forEach(e -> sb.append("> ").append(e.getKey()).append(" : ").append(e.getValue().learningInfo()).append("\n"));
}
return sb.toString();
}
@Override
public NaiveBayes withInputFilters(List<FFilter> filters) {
return (NaiveBayes) super.withInputFilters(filters);
}
@Override
public NaiveBayes withInputFilters(FFilter... filters) {
return (NaiveBayes) super.withInputFilters(filters);
}
interface PriorSupplier extends Serializable {
Map<String, Double> learnPriors(Frame df, Var weights, NaiveBayes nb);
}
}