/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.actor.worker;
import akka.actor.ActorRef;
import ml.shifu.shifu.container.ValueObject;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ColumnConfig.ColumnType;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.core.BasicStatsCalculator;
import ml.shifu.shifu.core.Binning;
import ml.shifu.shifu.core.Binning.BinningDataType;
import ml.shifu.shifu.core.ColumnStatsCalculator;
import ml.shifu.shifu.core.ColumnStatsCalculator.ColumnMetrics;
import ml.shifu.shifu.message.StatsResultMessage;
import ml.shifu.shifu.message.StatsValueObjectMessage;
import org.apache.commons.collections.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
/**
 * StatsCalculateWorker calculates the stats for a single column.
 * <p>
 * It accumulates the value objects delivered through {@link StatsValueObjectMessage}s, and once the
 * expected number of messages has arrived it bins the column, calculates max/min/mean/stdDev/median,
 * calculates KS/IV/WOE (count based and weight based), and reports the updated {@link ColumnConfig}
 * to the parent actor in a {@link StatsResultMessage}.
 */
public class StatsCalculateWorker extends AbstractWorkerActor {

    private static final Logger log = LoggerFactory.getLogger(StatsCalculateWorker.class);

    /** Value objects collected from all messages received so far. */
    private final List<ValueObject> voList;

    /** Number of {@link StatsValueObjectMessage}s received so far. */
    private int receivedMsgCnt;

    /** Sum of the missing counts reported by the received messages. */
    private long missing;

    /** Sum of the total counts reported by the received messages. */
    private long total;

    /**
     * Creates a worker with empty accumulators.
     *
     * @param modelConfig      model configuration, forwarded to the parent class
     * @param columnConfigList column configurations, forwarded to the parent class
     * @param parentActorRef   actor that receives the {@link StatsResultMessage}
     * @param nextActorRef     next actor reference, forwarded to the parent class
     */
    public StatsCalculateWorker(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, ActorRef parentActorRef,
            ActorRef nextActorRef) {
        super(modelConfig, columnConfigList, parentActorRef, nextActorRef);
        this.voList = new ArrayList<ValueObject>();
        this.receivedMsgCnt = 0;
        this.missing = 0L;
        this.total = 0L;
    }

    /**
     * Accumulates one {@link StatsValueObjectMessage}; any other message type is passed to
     * {@code unhandled}. When the number of received messages equals the total message count
     * carried by the current message, the column stats are calculated and sent to the parent actor.
     *
     * @param message the incoming actor message
     */
    @Override
    public void handleMsg(Object message) {
        if(!(message instanceof StatsValueObjectMessage)) {
            unhandled(message);
            return;
        }

        log.debug("Received value object list for stats");
        StatsValueObjectMessage statsVoMessage = (StatsValueObjectMessage) message;
        voList.addAll(statsVoMessage.getVoList());
        this.missing += statsVoMessage.getMissing();
        this.total += statsVoMessage.getTotal();
        receivedMsgCnt++;

        if(receivedMsgCnt == statsVoMessage.getTotalMsgCnt()) {
            log.debug("received {}, start to work", receivedMsgCnt);
            ColumnConfig columnConfig = columnConfigList.get(statsVoMessage.getColumnNum());
            calculateColumnStats(columnConfig, voList);
            columnConfig.setMissingCnt(this.missing);
            columnConfig.setTotalCount(this.total);
            columnConfig.setMissingPercentage(missingRatio(this.missing, this.total));
            parentActorRef.tell(new StatsResultMessage(columnConfig), this.getSelf());
        }
    }

    /**
     * Computes the fraction of missing records.
     * <p>
     * The division is a floating-point one, so a zero {@code totalCount} would silently produce
     * {@code NaN} or {@code Infinity}; in that case 0.0 is returned instead.
     *
     * @param missingCount number of missing records
     * @param totalCount   number of all records
     * @return {@code missingCount / totalCount}, or 0.0 when {@code totalCount} is not positive
     */
    private static double missingRatio(long missingCount, long totalCount) {
        if(totalCount <= 0L) {
            return 0.0d;
        }
        return (double) missingCount / totalCount;
    }

