/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.udf;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelStatsConf;
import ml.shifu.shifu.core.binning.AbstractBinning;
import ml.shifu.shifu.core.binning.CategoricalBinning;
import ml.shifu.shifu.core.binning.EqualIntervalBinning;
import ml.shifu.shifu.core.binning.MunroPatBinning;
import org.apache.commons.lang.StringUtils;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.DataType;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;
import org.apache.pig.impl.logicalLayer.schema.Schema;
import java.io.IOException;
import java.util.Iterator;
/**
 * Pig UDF that consumes a bag of {@code (columnId, value, isPositive)} tuples and emits a
 * {@code (columnId, bins)} tuple, where {@code bins} is the binning handler's data bin joined
 * with {@link AbstractBinning#FIELD_SEPARATOR}.
 *
 * <p>The column id is taken from the first usable tuple in the bag; all tuples in one bag are
 * presumably for the same column (NOTE(review): confirm against the calling Pig script).
 *
 * Created by zhanhu on 7/5/16.
 */
public class GenSmallBinningInfoUDF extends AbstractTrainerUDF<Tuple> {

    /** Scale factor handed to every binning handler created by this UDF. */
    private final int scaleFactor;

    /**
     * @param source           passed through to the parent constructor
     * @param pathModelConfig  passed through to the parent constructor
     * @param pathColumnConfig passed through to the parent constructor
     * @param histoScaleFactor decimal integer string used as the binning scale factor
     * @throws IOException           if the parent constructor fails
     * @throws NumberFormatException if {@code histoScaleFactor} is not a parsable integer
     */
    public GenSmallBinningInfoUDF(String source, String pathModelConfig, String pathColumnConfig,
            String histoScaleFactor) throws IOException {
        super(source, pathModelConfig, pathColumnConfig);
        this.scaleFactor = Integer.parseInt(histoScaleFactor);
    }

    /**
     * Feeds the values of one bag into a binning handler and returns the resulting bins.
     *
     * @param input a tuple with exactly one field holding a {@link DataBag}; each bag tuple is
     *              read as {@code (Integer columnId, String value, Boolean isPositive)}
     * @return a 2-field tuple {@code (columnId, joined bins)}, or {@code null} when the input is
     *         null, does not have exactly one field, holds a null bag, or the bag contains no
     *         tuple from which a column id could be read
     * @throws IOException if reading a field from a Pig tuple fails
     */
    @Override
    public Tuple exec(Tuple input) throws IOException {
        if (input == null || input.size() != 1) {
            return null;
        }

        DataBag dataBag = (DataBag) input.get(0);
        if (dataBag == null) {
            // A null bag has nothing to iterate over; dereferencing it would throw.
            return null;
        }

        Integer columnId = null;
        ColumnConfig columnConfig = null;
        @SuppressWarnings("rawtypes")
        AbstractBinning binning = null;

        Iterator<Tuple> iterator = dataBag.iterator();
        while (iterator.hasNext()) {
            Tuple tuple = iterator.next();
            if (tuple == null || tuple.size() < 3) {
                continue;
            }

            if (columnId == null) {
                Integer candidateId = (Integer) tuple.get(0);
                if (candidateId == null) {
                    // Without an id there is no column config to pick a handler from.
                    continue;
                }
                // The handler is created lazily from the first tuple that carries a column id.
                columnId = candidateId;
                columnConfig = super.columnConfigList.get(columnId);
                binning = getBinningHandler(columnConfig);
            }

            Boolean isPostive = (Boolean) tuple.get(2);
            if (isToBinningVal(columnConfig, isPostive)) {
                String val = (String) tuple.get(1);
                binning.addData(val);
            }
        }

        if (binning == null) {
            // Empty bag or no usable tuple: there is no column to report bins for.
            return null;
        }

        Tuple output = TupleFactory.getInstance().newTuple(2);
        output.set(0, columnId);
        output.set(1, StringUtils.join(binning.getDataBin(), AbstractBinning.FIELD_SEPARATOR));
        return output;
    }

    /**
     * Decides whether a value takes part in binning.
     *
     * <p>Categorical columns always take part. Otherwise the configured binning method decides:
     * the total/interval methods accept every value, the positive methods accept only values
     * tagged {@code true}, and the negative methods accept only values tagged {@code false}.
     *
     * @param columnConfig config of the column being binned
     * @param isPostive    tag of the record; {@code null} is treated as neither positive nor
     *                     negative instead of failing on auto-unboxing
     * @return {@code true} if the value should be added to the binning handler
     */
    private boolean isToBinningVal(ColumnConfig columnConfig, Boolean isPostive) {
        if (columnConfig.isCategorical()) {
            return true;
        }

        ModelStatsConf.BinningMethod method = modelConfig.getBinningMethod();
        boolean positive = Boolean.TRUE.equals(isPostive);
        boolean negative = Boolean.FALSE.equals(isPostive);

        return method.equals(ModelStatsConf.BinningMethod.EqualTotal)
                || method.equals(ModelStatsConf.BinningMethod.EqualInterval)
                || (method.equals(ModelStatsConf.BinningMethod.EqualPositive) && positive)
                || (method.equals(ModelStatsConf.BinningMethod.EqualNegtive) && negative)
                || method.equals(ModelStatsConf.BinningMethod.WeightEqualTotal)
                || method.equals(ModelStatsConf.BinningMethod.WeightEqualInterval)
                || (method.equals(ModelStatsConf.BinningMethod.WeightEqualPositive) && positive)
                || (method.equals(ModelStatsConf.BinningMethod.WeightEqualNegative) && negative);
    }

    /**
     * Creates the binning handler for a column: {@link EqualIntervalBinning} for numerical
     * columns when the binning method is {@code EqualInterval}, {@link MunroPatBinning} for
     * numerical columns otherwise, and {@link CategoricalBinning} for non-numerical columns.
     *
     * @param columnConfig config of the column being binned
     * @return a new handler built with this UDF's scale factor and the model's
     *         missing-or-invalid value list
     */
    @SuppressWarnings("rawtypes")
    private AbstractBinning getBinningHandler(ColumnConfig columnConfig) {
        if (!columnConfig.isNumerical()) {
            return new CategoricalBinning(this.scaleFactor, super.modelConfig.getMissingOrInvalidValues());
        }
        if (modelConfig.getBinningMethod().equals(ModelStatsConf.BinningMethod.EqualInterval)) {
            return new EqualIntervalBinning(this.scaleFactor, super.modelConfig.getMissingOrInvalidValues());
        }
        return new MunroPatBinning(this.scaleFactor, super.modelConfig.getMissingOrInvalidValues());
    }

    /**
     * Declares the output as a tuple named {@code binning} with fields
     * {@code columnId:int} and {@code bins:chararray}.
     *
     * @param input schema of the UDF input (not used)
     * @return the output schema, or {@code null} if building it fails
     */
    @Override
    public Schema outputSchema(Schema input) {
        try {
            Schema tupleSchema = new Schema();
            tupleSchema.add(new Schema.FieldSchema("columnId", DataType.INTEGER));
            tupleSchema.add(new Schema.FieldSchema("bins", DataType.CHARARRAY));
            return new Schema(new Schema.FieldSchema("binning", tupleSchema, DataType.TUPLE));
        } catch (IOException e) {
            log.error("Error in outputSchema", e);
            return null;
        }
    }
}