/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dvarsel.wrapper; import java.io.FileInputStream; import java.io.IOException; import java.util.ArrayList; import java.util.List; import junit.framework.Assert; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.RawSourceData; import ml.shifu.shifu.core.Normalizer; import ml.shifu.shifu.core.dtrain.CommonConstants; import ml.shifu.shifu.core.dvarsel.CandidateSeed; import ml.shifu.shifu.core.dvarsel.VarSelMasterResult; import ml.shifu.shifu.core.dvarsel.VarSelWorkerResult; import ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet; import ml.shifu.shifu.core.dvarsel.dataset.TrainingRecord; import ml.shifu.shifu.util.CommonUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.lang.StringUtils; import org.testng.annotations.Test; /** * Created on 11/27/2014. 
*/
public class WrapperWorkerConductorTest {

    // Fixture locations, relative to the module root the tests are launched from.
    private static final String MODEL_CONFIG_PATH =
            "src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json";
    private static final String COLUMN_CONFIG_PATH =
            "src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ColumnConfig.json";
    private static final String TRAINING_DATA_PATH =
            "src/test/resources/example/cancer-judgement/DataStore/DataSet1/part-00";

    /**
     * Feeds a {@link WrapperWorkerConductor} the cancer-judgement training data plus ten
     * candidate seeds and checks that it produces a non-null worker result carrying at
     * least one seed performance entry.
     *
     * @throws IOException if any of the fixture files cannot be read
     */
    @Test
    public void testWrapperConductor() throws IOException {
        ModelConfig modelConfig = CommonUtils.loadModelConfig(MODEL_CONFIG_PATH, RawSourceData.SourceType.LOCAL);
        List<ColumnConfig> columnConfigList =
                CommonUtils.loadColumnConfigList(COLUMN_CONFIG_PATH, RawSourceData.SourceType.LOCAL);

        WrapperWorkerConductor wrapper = new WrapperWorkerConductor(modelConfig, columnConfigList);
        TrainingDataSet trainingDataSet = genTrainingDataSet(modelConfig, columnConfigList);
        wrapper.retainData(trainingDataSet);

        // Pool of column ids 2..29 that the seeds below are cut from.
        List<Integer> columnIdList = new ArrayList<Integer>();
        for (int i = 2; i < 30; i++) {
            columnIdList.add(i);
        }

        // Ten seeds, each a sliding window of six consecutive ids from the pool.
        // NOTE(review): the first CandidateSeed argument (0) is presumably the seed id - confirm.
        List<CandidateSeed> seedList = new ArrayList<CandidateSeed>();
        for (int i = 0; i < 10; i++) {
            seedList.add(new CandidateSeed(0, columnIdList.subList(i + 1, i + 7)));
        }

        wrapper.consumeMasterResult(new VarSelMasterResult(seedList));
        VarSelWorkerResult workerResult = wrapper.generateVarSelResult();

        Assert.assertNotNull(workerResult);
        Assert.assertTrue(workerResult.getSeedPerfList().size() > 0);
    }

    /**
     * Builds a {@link TrainingDataSet} whose columns are every candidate column in
     * {@code columnConfigList} and whose records are read, one per line, from the
     * cancer-judgement sample data file.
     *
     * @param modelConfig      model configuration supplying the delimiter and positive tags
     * @param columnConfigList column configurations used to pick and normalize the inputs
     * @return the populated training data set
     * @throws IOException if the sample data file cannot be opened or read
     */
    public TrainingDataSet genTrainingDataSet(ModelConfig modelConfig, List<ColumnConfig> columnConfigList)
            throws IOException {
        List<Integer> columnIdList = new ArrayList<Integer>();
        for (ColumnConfig columnConfig : columnConfigList) {
            if (columnConfig.isCandidate()) {
                columnIdList.add(columnConfig.getColumnNum());
            }
        }

        TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);

        // Keep a handle on the stream so it is always closed; handing a fresh
        // FileInputStream straight to readLines() would leak the file descriptor.
        List<String> recordsList;
        FileInputStream dataStream = new FileInputStream(TRAINING_DATA_PATH);
        try {
            recordsList = IOUtils.readLines(dataStream);
        } finally {
            IOUtils.closeQuietly(dataStream);
        }

        for (String record : recordsList) {
            addRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
        }

        return trainingDataSet;
    }

    /**
     * Parses one raw data line and appends it to {@code trainingDataSet} as a
     * {@link TrainingRecord}.
     * <p>
     * The ideal output is {@code 1.0} when the trimmed target field is one of the model's
     * positive tags and {@code 0.0} otherwise; each input is the value
     * {@link Normalizer#normalize} returns for the corresponding data column.
     *
     * @param modelConfig      model configuration supplying the delimiter and positive tags
     * @param columnConfigList column configurations, indexed by column id
     * @param trainingDataSet  data set to append the parsed record to
     * @param record           one raw, delimiter-separated line of training data
     */
    public void addRecordIntoTrainDataSet(ModelConfig modelConfig, List<ColumnConfig> columnConfigList,
            TrainingDataSet trainingDataSet, String record) {
        String[] fields = CommonUtils.split(record, modelConfig.getDataSetDelimiter());
        int targetColumnId = CommonUtils.getTargetColumnNum(columnConfigList);
        String tag = StringUtils.trim(fields[targetColumnId]);

        double[] inputs = new double[trainingDataSet.getDataColumnIdList().size()];
        double[] ideal = new double[1];
        double significance = CommonConstants.DEFAULT_SIGNIFICANCE_VALUE;

        ideal[0] = (modelConfig.getPosTags().contains(tag) ? 1.0d : 0.0d);

        // Inputs follow the order of the data set's column id list.
        int i = 0;
        for (Integer columnId : trainingDataSet.getDataColumnIdList()) {
            inputs[i++] = Normalizer.normalize(columnConfigList.get(columnId), fields[columnId]);
        }

        trainingDataSet.addTrainingRecord(new TrainingRecord(inputs, ideal, significance));
    }
}