/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import ml.shifu.shifu.core.dtrain.dt.IndependentTreeModel;
import ml.shifu.shifu.core.dtrain.dt.TreeNode;
import org.apache.commons.lang3.tuple.MutablePair;
import org.encog.ml.BasicML;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData;
import org.encog.ml.data.basic.BasicMLData;
/**
 * {@link TreeModel} loads Random Forest or Gradient Boosted Decision Tree models through the Encog interfaces. If a
 * user would rather not depend on Encog, {@link IndependentTreeModel} can be used to execute the tree model from Shifu.
*
* <p>
 * {@link #loadFromStream(InputStream, boolean)} can be used to read serialized models; the actual loading is
 * delegated to {@link IndependentTreeModel}.
*
* @author Zhang David (pengzhang@paypal.com)
*/
public class TreeModel extends BasicML implements MLRegression {

    private static final long serialVersionUID = 479043597958785224L;

    /**
     * Tree model instance without dependency on encog.
     *
     * <p>
     * NOTE(review): the field is {@code transient}, so a {@link TreeModel} restored via Java serialization has a
     * {@code null} delegate — confirm callers never serialize/deserialize this class directly.
     */
    private transient IndependentTreeModel independentTreeModel;

    /**
     * Constructor on current {@link IndependentTreeModel}
     *
     * @param independentTreeModel
     *            the independent tree model
     */
    public TreeModel(IndependentTreeModel independentTreeModel) {
        this.independentTreeModel = independentTreeModel;
    }

    /**
     * Compute model score based on given input double array. Scoring is fully delegated to
     * {@link IndependentTreeModel#compute(double[])}.
     */
    @Override
    public final MLData compute(final MLData input) {
        double[] data = input.getData();
        return new BasicMLData(this.getIndependentTreeModel().compute(data));
    }

    /**
     * How many input columns.
     */
    @Override
    public int getInputCount() {
        return this.getIndependentTreeModel().getInputNode();
    }

    @Override
    public void updateProperties() {
        // No need implementation
    }

    /**
     * Load a serialized tree model from the given stream without score-to-probability conversion.
     *
     * @param input
     *            the stream containing the serialized model
     * @return the loaded {@link TreeModel}
     * @throws IOException
     *             if the stream cannot be read or the model format is invalid
     */
    public static TreeModel loadFromStream(InputStream input) throws IOException {
        return loadFromStream(input, false);
    }

    /**
     * Load a serialized tree model from the given stream; loading is delegated to
     * {@link IndependentTreeModel#loadFromStream(InputStream, boolean)}.
     *
     * @param input
     *            the stream containing the serialized model
     * @param isConvertToProb
     *            whether raw scores should be converted to probabilities
     * @return the loaded {@link TreeModel}
     * @throws IOException
     *             if the stream cannot be read or the model format is invalid
     */
    public static TreeModel loadFromStream(InputStream input, boolean isConvertToProb) throws IOException {
        return new TreeModel(IndependentTreeModel.loadFromStream(input, isConvertToProb));
    }

    @Override
    public int getOutputCount() {
        // mock as output is only 1 dimension
        return 1;
    }

    /**
     * @return the algorithm name of the underlying model (e.g. RF or GBT)
     */
    public String getAlgorithm() {
        return this.getIndependentTreeModel().getAlgorithm();
    }

    /**
     * @return the loss function description of the underlying model
     */
    public String getLossStr() {
        return this.getIndependentTreeModel().getLossStr();
    }

    /**
     * @return the list of trees composing the underlying model
     */
    public List<TreeNode> getTrees() {
        return this.getIndependentTreeModel().getTrees();
    }

    /**
     * @return true if the underlying model is a gradient boosted tree model
     */
    public boolean isGBDT() {
        return this.getIndependentTreeModel().isGBDT();
    }

    /**
     * @return true if the underlying model is a classification model
     */
    public boolean isClassification() {
        return this.getIndependentTreeModel().isClassification();
    }

    /**
     * @return true if the underlying model is a classification model
     * @deprecated method name contains a typo; use {@link #isClassification()} instead. Kept only for backward
     *             compatibility with existing callers.
     */
    @Deprecated
    public boolean isClassfication() {
        return isClassification();
    }

    /**
     * @return the wrapped encog-independent tree model
     */
    public IndependentTreeModel getIndependentTreeModel() {
        return independentTreeModel;
    }

    /**
     * Get feature importance of current model: each tree's per-feature importance contribution is averaged over the
     * total number of trees.
     *
     * @return map of feature importance, key is column index.
     */
    public Map<Integer, MutablePair<String, Double>> getFeatureImportances() {
        Map<Integer, MutablePair<String, Double>> importancesSum = new HashMap<Integer, MutablePair<String, Double>>();
        Map<Integer, String> nameMapping = this.getIndependentTreeModel().getNumNameMapping();
        List<TreeNode> trees = this.getIndependentTreeModel().getTrees();
        int size = trees.size();
        for(TreeNode tree: trees) {
            Map<Integer, Double> subImportances = tree.computeFeatureImportance();
            for(Entry<Integer, Double> entry: subImportances.entrySet()) {
                Integer columnIndex = entry.getKey();
                // average the per-tree contribution over the number of trees
                double contribution = entry.getValue() / size;
                MutablePair<String, Double> current = importancesSum.get(columnIndex);
                if(current == null) {
                    importancesSum.put(columnIndex, MutablePair.of(nameMapping.get(columnIndex), contribution));
                } else {
                    // pair is mutable, so updating in place is enough; no second put needed
                    current.setValue(current.getValue() + contribution);
                }
            }
        }
        return importancesSum;
    }

    /**
     * Sort entries by feature importance value.
     *
     * @param unsortMap
     *            the unsorted importance map, key is column index
     * @param order
     *            true for ascending order, false for descending order
     * @return map of feature importance sorted by importance value, key is column index.
     */
    public static Map<Integer, MutablePair<String, Double>> sortByValue(
            Map<Integer, MutablePair<String, Double>> unsortMap, final boolean order) {
        // ArrayList sorts faster than LinkedList (random access, no node churn)
        List<Entry<Integer, MutablePair<String, Double>>> list = new ArrayList<Entry<Integer, MutablePair<String, Double>>>(
                unsortMap.entrySet());
        Collections.sort(list, new Comparator<Entry<Integer, MutablePair<String, Double>>>() {
            @Override
            public int compare(Entry<Integer, MutablePair<String, Double>> o1,
                    Entry<Integer, MutablePair<String, Double>> o2) {
                // swap operands instead of negating to reverse the order
                if(order) {
                    return o1.getValue().getValue().compareTo(o2.getValue().getValue());
                } else {
                    return o2.getValue().getValue().compareTo(o1.getValue().getValue());
                }
            }
        });
        // LinkedHashMap preserves the sorted insertion order
        Map<Integer, MutablePair<String, Double>> sortedMap = new LinkedHashMap<Integer, MutablePair<String, Double>>();
        for(Entry<Integer, MutablePair<String, Double>> entry: list) {
            sortedMap.put(entry.getKey(), entry.getValue());
        }
        return sortedMap;
    }

    @Override
    public String toString() {
        return this.getIndependentTreeModel().getTrees().toString();
    }
}