/*
 * Copyright (c) 2017 Villu Ruusmann
 *
 * This file is part of JPMML-SkLearn
 *
 * JPMML-SkLearn is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SkLearn is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
 */
package sklearn.preprocessing;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import com.google.common.collect.Iterables;
import numpy.core.NDArray;
import org.dmg.pmml.FieldName;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Decorator;
import org.jpmml.converter.Feature;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.sklearn.SkLearnEncoder;
import org.junit.Test;
import sklearn.Transformer;
import sklearn2pmml.decoration.CategoricalDomain;
import sklearn2pmml.decoration.ContinuousDomain;
import sklearn_pandas.DataFrameMapper;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;

/**
 * Checks what an {@link Imputer} step leaves behind in a {@link SkLearnEncoder}
 * when it is run through a single-column {@link DataFrameMapper}: which fields
 * exist, how many decorators are attached to the input field, and what kind and
 * name the resulting {@link Feature} has.
 */
public class ImputerTest {

    /**
     * Runs a "most_frequent" imputer over column {@code x} in two pipelines:
     * on its own, and after a {@link CategoricalDomain} step.
     *
     * In both pipelines the data field {@code x} must exist, no derived field
     * named {@code imputer(x)} may exist, and the returned feature must keep
     * the name {@code x}. The feature type and the decorator count on
     * {@code x} differ per pipeline (see inline comments).
     */
    @Test
    public void encodeCategorical() {
        FieldName inputName = FieldName.create("x");
        FieldName imputedName = FieldName.create("imputer(x)");

        Imputer imputer = createImputer("most_frequent", "NaN", 0);

        // Pipeline 1: imputer only -> wildcard feature, one decorator on "x".
        SkLearnEncoder plainEncoder = new SkLearnEncoder();
        Feature plainFeature = encodeFeature(inputName.getValue(), Arrays.asList(imputer), plainEncoder);

        assertNotNull(plainEncoder.getDataField(inputName));
        assertNull(plainEncoder.getDerivedField(imputedName));

        List<Decorator> plainDecorators = plainEncoder.getDecorators(inputName);
        assertEquals(1, plainDecorators.size());

        assertTrue(plainFeature instanceof WildcardFeature);
        assertEquals(inputName, plainFeature.getName());

        // Pipeline 2: categorical domain, then imputer
        // -> categorical feature, two decorators on "x".
        NDArray categories = new NDArray();
        categories.put("data", Arrays.asList(0, 1, 2, 3, 4, 5, 6));
        categories.put("fortran_order", Boolean.FALSE);

        CategoricalDomain categoricalDomain = new CategoricalDomain("sklearn2pmml.decoration", "CategoricalDomain");
        categoricalDomain.put("invalid_value_treatment", "as_is");
        categoricalDomain.put("data_", categories);

        SkLearnEncoder domainEncoder = new SkLearnEncoder();
        Feature domainFeature = encodeFeature(inputName.getValue(), Arrays.asList(categoricalDomain, imputer), domainEncoder);

        assertNotNull(domainEncoder.getDataField(inputName));
        assertNull(domainEncoder.getDerivedField(imputedName));

        List<Decorator> domainDecorators = domainEncoder.getDecorators(inputName);
        assertEquals(2, domainDecorators.size());

        assertTrue(domainFeature instanceof CategoricalFeature);
        assertEquals(inputName, domainFeature.getName());
    }

