package ml.shifu.shifu.core.binning;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.junit.Test;
import org.testng.Assert;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
* Created by zhanhu on 4/18/17.
*/
public class DynamicCategoricalBinTest {
@Test
public void testDynamicCategoricalBin() throws IOException {
List<CategoricalBinInfo> categoricalBinInfoList = loadTestData();
Collections.sort(categoricalBinInfoList);
CateDynamicBinning binning = new CateDynamicBinning(2);
List<CategoricalBinInfo> finalBins = binning.merge(categoricalBinInfoList);
Assert.assertEquals(finalBins.size(), 2);
List<String> categoricalVals = new ArrayList<String>();
List<Long> negativeCnts = new ArrayList<Long>();
List<Long> positiveCnts = new ArrayList<Long>();
List<Double> positiveRates = new ArrayList<Double>();
for ( int i = 0; i < finalBins.size(); i ++ ) {
CategoricalBinInfo binInfo = finalBins.get(i);
categoricalVals.add("\"" + StringUtils.join(binInfo.getValues(), '^') + "\"");
negativeCnts.add(binInfo.getNegativeCnt());
positiveCnts.add(binInfo.getPositiveCnt());
positiveRates.add(binInfo.getPositiveRate());
}
System.out.println(StringUtils.join(categoricalVals, ','));
System.out.println(StringUtils.join(negativeCnts, ','));
System.out.println(StringUtils.join(positiveCnts, ','));
System.out.println(StringUtils.join(positiveRates, ','));
}
private List<CategoricalBinInfo> loadTestData() throws IOException {
List<String> lines = IOUtils.readLines(DynamicCategoricalBinTest.class
.getResourceAsStream("/example/binning-data/categorical-binning"));
String[] categories = lines.get(0)
.replaceAll("^.* \\[", "")
.replaceAll("].*$", "")
.replaceAll("\"", "").trim().split(",");
String[] binPosCounts = lines.get(2)
.replaceAll("^.* \\[", "")
.replaceAll("].*$", "")
.replaceAll("\"", "").trim().split(",");
String[] binNegCounts = lines.get(1)
.replaceAll("^.* \\[", "")
.replaceAll("].*$", "")
.replaceAll("\"", "").trim().split(",");
@SuppressWarnings("unused")
String[] positiveRates = lines.get(3)
.replaceAll("^.* \\[", "")
.replaceAll("].*$", "")
.replaceAll("\"", "").trim().split(",");
List<CategoricalBinInfo> categoricalBinInfos = new ArrayList<CategoricalBinInfo>();
for ( int i = 0; i < categories.length; i ++ ) {
CategoricalBinInfo binInfo = new CategoricalBinInfo();
List<String> values = new ArrayList<String>();
values.add(categories[i].trim());
binInfo.setValues(values);
binInfo.setPositiveCnt(Long.parseLong(binPosCounts[i].trim()));
binInfo.setNegativeCnt(Long.parseLong(binNegCounts[i].trim()));
categoricalBinInfos.add(binInfo);
}
return categoricalBinInfos;
}
}