/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.dvarsel.wrapper;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.core.alg.NNTrainer;
import ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import java.io.IOException;
import java.util.List;
import java.util.Set;
/**
 * Runs a single train-and-validate pass of an {@link NNTrainer} over a given
 * set of working columns and reports the resulting validation error.
 *
 * <p>Created on 11/24/2014.
 */
public class ValidationConductor {

    /** Model configuration; supplies the validation-set rate and is handed to the trainer. */
    private ModelConfig modelConfig;

    /** Column configurations received at construction; stored but not read by this class. */
    @SuppressWarnings("unused")
    private List<ColumnConfig> columnConfigList;

    /** Working columns forwarded to the data set when the validation data is generated. */
    private Set<Integer> workingColumnSet;

    /** Data source asked to populate the training and validation sets. */
    private TrainingDataSet trainingDataSet;

    /**
     * Stores the collaborators needed for a validation run. No work is done here.
     *
     * @param modelConfig      model configuration passed on to the trainer
     * @param columnConfigList column configurations (kept, currently unused)
     * @param workingColumnSet columns to use when generating the data sets
     * @param trainingDataSet  source of the training and validation data
     */
    public ValidationConductor(ModelConfig modelConfig, List<ColumnConfig> columnConfigList,
            Set<Integer> workingColumnSet, TrainingDataSet trainingDataSet) {
        this.modelConfig = modelConfig;
        this.columnConfigList = columnConfigList;
        this.workingColumnSet = workingColumnSet;
        this.trainingDataSet = trainingDataSet;
    }

    /**
     * Generates training and validation data for the working columns, trains a
     * fresh {@link NNTrainer} on it with model persistence and logging disabled,
     * and returns the trainer's result.
     *
     * @return the value returned by {@code trainer.train()}; if training throws
     *         an {@link IOException}, the trainer's base MSE is returned instead
     */
    public double runValidate() {
        // Step 1: start from two empty data sets and hand them to the data source
        // together with the working columns and the configured valid-set rate.
        MLDataSet trainSet = new BasicMLDataSet();
        MLDataSet validSet = new BasicMLDataSet();
        this.trainingDataSet.generateValidateData(this.workingColumnSet, this.modelConfig.getValidSetRate(),
                trainSet, validSet);

        // Step 2: set up the trainer; persistence and logging are switched off for this run.
        NNTrainer nnTrainer = new NNTrainer(this.modelConfig, 1, false);
        nnTrainer.setTrainSet(trainSet);
        nnTrainer.setValidSet(validSet);
        nnTrainer.disableModelPersistence();
        nnTrainer.disableLogging();

        // Step 3: train and report the validation error.
        try {
            return nnTrainer.train();
        } catch (IOException ignored) {
            // Deliberately not propagated: an I/O failure during training
            // falls back to the trainer's base MSE.
            return nnTrainer.getBaseMSE();
        }
    }
}