/*
 * Copyright [2012-2014] PayPal Software Foundation
 * <p/>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p/>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p/>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.processor;

import java.io.BufferedWriter;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Scanner;
import java.util.SortedMap;
import java.util.TreeMap;

import ml.shifu.guagua.hadoop.util.HDPUtils;
import ml.shifu.guagua.mapreduce.GuaguaMapReduceConstants;
import ml.shifu.guagua.util.NumberFormatUtils;
import ml.shifu.guagua.util.ReflectionUtils;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ColumnConfig.ColumnFlag;
import ml.shifu.shifu.container.obj.ModelStatsConf;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.correlation.CorrelationMapper;
import ml.shifu.shifu.core.correlation.CorrelationMultithreadedMapper;
import ml.shifu.shifu.core.correlation.CorrelationReducer;
import ml.shifu.shifu.core.correlation.CorrelationWritable;
import ml.shifu.shifu.core.dtrain.nn.NNConstants;
import ml.shifu.shifu.core.mr.input.CombineInputFormat;
import ml.shifu.shifu.core.processor.stats.AbstractStatsExecutor;
import ml.shifu.shifu.core.processor.stats.AkkaStatsWorker;
import ml.shifu.shifu.core.processor.stats.DIBStatsExecutor;
import ml.shifu.shifu.core.processor.stats.MunroPatIStatsExecutor;
import ml.shifu.shifu.core.processor.stats.MunroPatStatsExecutor;
import ml.shifu.shifu.core.processor.stats.SPDTIStatsExecutor;
import ml.shifu.shifu.core.processor.stats.SPDTStatsExecutor;
import ml.shifu.shifu.core.validator.ModelInspector.ModelStep;
import ml.shifu.shifu.exception.ShifuErrorCode;
import ml.shifu.shifu.exception.ShifuException;
import ml.shifu.shifu.fs.ShifuFileUtils;
import ml.shifu.shifu.util.CommonUtils;
import ml.shifu.shifu.util.Constants;
import ml.shifu.shifu.util.Environment;

import org.apache.commons.codec.binary.Base64;
import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream;
import org.apache.commons.io.IOUtils;
import org.apache.commons.jexl2.JexlException;
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;
import org.apache.pig.impl.util.JarManager;
import org.encog.ml.data.MLDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Predicate;
import com.google.common.base.Splitter;
/**
 * Statistics processor: computes max/min/avg/std for each column of the dataset if it is numerical; with the
 * correlation option, computes pairwise column correlation via a dedicated MapReduce job.
 */
public class StatsModelProcessor extends BasicModelProcessor implements Processor {

    private final static Logger log = LoggerFactory.getLogger(StatsModelProcessor.class);

    private boolean isComputeCorr = false;

    public StatsModelProcessor(boolean isComputeCorr) {
        this.isComputeCorr = isComputeCorr;
    }

    public StatsModelProcessor() {
    }

    /**
     * runner for statistics
     */
    @Override
    public int run() throws Exception {
        log.info("Step Start: stats");
        long start = System.currentTimeMillis();
        try {
            setUp(ModelStep.STATS);
            log.info("catMaxBinNum - {}", this.modelConfig.getStats().getCateMaxNumBin());
            // resync ModelConfig.json/ColumnConfig.json to HDFS
            syncDataToHdfs(modelConfig.getDataSet().getSource());

            if(isComputeCorr) {
                // 1. validate that 'shifu stats' has been run before running 'shifu stats -correlation'
                boolean foundValidMeanValueColumn = false;
                for(ColumnConfig config: this.columnConfigList) {
                    if(!config.isMeta() && !config.isTarget() && config.isNumerical()) {
                        if(config.getMean() != null) {
                            foundValidMeanValueColumn = true;
                            break;
                        }
                    }
                }
                if(!foundValidMeanValueColumn) {
                    log.warn("No column with a valid mean value was found, please check if 'shifu stats' has been run first.");
                    return -1;
                }
                // 2. compute correlation
                log.info("Start computing correlation value ...");
                runCorrMapReduceJob();
                // 3. save column config list
                saveColumnConfigList();
            } else {
                AbstractStatsExecutor statsExecutor = null;

                if(modelConfig.isMapReduceRunMode()) {
                    if(modelConfig.getBinningAlgorithm().equals(ModelStatsConf.BinningAlgorithm.DynamicBinning)) {
                        statsExecutor = new DIBStatsExecutor(this, modelConfig, columnConfigList);
                    } else if(modelConfig.getBinningAlgorithm().equals(ModelStatsConf.BinningAlgorithm.MunroPat)) {
                        statsExecutor = new MunroPatStatsExecutor(this, modelConfig, columnConfigList);
                    } else if(modelConfig.getBinningAlgorithm().equals(ModelStatsConf.BinningAlgorithm.MunroPatI)) {
                        statsExecutor = new MunroPatIStatsExecutor(this, modelConfig, columnConfigList);
                    } else if(modelConfig.getBinningAlgorithm().equals(ModelStatsConf.BinningAlgorithm.SPDT)) {
                        statsExecutor = new SPDTStatsExecutor(this, modelConfig, columnConfigList);
                    } else if(modelConfig.getBinningAlgorithm().equals(ModelStatsConf.BinningAlgorithm.SPDTI)) {
                        statsExecutor = new SPDTIStatsExecutor(this, modelConfig, columnConfigList);
                    } else {
                        statsExecutor = new SPDTIStatsExecutor(this, modelConfig, columnConfigList);
                    }
                } else if(modelConfig.isLocalRunMode()) {
                    statsExecutor = new AkkaStatsWorker(this, modelConfig, columnConfigList);
                } else {
                    throw new ShifuException(ShifuErrorCode.ERROR_UNSUPPORT_MODE);
                }
                statsExecutor.doStats();
            }
            syncDataToHdfs(modelConfig.getDataSet().getSource());
            clearUp(ModelStep.STATS);
        } catch (Exception e) {
            log.error("Error:", e);
            return -1;
        }
        log.info("Step Finished: stats with {} ms", (System.currentTimeMillis() - start));
        return 0;
    }
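
    /*
     * Note on jar shipping: the list built by addRuntimeJars() below is handed to Hadoop through the
     * "-libjars" generic option, roughly equivalent to running:
     *
     *   hadoop jar shifu-*.jar ... -libjars commons-codec-*.jar,guava-*.jar,...
     *
     * GenericOptionsParser registers these jars in the job configuration so that they are shipped via the
     * distributed cache and appear on the mapper/reducer classpath at runtime.
     */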
    // GenericOptionsParser doesn't support *.jar wildcards currently, so each dependency jar is located explicitly.
    private String addRuntimeJars() {
        List<String> jars = new ArrayList<String>(16);
        // commons-codec-*.jar
        jars.add(JarManager.findContainingJar(Base64.class));
        // commons-compress-*.jar
        jars.add(JarManager.findContainingJar(BZip2CompressorInputStream.class));
        // commons-lang-*.jar
        jars.add(JarManager.findContainingJar(StringUtils.class));
        // commons-io-*.jar
        jars.add(JarManager.findContainingJar(org.apache.commons.io.IOUtils.class));
        // guava-*.jar (com.google.common.base.Predicate)
        jars.add(JarManager.findContainingJar(Predicate.class));
        // guava-*.jar (com.google.common.base.Splitter)
        jars.add(JarManager.findContainingJar(Splitter.class));
        // guagua-core-*.jar
        jars.add(JarManager.findContainingJar(NumberFormatUtils.class));
        // shifu-*.jar
        jars.add(JarManager.findContainingJar(getClass()));
        // jexl-*.jar
        jars.add(JarManager.findContainingJar(JexlException.class));
        // encog-core-*.jar
        jars.add(JarManager.findContainingJar(MLDataSet.class));
        // jackson-databind-*.jar
        jars.add(JarManager.findContainingJar(ObjectMapper.class));
        // jackson-core-*.jar
        jars.add(JarManager.findContainingJar(JsonParser.class));
        // jackson-annotations-*.jar
        jars.add(JarManager.findContainingJar(JsonIgnore.class));
        return StringUtils.join(jars, NNConstants.LIB_JAR_SEPARATOR);
    }

    private void runCorrMapReduceJob() throws IOException, InterruptedException, ClassNotFoundException {
        SourceType source = this.modelConfig.getDataSet().getSource();
        Configuration conf = new Configuration();
        // add jars to hadoop mapper and reducer
        new GenericOptionsParser(conf, new String[] { "-libjars", addRuntimeJars() });

        conf.setBoolean(GuaguaMapReduceConstants.MAPRED_MAP_TASKS_SPECULATIVE_EXECUTION, true);
        conf.setBoolean(GuaguaMapReduceConstants.MAPRED_REDUCE_TASKS_SPECULATIVE_EXECUTION, true);
        conf.set(NNConstants.MAPRED_JOB_QUEUE_NAME,
                Environment.getProperty(Environment.HADOOP_JOB_QUEUE, "default"));
        conf.setInt(GuaguaMapReduceConstants.MAPREDUCE_JOB_MAX_SPLIT_LOCATIONS, 5000);
        conf.set(Constants.SHIFU_MODEL_CONFIG,
                ShifuFileUtils.getFileSystemBySourceType(source)
                        .makeQualified(new Path(super.getPathFinder().getModelConfigPath(source))).toString());
        conf.set(Constants.SHIFU_COLUMN_CONFIG,
                ShifuFileUtils.getFileSystemBySourceType(source)
                        .makeQualified(new Path(super.getPathFinder().getColumnConfigPath(source))).toString());
        conf.set(Constants.SHIFU_MODELSET_SOURCE_TYPE, source.toString());
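
        // The fully-qualified paths set above are read back from the job configuration by the correlation
        // mapper/reducer tasks to load ModelConfig.json and ColumnConfig.json (an assumption based on the
        // SHIFU_MODEL_CONFIG/SHIFU_COLUMN_CONFIG keys; see CorrelationMapper).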
conf.set("mapred.reduce.slowstart.completed.maps", Environment.getProperty("mapred.reduce.slowstart.completed.maps", "0.7")); String hdpVersion = HDPUtils.getHdpVersionForHDP224(); if(StringUtils.isNotBlank(hdpVersion)) { // for hdp 2.2.4, hdp.version should be set and configuration files should be add to container class path conf.set("hdp.version", hdpVersion); HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("hdfs-site.xml"), conf); HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("core-site.xml"), conf); HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("mapred-site.xml"), conf); HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("yarn-site.xml"), conf); } conf.setBoolean(CombineInputFormat.SHIFU_VS_SPLIT_COMBINABLE, true); conf.setBoolean("mapreduce.input.fileinputformat.input.dir.recursive", true); int threads = parseThreadNum(); conf.setInt("mapreduce.map.cpu.vcores", threads); // one can set guagua conf in shifuconfig for(Map.Entry<Object, Object> entry: Environment.getProperties().entrySet()) { if(CommonUtils.isHadoopConfigurationInjected(entry.getKey().toString())) { conf.set(entry.getKey().toString(), entry.getValue().toString()); } } setMapperMemory(conf, threads); @SuppressWarnings("deprecation") Job job = new Job(conf, "Shifu: Correlation Computing Job : " + this.modelConfig.getModelSetName()); job.setJarByClass(getClass()); job.setMapperClass(CorrelationMultithreadedMapper.class); CorrelationMultithreadedMapper.setMapperClass(job, CorrelationMapper.class); CorrelationMultithreadedMapper.setNumberOfThreads(job, threads); job.setMapOutputKeyClass(IntWritable.class); job.setMapOutputValueClass(CorrelationWritable.class); job.setInputFormatClass(CombineInputFormat.class); FileInputFormat.setInputPaths( job, ShifuFileUtils.getFileSystemBySourceType(source).makeQualified( new Path(super.modelConfig.getDataSetRawPath()))); job.setReducerClass(CorrelationReducer.class); // 3000 features will be 30 reducers, 600 will be 6, much more reducer to avoid data all copied to one reducer // especially when features over 3000, each mapper output is 700M, 400 mapper will be 280G size job.setNumReduceTasks(this.columnConfigList.size() < 50 ? 2 : this.columnConfigList.size() / 50); job.setOutputKeyClass(IntWritable.class); job.setOutputValueClass(Text.class); job.setOutputFormatClass(TextOutputFormat.class); String corrPath = super.getPathFinder().getCorrelationPath(source); FileOutputFormat.setOutputPath(job, new Path(corrPath)); // clean output firstly ShifuFileUtils.deleteFile(corrPath, source); // submit job if(job.waitForCompletion(true)) { dumpCorrelationResult(source, corrPath); } else { throw new RuntimeException("MapReduce Correlation Computing Job failed."); } } /** * If 3000 * 3000 correlation computing, per default threads number setting, memory should be set according to * column size to avoid OOM issue. 
    /**
     * For a 3000 * 3000 correlation computation with the default thread count, mapper memory should be set
     * according to the column size to avoid OOM issues.
     */
    private void setMapperMemory(Configuration conf, int threads) {
        int memoryBuffer = 500;
        // container memory (in MB) scales with the column count: +10% for 4000~5000 columns, +20% above
        // 5000, with a floor of 2048MB, plus a 500MB buffer on top of the JVM heap
        int memoryInContainer = this.columnConfigList.size();
        if(memoryInContainer > 4000 && memoryInContainer <= 5000) {
            memoryInContainer = (int) (memoryInContainer * 1.1d);
        } else if(memoryInContainer > 5000) {
            memoryInContainer = (int) (memoryInContainer * 1.2d);
        }
        if(memoryInContainer < 2048) {
            memoryInContainer = 2048; // at least 2048MB
        }
        memoryInContainer += memoryBuffer; // (MB, 500 is buffer)
        log.info("Correlation map memory is set to {}MB.", memoryInContainer);
        conf.set("mapreduce.map.memory.mb", memoryInContainer + "");
        conf.set("mapreduce.map.java.opts",
                "-Xms" + (memoryInContainer - memoryBuffer) + "m -Xmx" + (memoryInContainer - memoryBuffer)
                        + "m -server -XX:MaxPermSize=128M -XX:PermSize=64M -XX:+UseParallelGC"
                        + " -XX:+UseParallelOldGC -XX:ParallelGCThreads=8 -verbose:gc -XX:+PrintGCDetails"
                        + " -XX:+PrintGCTimeStamps");
    }

    private int parseThreadNum() {
        int threads = 6;
        try {
            threads = Integer
                    .parseInt(Environment.getProperty(Constants.SHIFU_CORRELATION_MULTI_THREADS, threads + ""));
        } catch (Exception e) {
            log.warn("'shifu.correlation.multi.threads' should be an int value, set default value: {}", threads);
        }
        if(threads <= 0) {
            threads = 6;
        }
        return threads;
    }

    private void dumpCorrelationResult(SourceType source, String corrPath) throws IOException {
        String outputFilePattern = corrPath + Path.SEPARATOR + "part-*";
        if(!ShifuFileUtils.isFileExists(outputFilePattern, source)) {
            throw new RuntimeException("Correlation computing output file does not exist.");
        }
        computeCorrValue(dumpCorrInfo(source, outputFilePattern));
    }
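
    /*
     * The math in computeCorrValue() below: given the per-pair sums accumulated by the MR job
     * (adjusted count n, sum(x), sum(y), sum(x*y), sum(x*x), sum(y*y)), the sample Pearson correlation is
     *
     *   corr(x, y) = (n * sum(x*y) - sum(x) * sum(y))
     *                / sqrt(n * sum(x*x) - sum(x)^2)
     *                / sqrt(n * sum(y*y) - sum(y)^2)
     *
     * with 0 substituted whenever a denominator is 0. Only the upper-right triangle of the matrix is
     * computed; a cell [i, j] with j < i is mirrored from the already-computed row j since corr is
     * symmetric.
     */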
    /**
     * Compute correlation values from the correlation statistics produced by the correlation MR job.
     *
     * @param corrMap
     *            CorrelationWritable map read from the MR job output file
     * @throws IOException
     *             any IOException when writing correlation values to the csv file.
     */
    private void computeCorrValue(SortedMap<Integer, CorrelationWritable> corrMap) throws IOException {
        String localCorrelationCsv = super.pathFinder.getLocalCorrelationCsvPath();
        ShifuFileUtils.createFileIfNotExists(localCorrelationCsv, SourceType.LOCAL);
        BufferedWriter writer = null;
        Map<Integer, double[]> finalCorrMap = new HashMap<Integer, double[]>();
        try {
            writer = ShifuFileUtils.getWriter(localCorrelationCsv, SourceType.LOCAL);
            writer.write(getColumnIndexes());
            writer.newLine();
            writer.write(getColumnNames());
            writer.newLine();
            for(Entry<Integer, CorrelationWritable> entry: corrMap.entrySet()) {
                ColumnConfig xColumnConfig = this.columnConfigList.get(entry.getKey());
                if(xColumnConfig.getColumnFlag() == ColumnFlag.Meta) {
                    continue;
                }
                CorrelationWritable xCw = corrMap.get(entry.getKey());
                double[] corrArray = new double[this.columnConfigList.size()];
                for(int i = 0; i < corrArray.length; i++) {
                    ColumnConfig yColumnConfig = this.columnConfigList.get(i);
                    if(yColumnConfig.getColumnFlag() == ColumnFlag.Meta) {
                        continue;
                    }
                    if(entry.getKey() > i) {
                        double[] reverseDoubleArray = finalCorrMap.get(i);
                        if(reverseDoubleArray != null) {
                            corrArray[i] = reverseDoubleArray[entry.getKey()];
                        } else {
                            corrArray[i] = 0d;
                        }
                        // not all cells are computed, only the upper-right matrix; in this case just
                        // mirror [i, j] from [j, i]
                        continue;
                    }
                    double numerator = xCw.getAdjustCount()[i] * xCw.getXySum()[i]
                            - xCw.getAdjustSumX()[i] * xCw.getAdjustSumY()[i];
                    double denominator1 = Math.sqrt(xCw.getAdjustCount()[i] * xCw.getXxSum()[i]
                            - xCw.getAdjustSumX()[i] * xCw.getAdjustSumX()[i]);
                    double denominator2 = Math.sqrt(xCw.getAdjustCount()[i] * xCw.getYySum()[i]
                            - xCw.getAdjustSumY()[i] * xCw.getAdjustSumY()[i]);
                    if(Double.compare(denominator1, 0d) == 0 || Double.compare(denominator2, 0d) == 0) {
                        corrArray[i] = 0d;
                    } else {
                        corrArray[i] = numerator / (denominator1 * denominator2);
                    }
                }
                // put to current map
                finalCorrMap.put(entry.getKey(), corrArray);
                // write to csv
                String corrStr = Arrays.toString(corrArray);
                String adjustCorrStr = corrStr.substring(1, corrStr.length() - 1);
                writer.write(entry.getKey() + "," + this.columnConfigList.get(entry.getKey()).getColumnName() + ","
                        + adjustCorrStr);
                writer.newLine();
            }
        } finally {
            IOUtils.closeQuietly(writer);
        }
    }
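
    /*
     * Expected layout of each reducer output line parsed by dumpCorrInfo() below (tab-separated):
     *
     *   <columnIndex> TAB <base64-encoded serialized CorrelationWritable>
     *
     * e.g. "12\tSGVsbG8..." (illustrative value only); the second field is Base64-decoded and then
     * de-serialized via bytesToObject() into the per-column sum statistics.
     */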
    /**
     * Dump {@link CorrelationWritable} instances from the correlation MR job output files. This may need more
     * memory for a high column number; local memory should be set to 4G instead of 2G.
     *
     * @param source
     *            source type
     * @param outputFilePattern
     *            output file pattern like part-*
     * @return Sorted map including CorrelationWritable info
     * @throws IOException
     *             any IO exception in reading output file
     * @throws UnsupportedEncodingException
     *             encoding exception to de-serialize correlation info in output file
     */
    private SortedMap<Integer, CorrelationWritable> dumpCorrInfo(SourceType source, String outputFilePattern)
            throws IOException, UnsupportedEncodingException {
        SortedMap<Integer, CorrelationWritable> corrMap = new TreeMap<Integer, CorrelationWritable>();
        FileStatus[] globStatus = ShifuFileUtils.getFileSystemBySourceType(source)
                .globStatus(new Path(outputFilePattern));
        if(globStatus == null || globStatus.length == 0) {
            throw new RuntimeException("Correlation computing output file does not exist.");
        }
        for(FileStatus fileStatus: globStatus) {
            List<Scanner> scanners = ShifuFileUtils.getDataScanners(fileStatus.getPath().toString(), source);
            for(Scanner scanner: scanners) {
                while(scanner.hasNext()) {
                    String str = scanner.nextLine().trim();
                    if(str.contains(Constants.TAB_STR)) {
                        String[] splits = str.split(Constants.TAB_STR);
                        String corrStr = splits[1];
                        int columnIndex = Integer.parseInt(splits[0].trim());
                        corrMap.put(columnIndex, bytesToObject(Base64.decodeBase64(corrStr.getBytes("utf-8"))));
                    }
                }
            }
            closeScanners(scanners);
        }
        return corrMap;
    }

    /**
     * De-serialize bytes into a {@link CorrelationWritable} instance.
     *
     * @throws NullPointerException
     *             if data is null.
     * @throws RuntimeException
     *             if any io exception or other reflection exception.
     */
    public CorrelationWritable bytesToObject(byte[] data) {
        if(data == null) {
            throw new NullPointerException("data should not be null.");
        }
        CorrelationWritable result = (CorrelationWritable) ReflectionUtils
                .newInstance(CorrelationWritable.class.getName());
        DataInputStream dataIn = null;
        try {
            ByteArrayInputStream in = new ByteArrayInputStream(data);
            dataIn = new DataInputStream(in);
            result.readFields(dataIn);
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            if(dataIn != null) {
                try {
                    dataIn.close();
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }
        }
        return result;
    }

    private String getColumnIndexes() {
        StringBuilder header = new StringBuilder("ColumnIndex,");
        for(ColumnConfig config: columnConfigList) {
            header.append(',').append(config.getColumnNum());
        }
        return header.toString();
    }

    private String getColumnNames() {
        StringBuilder header = new StringBuilder(",ColumnName");
        for(ColumnConfig config: columnConfigList) {
            header.append(',').append(config.getColumnName());
        }
        return header.toString();
    }
}
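
// Layout of the resulting local correlation csv written by computeCorrValue():
//   ColumnIndex,,0,1,2,...           <- header row built by getColumnIndexes()
//   ,ColumnName,name0,name1,...      <- header row built by getColumnNames()
//   <index>,<name>,<corr values...>  <- one row per non-meta column
// The extra commas at the start of the two header rows shift them by one cell so that they line up with
// the "index,name" prefix of every data row.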