/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.processor; import java.io.BufferedWriter; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStreamWriter; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Scanner; import java.util.Set; import ml.shifu.shifu.column.NSColumn; import ml.shifu.shifu.container.meta.ValidateResult; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.RawSourceData.SourceType; import ml.shifu.shifu.core.shuffle.MapReduceShuffle; import ml.shifu.shifu.core.validator.ModelInspector; import ml.shifu.shifu.core.validator.ModelInspector.ModelStep; import ml.shifu.shifu.exception.ShifuErrorCode; import ml.shifu.shifu.exception.ShifuException; import ml.shifu.shifu.fs.PathFinder; import ml.shifu.shifu.fs.ShifuFileUtils; import ml.shifu.shifu.pig.PigExecutor; import ml.shifu.shifu.util.CommonUtils; import ml.shifu.shifu.util.Constants; import ml.shifu.shifu.util.Environment; import ml.shifu.shifu.util.JSONUtils; import ml.shifu.shifu.util.updater.ColumnConfigUpdater; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang.StringUtils; import org.apache.hadoop.fs.Path; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Model Basic Processor, it helps to do basic manipulate in model, including load/save configuration, copy * configuration file */ public class BasicModelProcessor { private final static Logger LOG = LoggerFactory.getLogger(BasicModelProcessor.class); protected ModelConfig modelConfig; protected List<ColumnConfig> columnConfigList; protected PathFinder pathFinder; /** * If not specified SHIFU_HOME env, some key configurations like pig path or lib path can be configured here */ protected Map<String, Object> otherConfigs; public BasicModelProcessor() { } public BasicModelProcessor(Map<String, Object> otherConfigs) { this.otherConfigs = otherConfigs; } public BasicModelProcessor(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, Map<String, Object> otherConfigs) { this.modelConfig = modelConfig; this.columnConfigList = columnConfigList; this.otherConfigs = otherConfigs; this.pathFinder = new PathFinder(modelConfig, otherConfigs); } /** * initialize the config file, pathFinder and other input * * @param step * Shifu running step * @throws Exception * any exception in setup */ protected void setUp(ModelStep step) throws Exception { if(hasInitialized()) { return; } // load model configuration and do validation loadModelConfig(); validateModelConfig(step); this.pathFinder = new PathFinder(modelConfig, this.getOtherConfigs()); checkAlgorithmParam(); LOG.info(String.format("Training Data Soure Location: %s", modelConfig.getDataSet().getSource())); switch(step) { case INIT: break; default: loadColumnConfig(); validateColumnConfig(); // update ColumnConfig and save to disk ColumnConfigUpdater.updateColumnConfigFlags(modelConfig, columnConfigList, step); saveColumnConfigList(); break; } } private void validateColumnConfig() { if(this.columnConfigList == null) { return; } Set<NSColumn> names = new HashSet<NSColumn>(); for(ColumnConfig config: this.columnConfigList) { if(StringUtils.isEmpty(config.getColumnName())) { throw new IllegalArgumentException("Empty column name, please check your header file."); } if(names.contains(new NSColumn(config.getColumnName()))) { LOG.warn("Duplicated {} in ColumnConfig.json file, later one will be append index to make it unique.", config.getColumnName()); } names.add(new NSColumn(config.getColumnName())); } if(!names.contains(new NSColumn(modelConfig.getTargetColumnName()))) { throw new IllegalArgumentException("target column " + modelConfig.getTargetColumnName() + " does not exist."); } if(StringUtils.isNotBlank(modelConfig.getWeightColumnName()) && !names.contains(new NSColumn(modelConfig.getWeightColumnName()))) { throw new IllegalArgumentException("weight column " + modelConfig.getWeightColumnName() + " does not exist."); } } /** * The post-logic after running * <p> * copy file to hdfs if SourceType is HDFS * * @param step * Shifu running step * @throws IOException * if any problem happen in copying files to HDFS */ protected void clearUp(ModelStep step) throws IOException { // do nothing now } /** * save Model Config * * @throws IOException * an exception in saving model config */ public void saveModelConfig() throws IOException { LOG.info("Saving ModelConfig..."); JSONUtils.writeValue(new File(pathFinder.getModelConfigPath(SourceType.LOCAL)), modelConfig); } /** * save the Column Config * * @throws IOException * an exception in saving column config */ public void saveColumnConfigList() throws IOException { LOG.info("Saving ColumnConfig..."); JSONUtils.writeValue(new File(pathFinder.getColumnConfigPath(SourceType.LOCAL)), columnConfigList); } /** * validate the modelconfig if it's well written. * * @param modelStep * the model step * @throws Exception * any exception in validation */ protected void validateModelConfig(ModelStep modelStep) throws Exception { ValidateResult result = new ValidateResult(false); if(modelConfig == null) { result.getCauses().add("The ModelConfig is not loaded!"); } else { result = ModelInspector.getInspector().probe(modelConfig, modelStep); } if(!result.getStatus()) { LOG.error("ModelConfig Validation - Fail! See below:"); for(String cause: result.getCauses()) { LOG.error("\t!!! " + cause); } throw new ShifuException(ShifuErrorCode.ERROR_MODELCONFIG_NOT_VALIDATION); } else { LOG.info("ModelConfig Validation - OK"); } } /** * Close all scanners * * @param scanners * the scanners */ public void closeScanners(List<Scanner> scanners) { if(CollectionUtils.isNotEmpty(scanners)) { for(Scanner scanner: scanners) { scanner.close(); } } } /** * Sync data into HDFS if necessary: * RunMode == pig and SourceType == HDFS * * @param sourceType * source type * @return if sync in hdfs or not * @throws IOException * any exception in file system io */ public boolean syncDataToHdfs(SourceType sourceType) throws IOException { if(SourceType.HDFS.equals(sourceType)) { CommonUtils.copyConfFromLocalToHDFS(modelConfig, this.pathFinder); return true; } return false; } public void copyModelFiles(String sourcePath, String targetPath) throws IOException { loadModelConfig(sourcePath + File.separator + "ModelConfig.json", SourceType.LOCAL); File targetFile = new File(targetPath); this.modelConfig.setModelSetName(targetFile.getName()); this.modelConfig.setModelSetCreator(Environment.getProperty(Environment.SYSTEM_USER)); try { JSONUtils.writeValue(new File(targetPath + File.separator + "ModelConfig.json"), modelConfig); } catch (IOException e) { throw new ShifuException(ShifuErrorCode.ERROR_WRITE_MODELCONFIG, e); } } /** * get the modelConfig instance * * @return the modelConfig */ public ModelConfig getModelConfig() { return modelConfig; } /** * get the columnConfigList instance * * @return the columnConfigList */ public List<ColumnConfig> getColumnConfigList() { return columnConfigList; } /** * get the pathFinder instance * * @return the pathFinder */ public PathFinder getPathFinder() { return pathFinder; } /** * check algorithm parameter * * @throws Exception * modelConfig is not loaded or save ModelConfig.json file error */ public void checkAlgorithmParam() throws Exception { String alg = modelConfig.getAlgorithm(); Map<String, Object> param = modelConfig.getParams(); LOG.info("Check algorithm parameter"); if(alg.equalsIgnoreCase("LR")) { if(!param.containsKey("LearningRate")) { param = new LinkedHashMap<String, Object>(); param.put("LearningRate", 0.1); modelConfig.setParams(param); saveModelConfig(); } } else if(alg.equalsIgnoreCase("NN")) { if(!param.containsKey("Propagation")) { param = new LinkedHashMap<String, Object>(); param.put("Propagation", "Q"); param.put("LearningRate", 0.1); param.put("NumHiddenLayers", 2); List<Integer> nodes = new ArrayList<Integer>(); nodes.add(20); nodes.add(10); param.put("NumHiddenNodes", nodes); List<String> func = new ArrayList<String>(); func.add("tanh"); func.add("tanh"); param.put("ActivationFunc", func); modelConfig.setParams(param); saveModelConfig(); } } else if(alg.equalsIgnoreCase("SVM")) { if(!param.containsKey("Kernel")) { param = new LinkedHashMap<String, Object>(); param.put("Kernel", "linear"); param.put("Gamma", 1.); param.put("Const", 1.); modelConfig.setParams(param); saveModelConfig(); } } else if(alg.equalsIgnoreCase("DT")) { // do nothing } else if(alg.equalsIgnoreCase("RF")) { if(!param.containsKey("FeatureSubsetStrategy")) { param = new LinkedHashMap<String, Object>(); param.put("FeatureSubsetStrategy", "all"); param.put("MaxDepth", 10); param.put("MaxStatsMemoryMB", 256); param.put("Impurity", "entropy"); modelConfig.setParams(param); saveModelConfig(); } } else if(alg.equalsIgnoreCase("GBT")) { if(!param.containsKey("FeatureSubsetStrategy")) { param = new LinkedHashMap<String, Object>(); param.put("FeatureSubsetStrategy", "all"); param.put("MaxDepth", 10); param.put("MaxStatsMemoryMB", 256); param.put("Impurity", "entropy"); param.put("Loss", "squared"); modelConfig.setParams(param); saveModelConfig(); } } else { throw new ShifuException(ShifuErrorCode.ERROR_UNSUPPORT_ALG); } // log.info("Finished: check the algorithm parameter"); } /** * load Model Config method * * @throws IOException * in load model config */ private void loadModelConfig() throws IOException { modelConfig = CommonUtils.loadModelConfig(new Path(CommonUtils.getLocalModelSetPath(otherConfigs), Constants.LOCAL_MODEL_CONFIG_JSON).toString(), SourceType.LOCAL); } /** * load Model Config method * * @throws IOException * in load model config */ private void loadModelConfig(String pathToModel, SourceType source) throws IOException { modelConfig = CommonUtils.loadModelConfig(pathToModel, source); } /** * load Column Config * * @throws IOException * in load column config */ private void loadColumnConfig() throws IOException { columnConfigList = CommonUtils.loadColumnConfigList(new Path(CommonUtils.getLocalModelSetPath(otherConfigs), Constants.LOCAL_COLUMN_CONFIG_JSON).toString(), SourceType.LOCAL); } /** * Check the processor is initialized or not * * @return true - if the process is initialized * false - if not */ private boolean hasInitialized() { return (null != this.modelConfig && null != this.columnConfigList && null != this.pathFinder); } /** * create HEAD file contain the workspace * * @param modelName * model name * @throws IOException * any exception in create header */ protected void createHead(String modelName) throws IOException { File header = new File(modelName == null ? "" : modelName + "/.HEAD"); if(header.exists()) { LOG.error("File {} already exist.", header.getAbsolutePath()); return; } BufferedWriter writer = null; try { writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(header), Constants.DEFAULT_CHARSET)); writer.write("master"); } catch (IOException e) { LOG.error("Fail to create HEAD file to store the current workspace"); } finally { if(writer != null) { writer.close(); } } } /** * @return the otherConfigs */ public Map<String, Object> getOtherConfigs() { return otherConfigs; } /** * @param otherConfigs * the otherConfigs to set */ public void setOtherConfigs(Map<String, Object> otherConfigs) { this.otherConfigs = otherConfigs; } protected void runDataClean(boolean isToShuffle) throws IOException { SourceType sourceType = modelConfig.getDataSet().getSource(); String cleanedDataPath = this.pathFinder.getCleanedDataPath(); LOG.info("Start to generate clean data for tree model ... "); if(ShifuFileUtils.isFileExists(cleanedDataPath, sourceType)) { ShifuFileUtils.deleteFile(cleanedDataPath, sourceType); } Map<String, String> paramsMap = new HashMap<String, String>(); paramsMap.put("sampleRate", modelConfig.getNormalizeSampleRate().toString()); paramsMap.put("sampleNegOnly", ((Boolean) modelConfig.isNormalizeSampleNegOnly()).toString()); paramsMap.put("delimiter", CommonUtils.escapePigString(modelConfig.getDataSetDelimiter())); try { String normPigPath = pathFinder.getScriptPath("scripts/Normalize.pig"); paramsMap.put(Constants.IS_COMPRESS, "true"); paramsMap.put(Constants.IS_NORM_FOR_CLEAN, "true"); paramsMap.put(Constants.PATH_NORMALIZED_DATA, pathFinder.getCleanedDataPath()); PigExecutor.getExecutor().submitJob(modelConfig, normPigPath, paramsMap, sourceType, this.pathFinder); // cleaned validation data if(StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) { String cleandedValidationDataPath = pathFinder.getCleanedValidationDataPath(); if(ShifuFileUtils.isFileExists(cleandedValidationDataPath, sourceType)) { ShifuFileUtils.deleteFile(cleandedValidationDataPath, sourceType); } paramsMap.put(Constants.IS_COMPRESS, "false"); paramsMap.put(Constants.PATH_RAW_DATA, modelConfig.getValidationDataSetRawPath()); paramsMap.put(Constants.PATH_NORMALIZED_DATA, pathFinder.getCleanedValidationDataPath()); PigExecutor.getExecutor().submitJob(modelConfig, normPigPath, paramsMap, sourceType, this.pathFinder); } } catch (IOException e) { throw new ShifuException(ShifuErrorCode.ERROR_RUNNING_PIG_JOB, e); } catch (Throwable e) { throw new RuntimeException(e); } if(isToShuffle) { MapReduceShuffle shuffler = new MapReduceShuffle(this.modelConfig); try { shuffler.run(pathFinder.getCleanedDataPath()); } catch (ClassNotFoundException e) { throw new RuntimeException("Fail to shuffle the cleaned data.", e); } catch (InterruptedException e) { throw new RuntimeException("Fail to shuffle the cleaned data.", e); } } LOG.info("Generate clean data for tree model successful."); } }