/* * avenir: Predictive analytic based on Hadoop Map Reduce * Author: Pranab Ghosh * * Licensed under the Apache License, Version 2.0 (the "License"); you * may not use this file except in compliance with the License. You may * obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. See the License for the specific language governing * permissions and limitations under the License. */ package org.avenir.util; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; import org.apache.log4j.Level; import org.apache.log4j.Logger; import org.avenir.explore.ClassPartitionGenerator.PartitionGeneratorReducer; /** * Info stat for splits based various statistical criteria, to enable selection * of optimum split for attributes * @author pranab * */ public class AttributeSplitStat { private int attrOrdinal; private Map<String, SplitStat> splitStats = new HashMap<String, SplitStat>(); private static final Logger LOG = Logger.getLogger(AttributeSplitStat.class); private Set<String> classValues = new HashSet<String>(); public static final String ALG_ENTROPY = "entropy"; public static final String ALG_GINI_INDEX = "giniIndex"; public static final String ALG_HELLINGER_DIST = "hellingerDistance"; public static final String ALG_CLASS_CONF = "classConfidenceRatio"; private String algorithm; public static void enableLog() { LOG.setLevel(Level.DEBUG); } /** * @param attrOrdinal * @param algorithm */ public AttributeSplitStat(int attrOrdinal, String algorithm) { this.attrOrdinal = attrOrdinal; this.algorithm = algorithm; } /** * @param key * @param segmentIndex * @param classVal * @param count */ public void countClassVal(String key, int segmentIndex, String classVal, int count) { SplitStat splitStat = splitStats.get(key); if (null == splitStat) { if (algorithm.equals(ALG_ENTROPY) || algorithm.equals(ALG_GINI_INDEX)) { splitStat = new SplitInfoContent(key); } else if (algorithm.equals(ALG_HELLINGER_DIST)){ splitStat = new SplitHellingerDistance(key); } else { splitStat = new SplitClassCofidenceRatio(key); } splitStats.put(key, splitStat); } splitStat.countClassVal(segmentIndex, classVal, count); classValues.add(classVal); } /** * @param algorithm * @return */ public Map<String, Double> processStat(String algorithm) { Map<String, Double> stats =new HashMap<String, Double>(); for (String key : splitStats.keySet()) { SplitStat splitStat = splitStats.get(key); stats.put(key, splitStat.processStat(algorithm, classValues)); } return stats; } /** * @param splitKey * @return */ public Map<Integer, Map<String, Double>> getClassProbab(String splitKey) { SplitStat splitStat = splitStats.get(splitKey); return splitStat.getClassProbab(); } /** * @param splitKey * @return */ public double getInfoContent(String splitKey) { SplitStat splitStat = splitStats.get(splitKey); return splitStat.getInfoContent(); } /** * Stats for a split which consists of multiple segments * @author pranab * */ private static abstract class SplitStat { protected String key; protected Map<Integer, SplitStatSegment> segments = new HashMap<Integer, SplitStatSegment>(); public SplitStat(String key) { this.key = key; } public void countClassVal(int segmentIndex, String classVal, int count) { LOG.debug("counting SplitStat key:" + key); SplitStatSegment statSegment = segments.get(segmentIndex); if (null == statSegment) { statSegment = new SplitStatSegment(segmentIndex); segments.put(segmentIndex, statSegment); } statSegment.countClassVal(classVal, count); } public abstract double processStat(String algorithm, Set<String> classValues); /** * @return */ public Map<Integer, Map<String, Double>> getClassProbab() { Map<Integer, Map<String, Double>> classProbab = new HashMap<Integer, Map<String, Double>>(); for (Integer segmentIndex : segments.keySet()) { SplitStatSegment statSegment = segments.get(segmentIndex); classProbab.put(segmentIndex, statSegment.getClassValPr()); } return classProbab; } /** * @return */ public double getInfoContent() { int totalCount = 0; for (Integer segmentIndex : segments.keySet()) { SplitStatSegment statSegment = segments.get(segmentIndex); totalCount += statSegment.getTotalCount(); } double pr = 0; double stat = 0; double log2 = Math.log(2); for (Integer segmentIndex : segments.keySet()) { SplitStatSegment statSegment = segments.get(segmentIndex); pr = (double)statSegment.getTotalCount() / totalCount; stat -= pr * Math.log(pr) / log2; } return stat; } } /** * Entropy or gini index based split stat * @author pranab * Entropy or gini index based */ private static class SplitInfoContent extends SplitStat { /** * @param key */ public SplitInfoContent(String key) { super(key); LOG.debug("constructing SplitInfoContent key:" + key); } /* (non-Javadoc) * @see org.avenir.util.AttributeSplitStat.SplitStat#processStat(java.lang.String, java.util.Set) */ public double processStat(String algorithm, Set<String> classValues) { double stats = 0; LOG.debug("processing SplitStat key:" + key); double[] statArr = new double[segments.size()]; int[] countArr = new int[segments.size()]; int totalCount = 0; int i = 0; for (Integer segmentIndex : segments.keySet()) { SplitStatSegment statSegment = segments.get(segmentIndex); double stat = statSegment.processStat(algorithm); statArr[i] = stat; int count = statSegment.getTotalCount(); countArr[i] = count; totalCount += count; ++i; } //weighted average double statSum = 0; for (int j = 0; j < statArr.length; ++j) { statSum += statArr[j] * countArr[j]; } stats = statSum / totalCount; LOG.debug("split key:" + key + " stats:" + stats); return stats; } } /** * Hellinger distance base split stat * @author pranab * Hellinger distance based crtieria * */ private static class SplitHellingerDistance extends SplitStat { /** * @param key */ public SplitHellingerDistance(String key) { super(key); LOG.debug("constructing SplitHellingerDistance key:" + key); } /* (non-Javadoc) * @see org.avenir.util.AttributeSplitStat.SplitStat#processStat(java.lang.String, java.util.Set) */ public double processStat(String algorithm, Set<String> classValues) { double stats = 0; LOG.debug("processing SplitStat key:" + key); if (classValues.size() != 2) { throw new IllegalArgumentException( "Hellinger distance algorithm is only valid for binary valued class attributes"); } //binary class values String[] classValueArr = new String[2]; int ci = 0; for (String classVal : classValues) { classValueArr[ci++] = classVal; } //class value counts int[] classValCount = new int[2]; for (int i = 0; i < 2; ++i) { classValCount[i] = 0; for (Integer segmentIndex : segments.keySet()) { SplitStatSegment statSegment = segments.get(segmentIndex); classValCount[i] += statSegment.getCountForClassVal(classValueArr[i]); } } //hellinger distance double sum = 0; for (Integer segmentIndex : segments.keySet()) { SplitStatSegment statSegment = segments.get(segmentIndex); double val0 = (double)statSegment.getCountForClassVal(classValueArr[0]) / classValCount[0]; statSegment.setClassConfidence(classValueArr[0], val0); val0 = Math.sqrt(val0); double val1 = (double)statSegment.getCountForClassVal(classValueArr[1]) / classValCount[1]; statSegment.setClassConfidence(classValueArr[1], val1); val1 = Math.sqrt(val1); sum += (val0 - val1) * (val0 - val1); } stats = Math.sqrt(sum); LOG.debug("split key:" + key + " stats:" + stats); return stats; } } /** * Confidence ration based split stat * @author pranab * Class confidence based criteria */ private static class SplitClassCofidenceRatio extends SplitStat { public SplitClassCofidenceRatio(String key) { super(key); LOG.debug("constructing SplitClassCofidenceRatio key:" + key); } @Override public double processStat(String algorithm, Set<String> classValues) { Map <String, Integer> classValCount = new HashMap <String, Integer>(); double stat = 0; //class attribute total count for (String classVal : classValues) { for (Integer segmentIndex : segments.keySet()) { SplitStatSegment statSegment = segments.get(segmentIndex); int count = statSegment.getCountForClassVal(classVal); if ( null == classValCount.get(classVal)) { classValCount.put(classVal, 0); } classValCount.put(classVal, classValCount.get(classVal) + count); } } // class confidence for (String classVal : classValues) { for (Integer segmentIndex : segments.keySet()) { SplitStatSegment statSegment = segments.get(segmentIndex); int count = statSegment.getCountForClassVal(classVal); int totalCount = classValCount.get(classVal); statSegment.setClassConfidence(classVal, (double)count / totalCount); } } //class cofidence ratio int totalCount = 0; double sum = 0; for (Integer segmentIndex : segments.keySet()) { SplitStatSegment statSegment = segments.get(segmentIndex); double classConfRatio = statSegment.processClassConfidenceRatio(); int count = statSegment.getTotalCount(); sum += classConfRatio * count; totalCount += count; } stat = sum / totalCount; return stat; } } /** Stats for a split segment i.e. range for numerical attribute and group of * of attributes for categorical * @author pranab * */ private static class SplitStatSegment { private int segmentIndex; private Map<String, Integer> classValCount = new HashMap<String, Integer>(); private Map<String, Double> classValPr = new HashMap<String, Double>(); private Map<String, Double> classValConfidence = new HashMap<String, Double>(); private Map<String, Double> classValConfidenceRatio = new HashMap<String, Double>(); private int totalCount = 0; /** * @param segmentIndex */ public SplitStatSegment(int segmentIndex) { LOG.debug("constructing SplitStatSegment segmentIndex:" + segmentIndex); this.segmentIndex = segmentIndex; } /** * @param classVal * @param count */ public void countClassVal(String classVal, int count) { LOG.debug("counting SplitStatSegment segmentIndex:" + segmentIndex + " classVal:" + classVal + " count:" + count); if (null == classValCount.get(classVal)) { classValCount.put(classVal, 0); } classValCount.put(classVal, classValCount.get(classVal) + count); } /** * @param algorithm * @return */ public double processStat(String algorithm) { double stat = 0.0; totalCount = 0; for (String key : classValCount.keySet()) { totalCount += classValCount.get(key); } LOG.debug("processing segment index:" + segmentIndex + " total count:" + totalCount); if (algorithm.equals(ALG_ENTROPY)) { //entropy based double log2 = Math.log(2); for (String key : classValCount.keySet()) { double pr = (double)classValCount.get(key) / totalCount; stat -= pr * Math.log(pr) / log2; classValPr.put(key, pr); } } else if (algorithm.equals(ALG_GINI_INDEX)) { //gini index based double prSquare = 0; for (String key : classValCount.keySet()) { int count = classValCount.get(key); double pr = (double)count / totalCount; LOG.debug("class val:" + key + " count:" + count); prSquare += pr * pr; classValPr.put(key, pr); } stat = 1.0 - prSquare; } return stat; } /** * @return */ public int getTotalCount() { if (0 == totalCount) { for (String key : classValCount.keySet()) { totalCount += classValCount.get(key); } } return totalCount; } /** * @return */ public Map<String, Double> getClassValPr() { return classValPr; } /** * @param classVal * @return */ public int getCountForClassVal(String classVal) { Integer countObj = classValCount.get(classVal); return countObj == null? 0 : countObj; } /** * @param classVal * @param confidence */ public void setClassConfidence(String classVal, double confidence) { classValConfidence.put(classVal, confidence); } /** * @return */ public double processClassConfidenceRatio() { double entropy = 0; double totalClassConf = 0; for (String classVal : classValConfidence.keySet()) { totalClassConf += classValConfidence.get(classVal); } for (String classVal : classValConfidence.keySet()) { double classConfRatio = classValConfidence.get(classVal) / totalClassConf; classValConfidenceRatio.put(classVal, classConfRatio); } //entropy based on class confidence ratio double log2 = Math.log(2); for (String key : classValConfidenceRatio.keySet()) { double ccr = classValConfidenceRatio.get(key); entropy -= ccr * Math.log(ccr) / log2; } return entropy; } } }