    /**
     * Picks the binning data type for a column: numerical, categorical, or auto-detect when the
     * column configuration reports neither.
     *
     * @param columnConfig the column being processed
     * @return the data type handed to {@link Binning}
     */
    private static BinningDataType resolveBinningDataType(ColumnConfig columnConfig) {
        if(columnConfig.isNumerical()) {
            return BinningDataType.Numerical;
        }
        if(columnConfig.isCategorical()) {
            return BinningDataType.Categorical;
        }
        return BinningDataType.Auto;
    }

    /**
     * Does the stats calculation and writes the results into {@code columnConfig}.
     * <p>
     * If {@code valueObjList} is null or empty an error is logged and {@code columnConfig} is left
     * untouched.
     *
     * @param columnConfig the column configuration to update
     * @param valueObjList the values collected for this column
     */
    private void calculateColumnStats(ColumnConfig columnConfig, List<ValueObject> valueObjList) {
        if(CollectionUtils.isEmpty(valueObjList)) {
            log.error("No values for column : {}, please check!", columnConfig.getColumnName());
            return;
        }

        // Binning
        Binning binning = new Binning(modelConfig.getPosTags(), modelConfig.getNegTags(),
                resolveBinningDataType(columnConfig), valueObjList);
        log.info("posTags - {}, negTags - {}, first example tag - {}", modelConfig.getPosTags(),
                modelConfig.getNegTags(), valueObjList.get(0).getTag());
        binning.setMaxNumOfBins(modelConfig.getBinningExpectedNum());
        binning.setBinningMethod(modelConfig.getBinningMethod());
        binning.setAutoTypeThreshold(modelConfig.getAutoTypeThreshold());
        binning.setMergeEnabled(Boolean.TRUE);
        binning.doBinning();

        // Calculate Basic Stats, using the value list as updated by the binning step
        BasicStatsCalculator basicStatsCalculator = new BasicStatsCalculator(binning.getUpdatedVoList(),
                modelConfig.getNumericalValueThreshold());

        // Calculate KSIV, based on Binning result: once on raw counts, once on weighted counts
        ColumnMetrics columnCountMetrics = ColumnStatsCalculator.calculateColumnMetrics(binning.getBinCountNeg(),
                binning.getBinCountPos());
        ColumnMetrics columnWeightMetrics = ColumnStatsCalculator.calculateColumnMetrics(binning.getBinWeightedNeg(),
                binning.getBinWeightedPos());

        // The data type reported by the binning after it ran decides the final column type
        BinningDataType updatedDataType = binning.getUpdatedDataType();
        if(updatedDataType.equals(BinningDataType.Numerical)) {
            columnConfig.setColumnType(ColumnType.N);
            columnConfig.setBinBoundary(binning.getBinBoundary());
        } else {
            columnConfig.setColumnType(ColumnType.C);
            columnConfig.setBinCategory(binning.getBinCategory());
        }

        // Binning results
        columnConfig.setBinCountNeg(binning.getBinCountNeg());
        columnConfig.setBinCountPos(binning.getBinCountPos());
        columnConfig.setBinPosCaseRate(binning.getBinPosCaseRate());

        // Count based and weight based metrics
        columnConfig.setKs(columnCountMetrics.getKs());
        columnConfig.setIv(columnCountMetrics.getIv());
        columnConfig.getColumnStats().setWoe(columnCountMetrics.getWoe());
        columnConfig.getColumnStats().setWeightedKs(columnWeightMetrics.getKs());
        columnConfig.getColumnStats().setWeightedIv(columnWeightMetrics.getIv());
        columnConfig.getColumnStats().setWeightedWoe(columnWeightMetrics.getWoe());

        // Basic stats
        columnConfig.setMax(basicStatsCalculator.getMax());
        columnConfig.setMin(basicStatsCalculator.getMin());
        columnConfig.setMean(basicStatsCalculator.getMean());
        columnConfig.setStdDev(basicStatsCalculator.getStdDev());
        columnConfig.setMedian(basicStatsCalculator.getMedian());

        // Weighted bin counts and per-bin WOE
        columnConfig.setBinWeightedNeg(binning.getBinWeightedNeg());
        columnConfig.setBinWeightedPos(binning.getBinWeightedPos());
        columnConfig.getColumnBinning().setBinCountWoe(columnCountMetrics.getBinningWoe());
        columnConfig.getColumnBinning().setBinWeightedWoe(columnWeightMetrics.getBinningWoe());
    }
}