/*
* Copyright (c) 2015 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.sklearn;
import java.util.LinkedHashSet;
import java.util.Set;
import com.google.common.collect.ImmutableSet;
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.Batch;
import org.junit.Test;
/**
 * Integration tests for classifier conversions.
 *
 * <p>Each test names an estimator ("DecisionTree", "LogisticRegression", ...) and a
 * dataset ("Audit", "Iris", ...) and delegates to the evaluation helpers inherited
 * from {@link EstimatorTest}. Most tests use the default {@code evaluate(String, String)}
 * entry point. A few open a {@link Batch} explicitly so that they can ignore selected
 * output fields or pass non-default numeric arguments.
 */
public class ClassifierTest extends EstimatorTest {

	@Test
	public void evaluateDecisionTreeAudit() throws Exception {
		evaluate("DecisionTree", "Audit");
	}

	@Test
	public void evaluateDecisionTreeAuditDict() throws Exception {
		evaluate("DecisionTree", "AuditDict");
	}

	@Test
	public void evaluateDecisionTreeAuditNA() throws Exception {
		evaluate("DecisionTree", "AuditNA");
	}

	@Test
	public void evaluateDecisionTreeEnsembleAudit() throws Exception {
		evaluate("DecisionTreeEnsemble", "Audit");
	}

	@Test
	public void evaluateDummyAudit() throws Exception {
		evaluate("Dummy", "Audit");
	}

	/**
	 * Evaluates the ExtraTrees classifier on the Audit dataset.
	 *
	 * <p>Previously misnamed {@code evaluatExtraTreesAudit}. JUnit discovers tests by the
	 * {@code @Test} annotation, so the rename only affects reporting and name-based filtering.
	 */
	@Test
	public void evaluateExtraTreesAudit() throws Exception {
		evaluate("ExtraTrees", "Audit");
	}

	@Test
	public void evaluateGradientBoostingAudit() throws Exception {
		evaluate("GradientBoosting", "Audit");
	}

	@Test
	public void evaluateLGBMAudit() throws Exception {
		evaluate("LGBM", "Audit");
	}

	@Test
	public void evaluateLinearDiscriminantAnalysisAudit() throws Exception {
		evaluate("LinearDiscriminantAnalysis", "Audit");
	}

	@Test
	public void evaluateLogisticRegressionAudit() throws Exception {
		evaluate("LogisticRegression", "Audit");
	}

	@Test
	public void evaluateLogisticRegressionAuditDict() throws Exception {
		evaluate("LogisticRegression", "AuditDict");
	}

	@Test
	public void evaluateLogisticRegressionAuditNA() throws Exception {
		evaluate("LogisticRegression", "AuditNA");
	}

	@Test
	public void evaluateLogisticRegressionEnsembleAudit() throws Exception {
		evaluate("LogisticRegressionEnsemble", "Audit");
	}

	@Test
	public void evaluateNaiveBayesAudit() throws Exception {
		evaluate("NaiveBayes", "Audit");
	}

	@Test
	public void evaluateRandomForestAudit() throws Exception {
		evaluate("RandomForest", "Audit");
	}

	@Test
	public void evaluateRidgeAudit() throws Exception {
		evaluate("Ridge", "Audit");
	}

	@Test
	public void evaluateRidgeEnsembleAudit() throws Exception {
		evaluate("RidgeEnsemble", "Audit");
	}

	@Test
	public void evaluateSVCAudit() throws Exception {
		evaluate("SVC", "Audit");
	}

	@Test
	public void evaluateVotingEnsembleAudit() throws Exception {
		evaluate("VotingEnsemble", "Audit");
	}

	/**
	 * Evaluates the XGB classifier on the Audit dataset with explicit numeric arguments
	 * (1e-5, 1e-5) and no ignored fields.
	 */
	@Test
	public void evaluateXGBAudit() throws Exception {
		try(Batch batch = createBatch("XGB", "Audit")){
			// NOTE(review): the two 1e-5 values are presumably comparison tolerances
			// defined by EstimatorTest; confirm against its evaluate(Batch, Set, double, double).
			evaluate(batch, null, 1e-5, 1e-5);
		}
	}

	@Test
	public void evaluateDecisionTreeIris() throws Exception {
		evaluate("DecisionTree", "Iris");
	}

	@Test
	public void evaluateDecisionTreeEnsembleIris() throws Exception {
		evaluate("DecisionTreeEnsemble", "Iris");
	}

	@Test
	public void evaluateDummyIris() throws Exception {
		evaluate("Dummy", "Iris");
	}

	@Test
	public void evaluateExtraTreesIris() throws Exception {
		evaluate("ExtraTrees", "Iris");
	}

	@Test
	public void evaluateGradientBoostingIris() throws Exception {
		evaluate("GradientBoosting", "Iris");
	}

	/**
	 * Evaluates the KNN classifier on the Iris dataset, ignoring the
	 * {@code neighbor(1)} .. {@code neighbor(5)} output fields.
	 */
	@Test
	public void evaluateKNNIris() throws Exception {
		try(Batch batch = createBatch("KNN", "Iris")){
			Set<FieldName> ignoredFields = createFieldSet("neighbor", 5);

			evaluate(batch, ignoredFields);
		}
	}

