/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.util; import com.google.common.base.Function; import com.google.common.base.Splitter; import com.google.common.collect.Collections2; import com.google.common.collect.Lists; import ml.shifu.shifu.column.NSColumn; import ml.shifu.shifu.column.NSColumnUtils; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ColumnConfig.ColumnFlag; import ml.shifu.shifu.container.obj.ColumnConfig.ColumnType; import ml.shifu.shifu.container.obj.EvalConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.ModelTrainConf.ALGORITHM; import ml.shifu.shifu.container.obj.RawSourceData.SourceType; import ml.shifu.shifu.core.LR; import ml.shifu.shifu.core.Normalizer; import ml.shifu.shifu.core.TreeModel; import ml.shifu.shifu.core.dtrain.CommonConstants; import ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork; import ml.shifu.shifu.core.dtrain.lr.LogisticRegressionContants; import ml.shifu.shifu.core.model.ModelSpec; import ml.shifu.shifu.exception.ShifuErrorCode; import ml.shifu.shifu.exception.ShifuException; import ml.shifu.shifu.fs.PathFinder; import ml.shifu.shifu.fs.ShifuFileUtils; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.Predicate; import org.apache.commons.io.IOUtils; import org.apache.commons.lang.ArrayUtils; import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.math.NumberUtils; import org.apache.commons.lang3.tuple.MutablePair; import org.apache.hadoop.fs.*; import org.apache.hadoop.fs.FileSystem; import org.apache.pig.backend.executionengine.ExecException; import org.apache.pig.data.Tuple; import org.encog.ml.BasicML; import org.encog.ml.data.MLDataPair; import org.encog.ml.data.basic.BasicMLData; import org.encog.ml.data.basic.BasicMLDataPair; import org.encog.persist.EncogDirectoryPersistence; import org.encog.persist.PersistorRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.*; import java.util.*; import java.util.Map.Entry; /** * {@link CommonUtils} is used to for almost all kinds of utility function in this framework. */ public final class CommonUtils { /** * Avoid using new for our utility class. */ private CommonUtils() { } private static final Logger log = LoggerFactory.getLogger(CommonUtils.class); /** * Sync up all local configuration files to HDFS. * * @param modelConfig * the model config * @param pathFinder * the path finder to locate file * @return if copy successful * * @throws IOException * If any exception on HDFS IO or local IO. * * @throws NullPointerException * If parameter {@code modelConfig} is null */ public static boolean copyConfFromLocalToHDFS(ModelConfig modelConfig, PathFinder pathFinder) throws IOException { FileSystem hdfs = HDFSUtils.getFS(); FileSystem localFs = HDFSUtils.getLocalFS(); Path pathModelSet = new Path(pathFinder.getModelSetPath(SourceType.HDFS)); // don't check whether pathModelSet is exists, should be remove by user. hdfs.mkdirs(pathModelSet); // Copy ModelConfig Path srcModelConfig = new Path(pathFinder.getModelConfigPath(SourceType.LOCAL)); Path dstModelConfig = new Path(pathFinder.getModelSetPath(SourceType.HDFS)); hdfs.copyFromLocalFile(srcModelConfig, dstModelConfig); // Copy ColumnConfig Path srcColumnConfig = new Path(pathFinder.getColumnConfigPath(SourceType.LOCAL)); Path dstColumnConfig = new Path(pathFinder.getColumnConfigPath(SourceType.HDFS)); if(ShifuFileUtils.isFileExists(srcColumnConfig.toString(), SourceType.LOCAL)) { hdfs.copyFromLocalFile(srcColumnConfig, dstColumnConfig); } // copy others Path srcVersion = new Path(pathFinder.getModelVersion(SourceType.LOCAL)); if(localFs.exists(srcVersion)) { Path dstVersion = new Path(pathFinder.getModelVersion(SourceType.HDFS)); hdfs.delete(dstVersion, true); hdfs.copyFromLocalFile(srcVersion, pathModelSet); } // Copy Models Path srcModels = new Path(pathFinder.getModelsPath(SourceType.LOCAL)); if(localFs.exists(srcModels)) { Path dstModels = new Path(pathFinder.getModelsPath(SourceType.HDFS)); hdfs.delete(dstModels, true); hdfs.copyFromLocalFile(srcModels, pathModelSet); } // Copy EvalSets Path evalsPath = new Path(pathFinder.getEvalsPath(SourceType.LOCAL)); if(localFs.exists(evalsPath)) { for(FileStatus evalset: localFs.listStatus(evalsPath)) { EvalConfig evalConfig = modelConfig.getEvalConfigByName(evalset.getPath().getName()); if(evalConfig != null) { copyEvalDataFromLocalToHDFS(modelConfig, evalConfig.getName()); } } } return true; } /** * Sync-up the evalulation data into HDFS * * @param modelConfig * - ModelConfig * @param evalName * eval name in ModelConfig * @throws IOException * - error occur when copying data */ @SuppressWarnings("deprecation") public static void copyEvalDataFromLocalToHDFS(ModelConfig modelConfig, String evalName) throws IOException { EvalConfig evalConfig = modelConfig.getEvalConfigByName(evalName); if(evalConfig != null) { FileSystem hdfs = HDFSUtils.getFS(); FileSystem localFs = HDFSUtils.getLocalFS(); PathFinder pathFinder = new PathFinder(modelConfig); Path evalDir = new Path(pathFinder.getEvalSetPath(evalConfig, SourceType.LOCAL)); Path dst = new Path(pathFinder.getEvalSetPath(evalConfig, SourceType.HDFS)); if(localFs.exists(evalDir) // local evaluation folder exists && localFs.getFileStatus(evalDir).isDir() // is directory && !hdfs.exists(dst)) { hdfs.copyFromLocalFile(evalDir, dst); } if(StringUtils.isNotBlank(evalConfig.getScoreMetaColumnNameFile())) { hdfs.copyFromLocalFile(new Path(evalConfig.getScoreMetaColumnNameFile()), new Path(pathFinder.getEvalSetPath(evalConfig))); } // sync evaluation meta.column.file to hdfs if(StringUtils.isNotBlank(evalConfig.getDataSet().getMetaColumnNameFile())) { hdfs.copyFromLocalFile(new Path(evalConfig.getDataSet().getMetaColumnNameFile()), new Path(pathFinder.getEvalSetPath(evalConfig))); } } } public static String getLocalModelSetPath(Map<String, Object> otherConfigs) { if(otherConfigs != null && otherConfigs.get(Constants.SHIFU_CURRENT_WORKING_DIR) != null) { return new Path(otherConfigs.get(Constants.SHIFU_CURRENT_WORKING_DIR).toString()).toString(); } else { return "."; } } /** * Load ModelConfig from local json ModelConfig.json file. * * @return model config instance from default model config file * @throws IOException * any io exception to load file */ public static ModelConfig loadModelConfig() throws IOException { return loadModelConfig(Constants.LOCAL_MODEL_CONFIG_JSON, SourceType.LOCAL); } /** * Load model configuration from the path and the source type. * * @param path * model file path * @param sourceType * source type of model file * @return model config instance * @throws IOException * if any IO exception in parsing json. * * @throws IllegalArgumentException * if {@code path} is null or empty, if sourceType is null. */ public static ModelConfig loadModelConfig(String path, SourceType sourceType) throws IOException { return loadJSON(path, sourceType, ModelConfig.class); } private static void checkPathAndMode(String path, SourceType sourceType) { if(StringUtils.isEmpty(path) || sourceType == null) { throw new IllegalArgumentException(String.format( "path should not be null or empty, sourceType should not be null, path:%s, sourceType:%s", path, sourceType)); } } /** * Load reason code map and change it to column > resonCode map. * * @param path * reason code path * @param sourceType * source type of file * @return reason code map * @throws IOException * if any IO exception in parsing json. * * @throws IllegalArgumentException * if {@code path} is null or empty, if sourceType is null. */ public static Map<String, String> loadAndFlattenReasonCodeMap(String path, SourceType sourceType) throws IOException { @SuppressWarnings("unchecked") Map<String, List<String>> rawMap = loadJSON(path, sourceType, Map.class); Map<String, String> reasonCodeMap = new HashMap<String, String>(); for(Map.Entry<String, List<String>> entry: rawMap.entrySet()) { for(String str: entry.getValue()) { reasonCodeMap.put(getRelativePigHeaderColumnName(str), entry.getKey()); } } return reasonCodeMap; } /** * Load JSON instance * * @param path * file path * @param sourceType * source type: hdfs or local * @param clazz * class of instance * @param <T> * class type to load * @return instance from json file * @throws IOException * if any IO exception in parsing json. * * @throws IllegalArgumentException * if {@code path} is null or empty, if sourceType is null. */ public static <T> T loadJSON(String path, SourceType sourceType, Class<T> clazz) throws IOException { checkPathAndMode(path, sourceType); log.debug("loading {} with sourceType {}", path, sourceType); BufferedReader reader = null; try { reader = ShifuFileUtils.getReader(path, sourceType); return JSONUtils.readValue(reader, clazz); } finally { IOUtils.closeQuietly(reader); } } /** * Load column configuration list. * * @return column config list * @throws IOException * if any IO exception in parsing json. */ public static List<ColumnConfig> loadColumnConfigList() throws IOException { return loadColumnConfigList(Constants.LOCAL_COLUMN_CONFIG_JSON, SourceType.LOCAL); } /** * Load column configuration list. * * @param path * file path * @param sourceType * source type: hdfs or local * @return column config list * @throws IOException * if any IO exception in parsing json. * @throws IllegalArgumentException * if {@code path} is null or empty, if sourceType is null. */ public static List<ColumnConfig> loadColumnConfigList(String path, SourceType sourceType) throws IOException { return Arrays.asList(loadJSON(path, sourceType, ColumnConfig[].class)); } /** * Return final selected column collection. * * @param columnConfigList * column config list * @return collection of column config list for final select is true */ public static Collection<ColumnConfig> getFinalSelectColumnConfigList(Collection<ColumnConfig> columnConfigList) { return Collections2.filter(columnConfigList, new com.google.common.base.Predicate<ColumnConfig>() { @Override public boolean apply(ColumnConfig input) { return input.isFinalSelect(); } }); } public static String[] getFinalHeaders(ModelConfig modelConfig) throws IOException { String[] fields = null; boolean isSchemaProvided = true; if(StringUtils.isNotBlank(modelConfig.getHeaderPath())) { fields = CommonUtils.getHeaders(modelConfig.getHeaderPath(), modelConfig.getHeaderDelimiter(), modelConfig .getDataSet().getSource()); } else { fields = CommonUtils.takeFirstLine(modelConfig.getDataSetRawPath(), StringUtils.isBlank(modelConfig .getHeaderDelimiter()) ? modelConfig.getDataSetDelimiter() : modelConfig.getHeaderDelimiter(), modelConfig.getDataSet().getSource()); if(StringUtils.join(fields, "").contains(modelConfig.getTargetColumnName())) { // if first line contains target column name, we guess it is csv format and first line is header. isSchemaProvided = true; log.warn("No header path is provided, we will try to read first line and detect schema."); log.warn("Schema in ColumnConfig.json are named as first line of data set path."); } else { isSchemaProvided = false; log.warn("No header path is provided, we will try to read first line and detect schema."); log.warn("Schema in ColumnConfig.json are named as index 0, 1, 2, 3 ..."); log.warn("Please make sure weight column and tag column are also taking index as name."); } } for(int i = 0; i < fields.length; i++) { if(!isSchemaProvided) { fields[i] = i + ""; } else { fields[i] = getRelativePigHeaderColumnName(fields[i]); } } return fields; } public static String[] getFinalHeaders(EvalConfig evalConfig) throws IOException { String[] fields = null; boolean isSchemaProvided = true; if(StringUtils.isNotBlank(evalConfig.getDataSet().getHeaderPath())) { String delimiter = StringUtils.isBlank(evalConfig.getDataSet().getHeaderDelimiter()) ? evalConfig .getDataSet().getDataDelimiter() : evalConfig.getDataSet().getHeaderDelimiter(); fields = CommonUtils.getHeaders(evalConfig.getDataSet().getHeaderPath(), delimiter, evalConfig.getDataSet() .getSource()); } else { fields = CommonUtils.takeFirstLine(evalConfig.getDataSet().getDataPath(), StringUtils.isBlank(evalConfig .getDataSet().getHeaderDelimiter()) ? evalConfig.getDataSet().getDataDelimiter() : evalConfig .getDataSet().getHeaderDelimiter(), evalConfig.getDataSet().getSource()); if(StringUtils.join(fields, "").contains(evalConfig.getDataSet().getTargetColumnName())) { // if first line contains target column name, we guess it is csv format and first line is header. isSchemaProvided = true; log.warn("No header path is provided, we will try to read first line and detect schema."); log.warn("Schema in ColumnConfig.json are named as first line of data set path."); } else { isSchemaProvided = false; log.warn("No header path is provided, we will try to read first line and detect schema."); log.warn("Schema in ColumnConfig.json are named as index 0, 1, 2, 3 ..."); log.warn("Please make sure weight column and tag column are also taking index as name."); } } for(int i = 0; i < fields.length; i++) { if(!isSchemaProvided) { fields[i] = i + ""; } else { fields[i] = getRelativePigHeaderColumnName(fields[i]); } } return fields; } /** * Return header column list from header file. * * @param pathHeader * header path * @param delimiter * the delimiter of headers * @param sourceType * source type: hdfs or local * @return headers array * @throws IOException * if any IO exception in reading file. * * @throws IllegalArgumentException * if sourceType is null, if pathHeader is null or empty, if delimiter is null or empty. * * @throws RuntimeException * if first line of pathHeader is null or empty. */ public static String[] getHeaders(String pathHeader, String delimiter, SourceType sourceType) throws IOException { return getHeaders(pathHeader, delimiter, sourceType, false); } /** * Return header column array from header file. * * @param pathHeader * header path * @param delimiter * the delimiter of headers * @param sourceType * source type: hdfs or local * @param isFull * if full header name including name space * @return headers array * @throws IOException * if any IO exception in reading file. * * @throws IllegalArgumentException * if sourceType is null, if pathHeader is null or empty, if delimiter is null or empty. * * @throws RuntimeException * if first line of pathHeader is null or empty. */ public static String[] getHeaders(String pathHeader, String delimiter, SourceType sourceType, boolean isFull) throws IOException { if(StringUtils.isEmpty(pathHeader) || StringUtils.isEmpty(delimiter) || sourceType == null) { throw new IllegalArgumentException(String.format( "Null or empty parameters srcDataPath:%s, dstDataPath:%s, sourceType:%s", pathHeader, delimiter, sourceType)); } BufferedReader reader = null; String pigHeaderStr = null; try { reader = ShifuFileUtils.getReader(pathHeader, sourceType); pigHeaderStr = reader.readLine(); if(StringUtils.isEmpty(pigHeaderStr)) { throw new RuntimeException(String.format("Cannot reade header info from the first line of file: %s", pathHeader)); } } catch (Exception e) { log.error( "Error in getReader, this must be catched in this method to make sure the next reader can be returned.", e); throw new ShifuException(ShifuErrorCode.ERROR_HEADER_NOT_FOUND); } finally { IOUtils.closeQuietly(reader); } List<String> headerList = new ArrayList<String>(); Set<String> headerSet = new HashSet<String>(); int index = 0; for(String str: Splitter.on(delimiter).split(pigHeaderStr)) { String columnName = StringUtils.trimToEmpty(str); /* * if(isFull) { * columnName = getFullPigHeaderColumnName(str); * } else { * columnName = getRelativePigHeaderColumnName(str); * } */ if(headerSet.contains(columnName)) { columnName = columnName + "_" + index; } headerSet.add(columnName); index++; headerList.add(columnName); } return headerList.toArray(new String[0]); } /** * Get full column name from pig header. For example, one column is a::b, return a_b. If b, return b. * * @param raw * raw name * @return full name including namespace */ public static String getFullPigHeaderColumnName(String raw) { return raw == null ? raw : raw.replaceAll(Constants.PIG_COLUMN_SEPARATOR, Constants.PIG_FULL_COLUMN_SEPARATOR); } /** * Get relative column name from pig header. For example, one column is a::b, return b. If b, return b. * * @param raw * raw name * @return relative name including namespace * @throws NullPointerException * if parameter raw is null. */ public static String getRelativePigHeaderColumnName(String raw) { int position = raw.lastIndexOf(Constants.PIG_COLUMN_SEPARATOR); return position >= 0 ? raw.substring(position + Constants.PIG_COLUMN_SEPARATOR.length()) : raw; } /** * Given a column value, return bin list index. Return 0 for Category because of index 0 is started from * NEGATIVE_INFINITY. * * @param columnConfig * column config * @param columnVal * value of the column * @return bin index of than value * @throws IllegalArgumentException * if input is null or empty. * * @throws NumberFormatException * if columnVal does not contain a parsable number. */ public static int getBinNum(ColumnConfig columnConfig, String columnVal) { if(columnConfig.isCategorical()) { List<String> binCategories = columnConfig.getBinCategory(); for(int i = 0; i < binCategories.size(); i++) { if(isCategoricalBinValue(binCategories.get(i), columnVal)) { return i; } } return -1; } else { if(StringUtils.isBlank(columnVal)) { return -1; } double dval = 0.0; try { dval = Double.parseDouble(columnVal); } catch (Exception e) { return -1; } return getBinIndex(columnConfig.getBinBoundary(), dval); } } /** * Check some categorical value is in the categorical value group or not * * @param binVal * - categorical value group, the format is lik cn^us^uk^jp * @param cval * - categorical value to look up * @return true if the categorical value exists in group, else false */ public static boolean isCategoricalBinValue(String binVal, String cval) { return binVal.equals(cval) ? true : CommonUtils.flattenCatValGrp(binVal).contains(cval); } /** * Return the real bin number for one value. As the first bin value is NEGATIVE_INFINITY, invalid index is 0, not * -1. * * @param binBoundary * bin boundary list which should be sorted. * @param value * value of column * @return bin index * * @throws IllegalArgumentException * if binBoundary is null or empty. */ @SuppressWarnings("unused") private static int getNumericBinNum(List<Double> binBoundary, double value) { if(CollectionUtils.isEmpty(binBoundary)) { throw new IllegalArgumentException("binBoundary should not be null or empty."); } int n = binBoundary.size() - 1; while(n > 0 && value < binBoundary.get(n)) { n--; } return n; } /** * Common split function to ignore special character like '|'. It's better to return a list while many calls in our * framework using string[]. * * @param raw * raw string * @param delimiter * the delimeter to split the string * @return array of split Strings * * @throws IllegalArgumentException * {@code raw} and {@code delimiter} is null or empty. */ public static String[] split(String raw, String delimiter) { return splitAndReturnList(raw, delimiter).toArray(new String[0]); } /** * Common split function to ignore special character like '|'. * * @param raw * raw string * @param delimiter * the delimeter to split the string * @return list of split Strings * @throws IllegalArgumentException * {@code raw} and {@code delimiter} is null or empty. */ public static List<String> splitAndReturnList(String raw, String delimiter) { if(StringUtils.isEmpty(raw) || StringUtils.isEmpty(delimiter)) { throw new IllegalArgumentException(String.format( "raw and delimeter should not be null or empty, raw:%s, delimeter:%s", raw, delimiter)); } List<String> headerList = new ArrayList<String>(); for(String str: Splitter.on(delimiter).split(raw)) { headerList.add(str); } return headerList; } /** * Get target column. * * @param columnConfigList * column config list * @return target column index * @throws IllegalArgumentException * if columnConfigList is null or empty. * * @throws IllegalStateException * if no target column can be found. */ public static Integer getTargetColumnNum(List<ColumnConfig> columnConfigList) { if(CollectionUtils.isEmpty(columnConfigList)) { throw new IllegalArgumentException("columnConfigList should not be null or empty."); } // I need cast operation because of common-collections dosen't support generic. ColumnConfig cc = (ColumnConfig) CollectionUtils.find(columnConfigList, new Predicate() { @Override public boolean evaluate(Object object) { return ((ColumnConfig) object).isTarget(); } }); if(cc == null) { throw new IllegalStateException("No target column can be found, please check your column configurations"); } return cc.getColumnNum(); } /** * Load basic models from files. * * @param modelConfig * ModelConfig * @param columnConfigList * column config list * @param evalConfig * eval config instance * @return the list of models * @throws IOException * if any IO exception in reading model file. * * @throws IllegalArgumentException * if {@code modelConfig} is, if invalid model algorithm . * * @throws IllegalStateException * if not HDFS or LOCAL source type or algorithm not supported. */ public static List<BasicML> loadBasicModels(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, EvalConfig evalConfig) throws IOException { if(modelConfig == null || (!Constants.NN.equalsIgnoreCase(modelConfig.getAlgorithm()) && !Constants.SVM.equalsIgnoreCase(modelConfig.getAlgorithm()) && !Constants.LR.equalsIgnoreCase(modelConfig.getAlgorithm()) && !CommonUtils .isTreeModel(modelConfig.getAlgorithm()))) { throw new IllegalArgumentException(modelConfig == null ? "modelConfig is null." : String.format( " invalid model algorithm %s.", modelConfig.getAlgorithm())); } return loadBasicModels(modelConfig, evalConfig, modelConfig.getDataSet().getSource()); } /** * Get bin index by binary search. The last bin in <code>binBoundary</code> is missing value bin. * * @param binBoundary * bin boundary list which should be sorted. * @param dVal * value of column * @return bin index */ public static int getBinIndex(List<Double> binBoundary, Double dVal) { assert binBoundary != null && binBoundary.size() > 0; assert dVal != null; int binSize = binBoundary.size(); int low = 0; int high = binSize - 1; while(low <= high) { int mid = (low + high) >>> 1; Double midVal = binBoundary.get(mid); int cmp = midVal.compareTo(dVal); if(cmp < 0) { low = mid + 1; } else if(cmp > 0) { high = mid - 1; } else { return mid; // key found } } return low == 0 ? 0 : low - 1; } public static List<BasicML> loadBasicModels(ModelConfig modelConfig, EvalConfig evalConfig, SourceType sourceType) throws IOException { List<BasicML> models = new ArrayList<BasicML>(); FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType); List<FileStatus> modelFileStats = locateBasicModels(modelConfig, evalConfig, sourceType); if(CollectionUtils.isNotEmpty(modelFileStats)) { for(FileStatus f: modelFileStats) { models.add(loadModel(modelConfig, f.getPath(), fs)); } } return models; } /** * Load basic models from files. * * @param modelConfig * model config * @param evalConfig * eval confg * @param sourceType * source type * @param gbtConvertToProb * convert gbt score to prob or not * @return list of models * @throws IOException * if any IO exception in reading model file. * * @throws IllegalArgumentException * if {@code modelConfig} is, if invalid model algorithm . * * @throws IllegalStateException * if not HDFS or LOCAL source type or algorithm not supported. */ public static List<BasicML> loadBasicModels(ModelConfig modelConfig, EvalConfig evalConfig, SourceType sourceType, boolean gbtConvertToProb) throws IOException { List<BasicML> models = new ArrayList<BasicML>(); FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType); List<FileStatus> modelFileStats = locateBasicModels(modelConfig, evalConfig, sourceType); if(CollectionUtils.isNotEmpty(modelFileStats)) { for(FileStatus f: modelFileStats) { models.add(loadModel(modelConfig, f.getPath(), fs, gbtConvertToProb)); } } return models; } public static int getBasicModelsCnt(ModelConfig modelConfig, EvalConfig evalConfig, SourceType sourceType) throws IOException { List<FileStatus> modelFileStats = locateBasicModels(modelConfig, evalConfig, sourceType); return (CollectionUtils.isEmpty(modelFileStats) ? 0 : modelFileStats.size()); } public static List<FileStatus> locateBasicModels(ModelConfig modelConfig, EvalConfig evalConfig, SourceType sourceType) throws IOException { // we have to register PersistBasicFloatNetwork for loading such models PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork()); List<FileStatus> listStatus = findModels(modelConfig, evalConfig, sourceType); if(CollectionUtils.isEmpty(listStatus)) { // throw new ShifuException(ShifuErrorCode.ERROR_MODEL_FILE_NOT_FOUND); // disable exception, since we there maybe sub-models return listStatus; } // to avoid the *unix and windows file list order Collections.sort(listStatus, new Comparator<FileStatus>() { @Override public int compare(FileStatus f1, FileStatus f2) { return f1.getPath().getName().compareToIgnoreCase(f2.getPath().getName()); } }); // added in shifu 0.2.5 to slice models not belonging to last training int baggingModelSize = modelConfig.getTrain().getBaggingNum(); if(modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()) { baggingModelSize = modelConfig.getTags().size(); } listStatus = listStatus.size() <= baggingModelSize ? listStatus : listStatus.subList(0, baggingModelSize); return listStatus; } public static BasicML loadModel(ModelConfig modelConfig, Path modelPath, FileSystem fs) throws IOException { return loadModel(modelConfig, modelPath, fs, false); } /** * Loading model according to existing model path. * * @param modelConfig * model config * @param modelPath * the path to store model * @param fs * file system used to store model * @param gbtConvertToProb * convert gbt score to prob or not * @return model object or null if no modelPath file, * * @throws IOException * if loading file for any IOException */ public static BasicML loadModel(ModelConfig modelConfig, Path modelPath, FileSystem fs, boolean gbtConvertToProb) throws IOException { if(!fs.exists(modelPath)) { // no such existing model, return null. return null; } // we have to register PersistBasicFloatNetwork for loading such models PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork()); FSDataInputStream stream = null; BufferedReader br = null; try { stream = fs.open(modelPath); if(modelPath.getName().endsWith(LogisticRegressionContants.LR_ALG_NAME.toLowerCase())) { br = new BufferedReader(new InputStreamReader(stream)); return LR.loadFromString(br.readLine()); } else if(modelPath.getName().endsWith(CommonConstants.RF_ALG_NAME.toLowerCase()) || modelPath.getName().endsWith(CommonConstants.GBT_ALG_NAME.toLowerCase())) { return TreeModel.loadFromStream(stream, gbtConvertToProb); } else { return BasicML.class.cast(EncogDirectoryPersistence.loadObject(stream)); } } catch (Exception e) { String msg = "the expecting model file is: " + modelPath; throw new ShifuException(ShifuErrorCode.ERROR_FAIL_TO_LOAD_MODEL_FILE, e, msg); } finally { if(br != null) { IOUtils.closeQuietly(br); } if(stream != null) { IOUtils.closeQuietly(stream); } } } /** * Find the model files for some @ModelConfig. There is a little tricky about this function. * If @EvalConfig is specified, try to load the models according setting in @EvalConfig, * or if @EvalConfig is null or ModelsPath is blank, Shifu will try to load models under `models` * directory * * @param modelConfig * - @ModelConfig, need this, since the model file may exist in HDFS * * @param evalConfig * - @EvalConfig, maybe null * * @param sourceType * - Where is file system * * @return - @FileStatus array for all found models * * @throws IOException * io exception to load files */ public static List<FileStatus> findModels(ModelConfig modelConfig, EvalConfig evalConfig, SourceType sourceType) throws IOException { FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType); PathFinder pathFinder = new PathFinder(modelConfig); // If the algorithm in ModelConfig is NN, we only load NN models // the same as SVM, LR String modelSuffix = "." + modelConfig.getAlgorithm().toLowerCase(); List<FileStatus> fileList = new ArrayList<FileStatus>(); if(null == evalConfig || StringUtils.isBlank(evalConfig.getModelsPath())) { Path path = new Path(pathFinder.getModelsPath(sourceType)); fileList.addAll(Arrays.asList(fs.listStatus(path, new FileSuffixPathFilter(modelSuffix)))); } else { String modelsPath = evalConfig.getModelsPath(); FileStatus[] expandedPaths = fs.globStatus(new Path(modelsPath)); if(ArrayUtils.isNotEmpty(expandedPaths)) { for(FileStatus epath: expandedPaths) { fileList.addAll(Arrays.asList(fs.listStatus(epath.getPath(), new FileSuffixPathFilter(modelSuffix)))); } } } return fileList; } @SuppressWarnings("deprecation") public static List<ModelSpec> loadSubModels(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, EvalConfig evalConfig, SourceType sourceType, Boolean gbtConvertToProb) { List<ModelSpec> modelSpecs = new ArrayList<ModelSpec>(); FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType); // we have to register PersistBasicFloatNetwork for loading such models PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork()); PathFinder pathFinder = new PathFinder(modelConfig); String modelsPath = null; if(evalConfig == null || StringUtils.isEmpty(evalConfig.getModelsPath())) { modelsPath = pathFinder.getModelsPath(sourceType); } else { modelsPath = evalConfig.getModelsPath(); } try { FileStatus[] fsArr = fs.listStatus(new Path(modelsPath)); for(FileStatus fileStatus: fsArr) { if(fileStatus.isDir()) { ModelSpec modelSpec = loadSubModelSpec(modelConfig, columnConfigList, fileStatus, sourceType, gbtConvertToProb); if(modelSpec != null) { modelSpecs.add(modelSpec); } } } } catch (IOException e) { log.error("Error occurred when loading sub-models.", e); } return modelSpecs; } private static ModelSpec loadSubModelSpec(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, FileStatus fileStatus, SourceType sourceType, Boolean gbtConvertToProb) throws IOException { FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType); String subModelName = fileStatus.getPath().getName(); List<FileStatus> modelFileStats = new ArrayList<FileStatus>(); FileStatus[] subConfigs = new FileStatus[2]; ALGORITHM algorithm = getModelsAlgAndSpecFiles(fileStatus, sourceType, modelFileStats, subConfigs); ModelSpec modelSpec = null; if(CollectionUtils.isNotEmpty(modelFileStats)) { Collections.sort(modelFileStats, new Comparator<FileStatus>() { @Override public int compare(FileStatus fa, FileStatus fb) { return fa.getPath().getName().compareTo(fb.getPath().getName()); } }); List<BasicML> models = new ArrayList<BasicML>(); for(FileStatus f: modelFileStats) { models.add(loadModel(modelConfig, f.getPath(), fs, gbtConvertToProb)); } ModelConfig subModelConfig = modelConfig; if(subConfigs[0] != null) { subModelConfig = CommonUtils.loadModelConfig(subConfigs[0].getPath().toString(), sourceType); } List<ColumnConfig> subColumnConfigList = columnConfigList; if(subConfigs[1] != null) { subColumnConfigList = CommonUtils.loadColumnConfigList(subConfigs[1].getPath().toString(), sourceType); } modelSpec = new ModelSpec(subModelName, subModelConfig, subColumnConfigList, algorithm, models); } return modelSpec; } @SuppressWarnings("deprecation") public static ALGORITHM getModelsAlgAndSpecFiles(FileStatus fileStatus, SourceType sourceType, List<FileStatus> modelFileStats, FileStatus[] subConfigs) throws IOException { assert modelFileStats != null; FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType); ALGORITHM algorithm = null; FileStatus[] fileStatsArr = fs.listStatus(fileStatus.getPath()); if(fileStatsArr != null) { for(FileStatus fls: fileStatsArr) { if(!fls.isDir()) { String fileName = fls.getPath().getName(); if(algorithm == null) { if(fileName.endsWith("." + ALGORITHM.NN.name().toLowerCase())) { algorithm = ALGORITHM.NN; } else if(fileName.endsWith("." + ALGORITHM.LR.name().toLowerCase())) { algorithm = ALGORITHM.LR; } else if(fileName.endsWith("." + ALGORITHM.GBT.name().toLowerCase())) { algorithm = ALGORITHM.GBT; } } if(algorithm != null && fileName.endsWith("." + algorithm.name().toLowerCase())) { modelFileStats.add(fls); } if(fileName.equalsIgnoreCase(Constants.MODEL_CONFIG_JSON_FILE_NAME)) { subConfigs[0] = fls; } else if(fileName.equalsIgnoreCase(Constants.COLUMN_CONFIG_JSON_FILE_NAME)) { subConfigs[1] = fls; } } } } return algorithm; } @SuppressWarnings("deprecation") public static Map<String, Integer> getSubModelsCnt(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, EvalConfig evalConfig, SourceType sourceType) throws IOException { FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType); PathFinder pathFinder = new PathFinder(modelConfig); String modelsPath = null; if(evalConfig == null || StringUtils.isEmpty(evalConfig.getModelsPath())) { modelsPath = pathFinder.getModelsPath(sourceType); } else { modelsPath = evalConfig.getModelsPath(); } Map<String, Integer> subModelsCnt = new TreeMap<String, Integer>(); try { FileStatus[] fsArr = fs.listStatus(new Path(modelsPath)); for(FileStatus fileStatus: fsArr) { if(fileStatus.isDir()) { List<FileStatus> subModelSpecFiles = new ArrayList<FileStatus>(); getModelsAlgAndSpecFiles(fileStatus, sourceType, subModelSpecFiles, new FileStatus[2]); if(CollectionUtils.isNotEmpty(subModelSpecFiles)) { subModelsCnt.put(fileStatus.getPath().getName(), subModelSpecFiles.size()); } } } } catch (IOException e) { log.error("Error occurred when finnding sub-models.", e); } return subModelsCnt; } public static class FileSuffixPathFilter implements PathFilter { private String fileSuffix; public FileSuffixPathFilter(String fileSuffix) { this.fileSuffix = fileSuffix; } @Override public boolean accept(Path path) { return path.getName().endsWith(fileSuffix); } } public static List<BasicML> loadBasicModels(final String modelsPath, final ALGORITHM alg) throws IOException { return loadBasicModels(modelsPath, alg, false); } /** * Load neural network models from specified file path * * @param modelsPath * - a file or directory that contains .nn files * @param alg * the algorithm * @param isConvertToProb * if convert to prob for gbt model * @return - a list of @BasicML * * @throws IOException * - throw exception when loading model files */ public static List<BasicML> loadBasicModels(final String modelsPath, final ALGORITHM alg, boolean isConvertToProb) throws IOException { if(modelsPath == null || alg == null || ALGORITHM.DT.equals(alg)) { throw new IllegalArgumentException("The model path shouldn't be null"); } // we have to register PersistBasicFloatNetwork for loading such models if(ALGORITHM.NN.equals(alg)) { PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork()); } File modelsPathDir = new File(modelsPath); File[] modelFiles = modelsPathDir.listFiles(new FilenameFilter() { @Override public boolean accept(File dir, String name) { return name.endsWith("." + alg.name().toLowerCase()); } }); if(modelFiles != null) { // sort file names Arrays.sort(modelFiles, new Comparator<File>() { @Override public int compare(File from, File to) { return from.getName().compareTo(to.getName()); } }); List<BasicML> models = new ArrayList<BasicML>(modelFiles.length); for(File nnf: modelFiles) { InputStream is = null; try { is = new FileInputStream(nnf); if(ALGORITHM.NN.equals(alg)) { models.add(BasicML.class.cast(EncogDirectoryPersistence.loadObject(is))); } else if(ALGORITHM.LR.equals(alg)) { models.add(LR.loadFromStream(is)); } else if(ALGORITHM.GBT.equals(alg) || ALGORITHM.RF.equals(alg)) { models.add(TreeModel.loadFromStream(is, isConvertToProb)); } } finally { IOUtils.closeQuietly(is); } } return models; } else { throw new IOException(String.format("Failed to list files in %s", modelsPathDir.getAbsolutePath())); } } /** * Return one HashMap Object contains keys in the first parameter, values in the second parameter. Before calling * this method, you should be aware that headers should be unique. * * @param header * - header that contains column name * @param data * - raw data * @return key-value map for variable */ public static Map<String, String> getRawDataMap(String[] header, String[] data) { if(header.length != data.length) { throw new IllegalArgumentException(String.format("Header/Data mismatch: Header length %s, Data length %s", header.length, data.length)); } Map<String, String> rawDataMap = new HashMap<String, String>(header.length); for(int i = 0; i < header.length; i++) { rawDataMap.put(header[i], data[i]); } return rawDataMap; } /** * Return all parameters for pig execution. * * @param modelConfig * model config * @param sourceType * source type * @return map of configurations * @throws IOException * any io exception * @throws IllegalArgumentException * if modelConfig is null. */ public static Map<String, String> getPigParamMap(ModelConfig modelConfig, SourceType sourceType) throws IOException { if(modelConfig == null) { throw new IllegalArgumentException("modelConfig should not be null."); } PathFinder pathFinder = new PathFinder(modelConfig); Map<String, String> pigParamMap = new HashMap<String, String>(); pigParamMap.put(Constants.NUM_PARALLEL, Environment.getInt(Environment.HADOOP_NUM_PARALLEL, 400).toString()); log.info("jar path is {}", pathFinder.getJarPath()); pigParamMap.put(Constants.PATH_JAR, pathFinder.getJarPath()); pigParamMap.put(Constants.PATH_RAW_DATA, modelConfig.getDataSetRawPath()); pigParamMap.put(Constants.PATH_NORMALIZED_DATA, pathFinder.getNormalizedDataPath(sourceType)); // default norm is not for clean, so set it to false, this will be overrided in Train#Norm for tree models pigParamMap.put(Constants.IS_NORM_FOR_CLEAN, Boolean.FALSE.toString()); pigParamMap.put(Constants.PATH_PRE_TRAINING_STATS, pathFinder.getPreTrainingStatsPath(sourceType)); pigParamMap.put(Constants.PATH_STATS_BINNING_INFO, pathFinder.getUpdatedBinningInfoPath(sourceType)); pigParamMap.put(Constants.PATH_STATS_PSI_INFO, pathFinder.getPSIInfoPath(sourceType)); pigParamMap.put(Constants.WITH_SCORE, Boolean.FALSE.toString()); pigParamMap.put(Constants.STATS_SAMPLE_RATE, modelConfig.getBinningSampleRate().toString()); pigParamMap.put(Constants.PATH_MODEL_CONFIG, pathFinder.getModelConfigPath(sourceType)); pigParamMap.put(Constants.PATH_COLUMN_CONFIG, pathFinder.getColumnConfigPath(sourceType)); pigParamMap.put(Constants.PATH_SELECTED_RAW_DATA, pathFinder.getSelectedRawDataPath(sourceType)); pigParamMap.put(Constants.PATH_BIN_AVG_SCORE, pathFinder.getBinAvgScorePath(sourceType)); pigParamMap.put(Constants.PATH_TRAIN_SCORE, pathFinder.getTrainScoresPath(sourceType)); pigParamMap.put(Constants.SOURCE_TYPE, sourceType.toString()); pigParamMap.put(Constants.JOB_QUEUE, Environment.getProperty(Environment.HADOOP_JOB_QUEUE, Constants.DEFAULT_JOB_QUEUE)); return pigParamMap; } /** * Return all parameters for pig execution. * * @param modelConfig * model config * @param sourceType * source type * @param pathFinder * path finder instance * @return map of configurations * @throws IOException * any io exception * @throws IllegalArgumentException * if modelConfig is null. */ public static Map<String, String> getPigParamMap(ModelConfig modelConfig, SourceType sourceType, PathFinder pathFinder) throws IOException { if(modelConfig == null) { throw new IllegalArgumentException("modelConfig should not be null."); } if(pathFinder == null) { pathFinder = new PathFinder(modelConfig); } Map<String, String> pigParamMap = new HashMap<String, String>(); pigParamMap.put(Constants.NUM_PARALLEL, Environment.getInt(Environment.HADOOP_NUM_PARALLEL, 400).toString()); log.info("jar path is {}", pathFinder.getJarPath()); pigParamMap.put(Constants.PATH_JAR, pathFinder.getJarPath()); pigParamMap.put(Constants.PATH_RAW_DATA, modelConfig.getDataSetRawPath()); pigParamMap.put(Constants.PATH_NORMALIZED_DATA, pathFinder.getNormalizedDataPath(sourceType)); pigParamMap.put(Constants.PATH_PRE_TRAINING_STATS, pathFinder.getPreTrainingStatsPath(sourceType)); pigParamMap.put(Constants.PATH_STATS_BINNING_INFO, pathFinder.getUpdatedBinningInfoPath(sourceType)); pigParamMap.put(Constants.PATH_STATS_PSI_INFO, pathFinder.getPSIInfoPath(sourceType)); pigParamMap.put(Constants.WITH_SCORE, Boolean.FALSE.toString()); pigParamMap.put(Constants.STATS_SAMPLE_RATE, modelConfig.getBinningSampleRate().toString()); pigParamMap.put(Constants.PATH_MODEL_CONFIG, pathFinder.getModelConfigPath(sourceType)); pigParamMap.put(Constants.PATH_COLUMN_CONFIG, pathFinder.getColumnConfigPath(sourceType)); pigParamMap.put(Constants.PATH_SELECTED_RAW_DATA, pathFinder.getSelectedRawDataPath(sourceType)); pigParamMap.put(Constants.PATH_BIN_AVG_SCORE, pathFinder.getBinAvgScorePath(sourceType)); pigParamMap.put(Constants.PATH_TRAIN_SCORE, pathFinder.getTrainScoresPath(sourceType)); pigParamMap.put(Constants.SOURCE_TYPE, sourceType.toString()); pigParamMap.put(Constants.JOB_QUEUE, Environment.getProperty(Environment.HADOOP_JOB_QUEUE, Constants.DEFAULT_JOB_QUEUE)); pigParamMap.put(Constants.DATASET_NAME, modelConfig.getBasic().getName()); return pigParamMap; } /** * Change list str to List object with double type. * * @param str * str to be split * @return list of double * @throws IllegalArgumentException * if str is not a valid list str. */ public static List<Double> stringToDoubleList(String str) { List<String> list = checkAndReturnSplitCollections(str); return Lists.transform(list, new Function<String, Double>() { @Override public Double apply(String input) { return Double.valueOf(input.trim()); } }); } private static List<String> checkAndReturnSplitCollections(String str) { checkListStr(str); return Arrays.asList(str.trim().substring(1, str.length() - 1).split(Constants.COMMA)); } private static List<String> checkAndReturnSplitCollections(String str, char separator) { checkListStr(str); return Arrays.asList(StringUtils.split(str.trim().substring(1, str.length() - 1), separator)); } private static void checkListStr(String str) { if(StringUtils.isEmpty(str)) { throw new IllegalArgumentException("str should not be null or empty"); } if(!str.startsWith("[") || !str.endsWith("]")) { throw new IllegalArgumentException("Invalid list string format, should be like '[1,2,3]'"); } } /** * Change list str to List object with int type. * * @param str * str to be split * @return list of int * @throws IllegalArgumentException * if str is not a valid list str. */ public static List<Integer> stringToIntegerList(String str) { List<String> list = checkAndReturnSplitCollections(str); return Lists.transform(list, new Function<String, Integer>() { @Override public Integer apply(String input) { return Integer.valueOf(input.trim()); } }); } /** * Change list str to List object with string type. * * @param str * str to be split * @return list of string * @throws IllegalArgumentException * if str is not a valid list str. */ public static List<String> stringToStringList(String str) { List<String> list = checkAndReturnSplitCollections(str); return Lists.transform(list, new Function<String, String>() { @Override public String apply(String input) { return input.trim(); } }); } /** * Change list str to List object with string type. * * @param str * str to be split * @param separator * the separator * @return list of string * @throws IllegalArgumentException * if str is not a valid list str. */ public static List<String> stringToStringList(String str, char separator) { List<String> list = checkAndReturnSplitCollections(str, separator); return Lists.transform(list, new Function<String, String>() { @Override public String apply(String input) { return input.trim(); } }); } /* * Return map entries sorted by value. */ public static <K, V extends Comparable<V>> List<Map.Entry<K, V>> getEntriesSortedByValues(Map<K, V> map) { List<Map.Entry<K, V>> entries = new LinkedList<Map.Entry<K, V>>(map.entrySet()); Collections.sort(entries, new Comparator<Map.Entry<K, V>>() { @Override public int compare(Entry<K, V> o1, Entry<K, V> o2) { return o1.getValue().compareTo(o2.getValue()); } }); return entries; } /** * Assemble map data to Encog standard input format with default cut off value. * * @param modelConfig * model config instance * @param columnConfigList * column config list * @param rawDataMap * raw data * @return data pair instance * @throws NullPointerException * if input is null * @throws NumberFormatException * if column value is not number format. */ public static MLDataPair assembleDataPair(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, Map<String, ? extends Object> rawDataMap) { return assembleDataPair(modelConfig, columnConfigList, rawDataMap, Constants.DEFAULT_CUT_OFF); } /** * Assemble map data to Encog standard input format. If no variable selected(noVarSel = true), all candidate * variables will be selected. * * @param binCategoryMap * categorical map * @param noVarSel * if after var select * @param modelConfig * model config instance * @param columnConfigList * column config list * @param rawDataMap * raw data * @param cutoff * cut off value * @param alg * algorithm used in model * @return data pair instance * @throws NullPointerException * if input is null * @throws NumberFormatException * if column value is not number format. */ public static MLDataPair assembleDataPair(Map<Integer, Map<String, Integer>> binCategoryMap, boolean noVarSel, ModelConfig modelConfig, List<ColumnConfig> columnConfigList, Map<String, ? extends Object> rawDataMap, double cutoff, String alg) { return assembleNsDataPair(binCategoryMap, noVarSel, modelConfig, columnConfigList, convertRawObjectMapToNsDataMap(rawDataMap), cutoff, alg); } /** * Assemble map data to Encog standard input format. If no variable selected(noVarSel = true), all candidate * variables will be selected. * * @param binCategoryMap * categorical map * @param noVarSel * if after var select * @param modelConfig * model config instance * @param columnConfigList * column config list * @param rawNsDataMap * raw NSColumn data * @param cutoff * cut off value * @param alg * algorithm used in model * @return data pair instance * @throws NullPointerException * if input is null * @throws NumberFormatException * if column value is not number format. */ public static MLDataPair assembleNsDataPair(Map<Integer, Map<String, Integer>> binCategoryMap, boolean noVarSel, ModelConfig modelConfig, List<ColumnConfig> columnConfigList, Map<NSColumn, String> rawNsDataMap, double cutoff, String alg) { double[] ideal = { Constants.DEFAULT_IDEAL_VALUE }; List<Double> inputList = new ArrayList<Double>(); for(ColumnConfig config: columnConfigList) { if(config == null) { continue; } NSColumn key = new NSColumn(config.getColumnName()); if(config.isFinalSelect() && !rawNsDataMap.containsKey(key)) { throw new IllegalStateException(String.format("Variable Missing in Test Data: %s", key)); } if(config.isTarget()) { continue; } else { if(!noVarSel) { if(config != null && !config.isMeta() && !config.isTarget() && config.isFinalSelect()) { String val = rawNsDataMap.get(key) == null ? null : rawNsDataMap.get(key).toString(); if(CommonUtils.isTreeModel(alg) && config.isCategorical()) { Integer index = binCategoryMap.get(config.getColumnNum()).get(val == null ? "" : val); if(index == null) { // not in binCategories, should be missing value // -1 as missing value inputList.add(-1d); } else { inputList.add(index * 1d); } } else { inputList.add(computeNumericNormResult(modelConfig, cutoff, config, val)); } } } else { if(!config.isMeta() && !config.isTarget() && CommonUtils.isGoodCandidate(config)) { String val = rawNsDataMap.get(key) == null ? null : rawNsDataMap.get(key).toString(); if(CommonUtils.isTreeModel(alg) && config.isCategorical()) { Integer index = binCategoryMap.get(config.getColumnNum()).get(val == null ? "" : val); if(index == null) { // not in binCategories, should be missing value // -1 as missing value inputList.add(-1d); } else { inputList.add(index * 1d); } } else { inputList.add(computeNumericNormResult(modelConfig, cutoff, config, val)); } } } } } // god, Double [] cannot be casted to double[], toArray doesn't work int size = inputList.size(); double[] input = new double[size]; for(int i = 0; i < size; i++) { input[i] = inputList.get(i); } return new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal)); } /** * Assemble map data to Encog standard input format. If no variable selected(noVarSel = true), all candidate * variables will be selected. * * @param binCategoryMap * categorical map * @param noVarSel * if after var select * @param modelConfig * model config instance * @param columnConfigList * column config list * @param rawNsDataMap * raw NSColumn data * @param cutoff * cut off value * @param alg * algorithm used in model * @param featureSet * feature set used in NN model * @return data pair instance * @throws NullPointerException * if input is null * @throws NumberFormatException * if column value is not number format. */ public static MLDataPair assembleNsDataPair(Map<Integer, Map<String, Integer>> binCategoryMap, boolean noVarSel, ModelConfig modelConfig, List<ColumnConfig> columnConfigList, Map<NSColumn, String> rawNsDataMap, double cutoff, String alg, Set<Integer> featureSet) { if(featureSet == null || featureSet.size() == 0) { return assembleNsDataPair(binCategoryMap, noVarSel, modelConfig, columnConfigList, rawNsDataMap, cutoff, alg); } double[] ideal = { Constants.DEFAULT_IDEAL_VALUE }; List<Double> inputList = new ArrayList<Double>(); for(ColumnConfig config: columnConfigList) { if(config == null) { continue; } NSColumn key = new NSColumn(config.getColumnName()); if(config.isFinalSelect() && !rawNsDataMap.containsKey(key)) { throw new IllegalStateException(String.format("Variable Missing in Test Data: %s", key)); } if(config.isTarget()) { continue; } else { if(featureSet.contains(config.getColumnNum())) { String val = rawNsDataMap.get(key) == null ? null : rawNsDataMap.get(key).toString(); if(CommonUtils.isTreeModel(alg) && config.isCategorical()) { Integer index = binCategoryMap.get(config.getColumnNum()).get(val == null ? "" : val); if(index == null) { // not in binCategories, should be missing value -1 as missing value inputList.add(-1d); } else { inputList.add(index * 1d); } } else { inputList.add(computeNumericNormResult(modelConfig, cutoff, config, val)); } } } } // god, Double [] cannot be casted to double[], toArray doesn't work int size = inputList.size(); double[] input = new double[size]; for(int i = 0; i < size; i++) { input[i] = inputList.get(i); } return new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal)); } public static List<Integer> getAllFeatureList(List<ColumnConfig> columnConfigList, boolean isAfterVarSelect) { List<Integer> features = new ArrayList<Integer>(); for(ColumnConfig config: columnConfigList) { if(isAfterVarSelect) { if(config.isFinalSelect() && !config.isTarget() && !config.isMeta()) { // only select numerical feature with getBinBoundary().size() larger than 1 // or categorical feature with getBinCategory().size() larger than 0 if((config.isNumerical() && config.getBinBoundary().size() > 1) || (config.isCategorical() && config.getBinCategory().size() > 0)) { features.add(config.getColumnNum()); } } } else { if(!config.isMeta() && !config.isTarget() && CommonUtils.isGoodCandidate(config)) { // only select numerical feature with getBinBoundary().size() larger than 1 // or categorical feature with getBinCategory().size() larger than 0 if((config.isNumerical() && config.getBinBoundary().size() > 1) || (config.isCategorical() && config.getBinCategory().size() > 0)) { features.add(config.getColumnNum()); } } } } return features; } private static double computeNumericNormResult(ModelConfig modelConfig, double cutoff, ColumnConfig config, String val) { Double normalizeValue = null; if(CommonUtils.isTreeModel(modelConfig.getAlgorithm())) { try { normalizeValue = Double.parseDouble(val); } catch (Exception e) { normalizeValue = Normalizer.defaultMissingValue(config); } } else { normalizeValue = Normalizer.normalize(config, val, cutoff, modelConfig.getNormalizeType()); } return normalizeValue; } public static boolean isTreeModel(String alg) { return CommonConstants.RF_ALG_NAME.equalsIgnoreCase(alg) || CommonConstants.GBT_ALG_NAME.equalsIgnoreCase(alg); } public static boolean isRandomForestAlgorithm(String alg) { return CommonConstants.RF_ALG_NAME.equalsIgnoreCase(alg); } public static boolean isGBDTAlgorithm(String alg) { return CommonConstants.GBT_ALG_NAME.equalsIgnoreCase(alg); } public static boolean isHadoopConfigurationInjected(String key) { return key.startsWith("nn") || key.startsWith("guagua") || key.startsWith("shifu") || key.startsWith("mapred") || key.startsWith("io") || key.startsWith("hadoop") || key.startsWith("yarn") || key.startsWith("pig") || key.startsWith("hive") || key.startsWith("job"); } /** * Assemble map data to Encog standard input format. * * @param modelConfig * - ModelConfig * @param columnConfigList * - ColumnConfig list * @param rawDataMap * - raw input key-value map * @param cutoff * - cutoff value when normalization * @return * - input data pair for neural network */ public static MLDataPair assembleDataPair(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, Map<String, ? extends Object> rawDataMap, double cutoff) { Map<NSColumn, Object> nsDataMap = new HashMap<NSColumn, Object>(); for(Entry<String, ? extends Object> entry: rawDataMap.entrySet()) { nsDataMap.put(new NSColumn(entry.getKey()), entry.getValue()); } // if the tag is provided, ideal will be updated; otherwise it defaults to -1 double[] ideal = { Constants.DEFAULT_IDEAL_VALUE }; List<Double> inputList = new ArrayList<Double>(); for(ColumnConfig config: columnConfigList) { NSColumn key = new NSColumn(config.getColumnName()); if(config.isFinalSelect() && !nsDataMap.containsKey(key)) { throw new IllegalStateException(String.format("Variable Missing in Test Data: %s", key)); } if(config.isTarget()) { // TODO - should we have this? maybe not // ideal[0] = Double.valueOf(rawDataMap.get(key).toString()); continue; } else if(config.isFinalSelect()) { // add log for debug purpose // log.info("key: " + key + ", raw_value " + rawDataMap.get(key).toString() + ", zscl_value: " + String val = nsDataMap.get(key) == null ? null : nsDataMap.get(key).toString(); Double normalizeValue = Normalizer.normalize(config, val, cutoff, modelConfig.getNormalizeType()); inputList.add(normalizeValue); } } // god, Double [] cannot be casted to double[], toArray doesn't work int size = inputList.size(); double[] input = new double[size]; for(int i = 0; i < size; i++) { input[i] = inputList.get(i); } return new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal)); } /* * Expanding score by expandingFactor */ public static long getExpandingScore(double d, int expandingFactor) { return Math.round(d * expandingFactor); } /** * Return column name string with 'derived_' started * * @param columnConfigList * list of column config * @return list of column names * @throws NullPointerException * if modelConfig is null or columnConfigList is null. */ public static List<String> getDerivedColumnNames(List<ColumnConfig> columnConfigList) { List<String> derivedColumnNames = new ArrayList<String>(); for(ColumnConfig config: columnConfigList) { if(config.getColumnName().startsWith(Constants.DERIVED)) { derivedColumnNames.add(config.getColumnName()); } } return derivedColumnNames; } /** * Get the file separator regex * * @return "/" - if the OS is Linux * "\\\\" - if the OS is Windows */ public static String getPathSeparatorRegx() { if(File.separator.equals(Constants.SLASH)) { return File.separator; } else { return Constants.BACK_SLASH + File.separator; } } /** * Update target, listMeta, listForceSelect, listForceRemove * * @param modelConfig * model config list * @param columnConfigList * the column config list * @throws IOException * any io exception * * @throws IllegalArgumentException * if modelConfig is null or columnConfigList is null. */ public static void updateColumnConfigFlags(ModelConfig modelConfig, List<ColumnConfig> columnConfigList) throws IOException { String targetColumnName = modelConfig.getTargetColumnName(); String weightColumnName = modelConfig.getWeightColumnName(); Set<NSColumn> setCategorialColumns = new HashSet<NSColumn>(); if(CollectionUtils.isNotEmpty(modelConfig.getCategoricalColumnNames())) { for(String column: modelConfig.getCategoricalColumnNames()) { setCategorialColumns.add(new NSColumn(column)); } } Set<NSColumn> setMeta = new HashSet<NSColumn>(); if(CollectionUtils.isNotEmpty(modelConfig.getMetaColumnNames())) { for(String meta: modelConfig.getMetaColumnNames()) { setMeta.add(new NSColumn(meta)); } } Set<NSColumn> setForceRemove = new HashSet<NSColumn>(); if(Boolean.TRUE.equals(modelConfig.getVarSelect().getForceEnable()) && CollectionUtils.isNotEmpty(modelConfig.getListForceRemove())) { // if we need to update force remove, only and if one the force is enabled for(String forceRemoveName: modelConfig.getListForceRemove()) { setForceRemove.add(new NSColumn(forceRemoveName)); } } Set<NSColumn> setForceSelect = new HashSet<NSColumn>(512); if(Boolean.TRUE.equals(modelConfig.getVarSelect().getForceEnable()) && CollectionUtils.isNotEmpty(modelConfig.getListForceSelect())) { // if we need to update force select, only and if one the force is enabled for(String forceSelectName: modelConfig.getListForceSelect()) { setForceSelect.add(new NSColumn(forceSelectName)); } } for(ColumnConfig config: columnConfigList) { String varName = config.getColumnName(); // reset it config.setColumnFlag(null); if(NSColumnUtils.isColumnEqual(weightColumnName, varName)) { config.setColumnFlag(ColumnFlag.Weight); config.setFinalSelect(false); // reset final select } else if(NSColumnUtils.isColumnEqual(targetColumnName, varName)) { config.setColumnFlag(ColumnFlag.Target); config.setFinalSelect(false); // reset final select } else if(setMeta.contains(new NSColumn(varName))) { config.setColumnFlag(ColumnFlag.Meta); config.setFinalSelect(false); // reset final select } else if(setForceRemove.contains(new NSColumn(varName))) { config.setColumnFlag(ColumnFlag.ForceRemove); config.setFinalSelect(false); // reset final select } else if(setForceSelect.contains(new NSColumn(varName))) { config.setColumnFlag(ColumnFlag.ForceSelect); } if(NSColumnUtils.isColumnEqual(weightColumnName, varName)) { // weight column is numerical config.setColumnType(ColumnType.N); } else if(NSColumnUtils.isColumnEqual(targetColumnName, varName)) { // target column is set to categorical column config.setColumnType(ColumnType.C); } else if(setCategorialColumns.contains(new NSColumn(varName))) { config.setColumnType(ColumnType.C); } else { config.setColumnType(ColumnType.N); } } } /** * To check whether there is targetColumn in columns or not * * @param columns * column array * @param targetColumn * target column * * @return true - if the columns contains targetColumn, or false */ public static boolean isColumnExists(String[] columns, String targetColumn) { if(ArrayUtils.isEmpty(columns) || StringUtils.isBlank(targetColumn)) { return false; } for(int i = 0; i < columns.length; i++) { if(columns[i] != null && columns[i].equalsIgnoreCase(targetColumn)) { return true; } } return false; } /** * Returns the element if it is in both collections. * - return null if any collection is null or empty * - return null if no element exists in both collections * * @param leftCol * - left collection * * @param rightCol * - right collection * @param <T> * - collection type * @return First element that are found in both collections * null if no elements in both collection or any collection is null or empty */ public static <T> T containsAny(Collection<T> leftCol, Collection<T> rightCol) { if(CollectionUtils.isEmpty(leftCol) || CollectionUtils.isEmpty(rightCol)) { return null; } Iterator<T> iterator = leftCol.iterator(); while(iterator.hasNext()) { T element = iterator.next(); if(rightCol.contains(element)) { return element; } } return null; } /** * Escape the delimiter for Pig.... Since the Pig doesn't support invisible character * * @param delimiter * - the original delimiter * @return the delimiter after escape */ public static String escapePigString(String delimiter) { StringBuffer buf = new StringBuffer(); for(int i = 0; i < delimiter.length(); i++) { char c = delimiter.charAt(i); switch(c) { case '\t': buf.append("\\\\t"); break; default: buf.append(c); break; } } return buf.toString(); } public static List<String> readConfFileIntoList(String columnConfFile, SourceType sourceType, String delimiter) throws IOException { List<String> columnNameList = new ArrayList<String>(); if(StringUtils.isBlank(columnConfFile) || !ShifuFileUtils.isFileExists(columnConfFile, sourceType)) { return columnNameList; } List<String> strList = null; Reader reader = ShifuFileUtils.getReader(columnConfFile, sourceType); try { strList = IOUtils.readLines(reader); } finally { IOUtils.closeQuietly(reader); } if(CollectionUtils.isNotEmpty(strList)) { for(String line: strList) { if(line.trim().equals("") || line.trim().startsWith("#")) { continue; } for(String str: Splitter.on(delimiter).split(line)) { // String column = CommonUtils.getRelativePigHeaderColumnName(str); if(StringUtils.isNotBlank(str)) { columnNameList.add(str.trim()); } } } } return columnNameList; } public static Map<String, Integer> generateColumnSeatMap(List<ColumnConfig> columnConfigList) { List<ColumnConfig> selectedColumnList = new ArrayList<ColumnConfig>(); for(ColumnConfig columnConfig: columnConfigList) { if(columnConfig.isFinalSelect()) { selectedColumnList.add(columnConfig); } } Collections.sort(selectedColumnList, new Comparator<ColumnConfig>() { @Override public int compare(ColumnConfig from, ColumnConfig to) { return from.getColumnName().compareTo(to.getColumnName()); } }); Map<String, Integer> columnSeatMap = new HashMap<String, Integer>(); for(int i = 0; i < selectedColumnList.size(); i++) { columnSeatMap.put(selectedColumnList.get(i).getColumnName(), i); } return columnSeatMap; } /** * Find the @ColumnConfig according the column name * * @param columnConfigList * list of column config * @param columnName * the column name * @return column config instance */ public static ColumnConfig findColumnConfigByName(List<ColumnConfig> columnConfigList, String columnName) { for(ColumnConfig columnConfig: columnConfigList) { if(columnConfig.getColumnName().equalsIgnoreCase(columnName)) { return columnConfig; } } return null; } /** * Convert data into (key, value) map. The inputData is String of a record, which is delimited by delimiter * If fields in inputData is not equal header size, return null * * @param inputData * - String of a record * @param delimiter * - the delimiter of the input data * @param header * - the column names for all the input data * @return (key, value) map for the record */ public static Map<String, String> convertDataIntoMap(String inputData, String delimiter, String[] header) { String[] input = CommonUtils.split(inputData, delimiter); if(input == null || input.length == 0 || input.length != header.length) { log.error("the wrong input data, {}", inputData); return null; } Map<String, String> rawDataMap = new HashMap<String, String>(input.length); for(int i = 0; i < header.length; i++) { if(input[i] == null) { rawDataMap.put(header[i], ""); } else { rawDataMap.put(header[i], input[i]); } } return rawDataMap; } /** * Convert tuple record into (key, value) map. The @tuple is Tuple for a record * If @tuple size is not equal @header size, return null * * @param tuple * - Tuple of a record * @param header * - the column names for all the input data * @return (key, value) map for the record * @throws ExecException * - throw exception when operating tuple */ public static Map<String, String> convertDataIntoMap(Tuple tuple, String[] header) throws ExecException { if(tuple == null || tuple.size() == 0 || tuple.size() != header.length) { log.error("Invalid input, the tuple.size is = " + (tuple == null ? null : tuple.size()) + ", header.length = " + header.length); return null; } Map<String, String> rawDataMap = new HashMap<String, String>(tuple.size()); for(int i = 0; i < header.length; i++) { if(tuple.get(i) == null) { rawDataMap.put(header[i], ""); } else { rawDataMap.put(header[i], tuple.get(i).toString()); } } return rawDataMap; } /** * Convert tuple record into (NSColumn, value) map. The @tuple is Tuple for a record * If @tuple size is not equal @header size, return null * * @param tuple * - Tuple of a record * @param header * - the column names for all the input data * @return (NSColumn, value) map for the record * @throws ExecException * - throw exception when operating tuple */ public static Map<NSColumn, String> convertDataIntoNsMap(Tuple tuple, String[] header) throws ExecException { if(tuple == null || tuple.size() == 0 || tuple.size() != header.length) { log.error("Invalid input, the tuple.size is = " + (tuple == null ? null : tuple.size()) + ", header.length = " + header.length); return null; } Map<NSColumn, String> rawDataNsMap = new HashMap<NSColumn, String>(tuple.size()); for(int i = 0; i < header.length; i++) { if(tuple.get(i) == null) { rawDataNsMap.put(new NSColumn(header[i]), ""); } else { rawDataNsMap.put(new NSColumn(header[i]), tuple.get(i).toString()); } } return rawDataNsMap; } public static boolean isGoodCandidate(boolean isBinaryClassification, ColumnConfig columnConfig) { if(columnConfig == null) { return false; } if(isBinaryClassification) { return columnConfig.isCandidate() && (columnConfig.getKs() != null && columnConfig.getKs() > 0 && columnConfig.getIv() != null && columnConfig.getIv() > 0 && columnConfig.getMean() != null && columnConfig.getStdDev() != null && ((columnConfig.isCategorical() && columnConfig.getBinCategory() != null && columnConfig.getBinCategory().size() > 1) || (columnConfig .isNumerical() && columnConfig.getBinBoundary() != null && columnConfig.getBinBoundary() .size() > 1))); } else { // multiple classification return columnConfig.isCandidate() && (columnConfig.getMean() != null && columnConfig.getStdDev() != null && ((columnConfig .isCategorical() && columnConfig.getBinCategory() != null && columnConfig.getBinCategory() .size() > 1) || (columnConfig.isNumerical() && columnConfig.getBinBoundary() != null && columnConfig .getBinBoundary().size() > 1))); } } public static boolean isGoodCandidate(ColumnConfig columnConfig) { if(columnConfig == null) { return false; } return columnConfig.isCandidate() && (columnConfig.getKs() != null && columnConfig.getKs() > 0 && columnConfig.getIv() != null && columnConfig.getIv() > 0 && columnConfig.getMean() != null && columnConfig.getStdDev() != null && ((columnConfig.isCategorical() && columnConfig.getBinCategory() != null && columnConfig.getBinCategory().size() > 1) || (columnConfig .isNumerical() && columnConfig.getBinBoundary() != null && columnConfig.getBinBoundary().size() > 1))); } /** * Return first line split string array. This is used to detect data schema. * * @param dataSetRawPath * raw data path * @param delimeter * the delimiter * @param source * source type * @return the first two lines * @throws IOException * any io exception */ public static String[] takeFirstLine(String dataSetRawPath, String delimeter, SourceType source) throws IOException { if(dataSetRawPath == null || delimeter == null || source == null) { throw new IllegalArgumentException("Input parameters should not be null."); } String firstValidFile = null; FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(source); FileStatus[] globStatus = fs.globStatus(new Path(dataSetRawPath), HIDDEN_FILE_FILTER); if(globStatus == null || globStatus.length == 0) { throw new IllegalArgumentException("No files founded in " + dataSetRawPath); } else { for(FileStatus fileStatus: globStatus) { RemoteIterator<LocatedFileStatus> iterator = fs.listFiles(fileStatus.getPath(), true); while(iterator.hasNext()) { LocatedFileStatus lfs = iterator.next(); String name = lfs.getPath().getName(); if(name.startsWith("_") || name.startsWith(".")) { // hidden files, continue; } if(lfs.getLen() > 1024L) { firstValidFile = lfs.getPath().toString(); break; } } if(StringUtils.isNotBlank(firstValidFile)) { break; } } } log.info("The first valid file is - {}", firstValidFile); BufferedReader reader = null; try { reader = ShifuFileUtils.getReader(firstValidFile, source); String firstLine = reader.readLine(); log.debug("The first line is - {}", firstLine); if(firstLine != null && firstLine.length() > 0) { List<String> list = new ArrayList<String>(); for(String unit: Splitter.on(delimeter).split(firstLine)) { list.add(unit); } return list.toArray(new String[0]); } } catch (Exception e) { log.error("Fail to read first line of file.", e); } finally { IOUtils.closeQuietly(reader); } return new String[0]; } /** * Return first two lines split string array. This is used to detect data schema and check if data * schema is the * same as data. * * @param dataSetRawPath * raw data path * @param delimiter * the delimiter * @param source * source type * @return the first two lines * @throws IOException * any io exception */ public static String[][] takeFirstTwoLines(String dataSetRawPath, String delimiter, SourceType source) throws IOException { if(dataSetRawPath == null || delimiter == null || source == null) { throw new IllegalArgumentException("Input parameters should not be null."); } String firstValidFile = null; FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(source); FileStatus[] globStatus = fs.globStatus(new Path(dataSetRawPath), HIDDEN_FILE_FILTER); if(globStatus == null || globStatus.length == 0) { throw new IllegalArgumentException("No files founded in " + dataSetRawPath); } else { for(FileStatus fileStatus: globStatus) { RemoteIterator<LocatedFileStatus> iterator = fs.listFiles(fileStatus.getPath(), true); while(iterator.hasNext()) { LocatedFileStatus lfs = iterator.next(); String name = lfs.getPath().getName(); if(name.startsWith("_") || name.startsWith(".")) { // hidden files, continue; } if(lfs.getLen() > 1024L) { firstValidFile = lfs.getPath().toString(); break; } } if(StringUtils.isNotBlank(firstValidFile)) { break; } } } log.info("The first valid file is - {}", firstValidFile); BufferedReader reader = null; try { reader = ShifuFileUtils.getReader(firstValidFile, source); String firstLine = reader.readLine(); String[] firstArray = null; if(firstLine != null && firstLine.length() > 0) { List<String> list = new ArrayList<String>(); for(String unit: Splitter.on(delimiter).split(firstLine)) { list.add(unit); } firstArray = list.toArray(new String[0]); } String secondLine = reader.readLine(); String[] secondArray = null; if(secondLine != null && secondLine.length() > 0) { List<String> list = new ArrayList<String>(); for(String unit: Splitter.on(delimiter).split(secondLine)) { list.add(unit); } secondArray = list.toArray(new String[0]); } String[][] results = new String[2][]; results[0] = firstArray; results[1] = secondArray; return results; } finally { IOUtils.closeQuietly(reader); } } private static final PathFilter HIDDEN_FILE_FILTER = new PathFilter() { public boolean accept(Path p) { String name = p.getName(); return !name.startsWith("_") && !name.startsWith("."); } }; public static String genPigFieldName(String name) { return ((name != null) ? name.replace('-', '_') : null); } public static String[] genPigFieldName(String[] names) { String[] pigScoreNames = new String[names.length]; for(int i = 0; i < names.length; i++) { pigScoreNames[i] = genPigFieldName(names[i]) + "::mean"; } return pigScoreNames; } public static Map<Integer, MutablePair<String, Double>> computeTreeModelFeatureImportance(List<BasicML> models) { List<Map<Integer, MutablePair<String, Double>>> importanceList = new ArrayList<Map<Integer, MutablePair<String, Double>>>(); for(BasicML basicModel: models) { if(basicModel instanceof TreeModel) { TreeModel model = (TreeModel) basicModel; Map<Integer, MutablePair<String, Double>> importances = model.getFeatureImportances(); importanceList.add(importances); } } if(importanceList.size() < 1) { throw new IllegalArgumentException("Feature importance calculation abort due to no tree model found!!"); } return mergeImportanceList(importanceList); } private static Map<Integer, MutablePair<String, Double>> mergeImportanceList( List<Map<Integer, MutablePair<String, Double>>> list) { Map<Integer, MutablePair<String, Double>> finalResult = new HashMap<Integer, MutablePair<String, Double>>(); int size = list.size(); for(Map<Integer, MutablePair<String, Double>> item: list) { for(Entry<Integer, MutablePair<String, Double>> entry: item.entrySet()) { if(!finalResult.containsKey(entry.getKey())) { MutablePair<String, Double> value = MutablePair.of(entry.getValue().getKey(), entry.getValue() .getValue() / size); finalResult.put(entry.getKey(), value); } else { MutablePair<String, Double> current = finalResult.get(entry.getKey()); double entryValue = entry.getValue().getValue(); current.setValue(current.getValue() + entryValue / size); finalResult.put(entry.getKey(), current); } } } return TreeModel.sortByValue(finalResult, false); } public static void writeFeatureImportance(String fiPath, Map<Integer, MutablePair<String, Double>> importances) throws IOException { ShifuFileUtils.createFileIfNotExists(fiPath, SourceType.LOCAL); BufferedWriter writer = null; log.info("Writing feature importances to file {}", fiPath); try { writer = ShifuFileUtils.getWriter(fiPath, SourceType.LOCAL); writer.write("column_id\t\tcolumn_name\t\timportance"); writer.newLine(); for(Map.Entry<Integer, MutablePair<String, Double>> entry: importances.entrySet()) { String content = entry.getKey() + "\t\t" + entry.getValue().getKey() + "\t\t" + entry.getValue().getValue(); writer.write(content); writer.newLine(); } writer.flush(); } finally { IOUtils.closeQuietly(writer); } } public static String trimTag(String tag) { if(NumberUtils.isNumber(tag)) { tag = tag.trim(); int firstPeriodPos = -1; int firstDeleteZero = -1; boolean hasMetNonZero = false; for(int i = tag.length(); i > 0; i--) { if((tag.charAt(i - 1) == '0' || tag.charAt(i - 1) == '.') && !hasMetNonZero) { firstDeleteZero = i - 1; } if(tag.charAt(i - 1) != '0') { hasMetNonZero = true; } if(tag.charAt(i - 1) == '.') { firstPeriodPos = i - 1; } } String result = (firstDeleteZero >= 0 && firstPeriodPos >= 0) ? tag.substring(0, firstDeleteZero) : tag; return (firstPeriodPos == 0) ? "0" + result : result; } else { return StringUtils.trimToEmpty(tag); } } /** * Convert (String, String) raw data map to (NSColumn, String) data map * * @param rawDataMap * - (String, String) raw data map * @return (NSColumn, String) data map */ public static Map<NSColumn, String> convertRawMapToNsDataMap(Map<String, String> rawDataMap) { if(rawDataMap == null) { return null; } Map<NSColumn, String> nsDataMap = new HashMap<NSColumn, String>(); for(String key: rawDataMap.keySet()) { nsDataMap.put(new NSColumn(key), rawDataMap.get(key)); } return nsDataMap; } /** * Convert (String, ? extends Object) raw data map to (NSColumn, String) data map * * @param rawDataMap * - (String, ? extends Object) raw data map * @return (NSColumn, String) data map */ public static Map<NSColumn, String> convertRawObjectMapToNsDataMap(Map<String, ? extends Object> rawDataMap) { if(rawDataMap == null) { return null; } Map<NSColumn, String> nsDataMap = new HashMap<NSColumn, String>(); for(String key: rawDataMap.keySet()) { Object value = rawDataMap.get(key); nsDataMap.put(new NSColumn(key), ((value == null) ? null : value.toString())); } return nsDataMap; } /** * flatten categorical value group into values list * * @param categoricalValGrp * - categorical val group, it some values like zn^us^ck^ * @return value list of categorical val */ public static List<String> flattenCatValGrp(String categoricalValGrp) { List<String> catVals = new ArrayList<String>(); if(StringUtils.isNotBlank(categoricalValGrp)) { for(String cval: Splitter.on('^').split(categoricalValGrp)) { catVals.add(cval); } } return catVals; } }