    /**
     * Runs a "mean" imputer over column {@code x} in three pipelines:
     * on its own, after a {@link ContinuousDomain} step, and after a
     * {@link ContinuousDomain} plus {@link Binarizer} pair.
     *
     * In the first two pipelines no derived field named {@code imputer(x)}
     * may exist and the feature keeps the name {@code x}. In the third, the
     * derived fields {@code binarizer(x)} and {@code imputer(binarizer(x))}
     * must both exist and the feature carries the latter name.
     */
    @Test
    public void encodeContinuous() {
        FieldName inputName = FieldName.create("x");
        FieldName imputedName = FieldName.create("imputer(x)");
        FieldName binarizedName = FieldName.create("binarizer(x)");
        FieldName imputedBinarizedName = FieldName.create("imputer(" + binarizedName.getValue() + ")");

        Imputer imputer = createImputer("mean", -999, 0.5d);

        // Pipeline 1: imputer only -> wildcard feature, one decorator on "x".
        SkLearnEncoder plainEncoder = new SkLearnEncoder();
        Feature plainFeature = encodeFeature(inputName.getValue(), Arrays.asList(imputer), plainEncoder);

        assertNotNull(plainEncoder.getDataField(inputName));
        assertNull(plainEncoder.getDerivedField(imputedName));

        List<Decorator> plainDecorators = plainEncoder.getDecorators(inputName);
        assertEquals(1, plainDecorators.size());

        assertTrue(plainFeature instanceof WildcardFeature);
        assertEquals(inputName, plainFeature.getName());

        // Pipeline 2: continuous domain, then imputer
        // -> continuous feature, three decorators on "x".
        ContinuousDomain continuousDomain = new ContinuousDomain("sklearn2pmml.decoration", "ContinuousDomain");
        continuousDomain.put("invalid_value_treatment", "return_invalid");
        continuousDomain.put("data_min_", 0d);
        continuousDomain.put("data_max_", 1d);

        SkLearnEncoder domainEncoder = new SkLearnEncoder();
        Feature domainFeature = encodeFeature(inputName.getValue(), Arrays.asList(continuousDomain, imputer), domainEncoder);

        assertNotNull(domainEncoder.getDataField(inputName));
        assertNull(domainEncoder.getDerivedField(imputedName));

        List<Decorator> domainDecorators = domainEncoder.getDecorators(inputName);
        assertEquals(3, domainDecorators.size());

        assertTrue(domainFeature instanceof ContinuousFeature);
        assertEquals(inputName, domainFeature.getName());

        // Pipeline 3: continuous domain, binarizer, then imputer
        // -> both derived fields exist, two decorators remain on "x".
        Binarizer binarizer = new Binarizer("sklearn.preprocessing.data", "Binarizer");
        binarizer.put("threshold", 1d / 3d);

        SkLearnEncoder binarizedEncoder = new SkLearnEncoder();
        Feature binarizedFeature = encodeFeature(inputName.getValue(), Arrays.asList(continuousDomain, binarizer, imputer), binarizedEncoder);

        assertNotNull(binarizedEncoder.getDataField(inputName));
        assertNotNull(binarizedEncoder.getDerivedField(binarizedName));
        assertNotNull(binarizedEncoder.getDerivedField(imputedBinarizedName));

        List<Decorator> binarizedDecorators = binarizedEncoder.getDecorators(inputName);
        assertEquals(2, binarizedDecorators.size());

        assertTrue(binarizedFeature instanceof ContinuousFeature);
        assertEquals(imputedBinarizedName, binarizedFeature.getName());
    }

    /**
     * Builds an {@link Imputer} and populates the three attributes that the
     * tests in this class set on it.
     *
     * @param strategy value stored under the {@code "strategy"} key
     * @param missingValues value stored under the {@code "missing_values"} key
     * @param statistics value stored under the {@code "statistics_"} key
     * @return the populated imputer
     */
    private static Imputer createImputer(String strategy, Object missingValues, Object statistics) {
        Imputer result = new Imputer("sklearn.preprocessing.imputation", "Imputer");
        result.put("strategy", strategy);
        result.put("missing_values", missingValues);
        result.put("statistics_", statistics);

        return result;
    }

    /**
     * Wraps the given transformer chain in a {@link DataFrameMapper} that maps
     * exactly one column, encodes it starting from an empty feature list, and
     * returns the single feature that comes out.
     *
     * @param name name of the mapped column
     * @param transformers transformer steps applied to the column, in order
     * @param encoder encoder that receives the fields and decorators
     * @return the only feature produced by the mapper
     */
    private static Feature encodeFeature(String name, List<? extends Transformer> transformers, SkLearnEncoder encoder) {
        // One mapper row: the column name paired with its transformer steps.
        Object[] columnSpec = new Object[]{name, transformers};

        DataFrameMapper mapper = new DataFrameMapper("sklearn_pandas.dataframe_mapper", "DataFrameMapper")
            .setDefault(Boolean.FALSE)
            .setFeatures(Collections.singletonList(columnSpec));

        return Iterables.getOnlyElement(mapper.encodeFeatures(new ArrayList<Feature>(), encoder));
    }
}