/*
* Copyright (c) 2015 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package sklearn.tree;
import java.util.ArrayList;
import java.util.List;
import com.google.common.base.Function;
import com.google.common.collect.Lists;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import sklearn.Estimator;
/**
 * Static helpers for converting the tree of a {@link HasTree} estimator into a PMML {@link TreeModel}.
 *
 * <p>
 * The tree is read as a set of parallel arrays (left children, right children, split features,
 * split thresholds and node values), all addressed by a zero-based node index.
 * </p>
 */
public class TreeModelUtil {

	private TreeModelUtil(){
	}

	/**
	 * Encodes every estimator as a separate tree model, using a newly created {@link PredicateManager}.
	 *
	 * @param estimators The estimators to encode.
	 * @param miningFunction The mining function of the resulting tree models.
	 * @param schema The schema that all estimators are encoded against.
	 *
	 * @return A new list holding one tree model per estimator, in the order of <code>estimators</code>.
	 */
	static
	public <E extends Estimator & HasTree> List<TreeModel> encodeTreeModelSegmentation(List<E> estimators, MiningFunction miningFunction, Schema schema){
		PredicateManager predicateManager = new PredicateManager();

		return encodeTreeModelSegmentation(estimators, predicateManager, miningFunction, schema);
	}

	/**
	 * Encodes every estimator as a separate tree model, sharing one {@link PredicateManager} between them.
	 *
	 * <p>
	 * Each estimator is encoded against its own schema, which is derived from the anonymous form of
	 * <code>schema</code> and the data type reported by the estimator (see {@link #toTreeModelSchema(Schema, DataType)}).
	 * </p>
	 *
	 * @param estimators The estimators to encode.
	 * @param predicateManager The factory that all split predicates are requested from.
	 * @param miningFunction The mining function of the resulting tree models.
	 * @param schema The schema that all estimators are encoded against.
	 *
	 * @return A new list holding one tree model per estimator, in the order of <code>estimators</code>.
	 */
	static
	public <E extends Estimator & HasTree> List<TreeModel> encodeTreeModelSegmentation(List<E> estimators, final PredicateManager predicateManager, final MiningFunction miningFunction, final Schema schema){
		Function<E, TreeModel> function = new Function<E, TreeModel>(){

			@Override
			public TreeModel apply(E estimator){
				Schema treeModelSchema = toTreeModelSchema(schema.toAnonymousSchema(), estimator.getDataType());

				return TreeModelUtil.encodeTreeModel(estimator, predicateManager, miningFunction, treeModelSchema);
			}
		};

		// Lists#transform returns a lazily evaluated view; copying it encodes each estimator exactly once
		return new ArrayList<>(Lists.transform(estimators, function));
	}

	/**
	 * Encodes a single estimator as a tree model, using a newly created {@link PredicateManager}.
	 *
	 * @param estimator The estimator to encode.
	 * @param miningFunction The mining function of the resulting tree model.
	 * @param schema The schema whose features are addressed by the split feature indices of the tree.
	 *
	 * @return The tree model.
	 */
	static
	public <E extends Estimator & HasTree> TreeModel encodeTreeModel(E estimator, MiningFunction miningFunction, Schema schema){
		PredicateManager predicateManager = new PredicateManager();

		return encodeTreeModel(estimator, predicateManager, miningFunction, schema);
	}

