/* * Copyright (c) 2016 Villu Ruusmann * * This file is part of JPMML-SkLearn * * JPMML-SkLearn is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * JPMML-SkLearn is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>. */ package sklearn2pmml; import java.util.ArrayList; import java.util.List; import com.google.common.base.CharMatcher; import com.google.common.base.Function; import com.google.common.collect.Lists; import net.razorvine.pickle.objects.ClassDict; import numpy.core.NDArray; import org.dmg.pmml.DataField; import org.dmg.pmml.DataType; import org.dmg.pmml.Extension; import org.dmg.pmml.FieldName; import org.dmg.pmml.MiningBuildTask; import org.dmg.pmml.MiningFunction; import org.dmg.pmml.Model; import org.dmg.pmml.OpType; import org.dmg.pmml.PMML; import org.jpmml.converter.CategoricalLabel; import org.jpmml.converter.ContinuousLabel; import org.jpmml.converter.Feature; import org.jpmml.converter.Label; import org.jpmml.converter.Schema; import org.jpmml.converter.ValueUtil; import org.jpmml.converter.WildcardFeature; import org.jpmml.sklearn.ClassDictUtil; import org.jpmml.sklearn.SkLearnEncoder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import sklearn.Estimator; import sklearn.EstimatorUtil; import sklearn.HasNumberOfFeatures; import sklearn.Initializer; import sklearn.Transformer; import sklearn.TransformerUtil; import sklearn.TypeUtil; import sklearn.pipeline.Pipeline; public class PMMLPipeline extends Pipeline { public PMMLPipeline(){ this("sklearn2pmml", "PMMLPipeline"); } public PMMLPipeline(String module, String name){ super(module, name, true); } public PMML encodePMML(){ List<Transformer> transformers = getTransformers(); Estimator estimator = getEstimator(); String repr = getRepr(); if(estimator == null){ throw new IllegalArgumentException(); } SkLearnEncoder encoder = new SkLearnEncoder(); Label label = null; if(estimator.isSupervised()){ String targetField = getTargetField(); if(targetField == null){ targetField = "y"; logger.warn("The 'target_field' attribute is not set. Assuming {} as the name of the target field", targetField); } MiningFunction miningFunction = estimator.getMiningFunction(); switch(miningFunction){ case CLASSIFICATION: { List<?> classes = EstimatorUtil.getClasses(estimator); DataField dataField = encoder.createDataField(FieldName.create(targetField), OpType.CATEGORICAL, TypeUtil.getDataType(classes, DataType.STRING), formatTargetCategories(classes)); label = new CategoricalLabel(dataField); } break; case REGRESSION: { DataField dataField = encoder.createDataField(FieldName.create(targetField), OpType.CONTINUOUS, DataType.DOUBLE); label = new ContinuousLabel(dataField); } break; default: throw new IllegalArgumentException(); } } List<Feature> features = new ArrayList<>(); Transformer transformer = TransformerUtil.getHead(transformers); if(transformer != null){ if(!(transformer instanceof Initializer)){ features = initFeatures(transformer, transformer.getOpType(), transformer.getDataType(), encoder); } features = encodeFeatures(features, encoder); } else { features = initFeatures(estimator, estimator.getOpType(), estimator.getDataType(), encoder); } int numberOfFeatures = estimator.getNumberOfFeatures(); if(numberOfFeatures > -1){ ClassDictUtil.checkSize(numberOfFeatures, features); } Schema schema = new Schema(label, features); Model model = estimator.encodeModel(schema, encoder); PMML pmml = encoder.encodePMML(model); if(repr != null){ Extension extension = new Extension() .addContent(repr); MiningBuildTask miningBuildTask = new MiningBuildTask() .addExtensions(extension); pmml.setMiningBuildTask(miningBuildTask); } return pmml; } private List<Feature> initFeatures(ClassDict object, OpType opType, DataType dataType, SkLearnEncoder encoder){ List<String> activeFields = getActiveFields(); if(activeFields == null){ int numberOfFeatures = -1; if(object instanceof HasNumberOfFeatures){ HasNumberOfFeatures hasNumberOfFeatures = (HasNumberOfFeatures)object; numberOfFeatures = hasNumberOfFeatures.getNumberOfFeatures(); } // End if if(numberOfFeatures < 0){ throw new IllegalArgumentException("The first transformer or estimator object (" + ClassDictUtil.formatClass(object) + ") does not specify the number of input features"); } activeFields = new ArrayList<>(numberOfFeatures); for(int i = 0, max = numberOfFeatures; i < max; i++){ activeFields.add("x" + String.valueOf(i + 1)); } logger.warn("The 'active_fields' attribute is not set. Assuming {} as the names of active fields", activeFields); } List<Feature> result = new ArrayList<>(); for(String activeField : activeFields){ DataField dataField = encoder.createDataField(FieldName.create(activeField), opType, dataType); result.add(new WildcardFeature(encoder, dataField)); } return result; } @Override public List<Object[]> getSteps(){ return super.getSteps(); } public PMMLPipeline setSteps(List<Object[]> steps){ put("steps", steps); return this; } public String getRepr(){ return (String)get("repr_"); } public PMMLPipeline setRepr(String repr){ put("repr_", repr); return this; } public String getTargetField(){ return (String)get("target_field"); } public PMMLPipeline setTargetField(String targetField){ put("target_field", targetField); return this; } public List<String> getActiveFields(){ if(!containsKey("active_fields")){ return null; } return (List)ClassDictUtil.getArray(this, "active_fields"); } public PMMLPipeline setActiveFields(List<String> activeFields){ NDArray array = new NDArray(); array.put("data", activeFields); array.put("fortran_order", Boolean.FALSE); put("active_fields", array); return this; } static private List<String> formatTargetCategories(List<?> objects){ Function<Object, String> function = new Function<Object, String>(){ @Override public String apply(Object object){ String targetCategory = ValueUtil.formatValue(object); if(targetCategory == null || CharMatcher.WHITESPACE.matchesAnyOf(targetCategory)){ throw new IllegalArgumentException(targetCategory); } return targetCategory; } }; return Lists.transform(objects, function); } private static final Logger logger = LoggerFactory.getLogger(PMMLPipeline.class); }