/*
* chombo: Hadoop Map Reduce utility
* Author: Pranab Ghosh
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package org.chombo.util;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
/**
 * Sampler for class imbalanced data. Makes a data set class balanced by sub sampling
 * the majority classes.
 * <p>
 * Every record seen before the first ratio calculation (i.e. the first
 * {@code imbalanceRatioCalInterval - 1} calls to {@link #next(String)}) is sampled.
 * Afterwards a record is kept with probability {@code ratio / 100}, where the ratio of
 * its class is {@code 100 * minClassCount / classCount}. Ratios are recomputed every
 * {@code imbalanceRatioCalInterval} calls. A class that is first seen between two
 * recalculations is sampled unconditionally until the next recalculation.
 * <p>
 * Not thread-safe: state is kept in an unsynchronized {@link HashMap} and plain counters.
 * @author pranab
 *
 */
public class ClassImbalancedSampler {
    /** Ratio value (in percent) that means "always sample". */
    private static final int FULL_SAMPLING_PERCENT = 100;

    private final int imbalanceRatioCalInterval;
    // long so the warm-up check (counter < interval) cannot wrap negative on very long streams
    private long counter = 0;
    private final Random randomGenerator;
    private final Map<String, ClassAttributeStat> classAttrStats = new HashMap<String, ClassAttributeStat>();

    /**
     * Creates a sampler backed by a fresh, unseeded {@link Random}.
     * @param imbalanceRatioCalInterval number of records between imbalance ratio recalculations; must be positive
     * @throws IllegalArgumentException if {@code imbalanceRatioCalInterval <= 0}
     */
    public ClassImbalancedSampler(int imbalanceRatioCalInterval) {
        this(imbalanceRatioCalInterval, new Random());
    }

    /**
     * Creates a sampler using the supplied random generator, e.g. a seeded one for reproducible sampling.
     * @param imbalanceRatioCalInterval number of records between imbalance ratio recalculations; must be positive
     * @param randomGenerator source of randomness for the sampling decision
     * @throws IllegalArgumentException if {@code imbalanceRatioCalInterval <= 0}
     * @throws NullPointerException if {@code randomGenerator} is null
     */
    public ClassImbalancedSampler(int imbalanceRatioCalInterval, Random randomGenerator) {
        if (imbalanceRatioCalInterval <= 0) {
            throw new IllegalArgumentException(
                "imbalanceRatioCalInterval must be positive: " + imbalanceRatioCalInterval);
        }
        if (null == randomGenerator) {
            throw new NullPointerException("randomGenerator");
        }
        this.imbalanceRatioCalInterval = imbalanceRatioCalInterval;
        this.randomGenerator = randomGenerator;
    }

    /**
     * Records one occurrence of the given class and decides whether the record should be kept.
     * @param classAttrVal class attribute value of the current record
     * @return true if the record is sampled (kept)
     */
    public boolean next(String classAttrVal) {
        // ratios are refreshed before the current record is counted
        if (++counter % imbalanceRatioCalInterval == 0) {
            calculateImbalanceRatio();
        }
        ClassAttributeStat classAttrStat = classAttrStats.get(classAttrVal);
        if (null == classAttrStat) {
            classAttrStat = new ClassAttributeStat();
            classAttrStats.put(classAttrVal, classAttrStat);
        }
        classAttrStat.incrCount();

        // warm-up: keep everything until the first ratio calculation; afterwards keep with probability ratio/100
        boolean sampleIt = counter < imbalanceRatioCalInterval
            || randomGenerator.nextInt(FULL_SAMPLING_PERCENT) < classAttrStat.getImbalanceRatio();
        if (sampleIt) {
            classAttrStat.incrSampleCount();
        }
        return sampleIt;
    }

    /**
     * Sets every class's ratio to {@code 100 * minCount / count}, so the rarest class gets 100.
     */
    private void calculateImbalanceRatio() {
        long min = Long.MAX_VALUE;
        for (ClassAttributeStat stat : classAttrStats.values()) {
            min = Math.min(min, stat.getCount());
        }
        for (ClassAttributeStat stat : classAttrStats.values()) {
            // long arithmetic: 100 * min overflows int once min exceeds ~21 million;
            // count >= 1 for every stored class and min <= count, so the result is in [0, 100]
            stat.setImbalanceRatio((int) ((FULL_SAMPLING_PERCENT * min) / stat.getCount()));
        }
    }

    /**
     * Per-class counters: total records seen, records sampled, and the current sampling ratio.
     */
    private static class ClassAttributeStat {
        private long count;
        // a class first seen between recalculations has no computed ratio yet; sample it fully
        private int imbalanceRatio = FULL_SAMPLING_PERCENT;
        private long sampledCount;

        public long getCount() {
            return count;
        }

        public int getImbalanceRatio() {
            return imbalanceRatio;
        }

        public void setImbalanceRatio(int imbalanceRatio) {
            this.imbalanceRatio = imbalanceRatio;
        }

        public long getSampledCount() {
            return sampledCount;
        }

        public void incrCount() {
            ++count;
        }

        public void incrSampleCount() {
            ++sampledCount;
        }
    }
}