	@Test
	public void evaluateLGBMIris() throws Exception {
		evaluate("LGBM", "Iris");
	}

	@Test
	public void evaluateLinearDiscriminantAnalysisIris() throws Exception {
		evaluate("LinearDiscriminantAnalysis", "Iris");
	}

	@Test
	public void evaluateLogisticRegressionIris() throws Exception {
		evaluate("LogisticRegression", "Iris");
	}

	@Test
	public void evaluateLogisticRegressionEnsembleIris() throws Exception {
		evaluate("LogisticRegressionEnsemble", "Iris");
	}

	@Test
	public void evaluateNaiveBayesIris() throws Exception {
		evaluate("NaiveBayes", "Iris");
	}

	@Test
	public void evaluateMLPIris() throws Exception {
		evaluate("MLP", "Iris");
	}

	@Test
	public void evaluateRandomForestIris() throws Exception {
		evaluate("RandomForest", "Iris");
	}

	@Test
	public void evaluateRidgeIris() throws Exception {
		evaluate("Ridge", "Iris");
	}

	@Test
	public void evaluateRidgeEnsembleIris() throws Exception {
		evaluate("RidgeEnsemble", "Iris");
	}

	@Test
	public void evaluateSGDIris() throws Exception {
		evaluate("SGD", "Iris");
	}

	@Test
	public void evaluateSGDLogIris() throws Exception {
		evaluate("SGDLog", "Iris");
	}

	@Test
	public void evaluateSVCIris() throws Exception {
		evaluate("SVC", "Iris");
	}

	@Test
	public void evaluateNuSVCIris() throws Exception {
		evaluate("NuSVC", "Iris");
	}

	/**
	 * Evaluates the VotingEnsemble classifier on the Iris dataset, ignoring the three
	 * per-class {@code probability(...)} output fields.
	 */
	@Test
	public void evaluateVotingEnsembleIris() throws Exception {
		try(Batch batch = createBatch("VotingEnsemble", "Iris")){
			Set<FieldName> ignoredFields = ImmutableSet.of(
				FieldName.create("probability(setosa)"),
				FieldName.create("probability(versicolor)"),
				FieldName.create("probability(virginica)")
			);

			evaluate(batch, ignoredFields);
		}
	}

	/**
	 * Evaluates the XGB classifier on the Iris dataset with explicit numeric arguments
	 * (1e-5, 1e-5) and no ignored fields.
	 */
	@Test
	public void evaluateXGBIris() throws Exception {
		try(Batch batch = createBatch("XGB", "Iris")){
			// NOTE(review): presumably comparison tolerances; see evaluateXGBAudit.
			evaluate(batch, null, 1e-5, 1e-5);
		}
	}

	@Test
	public void evaluateLogisticRegressionSentiment() throws Exception {
		evaluate("LogisticRegression", "Sentiment");
	}

	@Test
	public void evaluateRandomForestSentiment() throws Exception {
		evaluate("RandomForest", "Sentiment");
	}

	@Test
	public void evaluateDecisionTreeVersicolor() throws Exception {
		evaluate("DecisionTree", "Versicolor");
	}

	@Test
	public void evaluateDummyVersicolor() throws Exception {
		evaluate("Dummy", "Versicolor");
	}

	/**
	 * Evaluates the KNN classifier on the Versicolor dataset, ignoring the
	 * {@code neighbor(1)} .. {@code neighbor(5)} output fields.
	 */
	@Test
	public void evaluateKNNVersicolor() throws Exception {
		try(Batch batch = createBatch("KNN", "Versicolor")){
			Set<FieldName> ignoredFields = createFieldSet("neighbor", 5);

			evaluate(batch, ignoredFields);
		}
	}

	@Test
	public void evaluateMLPVersicolor() throws Exception {
		evaluate("MLP", "Versicolor");
	}

	@Test
	public void evaluateSGDVersicolor() throws Exception {
		evaluate("SGD", "Versicolor");
	}

	@Test
	public void evaluateSGDLogVersicolor() throws Exception {
		evaluate("SGDLog", "Versicolor");
	}

	@Test
	public void evaluateSVCVersicolor() throws Exception {
		evaluate("SVC", "Versicolor");
	}

	@Test
	public void evaluateNuSVCVersicolor() throws Exception {
		evaluate("NuSVC", "Versicolor");
	}

	/**
	 * Builds the field names {@code prefix(1)}, {@code prefix(2)}, ..., {@code prefix(count)}.
	 *
	 * @param prefix the base name of every field
	 * @param count the number of fields to create; zero or a negative value yields an empty set
	 * @return a mutable, insertion-ordered set of the generated field names
	 */
	static
	private Set<FieldName> createFieldSet(String prefix, int count){
		Set<FieldName> result = new LinkedHashSet<>(Math.max(16, (int)(count / 0.75f) + 1));

		// Indices are 1-based, matching the "neighbor(1)".."neighbor(k)" naming
		for(int i = 1; i <= count; i++){
			result.add(FieldName.create(prefix + "(" + i + ")"));
		}

		return result;
	}
}