/* * Copyright [2013-2016] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.stats; import java.io.File; import java.io.IOException; import java.util.List; import java.util.Map; import ml.shifu.common.Step; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.ModelStatsConf; import ml.shifu.shifu.container.obj.RawSourceData.SourceType; import ml.shifu.shifu.core.processor.BasicModelProcessor; import ml.shifu.shifu.core.processor.stats.AbstractStatsExecutor; import ml.shifu.shifu.core.processor.stats.AkkaStatsWorker; import ml.shifu.shifu.core.processor.stats.DIBStatsExecutor; import ml.shifu.shifu.core.processor.stats.MunroPatIStatsExecutor; import ml.shifu.shifu.core.processor.stats.MunroPatStatsExecutor; import ml.shifu.shifu.core.processor.stats.SPDTIStatsExecutor; import ml.shifu.shifu.core.processor.stats.SPDTStatsExecutor; import ml.shifu.shifu.core.validator.ModelInspector.ModelStep; import ml.shifu.shifu.exception.ShifuErrorCode; import ml.shifu.shifu.exception.ShifuException; import ml.shifu.shifu.util.CommonUtils; import ml.shifu.shifu.util.JSONUtils; import ml.shifu.shifu.util.updater.ColumnConfigUpdater; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Stats step for stats functions in SHiuf pipeline. * * @author Zhang David (pengzhang@paypal.com) */ public class StatsStep extends Step<List<ColumnConfig>> { private final static Logger LOG = LoggerFactory.getLogger(StatsStep.class); public StatsStep(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, Map<String, Object> otherConfigs) { super(ModelStep.STATS, modelConfig, columnConfigList, otherConfigs); } /* * (non-Javadoc) * * @see ml.shifu.common.Step#process() */ @Override public List<ColumnConfig> process() throws IOException { LOG.info("Step Start: stats"); long start = System.currentTimeMillis(); try { // User may change variable type after `shifu init` ColumnConfigUpdater.updateColumnConfigFlags(this.modelConfig, this.columnConfigList, ModelStep.STATS); LOG.info("Saving ModelConfig, ColumnConfig and then upload to HDFS ..."); JSONUtils.writeValue(new File(pathFinder.getModelConfigPath(SourceType.LOCAL)), modelConfig); JSONUtils.writeValue(new File(pathFinder.getColumnConfigPath(SourceType.LOCAL)), columnConfigList); if(SourceType.HDFS.equals(modelConfig.getDataSet().getSource())) { CommonUtils.copyConfFromLocalToHDFS(modelConfig, this.pathFinder); } AbstractStatsExecutor statsExecutor = null; if(modelConfig.isMapReduceRunMode()) { if(modelConfig.getBinningAlgorithm().equals(ModelStatsConf.BinningAlgorithm.DynamicBinning)) { statsExecutor = new DIBStatsExecutor(new BasicModelProcessor(super.modelConfig, super.columnConfigList, super.otherConfigs), modelConfig, columnConfigList); } else if(modelConfig.getBinningAlgorithm().equals(ModelStatsConf.BinningAlgorithm.MunroPat)) { statsExecutor = new MunroPatStatsExecutor(new BasicModelProcessor(super.modelConfig, super.columnConfigList, super.otherConfigs), modelConfig, columnConfigList); } else if(modelConfig.getBinningAlgorithm().equals(ModelStatsConf.BinningAlgorithm.MunroPatI)) { statsExecutor = new MunroPatIStatsExecutor(new BasicModelProcessor(super.modelConfig, super.columnConfigList, super.otherConfigs), modelConfig, columnConfigList); } else if(modelConfig.getBinningAlgorithm().equals(ModelStatsConf.BinningAlgorithm.SPDT)) { statsExecutor = new SPDTStatsExecutor(new BasicModelProcessor(super.modelConfig, super.columnConfigList, super.otherConfigs), modelConfig, columnConfigList); } else if(modelConfig.getBinningAlgorithm().equals(ModelStatsConf.BinningAlgorithm.SPDTI)) { statsExecutor = new SPDTIStatsExecutor(new BasicModelProcessor(super.modelConfig, super.columnConfigList, super.otherConfigs), modelConfig, columnConfigList); } else { statsExecutor = new SPDTIStatsExecutor(new BasicModelProcessor(super.modelConfig, super.columnConfigList, super.otherConfigs), modelConfig, columnConfigList); } } else if(modelConfig.isLocalRunMode()) { statsExecutor = new AkkaStatsWorker(new BasicModelProcessor(super.modelConfig, super.columnConfigList, super.otherConfigs), modelConfig, columnConfigList); } else { throw new ShifuException(ShifuErrorCode.ERROR_UNSUPPORT_MODE); } statsExecutor.doStats(); if(SourceType.HDFS.equals(modelConfig.getDataSet().getSource())) { CommonUtils.copyConfFromLocalToHDFS(modelConfig, this.pathFinder); } } catch (Exception e) { LOG.error("Error:", e); } LOG.info("Step Finished: stats with {} ms", (System.currentTimeMillis() - start)); return columnConfigList; } }