/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dvarsel.wrapper; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.RawSourceData; import ml.shifu.shifu.core.Normalizer; import ml.shifu.shifu.core.dtrain.CommonConstants; import ml.shifu.shifu.core.dtrain.nn.NNConstants; import ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet; import ml.shifu.shifu.core.dvarsel.dataset.TrainingRecord; import ml.shifu.shifu.util.CommonUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.lang.StringUtils; import org.testng.annotations.Test; import java.io.FileInputStream; import java.io.IOException; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; /** * Created on 11/27/2014. 
* <p>Print-only exercises of {@link ValidationConductor}: every data column of a training set is validated
 * on its own and the returned error is written to stdout. The tests make no assertions.
 */
public class ValidationConductorTest {

    /**
     * Loads the cancer-judgement example model and column configs, normalizes the rows of
     * {@code DataSet1/part-00} with {@link Normalizer}, and validates each candidate column in isolation.
     *
     * @throws IOException if a config or data file cannot be read
     */
    @Test
    public void testRunValidate() throws IOException {
        ModelConfig modelConfig = CommonUtils.loadModelConfig(
                "src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json",
                RawSourceData.SourceType.LOCAL);
        List<ColumnConfig> columnConfigList = CommonUtils.loadColumnConfigList(
                "src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ColumnConfig.json",
                RawSourceData.SourceType.LOCAL);

        List<Integer> columnIdList = new ArrayList<Integer>();
        for (ColumnConfig columnConfig : columnConfigList) {
            if (columnConfig.isCandidate()) {
                columnIdList.add(columnConfig.getColumnNum());
            }
        }

        TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);
        List<String> recordsList =
                readAllLines("src/test/resources/example/cancer-judgement/DataStore/DataSet1/part-00");
        for (String record : recordsList) {
            addRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
        }

