/* * avenir: Predictive analytic based on Hadoop Map Reduce * Author: Pranab Ghosh * * Licensed under the Apache License, Version 2.0 (the "License"); you * may not use this file except in compliance with the License. You may * obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. See the License for the specific language governing * permissions and limitations under the License. */ package org.avenir.tree; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; import org.chombo.mr.FeatureField; import org.chombo.util.BasicUtils; import org.chombo.util.FeatureSchema; import org.chombo.util.Pair; /** * List of decisions paths * @author pranab * */ public class DecisionPathList { private List<DecisionPath> decisionPaths; public List<DecisionPath> getDecisionPaths() { return decisionPaths; } public void setDecisionPaths(List<DecisionPath> decisionPaths) { this.decisionPaths = decisionPaths; } public void addDecisionPath(DecisionPath decPath) { if (null == decisionPaths) { decisionPaths = new ArrayList<DecisionPath>(); } decisionPaths.add(decPath); } /** * @param predcateStrings * @return */ public DecisionPath findDecisionPath(String[] predcateStrings) { DecisionPath foundDecPath = null; for (DecisionPath decPath : decisionPaths) { if (decPath.isMatchedByPredicates(predcateStrings)) { foundDecPath = decPath; break; } } return foundDecPath; } /** * @param predcateString * @return */ public DecisionPath findDecisionPath(String predcateString) { DecisionPath foundDecPath = null; for (DecisionPath decPath : decisionPaths) { if (decPath.isMatchedByPredicateString(predcateString)) { foundDecPath = decPath; break; } } return foundDecPath; } /** * @param predicates * @return */ public static String[] stripSplitId(String[] predicates) { String[] strippedPredicates = new String[predicates.length]; for (int i = 0; i < predicates.length; ++i ) { if (predicates[i].equals(DecisionTreeBuilder.ROOT_PATH)) { strippedPredicates[i] = predicates[i]; } else { strippedPredicates[i] = BasicUtils.splitOnFirstOccurence(predicates[i], DecisionTreeBuilder.SPLIT_DELIM, true)[1]; } } return strippedPredicates; } /** * Decision path containing a list of predicates * @author pranab * */ public static class DecisionPath { private List<DecisionPathPredicate> predicates; private int population; private double infoContent; private boolean stopped; private Map<String, Double> classValPr; public DecisionPath() { } /** * @param predicates * @param population * @param infoContent * @param stopped */ public DecisionPath(List<DecisionPathPredicate> predicates, int population, double infoContent, boolean stopped, Map<String, Double> classValPr) { super(); this.predicates = predicates; this.population = population; this.infoContent = infoContent; this.stopped = stopped; this.classValPr = classValPr; } /** * @param population * @param infoContent */ public DecisionPath(int population, double infoContent, Map<String, Double> classValPr) { this.population = population; this.infoContent = infoContent; this.stopped = false; this.classValPr = classValPr; } /** * @param predcateStrings * @return */ public boolean isMatchedByPredicates(String[] predcateStrings) { boolean matched = true; if (null == predicates ) { //root matched = predcateStrings[0].equals(DecisionTreeBuilder.ROOT_PATH); } else { int i = 0; for (DecisionPathPredicate predicate : predicates) { if (!predicate.getPredicateStr().equals(predcateStrings[i++])) { matched = false; break; } } } return matched; } /** * @param predcateString * @return */ public boolean isMatchedByPredicateString(String predcateString) { boolean matched = false;; if (null == predicates) { matched = predcateString.equals(DecisionTreeBuilder.ROOT_PATH); } else { matched = toStringAllPredicate().equals(predcateString); } return matched; } /** * @return */ public List<DecisionPathPredicate> getPredicates() { return predicates; } /** * @param predicates */ public void setPredicates(List<DecisionPathList.DecisionPathPredicate> predicates) { this.predicates = predicates; } /** * @return */ public int getPopulation() { return population; } /** * @param population */ public void setPopulation(int population) { this.population = population; } /** * @return */ public double getInfoContent() { return infoContent; } /** * @param infoContent */ public void setInfoContent(double infoContent) { this.infoContent = infoContent; } /** * @return */ public boolean isStopped() { return stopped; } /** * @param stopped */ public void setStopped(boolean stopped) { this.stopped = stopped; } /** * @return */ public Map<String, Double> getClassValPr() { return classValPr; } /** * @param classValPr */ public void setClassValPr(Map<String, Double> classValPr) { this.classValPr = classValPr; } /** * @return */ public String toStringAllPredicate() { List<String> strPredicates = new ArrayList<String>(); for (DecisionPathPredicate predicate : predicates) { strPredicates.add(predicate.toString()); } return BasicUtils.join(strPredicates, DecisionTreeBuilder.PRED_DELIM); } /** * @return */ public Pair<String, Double> getPrediction() { String predClVal = null; double maxProb = 0; for (String clVal : classValPr.keySet()) { if (classValPr.get(clVal) > maxProb) { predClVal = clVal; maxProb = classValPr.get(clVal); } } return new Pair<String, Double>(predClVal, maxProb); } } /** * Decision path predicate * @author pranab * */ public static class DecisionPathPredicate { private int attribute; private String operator; private int valueInt; private double valueDbl; private List<String> categoricalValues; private Integer otherBoundInt; private Double otherBoundDbl; private String predicateStr; public static final String OP_LE = "le"; public static final String OP_GT = "gt"; public static final String OP_IN = "in"; /** * @param predicateStr * @return */ public static DecisionPathPredicate createRootPredicate(String predicateStr) { DecisionPathPredicate predicate = new DecisionPathPredicate() ; predicate.setPredicateStr(predicateStr); return predicate; } /** * @param predicateStr * @return */ public static DecisionPathPredicate createIntPredicate(String predicateStr) { DecisionPathPredicate predicate = new DecisionPathPredicate() ; String[] items = predicateStr.split("\\s+"); predicate.setAttribute(Integer.parseInt(items[0])); predicate.setOperator(items[1]); predicate.setValueInt(Integer.parseInt(items[2])); if (items.length == 4) { predicate.setOtherBoundInt(Integer.parseInt(items[3])); } predicate.setPredicateStr(predicateStr); return predicate; } /** * @param predicateStr * @return */ public static DecisionPathPredicate createDoublePredicate(String predicateStr) { DecisionPathPredicate predicate = new DecisionPathPredicate() ; String[] items = predicateStr.split("\\s+"); predicate.setAttribute(Integer.parseInt(items[0])); predicate.setOperator(items[1]); predicate.setValueDbl(Double.parseDouble(items[2])); if (items.length == 4) { predicate.setOtherBoundDbl(Double.parseDouble(items[3])); } predicate.setPredicateStr(predicateStr); return predicate; } /** * @param predicateStr * @return */ public static DecisionPathPredicate createCategoricalPredicate(String predicateStr) { DecisionPathPredicate predicate = new DecisionPathPredicate() ; String[] items = predicateStr.split("\\s+"); predicate.setAttribute(Integer.parseInt(items[0])); predicate.setOperator(items[1]); String[] valueArray = items[2].split(":"); List<String> categoricalValues = Arrays.asList(valueArray); predicate.setCategoricalValues(categoricalValues); predicate.setPredicateStr(predicateStr); return predicate; } /** * @param predicatesStr * @param schema * @return */ public static List< DecisionPathList.DecisionPathPredicate> createPredicates(String predicatesStr, FeatureSchema schema) { List< DecisionPathList.DecisionPathPredicate> predicates = new ArrayList< DecisionPathList.DecisionPathPredicate>(); if (predicatesStr.equals(DecisionTreeBuilder.ROOT_PATH)) { predicates.add(DecisionPathPredicate.createRootPredicate(predicatesStr)); } else { String[] predicateItems = predicatesStr.split(";"); for (String predicateItem : predicateItems) { if(predicateItem.equals(DecisionTreeBuilder.ROOT_PATH)) { predicates.add(DecisionPathPredicate.createRootPredicate(predicateItem)); } else { int attr = Integer.parseInt(predicateItem.split("\\s+")[0]); FeatureField field = schema.findFieldByOrdinal(attr); DecisionPathList.DecisionPathPredicate predicate = deserializePredicate(predicateItem, field); predicates.add(predicate); } } } return predicates; } /** * @param predicateStr * @param field * @return */ public static DecisionPathList.DecisionPathPredicate deserializePredicate(String predicateStr, FeatureField field) { DecisionPathList.DecisionPathPredicate predicate = null; if (field.isInteger()) { predicate = DecisionPathList.DecisionPathPredicate.createIntPredicate(predicateStr); } else if (field.isDouble()) { predicate = DecisionPathList.DecisionPathPredicate.createDoublePredicate(predicateStr); } else if (field.isCategorical()) { predicate = DecisionPathList.DecisionPathPredicate.createCategoricalPredicate(predicateStr); } else { throw new IllegalArgumentException("invalid data type for predicates"); } return predicate; } public int getAttribute() { return attribute; } public void setAttribute(int attribute) { this.attribute = attribute; } public String getOperator() { return operator; } public void setOperator(String operator) { this.operator = operator; } public int getValueInt() { return valueInt; } public void setValueInt(int valueInt) { this.valueInt = valueInt; } public double getValueDbl() { return valueDbl; } public void setValueDbl(double valueDbl) { this.valueDbl = valueDbl; } public List<String> getCategoricalValues() { return categoricalValues; } public void setCategoricalValues(List<String> categoricalValues) { this.categoricalValues = categoricalValues; } public Integer getOtherBoundInt() { return otherBoundInt; } public void setOtherBoundInt(Integer otherBoundInt) { this.otherBoundInt = otherBoundInt; } public Double getOtherBoundDbl() { return otherBoundDbl; } public void setOtherBoundDbl(Double otherBoundDbl) { this.otherBoundDbl = otherBoundDbl; } public String getPredicateStr() { return predicateStr; } public void setPredicateStr(String predicateStr) { this.predicateStr = predicateStr; } @Override public int hashCode() { return predicateStr.hashCode(); } @Override public boolean equals(Object obj) { DecisionPathPredicate that = (DecisionPathPredicate)obj; return predicateStr.equals(that.predicateStr); } } }