	/**
	 * Encodes a single estimator as a binary-split tree model.
	 *
	 * <p>
	 * The root node has identifier <code>"1"</code> and an always-true predicate;
	 * its content is encoded recursively, starting from tree node index <code>0</code>.
	 * </p>
	 *
	 * @param estimator The estimator to encode.
	 * @param predicateManager The factory that all split predicates are requested from.
	 * @param miningFunction The mining function of the resulting tree model.
	 * Leaf nodes can only be encoded for {@link MiningFunction#CLASSIFICATION} and {@link MiningFunction#REGRESSION}.
	 * @param schema The schema whose features are addressed by the split feature indices of the tree.
	 *
	 * @return The tree model.
	 *
	 * @throws IllegalArgumentException If a leaf node is reached with an unsupported mining function,
	 * or if a split on a binary feature has a threshold outside of the range <code>[0, 1]</code>.
	 */
	static
	public <E extends Estimator & HasTree> TreeModel encodeTreeModel(E estimator, PredicateManager predicateManager, MiningFunction miningFunction, Schema schema){
		Tree tree = estimator.getTree();

		int[] leftChildren = tree.getChildrenLeft();
		int[] rightChildren = tree.getChildrenRight();
		int[] features = tree.getFeature();
		double[] thresholds = tree.getThreshold();
		double[] values = tree.getValues();

		Node root = new Node()
			.setId("1")
			.setPredicate(new True());

		encodeNode(root, predicateManager, 0, leftChildren, rightChildren, features, thresholds, values, miningFunction, schema);

		TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema), root)
			.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

		return treeModel;
	}

	/**
	 * Fills in the PMML node that corresponds to the tree node at <code>index</code>.
	 *
	 * <p>
	 * A non-negative split feature index marks a non-leaf node, which receives two recursively encoded child nodes.
	 * Any other split feature index marks a leaf node, which receives a score.
	 * Child node identifiers are their tree node indices plus one.
	 * </p>
	 *
	 * @throws IllegalArgumentException If a leaf node is reached with an unsupported mining function,
	 * or if a split on a binary feature has a threshold outside of the range <code>[0, 1]</code>.
	 */
	static
	private void encodeNode(Node node, PredicateManager predicateManager, int index, int[] leftChildren, int[] rightChildren, int[] features, double[] thresholds, double[] values, MiningFunction miningFunction, Schema schema){
		int featureIndex = features[index];

		// A non-leaf (binary split) node
		if(featureIndex >= 0){
			Feature feature = schema.getFeature(featureIndex);

			// The threshold is narrowed to single precision before any comparison or formatting.
			// NOTE(review): presumably mirrors the float32 feature values of scikit-learn - confirm
			float threshold = (float)thresholds[index];

			Predicate leftPredicate;
			Predicate rightPredicate;

			if(feature instanceof BinaryFeature){
				BinaryFeature binaryFeature = (BinaryFeature)feature;

				if(threshold < 0 || threshold > 1){
					throw new IllegalArgumentException("Expected a threshold in range [0, 1] for the binary feature split at tree node index " + index + ", got " + threshold);
				}

				String value = binaryFeature.getValue();

				// The left branch is taken when the feature does not equal its value, the right branch when it does
				leftPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.NOT_EQUAL, value);
				rightPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.EQUAL, value);
			} else

			{
				ContinuousFeature continuousFeature = feature.toContinuousFeature(DataType.FLOAT);

				String value = ValueUtil.formatValue(threshold);

				leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
				rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
			}

			int leftIndex = leftChildren[index];
			int rightIndex = rightChildren[index];

			Node leftChild = new Node()
				.setId(String.valueOf(leftIndex + 1))
				.setPredicate(leftPredicate);

			encodeNode(leftChild, predicateManager, leftIndex, leftChildren, rightChildren, features, thresholds, values, miningFunction, schema);

			Node rightChild = new Node()
				.setId(String.valueOf(rightIndex + 1))
				.setPredicate(rightPredicate);

			encodeNode(rightChild, predicateManager, rightIndex, leftChildren, rightChildren, features, thresholds, values, miningFunction, schema);

			node.addNodes(leftChild, rightChild);
		} else

		// A leaf node
		{
			if((MiningFunction.CLASSIFICATION).equals(miningFunction)){
				CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

				// There is one row of per-category record counts for every tree node
				encodeClassificationLeaf(node, index, leftChildren.length, values, categoricalLabel);
			} else

			if((MiningFunction.REGRESSION).equals(miningFunction)){
				String score = ValueUtil.formatValue(values[index]);

				node.setScore(score);
			} else

			{
				throw new IllegalArgumentException("Expected a classification or regression mining function, got " + miningFunction);
			}
		}
	}

	/**
	 * Fills in a classification leaf node: the total record count, one score distribution per category,
	 * and the score, which is the first category with the highest share of records.
	 *
	 * @param node The leaf node to fill in.
	 * @param index The tree node index of the leaf node.
	 * @param rows The number of tree nodes.
	 * @param values The flattened per-node, per-category record counts.
	 * @param categoricalLabel The label that supplies the category values.
	 */
	static
	private void encodeClassificationLeaf(Node node, int index, int rows, double[] values, CategoricalLabel categoricalLabel){
		double[] scoreRecordCounts = getRow(values, rows, categoricalLabel.size(), index);

		double recordCount = 0;

		for(double scoreRecordCount : scoreRecordCounts){
			recordCount += scoreRecordCount;
		}

		node.setRecordCount(recordCount);

		String score = null;

		// Only meaningful after the first iteration, which always takes over the first category
		double probability = Double.NaN;

		for(int i = 0; i < categoricalLabel.size(); i++){
			String value = categoricalLabel.getValue(i);

			ScoreDistribution scoreDistribution = new ScoreDistribution(value, scoreRecordCounts[i]);

			node.addScoreDistributions(scoreDistribution);

			double scoreProbability = (scoreRecordCounts[i] / recordCount);

			// Double#compare keeps the total ordering of Double#compareTo; the strict comparison keeps the first category on ties
			if(i == 0 || Double.compare(probability, scoreProbability) < 0){
				score = scoreDistribution.getValue();
				probability = scoreProbability;
			}
		}

		node.setScore(score);
	}

	/**
	 * Derives a schema for tree model encoding.
	 *
	 * <p>
	 * Binary features are retained as they are.
	 * All other features are converted to continuous features of the specified data type.
	 * </p>
	 *
	 * @param schema The schema to transform.
	 * @param dataType The data type that is requested for continuous features.
	 *
	 * @return The transformed schema.
	 */
	static
	public Schema toTreeModelSchema(Schema schema, final DataType dataType){
		Function<Feature, Feature> function = new Function<Feature, Feature>(){

			@Override
			public Feature apply(Feature feature){

				if(feature instanceof BinaryFeature){
					BinaryFeature binaryFeature = (BinaryFeature)feature;

					return binaryFeature;
				} else

				{
					ContinuousFeature continuousFeature = feature.toContinuousFeature(dataType);

					return continuousFeature;
				}
			}
		};

		return schema.toTransformedSchema(function);
	}

	/**
	 * Copies one row out of a flat array that is laid out row by row.
	 *
	 * @param values The flat array of <code>rows * columns</code> elements.
	 * @param rows The number of rows.
	 * @param columns The number of columns.
	 * @param row The zero-based index of the row to copy.
	 *
	 * @return A new array holding the <code>columns</code> elements of the row.
	 *
	 * @throws IllegalArgumentException If the length of <code>values</code> is not <code>rows * columns</code>.
	 */
	static
	private double[] getRow(double[] values, int rows, int columns, int row){

		if(values.length != (rows * columns)){
			throw new IllegalArgumentException("Expected " + (rows * columns) + " element(s), got " + values.length + " element(s)");
		}

		double[] result = new double[columns];

		System.arraycopy(values, (row * columns), result, 0, columns);

		return result;
	}
}