/* * Copyright [2013-2015] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.udf; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.RawSourceData.SourceType; import ml.shifu.shifu.core.binning.AbstractBinning; import ml.shifu.shifu.core.binning.DynamicBinning; import ml.shifu.shifu.core.binning.obj.NumBinInfo; import ml.shifu.shifu.fs.ShifuFileUtils; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang.StringUtils; import org.apache.pig.data.DataBag; import org.apache.pig.data.Tuple; import org.apache.pig.data.TupleFactory; import java.io.IOException; import java.util.*; /** * Created by zhanhu on 7/6/16. */ public class DynamicBinningUDF extends AbstractTrainerUDF<Tuple> { private Map<Integer, String> smallBinsMap; public DynamicBinningUDF(String source, String pathModelConfig, String pathColumnConfig, String smallBinsPath) throws IOException { super(source, pathModelConfig, pathColumnConfig); smallBinsMap = new HashMap<Integer, String>(); List<String> smallBinsList = ShifuFileUtils.readFilePartsIntoList(smallBinsPath, SourceType.HDFS); for (String smallBin : smallBinsList) { String[] fields = StringUtils.split(smallBin, '\u0007'); if (fields.length == 2) { smallBinsMap.put(Integer.parseInt(fields[0]), fields[1]); } } } @Override public Tuple exec(Tuple input) throws IOException { if (input == null || input.size() != 1) { return null; } Integer columnId = null; ColumnConfig columnConfig = null; String binsData = null; Set<String> missingValSet = new HashSet<String>(super.modelConfig.getMissingOrInvalidValues()); List<NumBinInfo> binInfoList = null; DataBag columnDataBag = (DataBag) input.get(0); Iterator<Tuple> iterator = columnDataBag.iterator(); while (iterator.hasNext()) { Tuple tuple = iterator.next(); if (columnId == null) { columnId = (Integer) tuple.get(0); columnConfig = super.columnConfigList.get(columnId); String smallBins = smallBinsMap.get(columnId); if (columnConfig.isCategorical()) { binsData = smallBins; break; } else { binInfoList = NumBinInfo.constructNumBinfo(smallBins, AbstractBinning.FIELD_SEPARATOR); } } String val = (String) tuple.get(1); Boolean isPositiveInst = (Boolean) tuple.get(2); if (missingValSet.contains(val)) { continue; } Double d = null; try { d = Double.valueOf(val); } catch (Exception e) { // illegal number, just skip it continue; } NumBinInfo numBinInfo = binaryLocate(binInfoList, d); if (numBinInfo != null) { numBinInfo.incInstCnt(isPositiveInst); } } if (binsData == null && CollectionUtils.isNotEmpty(binInfoList)) { DynamicBinning dynamicBinning = new DynamicBinning(binInfoList, modelConfig.getStats().getMaxNumBin()); List<Double> binFields = dynamicBinning.getDataBin(); binsData = StringUtils.join(binFields, CalculateStatsUDF.CATEGORY_VAL_SEPARATOR); } Tuple output = TupleFactory.getInstance().newTuple(2); output.set(0, columnId); output.set(1, binsData); return output; } public NumBinInfo binaryLocate(List<NumBinInfo> binInfoList, Double d) { int left = 0; int right = binInfoList.size() - 1; while (left <= right) { int middle = (left + right) / 2; NumBinInfo binInfo = binInfoList.get(middle); if (d >= binInfo.getLeftThreshold() && d < binInfo.getRightThreshold()) { return binInfo; } else if (d >= binInfo.getRightThreshold()) { left = middle + 1; } else if (d < binInfo.getLeftThreshold()) { right = middle - 1; } else { return null; } } return null; } }