/*
 * Copyright [2012-2014] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.processor;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

import ml.shifu.guagua.hadoop.util.HDPUtils;
import ml.shifu.guagua.mapreduce.GuaguaMapReduceConstants;
import ml.shifu.shifu.actor.AkkaSystemExecutor;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.dtrain.nn.NNConstants;
import ml.shifu.shifu.core.mr.input.CombineInputFormat;
import ml.shifu.shifu.core.posttrain.FeatureImportanceMapper;
import ml.shifu.shifu.core.posttrain.FeatureImportanceReducer;
import ml.shifu.shifu.core.posttrain.FeatureStatsWritable;
import ml.shifu.shifu.core.posttrain.PostTrainMapper;
import ml.shifu.shifu.core.posttrain.PostTrainReducer;
import ml.shifu.shifu.core.validator.ModelInspector.ModelStep;
import ml.shifu.shifu.exception.ShifuErrorCode;
import ml.shifu.shifu.exception.ShifuException;
import ml.shifu.shifu.fs.ShifuFileUtils;
import ml.shifu.shifu.pig.PigExecutor;
import ml.shifu.shifu.util.CommonUtils;
import ml.shifu.shifu.util.Constants;
import ml.shifu.shifu.util.Environment;

import org.apache.commons.codec.binary.Base64;
import org.apache.commons.collections.Predicate;
import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream;
import org.apache.commons.jexl2.JexlException;
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.PathFilter;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.MultipleOutputs;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;
import org.apache.pig.impl.util.JarManager;
import org.encog.ml.data.MLDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Splitter;

/**
 * Post-train processor: computes per-bin average scores and updates them into the column config list.
 */
public class PostTrainModelProcessor extends BasicModelProcessor implements Processor {

    /**
     * log object
     */
    private static final Logger log = LoggerFactory.getLogger(PostTrainModelProcessor.class);

    /**
     * runner for post train
     */
    @Override
    public int run() throws Exception {
        log.info("Step Start: posttrain");
        long start = System.currentTimeMillis();
        try {
            setUp(ModelStep.POSTTRAIN);
            syncDataToHdfs(modelConfig.getDataSet().getSource());
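
            // per-bin average score computing is only meaningful for regression
            // model sets, so fail fast for classification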
            if(modelConfig.isClassification()) {
                throw new IllegalArgumentException(
                        "post train step is only effective in regression, not classification.");
            }
            if(modelConfig.isMapReduceRunMode()) {
                runMapRedPostTrain();
            } else if(modelConfig.isLocalRunMode()) {
                runAkkaPostTrain();
            } else {
                log.error("Invalid RunMode Setting!");
            }
            clearUp(ModelStep.POSTTRAIN);
        } catch (Exception e) {
            log.error("Error:", e);
            return -1;
        }
        log.info("Step Finished: posttrain with {} ms", (System.currentTimeMillis() - start));
        return 0;
    }

    /**
     * Collect runtime dependency jars for mapper and reducer; GuaguaOptionsParser doesn't support *.jar wildcards
     * currently, so each jar is located explicitly by a class it contains.
     */
    private String addRuntimeJars() {
        List<String> jars = new ArrayList<String>(16);
        // commons-codec-*.jar
        jars.add(JarManager.findContainingJar(Base64.class));
        // commons-compress-*.jar
        jars.add(JarManager.findContainingJar(BZip2CompressorInputStream.class));
        // commons-lang-*.jar
        jars.add(JarManager.findContainingJar(StringUtils.class));
        // commons-io-*.jar
        jars.add(JarManager.findContainingJar(org.apache.commons.io.IOUtils.class));
        // commons-collections-*.jar
        jars.add(JarManager.findContainingJar(Predicate.class));
        // guava-*.jar
        jars.add(JarManager.findContainingJar(Splitter.class));
        // shifu-*.jar
        jars.add(JarManager.findContainingJar(getClass()));
        // jexl-*.jar
        jars.add(JarManager.findContainingJar(JexlException.class));
        // encog-core-*.jar
        jars.add(JarManager.findContainingJar(MLDataSet.class));
        // jackson-databind-*.jar
        jars.add(JarManager.findContainingJar(ObjectMapper.class));
        // jackson-core-*.jar
        jars.add(JarManager.findContainingJar(JsonParser.class));
        // jackson-annotations-*.jar
        jars.add(JarManager.findContainingJar(JsonIgnore.class));
        return StringUtils.join(jars, NNConstants.LIB_JAR_SEPARATOR);
    }

    /**
     * Run the bin average score MapReduce job and update its results into the column config list; if post train is
     * enabled in model config, also run the feature importance MapReduce job.
     */
    private void runMapRedPostTrain() throws IOException, InterruptedException, ClassNotFoundException {
        SourceType source = modelConfig.getDataSet().getSource();
        String postTrainOutputPath = super.getPathFinder().getTrainScoresPath(source);
        // run mr job to compute bin avg score
        runMRBinAvgScoreJob(source, postTrainOutputPath);
        // read from output file for avg score update
        updateAvgScores(source, postTrainOutputPath);
        ShifuFileUtils.deleteFile(new Path(postTrainOutputPath, "part-r-00000*").toString(), source);
        saveColumnConfigList();

        if(super.modelConfig.getBasic().getPostTrainOn() != null && super.modelConfig.getBasic().getPostTrainOn()) {
            syncDataToHdfs(modelConfig.getDataSet().getSource());
            String output = super.getPathFinder().getPostTrainOutputPath(source);
            runMRFeatureImportanceJob(source, output);
            List<Integer> fss = getFeatureImportance(source, output);
            log.info("Feature importance list is: {}", fss);
        }
    }

    /**
     * Read bin average scores from the MR job output (each part-r-* line is "columnNum\tscore1,score2,...") and
     * update them into the column config list.
     */
    private void updateAvgScores(SourceType source, String postTrainOutputPath) throws IOException {
        List<Scanner> scanners = null;
        try {
            scanners = ShifuFileUtils.getDataScanners(postTrainOutputPath, source, new PathFilter() {
                @Override
                public boolean accept(Path path) {
                    return path.toString().contains("part-r-");
                }
            });
            for(Scanner scanner: scanners) {
                while(scanner.hasNextLine()) {
                    String line = scanner.nextLine().trim();
                    String[] keyValues = line.split("\t");
                    String key = keyValues[0];
                    String value = keyValues[1];
                    ColumnConfig config = this.columnConfigList.get(Integer.parseInt(key));
                    List<Integer> binAvgScores = new ArrayList<Integer>();
                    String[] avgScores = value.split(",");
                    for(int i = 0; i < avgScores.length; i++) {
                        binAvgScores.add(Integer.parseInt(avgScores[i]));
                    }
                    config.setBinAvgScore(binAvgScores);
                }
            }
        } finally {
            // release
            closeScanners(scanners);
        }
    }

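    /**
     * Read feature importance results from the MR job output; each part-r-* line is expected in the tab-separated
     * "columnNum\timportance" form mirrored by the parsing below, and only the column numbers are collected, in
     * output order.
     */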
    private List<Integer> getFeatureImportance(SourceType source, String output) throws IOException {
        List<Integer> featureImportance = new ArrayList<Integer>();
        List<Scanner> scanners = null;
        try {
            scanners = ShifuFileUtils.getDataScanners(output, source, new PathFilter() {
                @Override
                public boolean accept(Path path) {
                    return path.toString().contains("part-r-");
                }
            });
            for(Scanner scanner: scanners) {
                while(scanner.hasNextLine()) {
                    String line = scanner.nextLine().trim();
                    String[] keyValues = line.split("\t");
                    String key = keyValues[0];
                    featureImportance.add(Integer.parseInt(key));
                }
            }
        } finally {
            // release
            closeScanners(scanners);
        }
        return featureImportance;
    }

    /**
     * Configure and submit the MapReduce job that computes per-bin average scores.
     */
    private void runMRBinAvgScoreJob(SourceType source, String postTrainOutputPath) throws IOException,
            InterruptedException, ClassNotFoundException {
        Configuration conf = new Configuration();

        // add jars to hadoop mapper and reducer
        new GenericOptionsParser(conf, new String[] { "-libjars", addRuntimeJars() });

        conf.setBoolean(CombineInputFormat.SHIFU_VS_SPLIT_COMBINABLE, true);
        conf.setBoolean("mapreduce.input.fileinputformat.input.dir.recursive", true);
        conf.set(Constants.SHIFU_STATS_EXLCUDE_MISSING,
                Environment.getProperty(Constants.SHIFU_STATS_EXLCUDE_MISSING, "true"));

        conf.setBoolean(GuaguaMapReduceConstants.MAPRED_MAP_TASKS_SPECULATIVE_EXECUTION, true);
        conf.setBoolean(GuaguaMapReduceConstants.MAPRED_REDUCE_TASKS_SPECULATIVE_EXECUTION, true);
        conf.set(
                Constants.SHIFU_MODEL_CONFIG,
                ShifuFileUtils.getFileSystemBySourceType(source)
                        .makeQualified(new Path(super.getPathFinder().getModelConfigPath(source))).toString());
        conf.set(
                Constants.SHIFU_COLUMN_CONFIG,
                ShifuFileUtils.getFileSystemBySourceType(source)
                        .makeQualified(new Path(super.getPathFinder().getColumnConfigPath(source))).toString());
        conf.set(NNConstants.MAPRED_JOB_QUEUE_NAME, Environment.getProperty(Environment.HADOOP_JOB_QUEUE, "default"));
        conf.set(Constants.SHIFU_MODELSET_SOURCE_TYPE, source.toString());
        // set mapreduce.job.max.split.locations to a large value (5000) to suppress warnings
        conf.setInt(GuaguaMapReduceConstants.MAPREDUCE_JOB_MAX_SPLIT_LOCATIONS, 5000);
        conf.set("mapred.reduce.slowstart.completed.maps",
                Environment.getProperty("mapred.reduce.slowstart.completed.maps", "0.8"));

        String hdpVersion = HDPUtils.getHdpVersionForHDP224();
        if(StringUtils.isNotBlank(hdpVersion)) {
            // for hdp 2.2.4, hdp.version should be set and configuration files should be added to container classpath
            conf.set("hdp.version", hdpVersion);
            HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("hdfs-site.xml"), conf);
            HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("core-site.xml"), conf);
            HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("mapred-site.xml"), conf);
            HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("yarn-site.xml"), conf);
        }

        // one can set guagua conf in shifuconfig
        for(Map.Entry<Object, Object> entry: Environment.getProperties().entrySet()) {
            if(CommonUtils.isHadoopConfigurationInjected(entry.getKey().toString())) {
                conf.set(entry.getKey().toString(), entry.getValue().toString());
            }
        }

        @SuppressWarnings("deprecation")
        Job job = new Job(conf, "Shifu: Post Train : " + this.modelConfig.getModelSetName());
        job.setJarByClass(getClass());
        job.setMapperClass(PostTrainMapper.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(FeatureStatsWritable.class);
        job.setInputFormatClass(CombineInputFormat.class);
        FileInputFormat.setInputPaths(
                job,
                ShifuFileUtils.getFileSystemBySourceType(source).makeQualified(
                        new Path(super.modelConfig.getDataSetRawPath())));
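        // register a named output for per-record score lines; it is written through MultipleOutputs
        // (assumption: populated by PostTrainMapper) in addition to the reducer's main bin-average output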
        MultipleOutputs.addNamedOutput(job, Constants.POST_TRAIN_OUTPUT_SCORE, TextOutputFormat.class,
                NullWritable.class, Text.class);

        job.setReducerClass(PostTrainReducer.class);
        job.setNumReduceTasks(1);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(Text.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        FileOutputFormat.setOutputPath(job, new Path(postTrainOutputPath));

        // clean output firstly
        ShifuFileUtils.deleteFile(postTrainOutputPath, source);
        // submit job
        if(!job.waitForCompletion(true)) {
            throw new RuntimeException("Post train Bin Avg Score MapReduce job failed.");
        }
    }

    /**
     * Configure and submit the MapReduce job that computes feature importance.
     */
    private void runMRFeatureImportanceJob(SourceType source, String output) throws IOException,
            InterruptedException, ClassNotFoundException {
        Configuration conf = new Configuration();

        // add jars to hadoop mapper and reducer
        new GenericOptionsParser(conf, new String[] { "-libjars", addRuntimeJars() });

        conf.setBoolean(CombineInputFormat.SHIFU_VS_SPLIT_COMBINABLE, true);
        conf.setBoolean("mapreduce.input.fileinputformat.input.dir.recursive", true);
        conf.set(Constants.SHIFU_STATS_EXLCUDE_MISSING,
                Environment.getProperty(Constants.SHIFU_STATS_EXLCUDE_MISSING, "true"));

        conf.setBoolean(GuaguaMapReduceConstants.MAPRED_MAP_TASKS_SPECULATIVE_EXECUTION, true);
        conf.setBoolean(GuaguaMapReduceConstants.MAPRED_REDUCE_TASKS_SPECULATIVE_EXECUTION, true);
        conf.set(
                Constants.SHIFU_MODEL_CONFIG,
                ShifuFileUtils.getFileSystemBySourceType(source)
                        .makeQualified(new Path(super.getPathFinder().getModelConfigPath(source))).toString());
        conf.set(
                Constants.SHIFU_COLUMN_CONFIG,
                ShifuFileUtils.getFileSystemBySourceType(source)
                        .makeQualified(new Path(super.getPathFinder().getColumnConfigPath(source))).toString());
        conf.set(NNConstants.MAPRED_JOB_QUEUE_NAME, Environment.getProperty(Environment.HADOOP_JOB_QUEUE, "default"));
        conf.set(Constants.SHIFU_MODELSET_SOURCE_TYPE, source.toString());
        // set mapreduce.job.max.split.locations to a large value (5000) to suppress warnings
        conf.setInt(GuaguaMapReduceConstants.MAPREDUCE_JOB_MAX_SPLIT_LOCATIONS, 5000);
        conf.set("mapred.reduce.slowstart.completed.maps",
                Environment.getProperty("mapred.reduce.slowstart.completed.maps", "0.8"));

        String hdpVersion = HDPUtils.getHdpVersionForHDP224();
        if(StringUtils.isNotBlank(hdpVersion)) {
            // for hdp 2.2.4, hdp.version should be set and configuration files should be added to container classpath
            conf.set("hdp.version", hdpVersion);
            HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("hdfs-site.xml"), conf);
            HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("core-site.xml"), conf);
            HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("mapred-site.xml"), conf);
            HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("yarn-site.xml"), conf);
        }

        // one can set guagua conf in shifuconfig
        for(Map.Entry<Object, Object> entry: Environment.getProperties().entrySet()) {
            if(CommonUtils.isHadoopConfigurationInjected(entry.getKey().toString())) {
                conf.set(entry.getKey().toString(), entry.getValue().toString());
            }
        }

        @SuppressWarnings("deprecation")
        Job job = new Job(conf, "Shifu: Post Train FeatureImportance : " + this.modelConfig.getModelSetName());
        job.setJarByClass(getClass());
        job.setMapperClass(FeatureImportanceMapper.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(DoubleWritable.class);
        job.setInputFormatClass(CombineInputFormat.class);
        FileInputFormat.setInputPaths(
                job,
                ShifuFileUtils.getFileSystemBySourceType(source).makeQualified(
                        new Path(super.modelConfig.getDataSetRawPath())));

        job.setReducerClass(FeatureImportanceReducer.class);
        job.setNumReduceTasks(1);
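        // a single reducer funnels every feature's aggregated importance into one
        // part-r-00000 file, which getFeatureImportance() then scans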
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(DoubleWritable.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        FileOutputFormat.setOutputPath(job, new Path(output));

        // clean output firstly
        ShifuFileUtils.deleteFile(output, source);
        // submit job
        if(!job.waitForCompletion(true)) {
            throw new RuntimeException("Post train Feature Importance MapReduce job failed.");
        }
    }

    /**
     * run pig post train
     * 
     * @throws IOException
     *             for any io exception
     */
    @SuppressWarnings("unused")
    private void runPigPostTrain() throws IOException {
        SourceType sourceType = modelConfig.getDataSet().getSource();

        ShifuFileUtils.deleteFile(pathFinder.getTrainScoresPath(), sourceType);
        ShifuFileUtils.deleteFile(pathFinder.getBinAvgScorePath(), sourceType);

        // prepare special parameters and execute pig
        Map<String, String> paramsMap = new HashMap<String, String>();
        paramsMap.put("pathHeader", modelConfig.getHeaderPath());
        paramsMap.put("pathDelimiter", CommonUtils.escapePigString(modelConfig.getHeaderDelimiter()));
        paramsMap.put("delimiter", CommonUtils.escapePigString(modelConfig.getDataSetDelimiter()));

        try {
            PigExecutor.getExecutor().submitJob(modelConfig, pathFinder.getScriptPath("scripts/PostTrain.pig"),
                    paramsMap);
        } catch (IOException e) {
            throw new ShifuException(ShifuErrorCode.ERROR_RUNNING_PIG_JOB, e);
        } catch (Throwable e) {
            throw new RuntimeException(e);
        }

        // sync down bin average scores into the column config list and save
        columnConfigList = updateColumnConfigWithBinAvgScore(columnConfigList);
        saveColumnConfigList();
    }

    /**
     * run akka post train
     * 
     * @throws IOException
     *             for any io exception
     */
    private void runAkkaPostTrain() throws IOException {
        SourceType sourceType = modelConfig.getDataSet().getSource();

        List<Scanner> scanners = ShifuFileUtils.getDataScanners(pathFinder.getSelectedRawDataPath(sourceType),
                sourceType);
        log.info("Num of Scanners: " + scanners.size());
        AkkaSystemExecutor.getExecutor().submitPostTrainJob(modelConfig, columnConfigList, scanners);

        closeScanners(scanners);
    }

    /**
     * read the bin average scores and update them into the column config list
     * 
     * @param columnConfigList
     *            input column config list
     * @return updated column config list
     * @throws IOException
     *             for any io exception
     */
    private List<ColumnConfig> updateColumnConfigWithBinAvgScore(List<ColumnConfig> columnConfigList)
            throws IOException {
        List<Scanner> scanners = ShifuFileUtils.getDataScanners(pathFinder.getBinAvgScorePath(), modelConfig
                .getDataSet().getSource());

        for(Scanner scanner: scanners) {
            while(scanner.hasNextLine()) {
                List<Integer> scores = new ArrayList<Integer>();
                String[] raw = scanner.nextLine().split("\\|");
                int columnNum = Integer.parseInt(raw[0]);
                for(int i = 1; i < raw.length; i++) {
                    scores.add(Integer.valueOf(raw[i]));
                }
                ColumnConfig config = columnConfigList.get(columnNum);
                config.setBinAvgScore(scores);
            }
        }

        // release
        closeScanners(scanners);
        return columnConfigList;
    }
}