/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.util; import com.fasterxml.jackson.core.JsonGenerationException; import com.fasterxml.jackson.databind.JsonMappingException; import com.fasterxml.jackson.databind.ObjectMapper; import ml.shifu.shifu.core.validator.ModelInspector; import ml.shifu.shifu.util.updater.ColumnConfigUpdater; import org.apache.commons.io.FileUtils; import org.apache.commons.lang.StringUtils; import org.apache.hadoop.fs.FileStatus; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testng.Assert; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; import java.io.BufferedWriter; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStreamWriter; import java.util.*; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ColumnConfig.ColumnFlag; import ml.shifu.shifu.container.obj.ColumnConfig.ColumnType; import ml.shifu.shifu.container.obj.EvalConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.ModelTrainConf.ALGORITHM; import ml.shifu.shifu.container.obj.RawSourceData.SourceType; import ml.shifu.shifu.fs.PathFinder; import ml.shifu.shifu.udf.CalculateStatsUDF; /** * CommonUtilsTest */ public class CommonUtilsTest { private static final Logger LOG = LoggerFactory.getLogger(CommonUtilsTest.class); private ObjectMapper jsonMapper = new ObjectMapper(); @Test public void stringToIntegerListTest() { Assert.assertEquals(Arrays.asList(new Integer[] { 1, 2, 3 }), CommonUtils.stringToIntegerList("[1, 2, 3]")); } // @Test public void syncTest() throws IOException { ModelConfig config = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, "test"); config.setModelSetName("testModel"); jsonMapper.writerWithDefaultPrettyPrinter().writeValue(new File("ModelConfig.json"), config); ColumnConfig col = new ColumnConfig(); col.setColumnName("ColumnA"); List<ColumnConfig> columnConfigList = new ArrayList<ColumnConfig>(); columnConfigList.add(col); config.getDataSet().setSource(SourceType.LOCAL);; jsonMapper.writerWithDefaultPrettyPrinter().writeValue(new File("ColumnConfig.json"), columnConfigList); File file = null; file = new File("models"); if(!file.exists()) { FileUtils.forceMkdir(file); } file = new File("models/model1.nn"); if(!file.exists()) { if(file.createNewFile()) { BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file), Constants.DEFAULT_CHARSET)); writer.write("test string"); writer.close(); } else { LOG.warn("Create file {} failed", file.getAbsolutePath()); } } file = new File("EvalSets/test"); if(!file.exists()) { FileUtils.forceMkdir(file); } file = new File("EvalSets/test/EvalConfig.json"); if(!file.exists()) { BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file), Constants.DEFAULT_CHARSET)); writer.write("test string"); writer.close(); } CommonUtils.copyConfFromLocalToHDFS(config, new PathFinder(config)); file = new File("ModelSets"); Assert.assertTrue(file.exists()); file = new File("ModelSets/testModel"); Assert.assertTrue(file.exists()); file = new File("ModelSets/testModel/ModelConfig.json"); Assert.assertTrue(file.exists()); file = new File("ModelSets/testModel/ColumnConfig.json"); Assert.assertTrue(file.exists()); file = new File("ModelSets/testModel/ReasonCodeMap.json"); Assert.assertTrue(file.exists()); file = new File("ModelSets/testModel/models/model1.nn"); Assert.assertTrue(file.exists()); file = new File("ModelSets/testModel/EvalSets/test/EvalConfig.json"); Assert.assertTrue(file.exists()); file = new File("ModelSets"); if(file.exists()) { FileUtils.deleteDirectory(file); } file = new File("ColumnConfig.json"); FileUtils.deleteQuietly(file); file = new File("ModelConfig.json"); FileUtils.deleteQuietly(file); FileUtils.deleteDirectory(new File("models")); FileUtils.deleteDirectory(new File("EvalSets")); } // @Test public void syncUpEvalTest() throws IOException { ModelConfig config = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, "test"); config.setModelSetName("shifu"); File file = new File("evals/EvalA"); if(!file.exists()) { FileUtils.forceMkdir(file); } file = new File("testEval/EvalConfig.json"); FileUtils.touch(file); // CommonUtils.copyEvalConfFromLocalToHDFS(config, "testEval"); Assert.assertTrue(file.exists()); FileUtils.deleteDirectory(new File("ModelSets")); FileUtils.deleteDirectory(new File("evals")); } @Test public void loadModelConfigTest() throws JsonGenerationException, JsonMappingException, IOException { ModelConfig config = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, "test"); config.setModelSetName("shifu"); jsonMapper.writerWithDefaultPrettyPrinter().writeValue(new File("ModelConfig.json"), config); ModelConfig anotherConfig = CommonUtils.loadModelConfig(); Assert.assertEquals(config, anotherConfig); FileUtils.deleteQuietly(new File("ModelConfig.json")); } @Test public void getFinalSelectColumnConfigListTest() { Collection<ColumnConfig> configList = new ArrayList<ColumnConfig>(); ColumnConfig config = new ColumnConfig(); config.setColumnName("A"); config.setFinalSelect(false); configList.add(config); config = new ColumnConfig(); config.setFinalSelect(true); config.setColumnName("B"); configList.add(config); config = new ColumnConfig(); config.setFinalSelect(false); config.setColumnName("C"); configList.add(config); configList = CommonUtils.getFinalSelectColumnConfigList(configList); Assert.assertTrue(configList.size() == 1); } @Test public void getBinNumTest() { ColumnConfig config = new ColumnConfig(); config.setColumnName("A"); config.setColumnType(ColumnType.C); config.setBinCategory(Arrays.asList(new String[] { "2", "1", "3" })); int rt = CommonUtils.getBinNum(config, "2"); Assert.assertTrue(rt == 0); } @Test public void testStringToIntegerList() { Assert.assertEquals(CommonUtils.stringToIntegerList("[]").size(), 1); } // @Test // public void assembleDataPairTest() throws Exception { // Map<String, String> rawDataMap = new HashMap<String, String>(); // rawDataMap.put("ColumnA", "TestValue"); // // ColumnConfig config = new ColumnConfig(); // config.setColumnName("ColumnA"); // List<ColumnConfig> columnConfigList = new ArrayList<ColumnConfig>(); // columnConfigList.add(config); // // MLDataPair dp = CommonUtils.assembleDataPair(columnConfigList, // rawDataMap); // Assert.assertTrue(dp.getInput().getData().length == 0); // // Map<String, Object> objDataMap = new HashMap<String, Object>(); // objDataMap.put("ColumnA", 10); // config.setFinalSelect(true); // config.setMean(12.0); // config.setStdDev(4.6); // MLDataPair pair = CommonUtils.assembleDataPair(columnConfigList, // objDataMap); // Assert.assertTrue(pair.getInput().getData()[0] < 0.0); // } @Test public void getTargetColumnNumTest() { List<ColumnConfig> list = new ArrayList<ColumnConfig>(); ColumnConfig config = new ColumnConfig(); config.setColumnFlag(null); list.add(config); config = new ColumnConfig(); config.setColumnFlag(ColumnFlag.Target); config.setColumnNum(20); list.add(config); config = new ColumnConfig(); config.setColumnFlag(null); list.add(config); Assert.assertEquals(Integer.valueOf(20), CommonUtils.getTargetColumnNum(list)); } @Test public void loadModelsTest() { // TODO load models test } @Test public void getRawDataMapTest() { Map<String, String> map = CommonUtils.getRawDataMap(new String[] { "input1", "input2" }, new String[] { "1", "2" }); Assert.assertTrue(map.containsKey("input2")); Assert.assertTrue(map.keySet().size() == 2); } @Test public void stringToDoubleListTest() { String str = "[0,1,2,3]"; List<Integer> list = CommonUtils.stringToIntegerList(str); Assert.assertTrue(list.get(0) == 0); } // @Test public void updateColumnConfigFlagsTest() throws IOException { ModelConfig config = ModelConfig.createInitModelConfig("test", ALGORITHM.NN, "test"); config.getDataSet().setMetaColumnNameFile("./conf/meta_column_conf.txt"); config.getVarSelect().setForceRemoveColumnNameFile("./conf/remove_column_list.txt"); List<ColumnConfig> list = new ArrayList<ColumnConfig>(); ColumnConfig e = new ColumnConfig(); e.setColumnName("a"); list.add(e); e = new ColumnConfig(); e.setColumnName("c"); list.add(e); e = new ColumnConfig(); e.setColumnName("d"); list.add(e); ColumnConfigUpdater.updateColumnConfigFlags(config, list, ModelInspector.ModelStep.VARSELECT); Assert.assertTrue(list.get(0).isMeta()); } @Test public void stringToStringListTest() { String str = "[1,2,3,,4]"; List<Integer> list = CommonUtils.stringToIntegerList(str); Assert.assertTrue(list.get(0) == 1); } @Test public void getDerivedColumnNamesTest() { List<ColumnConfig> list = new ArrayList<ColumnConfig>(); ColumnConfig e = new ColumnConfig(); e.setColumnName("a"); list.add(e); e = new ColumnConfig(); e.setColumnName("derived_c"); list.add(e); e = new ColumnConfig(); e.setColumnName("d"); list.add(e); List<String> output = CommonUtils.getDerivedColumnNames(list); Assert.assertEquals(output.get(0), "derived_c"); } @Test public void testLoadModelConfig() throws IOException { ModelConfig config = CommonUtils.loadModelConfig( "src/test/resources/example/wdbc/wdbcModelSetLocal/ModelConfig.json", SourceType.LOCAL); Assert.assertEquals(config.getDataSet().getNegTags().get(0), "B"); } @Test public void testEscape() { Assert.assertEquals("\\\\t", CommonUtils.escapePigString("\t")); } @AfterClass public void delete() throws IOException { FileUtils.deleteDirectory(new File("common-utils")); } @Test public void testFindModels() throws IOException { ModelConfig modelConfig = CommonUtils.loadModelConfig( "src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json", SourceType.LOCAL); File srcModels = new File("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/models"); File dstModels = new File("models"); FileUtils.copyDirectory(srcModels, dstModels); List<FileStatus> modelFiles = CommonUtils.findModels(modelConfig, null, SourceType.LOCAL); Assert.assertEquals(5, modelFiles.size()); EvalConfig evalConfig = modelConfig.getEvalConfigByName("EvalA"); evalConfig.setCustomPaths(new HashMap<String, String>()); evalConfig.getCustomPaths().put(Constants.KEY_MODELS_PATH, null); modelFiles = CommonUtils.findModels(modelConfig, evalConfig, SourceType.LOCAL); Assert.assertEquals(5, modelFiles.size()); evalConfig.getCustomPaths().put(Constants.KEY_MODELS_PATH, " "); modelFiles = CommonUtils.findModels(modelConfig, evalConfig, SourceType.LOCAL); Assert.assertEquals(5, modelFiles.size()); FileUtils.deleteDirectory(dstModels); evalConfig.getCustomPaths().put(Constants.KEY_MODELS_PATH, "./src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/models"); modelFiles = CommonUtils.findModels(modelConfig, evalConfig, SourceType.LOCAL); Assert.assertEquals(5, modelFiles.size()); evalConfig.getCustomPaths().put(Constants.KEY_MODELS_PATH, "./src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/models/model0.nn"); modelFiles = CommonUtils.findModels(modelConfig, evalConfig, SourceType.LOCAL); Assert.assertEquals(1, modelFiles.size()); evalConfig.getCustomPaths().put(Constants.KEY_MODELS_PATH, "not-exists"); modelFiles = CommonUtils.findModels(modelConfig, evalConfig, SourceType.LOCAL); Assert.assertEquals(0, modelFiles.size()); evalConfig.getCustomPaths().put(Constants.KEY_MODELS_PATH, "./src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/models/*.nn"); modelFiles = CommonUtils.findModels(modelConfig, evalConfig, SourceType.LOCAL); Assert.assertEquals(5, modelFiles.size()); evalConfig.getCustomPaths().put(Constants.KEY_MODELS_PATH, "./src/test/resources/example/cancer-judgement/ModelStore/ModelSet{0,1,9}/*/*.nn"); modelFiles = CommonUtils.findModels(modelConfig, evalConfig, SourceType.LOCAL); Assert.assertEquals(5, modelFiles.size()); } @Test public void testStringToArray() { String input = "[-37.075125208681136, 0.5043788517677587, 1.2588712402838798, 2.543219666931007, 4.896511355654414, 8.986345381526105, 17.06859410430839, 33.557046979865774, 73.27777777777777, 231.63698630136986, 100000.0]"; List<Double> output = CommonUtils.stringToDoubleList(input); Assert.assertEquals( output, Arrays.asList(new Double[] { -37.075125208681136, 0.5043788517677587, 1.2588712402838798, 2.543219666931007, 4.896511355654414, 8.986345381526105, 17.06859410430839, 33.557046979865774, 73.27777777777777, 231.63698630136986, 100000.0 })); } @Test public void testCategoryVauleSepartor() { List<String> strList = new ArrayList<String>(); strList.add("[Hello, Testing"); strList.add("Haha, It's a testing]"); String joinStr = StringUtils.join(strList, CalculateStatsUDF.CATEGORY_VAL_SEPARATOR); List<String> recoverList = CommonUtils.stringToStringList(joinStr, CalculateStatsUDF.CATEGORY_VAL_SEPARATOR); Assert.assertEquals(2, recoverList.size()); Assert.assertEquals(strList.get(0).substring(1), recoverList.get(0)); Assert.assertEquals(strList.get(1).substring(0, strList.get(1).length() - 1), recoverList.get(1)); } @Test public void testSortFileNames() { File[] modelFiles = new File[5]; modelFiles[0] = new File("model3.nn"); modelFiles[1] = new File("model1.nn"); modelFiles[2] = new File("model0.nn"); modelFiles[3] = new File("model4.nn"); modelFiles[4] = new File("model2.nn"); Arrays.sort(modelFiles, new Comparator<File>() { @Override public int compare(File from, File to) { return from.getName().compareTo(to.getName()); } }); Assert.assertEquals(modelFiles[0].getName(), "model0.nn"); Assert.assertEquals(modelFiles[4].getName(), "model4.nn"); } @Test public void binIndexTest() { Double[] array = { Double.NEGATIVE_INFINITY, 2.1E-4, 0.00351, 0.01488, 0.02945, 0.0642, 0.11367, 0.22522, 0.23977 }; List<Double> binBoundary = Arrays.asList(array); Assert.assertEquals(CommonUtils.getBinIndex(binBoundary, 0.00350), 1); Assert.assertEquals(CommonUtils.getBinIndex(binBoundary, 0.00010), 0); Assert.assertEquals(CommonUtils.getBinIndex(binBoundary, 5D), 8); } @Test public void trimNumber() { Assert.assertEquals(CommonUtils.trimTag("1000"), "1000"); Assert.assertEquals(CommonUtils.trimTag("1.000"), "1"); Assert.assertEquals(CommonUtils.trimTag("1."), "1"); Assert.assertEquals(CommonUtils.trimTag("0.0000"), "0"); Assert.assertEquals(CommonUtils.trimTag("1.03400"), "1.034"); Assert.assertEquals(CommonUtils.trimTag("1.034001"), "1.034001"); Assert.assertEquals(CommonUtils.trimTag(".0000"), "0"); Assert.assertEquals(CommonUtils.trimTag(".00001"), "0.00001"); Assert.assertEquals(CommonUtils.trimTag(".M0001"), ".M0001"); Assert.assertEquals(CommonUtils.trimTag("M."), "M."); Assert.assertEquals(CommonUtils.trimTag(".L"), ".L"); Assert.assertEquals(CommonUtils.trimTag(" .L "), ".L"); Assert.assertEquals(CommonUtils.trimTag(" "), ""); Assert.assertEquals(CommonUtils.trimTag(null), ""); Assert.assertEquals(CommonUtils.trimTag("1.0B"), "1.0B"); } @AfterClass public void tearDown() throws IOException { FileUtils.deleteDirectory(new File(Constants.COLUMN_META_FOLDER_NAME)); } }