/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.classifier.bayes.estimator;
import rapaio.core.distributions.empirical.KDE;
import rapaio.core.distributions.empirical.KFunc;
import rapaio.core.distributions.empirical.KFuncGaussian;
import rapaio.data.Frame;
import rapaio.data.Var;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Numeric estimator that models the conditional density of a numeric test variable, separately for
 * each target class label, with a kernel density estimate ({@link KDE}).
 *
 * <p>If the configured bandwidth is {@code 0}, the bandwidth for each class is computed from that
 * class's values with {@code KDE.getSilvermanBandwidth}; otherwise the fixed bandwidth is used for
 * every class.
 *
 * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 5/18/15.
 */
public class KernelPdf implements NumericEstimator {

    private static final long serialVersionUID = 7974390604811353859L;

    /** Learned density estimates, keyed by target class label; rebuilt on every call to learn. */
    private final Map<String, KDE> kde = new ConcurrentHashMap<>();
    /** Kernel function used by every per-class estimate. */
    private final KFunc kfunc;
    /** Fixed bandwidth, or {@code 0} to compute one per class via {@code KDE.getSilvermanBandwidth}. */
    private final double bandwidth;

    /** Creates an estimator with a Gaussian kernel and per-class computed bandwidth. */
    public KernelPdf() {
        this(new KFuncGaussian(), 0);
    }

    /**
     * Creates an estimator with the given kernel and per-class computed bandwidth.
     *
     * @param kfunc kernel function, must not be {@code null}
     * @throws NullPointerException if {@code kfunc} is {@code null}
     */
    public KernelPdf(KFunc kfunc) {
        this(kfunc, 0);
    }

    /**
     * Creates an estimator with the given kernel and bandwidth.
     *
     * @param kfunc     kernel function, must not be {@code null}
     * @param bandwidth fixed bandwidth, or {@code 0} to compute one per class
     * @throws NullPointerException     if {@code kfunc} is {@code null}
     * @throws IllegalArgumentException if {@code bandwidth} is negative or NaN
     */
    public KernelPdf(KFunc kfunc, double bandwidth) {
        this.kfunc = Objects.requireNonNull(kfunc, "kfunc");
        // written as !(x >= 0) so that NaN is rejected as well as negative values
        if (!(bandwidth >= 0)) {
            throw new IllegalArgumentException(
                    "bandwidth must be >= 0 (0 means computed per class), got " + bandwidth);
        }
        this.bandwidth = bandwidth;
    }

    @Override
    public String name() {
        return "EmpiricKDE";
    }

    /**
     * Fits one kernel density estimate of {@code testVar} for each level of {@code targetVar}. Each
     * estimate uses only the rows whose target label equals that level. The level {@code "?"}
     * (presumably the missing-value label) is skipped. Previously learned estimates are discarded.
     *
     * @param df        training frame
     * @param targetVar name of the nominal target variable
     * @param testVar   name of the numeric variable whose density is estimated
     */
    @Override
    public void learn(Frame df, String targetVar, String testVar) {
        kde.clear();
        for (String classLabel : df.var(targetVar).levels()) {
            if ("?".equals(classLabel)) {
                continue;
            }
            Frame cond = df.stream().filter(s -> classLabel.equals(s.label(targetVar))).toMappedFrame();
            Var v = cond.var(testVar);
            // NOTE(review): a level with no rows yields an empty Var here; KDE's behavior in that case
            // is not visible from this file — confirm it is acceptable.
            double bw = (bandwidth == 0) ? KDE.getSilvermanBandwidth(v) : bandwidth;
            kde.put(classLabel, new KDE(v, kfunc, bw));
        }
    }

    /**
     * Returns the estimated density of {@code testValue} under the estimate learned for
     * {@code targetLabel}.
     *
     * @param testValue   value of the numeric test variable
     * @param targetLabel target class label; must have been learned by the last call to learn
     * @return the KDE density value
     * @throws IllegalArgumentException if no estimate was learned for {@code targetLabel}
     */
    @Override
    public double cpValue(double testValue, String targetLabel) {
        KDE k = kde.get(targetLabel);
        if (k == null) {
            throw new IllegalArgumentException(
                    "no kernel density estimate learned for target label '" + targetLabel + "'");
        }
        return k.pdf(testValue);
    }

    /**
     * Returns a new, unlearned estimator configured with the same kernel function instance and the
     * same bandwidth.
     */
    @Override
    public NumericEstimator newInstance() {
        return new KernelPdf(kfunc, bandwidth);
    }

    @Override
    public String learningInfo() {
        return name() + "{ " + kfunc.summary() + " }";
    }
}