/*
* avenir: Predictive analytic based on Hadoop Map Reduce
* Author: Pranab Ghosh
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package org.avenir.tree;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.avenir.model.ProbabilisticPredictiveModel;
import org.avenir.tree.DecisionPathList.DecisionPath;
import org.avenir.tree.DecisionPathList.DecisionPathPredicate;
import org.chombo.mr.FeatureField;
import org.chombo.util.FeatureSchema;
import org.chombo.util.Pair;
import org.codehaus.jackson.map.ObjectMapper;
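/**
 * Decision tree predictive model. The tree is loaded as a list of decision paths,
 * and a record is classified by the first path whose predicates all hold.
 *
 * @author pranab
 *
 */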
public class DecisionTreeModel extends ProbabilisticPredictiveModel {
	/** Decision paths deserialized from the model stream, evaluated in list order. */
	private final DecisionPathList decPathList;

	/**
	 * Per-record cache of predicate results, cleared at the start of every prediction,
	 * so that a predicate shared by several paths is evaluated only once.
	 * NOTE(review): relies on DecisionPathPredicate implementing equals/hashCode; confirm.
	 */
	private final Map<DecisionPathPredicate, Boolean> predicateValues = new HashMap<DecisionPathPredicate, Boolean>();

	/** Lazily populated cache of schema fields keyed by attribute ordinal. */
	private final Map<Integer, FeatureField> fields = new HashMap<Integer, FeatureField>();

	/**
	 * Builds the model by deserializing a {@link DecisionPathList} from JSON.
	 * The stream is read but not closed; the caller owns it.
	 *
	 * <p>Both caches are unsynchronized {@code HashMap}s mutated during prediction,
	 * so an instance must not be used by several threads concurrently.
	 *
	 * @param schema feature schema used to resolve attribute ordinals to fields
	 * @param modelStream stream containing the serialized decision path list
	 * @throws IOException if the model cannot be read or parsed
	 * @throws IllegalStateException if {@code modelStream} is null
	 */
	public DecisionTreeModel(FeatureSchema schema, InputStream modelStream) throws IOException {
		super(schema);
		if (null == modelStream) {
			throw new IllegalStateException("null stream for model");
		}
		ObjectMapper mapper = new ObjectMapper();
		decPathList = mapper.readValue(modelStream, DecisionPathList.class);
	}

	/* (non-Javadoc)
	 * @see org.avenir.model.PredictiveModel#predictClassProb(java.lang.String[])
	 */
	/**
	 * Predicts the class of a record.
	 *
	 * <p>Paths are tried in order, and the prediction of the first path whose predicates
	 * all evaluate to true is returned.
	 *
	 * @param items the record's field values, indexed by attribute ordinal
	 * @return the prediction stored on the first matching decision path
	 * @throws IllegalStateException if no decision path matches the record
	 */
	@Override
	protected Pair<String, Double> predictClassProb(String[] items) {
		predicateValues.clear();
		for (DecisionPath decPath : decPathList.getDecisionPaths()) {
			if (pathMatches(decPath, items)) {
				return decPath.getPrediction();
			}
		}
		// previously an unexplained NullPointerException
		throw new IllegalStateException("no decision path matched the record");
	}

	/**
	 * Returns true when every predicate of the path holds for the record.
	 *
	 * <p>Evaluation short-circuits on the first false predicate. A path with no
	 * predicates matches.
	 */
	private boolean pathMatches(DecisionPath decPath, String[] items) {
		for (DecisionPathPredicate predicate : decPath.getPredicates()) {
			Boolean predEval = predicateValues.get(predicate);
			if (null == predEval) {
				predEval = evaluate(predicate, resolveField(predicate.getAttribute()), items);
				predicateValues.put(predicate, predEval);
			}
			if (!predEval) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Looks up the schema field for an attribute ordinal and caches it.
	 *
	 * @throws IllegalStateException if the schema has no field with that ordinal
	 */
	private FeatureField resolveField(int attrOrd) {
		FeatureField field = fields.get(attrOrd);
		if (null == field) {
			field = schema.findFieldByOrdinal(attrOrd);
			if (null == field) {
				throw new IllegalStateException("no schema field for attribute ordinal " + attrOrd);
			}
			fields.put(attrOrd, field);
		}
		return field;
	}

	/**
	 * Evaluates a single predicate against the record's value for the predicate's field.
	 *
	 * <p>Supported forms:
	 * <ul>
	 *   <li>Int and double fields support {@code OP_LE} (value &lt;= bound) and
	 *       {@code OP_GT} (value &gt; bound, and also value &lt;= the optional other
	 *       bound when one is set).</li>
	 *   <li>Categorical fields support {@code OP_IN} (value is in the predicate's
	 *       value list).</li>
	 * </ul>
	 *
	 * @param predicate predicate to evaluate
	 * @param field schema field the predicate refers to
	 * @param items the record's field values, indexed by attribute ordinal
	 * @return whether the predicate holds for the record
	 * @throws IllegalStateException for an operator not valid for the field type, or an unsupported field type
	 * @throws NumberFormatException if a numeric field's value cannot be parsed
	 */
	private boolean evaluate(DecisionPathPredicate predicate, FeatureField field, String[] items) {
		int attrOrd = field.getOrdinal();
		String operator = predicate.getOperator();
		if (field.isInteger()) {
			int operand = Integer.parseInt(items[attrOrd]);
			int operandValue = predicate.getValueInt();
			if (operator.equals(DecisionPathPredicate.OP_LE)) {
				return operand <= operandValue;
			} else if (operator.equals(DecisionPathPredicate.OP_GT)) {
				Integer otherBound = predicate.getOtherBoundInt();
				return operand > operandValue && (null == otherBound || operand <= otherBound);
			}
			throw new IllegalStateException("invalid operator type for int attribute");
		} else if (field.isDouble()) {
			double operand = Double.parseDouble(items[attrOrd]);
			double operandValue = predicate.getValueDbl();
			if (operator.equals(DecisionPathPredicate.OP_LE)) {
				return operand <= operandValue;
			} else if (operator.equals(DecisionPathPredicate.OP_GT)) {
				// The upper bound of a double range is stored as a double; the original read the int bound here.
				Double otherBound = predicate.getOtherBoundDbl();
				return operand > operandValue && (null == otherBound || operand <= otherBound);
			}
			throw new IllegalStateException("invalid operator type for double attribute");
		} else if (field.isCategorical()) {
			String operand = items[attrOrd];
			List<String> operandValue = predicate.getCategoricalValues();
			if (operator.equals(DecisionPathPredicate.OP_IN)) {
				return operandValue.contains(operand);
			}
			throw new IllegalStateException("invalid operator type for categorical attribute");
		}
		// previously evaluated silently to false, hiding a model/schema mismatch
		throw new IllegalStateException("unsupported field type for attribute ordinal " + attrOrd);
	}
}