/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.udf.stats; import java.util.*; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.udf.CalculateStatsUDF; import ml.shifu.shifu.util.CommonUtils; import org.apache.commons.lang.StringUtils; import org.apache.pig.backend.executionengine.ExecException; import org.apache.pig.data.DataBag; import org.apache.pig.data.Tuple; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * CategoricalVarStats class */ public class CategoricalVarStats extends AbstractVarStats { private static Logger log = LoggerFactory.getLogger(CategoricalVarStats.class); private Map<String, Integer> categoricalBinMap; public CategoricalVarStats(ModelConfig modelConfig, ColumnConfig columnConfig, Double valueThreshold) { super(modelConfig, columnConfig, valueThreshold); } /* * (non-Javadoc) * * @see ml.shifu.shifu.udf.stats.AbstractVarStats#runVarStats(java.lang.String, org.apache.pig.data.DataBag) */ @Override public void runVarStats(String binningInfo, DataBag databag) throws ExecException { String[] binningDataArr = StringUtils.split(binningInfo, CalculateStatsUDF.CATEGORY_VAL_SEPARATOR); log.info("Column Name - " + this.columnConfig.getColumnName() + ", Column Bin Length - " + binningDataArr.length); columnConfig.setBinCategory(Arrays.asList(binningDataArr)); categoricalBinMap = new HashMap<String, Integer>(columnConfig.getBinCategory().size()); for(int i = 0; i < columnConfig.getBinCategory().size(); i++) { List<String> catValues = CommonUtils.flattenCatValGrp(columnConfig.getBinCategory().get(i)); for ( String cval : catValues ) { categoricalBinMap.put(cval, Integer.valueOf(i)); } } statsCategoricalColumnInfo(databag, columnConfig); } /** * @param databag * @param columnConfig * @throws ExecException */ private void statsCategoricalColumnInfo(DataBag databag, ColumnConfig columnConfig) throws ExecException { // The last bin is for missingOrInvalid values Integer[] binCountPos = new Integer[columnConfig.getBinCategory().size() + 1]; Integer[] binCountNeg = new Integer[columnConfig.getBinCategory().size() + 1]; Double[] binWeightCountPos = new Double[columnConfig.getBinCategory().size() + 1]; Double[] binWeightCountNeg = new Double[columnConfig.getBinCategory().size() + 1]; int lastBinIndex = columnConfig.getBinCategory().size(); initializeZeroArr(binCountPos); initializeZeroArr(binCountNeg); initializeZeroArr(binWeightCountPos); initializeZeroArr(binWeightCountNeg); Iterator<Tuple> iterator = databag.iterator(); boolean isMissingValue = false; boolean isInvalidValue = false; while(iterator.hasNext()) { isInvalidValue = false; isMissingValue = false; Tuple element = iterator.next(); if(element.size() < 4) { continue; } Object value = element.get(1); String tag = CommonUtils.trimTag((String) element.get(2)); Double weight = (Double) element.get(3); int binNum = 0; if(value == null || modelConfig.getDataSet().getMissingOrInvalidValues() .contains(value.toString().toLowerCase().trim())) { // TODO check missing value list in ModelConfig?? missingValueCnt++; isMissingValue = true; } else { String str = StringUtils.trim(value.toString()); binNum = quickLocateCategorialBin(str); if(binNum < 0) { invalidValueCnt++; isInvalidValue = true; } } if(isInvalidValue || isMissingValue) { binNum = lastBinIndex; } if(modelConfig.getPosTags().contains(tag)) { increaseInstCnt(binCountPos, binNum); increaseInstCnt(binWeightCountPos, binNum, weight); } else if(modelConfig.getNegTags().contains(tag)) { increaseInstCnt(binCountNeg, binNum); increaseInstCnt(binWeightCountNeg, binNum, weight); } } columnConfig.setBinCountPos(Arrays.asList(binCountPos)); columnConfig.setBinCountNeg(Arrays.asList(binCountNeg)); columnConfig.setBinWeightedPos(Arrays.asList(binWeightCountPos)); columnConfig.setBinWeightedNeg(Arrays.asList(binWeightCountNeg)); calculateBinPosRateAndAvgScore(); for(int i = 0; i < columnConfig.getBinCountPos().size(); i++) { int posCount = columnConfig.getBinCountPos().get(i); int negCount = columnConfig.getBinCountNeg().get(i); binning.addData(columnConfig.getBinPosRate().get(i), posCount); binning.addData(columnConfig.getBinPosRate().get(i), negCount); streamStatsCalculator.addData(columnConfig.getBinPosRate().get(i), posCount); streamStatsCalculator.addData(columnConfig.getBinPosRate().get(i), negCount); } columnConfig.setMax(streamStatsCalculator.getMax()); columnConfig.setMean(streamStatsCalculator.getMean()); columnConfig.setMin(streamStatsCalculator.getMin()); if(binning.getMedian() == null) { columnConfig.setMedian(streamStatsCalculator.getMean()); } else { columnConfig.setMedian(binning.getMedian()); } columnConfig.setStdDev(streamStatsCalculator.getStdDev()); // Currently, invalid value will be regarded as missing columnConfig.setMissingCnt(missingValueCnt + invalidValueCnt); columnConfig.setTotalCount(databag.size()); columnConfig.setMissingPercentage(((double) columnConfig.getMissingCount()) / columnConfig.getTotalCount()); columnConfig.getColumnStats().setSkewness(streamStatsCalculator.getSkewness()); columnConfig.getColumnStats().setKurtosis(streamStatsCalculator.getKurtosis()); } private int quickLocateCategorialBin(String val) { Integer binNum = categoricalBinMap.get(val); return ((binNum == null) ? -1 : binNum); } public static void main(String[] args) { System.out.println(Math.log((9 * 1.0d / 21d) / (41 * 1.0d / 90d))); } }