        validateEachColumnAlone(modelConfig, columnConfigList, trainingDataSet);
    }

    /**
     * Parses one raw (un-normalized) record and appends it to {@code trainingDataSet}.
     *
     * <p>The record is split on the model's data-set delimiter. The ideal output is 1.0 when the trimmed target
     * value is one of the model's positive tags, otherwise 0.0. Each input is the {@link Normalizer} output for
     * the corresponding data column, in the order of {@code trainingDataSet.getDataColumnIdList()}.
     *
     * @param modelConfig      supplies the delimiter and positive tags
     * @param columnConfigList column configs, indexed by column id
     * @param trainingDataSet  destination data set; also defines which columns become inputs
     * @param record           one raw data line
     */
    public void addRecordIntoTrainDataSet(ModelConfig modelConfig, List<ColumnConfig> columnConfigList,
            TrainingDataSet trainingDataSet, String record) {
        String[] fields = CommonUtils.split(record, modelConfig.getDataSetDelimiter());
        int targetColumnId = CommonUtils.getTargetColumnNum(columnConfigList);
        String tag = StringUtils.trim(fields[targetColumnId]);

        double[] inputs = new double[trainingDataSet.getDataColumnIdList().size()];
        double[] ideal = new double[1];
        double significance = CommonConstants.DEFAULT_SIGNIFICANCE_VALUE;

        ideal[0] = (modelConfig.getPosTags().contains(tag) ? 1.0d : 0.0d);

        int i = 0;
        for (Integer columnId : trainingDataSet.getDataColumnIdList()) {
            inputs[i++] = Normalizer.normalize(columnConfigList.get(columnId), fields[columnId]);
        }

        trainingDataSet.addTrainingRecord(new TrainingRecord(inputs, ideal, significance));
    }

    /**
     * Same column-by-column validation as {@link #testRunValidate()}, run against an already-normalized data part.
     *
     * <p>Disabled: it reads a developer's local files under {@code /Users/zhanhu/temp/partnership_varselect/},
     * which do not exist elsewhere. Column selection here uses {@code CommonUtils.isGoodCandidate}.
     *
     * @throws IOException if a config or data file cannot be read
     */
    //@Test
    public void testPartershipModel() throws IOException {
        ModelConfig modelConfig = CommonUtils.loadModelConfig(
                "/Users/zhanhu/temp/partnership_varselect/ModelConfig.json",
                RawSourceData.SourceType.LOCAL);
        List<ColumnConfig> columnConfigList = CommonUtils.loadColumnConfigList(
                "/Users/zhanhu/temp/partnership_varselect/ColumnConfig.json",
                RawSourceData.SourceType.LOCAL);

        List<Integer> columnIdList = new ArrayList<Integer>();
        for (ColumnConfig columnConfig : columnConfigList) {
            if (CommonUtils.isGoodCandidate(columnConfig)) {
                columnIdList.add(columnConfig.getColumnNum());
            }
        }

        TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);
        List<String> recordsList = readAllLines("/Users/zhanhu/temp/partnership_varselect/part-m-00479");
        for (String record : recordsList) {
            addNormalizedRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
        }

        validateEachColumnAlone(modelConfig, columnConfigList, trainingDataSet);
    }

    /**
     * Parses one already-normalized, {@code |}-delimited record and appends it to {@code trainingDataSet}.
     *
     * <p>The target field is parsed directly as the ideal output. Each data-column field is parsed as a double.
     * A blank or unparseable field is reported on stdout as {@code columnId|value}, and its input slot is left
     * at 0.0, so later columns keep their positions.
     *
     * @param modelConfig      unused here; kept for signature parity with {@link #addRecordIntoTrainDataSet}
     * @param columnConfigList column configs, used to locate the target column
     * @param trainingDataSet  destination data set; also defines which columns become inputs
     * @param record           one normalized data line
     */
    public void addNormalizedRecordIntoTrainDataSet(ModelConfig modelConfig, List<ColumnConfig> columnConfigList,
            TrainingDataSet trainingDataSet, String record) {
        String[] fields = CommonUtils.split(record, "|");

        double[] inputs = new double[trainingDataSet.getDataColumnIdList().size()];
        double[] ideal = new double[1];
        double significance = NNConstants.DEFAULT_SIGNIFICANCE_VALUE;

        int targetColumnId = CommonUtils.getTargetColumnNum(columnConfigList);
        ideal[0] = Double.parseDouble(fields[targetColumnId]);

        int i = 0;
        for (Integer columnId : trainingDataSet.getDataColumnIdList()) {
            String value = fields[columnId];
            // Claim this column's slot up front so a bad value never shifts later columns.
            int slot = i++;
            if (StringUtils.isBlank(value)) {
                // Blank (or null) cannot be parsed; report once and leave the slot at 0.0.
                System.out.println(columnId + "|" + value);
                continue;
            }
            try {
                inputs[slot] = Double.parseDouble(value);
            } catch (NumberFormatException e) {
                System.out.println(columnId + "|" + value);
            }
        }

        trainingDataSet.addTrainingRecord(new TrainingRecord(inputs, ideal, significance));
    }

    /**
     * Validates every data column of {@code trainingDataSet} on its own and prints the resulting error.
     *
     * <p>A fresh single-element working set is built per column, so no state is shared between the
     * {@link ValidationConductor} instances.
     */
    private static void validateEachColumnAlone(ModelConfig modelConfig, List<ColumnConfig> columnConfigList,
            TrainingDataSet trainingDataSet) throws IOException {
        for (Integer columnId : trainingDataSet.getDataColumnIdList()) {
            Set<Integer> workingList = new HashSet<Integer>();
            workingList.add(columnId);

            ValidationConductor conductor =
                    new ValidationConductor(modelConfig, columnConfigList, workingList, trainingDataSet);
            double error = conductor.runValidate();
            System.out.println("The error is - " + error + ", for columnId - " + columnId);
        }
    }

    /**
     * Reads all lines of a local file using the platform default charset (same as the previous inline calls).
     * The stream is always closed.
     */
    private static List<String> readAllLines(String path) throws IOException {
        FileInputStream input = new FileInputStream(path);
        try {
            return IOUtils.readLines(input);
        } finally {
            input.close();
        }
    }
}