/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.container.obj;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import ml.shifu.shifu.util.JSONUtils;
import org.apache.commons.io.FileUtils;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.Test;
/**
 * Serialization round-trip test for {@link ComboModelTrain}.
 * <p>
 * The test builds a {@link ComboModelTrain} holding one sub-train configuration per
 * algorithm (NN and GBT), writes it out with {@link JSONUtils#writeValue}, reads it back
 * with {@link JSONUtils#readValue} and checks that the number of sub-train configurations
 * is preserved.
 * <p>
 * Created by zhanhu on 11/18/16.
 */
public class ComboModelTrainTest {

    /**
     * Scratch file the JSON is written to. It lives in the JVM temp directory rather than
     * under {@code src/test/resources} so the test neither depends on the working directory
     * nor leaves artifacts in the source tree. Stays {@code null} until the test creates it.
     */
    private File comboTrainFile;

    /**
     * Writes a two-element {@link ComboModelTrain} to JSON and reads it back.
     *
     * @throws IOException if the scratch file cannot be created, written or read
     */
    @Test
    public void testSerDeser() throws IOException {
        ComboModelTrain inst = new ComboModelTrain();

        List<SubTrainConf> varTrainConfList = new ArrayList<SubTrainConf>();
        varTrainConfList.add(createSubTrainConf(ModelTrainConf.ALGORITHM.NN));
        varTrainConfList.add(createSubTrainConf(ModelTrainConf.ALGORITHM.GBT));
        inst.setSubTrainConfList(varTrainConfList);

        comboTrainFile = File.createTempFile("ComboTrain", ".json");
        JSONUtils.writeValue(comboTrainFile, inst);

        ComboModelTrain anotherInst = JSONUtils.readValue(comboTrainFile, ComboModelTrain.class);

        // Fail with a readable assertion instead of a NullPointerException when nothing comes back.
        Assert.assertNotNull(anotherInst, "deserialized ComboModelTrain");
        Assert.assertNotNull(anotherInst.getSubTrainConfList(), "deserialized sub-train conf list");
        // TestNG argument order is (actual, expected).
        Assert.assertEquals(anotherInst.getSubTrainConfList().size(), inst.getSubTrainConfList().size());
    }

    /**
     * Assembles a complete {@link SubTrainConf} (stats, normalize, var-select and train
     * sections) for the given algorithm.
     *
     * @param alg algorithm the sub-configuration is built for
     * @return the populated sub-train configuration
     */
    private SubTrainConf createSubTrainConf(ModelTrainConf.ALGORITHM alg) {
        SubTrainConf subTrainConf = new SubTrainConf();
        subTrainConf.setModelStatsConf(createModelStatsConf(alg));
        subTrainConf.setModelNormalizeConf(createModelNormalizeConf(alg));
        subTrainConf.setModelVarSelectConf(createModelVarSelectConf(alg));
        subTrainConf.setModelTrainConf(createModelTrainConf(alg));
        return subTrainConf;
    }

    /**
     * Builds the stats section. NN/LR get DynamicBinning + EqualTotal, RF/GBT get
     * SPDTI + EqualPositive, both with 20 bins; any other algorithm keeps whatever
     * {@code new ModelStatsConf()} provides.
     *
     * @param alg algorithm the configuration is built for
     * @return the stats configuration
     */
    private ModelStatsConf createModelStatsConf(ModelTrainConf.ALGORITHM alg) {
        ModelStatsConf statsConf = new ModelStatsConf();
        if (isNnOrLr(alg)) {
            statsConf.setBinningAlgorithm(ModelStatsConf.BinningAlgorithm.DynamicBinning);
            statsConf.setBinningMethod(ModelStatsConf.BinningMethod.EqualTotal);
            statsConf.setMaxNumBin(20);
        } else if (isRfOrGbt(alg)) {
            statsConf.setBinningAlgorithm(ModelStatsConf.BinningAlgorithm.SPDTI);
            statsConf.setBinningMethod(ModelStatsConf.BinningMethod.EqualPositive);
            statsConf.setMaxNumBin(20);
        }
        return statsConf;
    }

    /**
     * Builds the normalize section: WOE norm type, sample rate 1.0, not negative-only.
     *
     * @param alg unused; kept so all section factories share the same signature
     * @return the normalize configuration
     */
    private ModelNormalizeConf createModelNormalizeConf(ModelTrainConf.ALGORITHM alg) {
        ModelNormalizeConf normalizeConf = new ModelNormalizeConf();
        normalizeConf.setNormType(ModelNormalizeConf.NormType.WOE);
        normalizeConf.setSampleNegOnly(false);
        normalizeConf.setSampleRate(1.0);
        return normalizeConf;
    }

    /**
     * Builds the variable-selection section with a filter number of 20. The filter-by
     * value is "SE" for NN/LR and "FI" for RF/GBT; it is left untouched otherwise.
     *
     * @param alg algorithm the configuration is built for
     * @return the variable-selection configuration
     */
    private ModelVarSelectConf createModelVarSelectConf(ModelTrainConf.ALGORITHM alg) {
        ModelVarSelectConf varSelectConf = new ModelVarSelectConf();
        varSelectConf.setFilterNum(20);
        if (isNnOrLr(alg)) {
            varSelectConf.setFilterBy("SE");
        } else if (isRfOrGbt(alg)) {
            varSelectConf.setFilterBy("FI");
        }
        return varSelectConf;
    }

    /**
     * Builds the train section for the given algorithm.
     *
     * @param alg algorithm the configuration is built for; must not be {@code null}
     * @return the train configuration
     */
    private ModelTrainConf createModelTrainConf(ModelTrainConf.ALGORITHM alg) {
        ModelTrainConf trainConf = new ModelTrainConf();
        trainConf.setAlgorithm(alg.name());
        trainConf.setEpochsPerIteration(1);
        // Params are created before the epoch count is set, as createParamsByAlg is handed
        // the partially configured trainConf.
        trainConf.setParams(ModelTrainConf.createParamsByAlg(alg, trainConf));
        trainConf.setNumTrainEpochs(numTrainEpochsFor(alg));
        trainConf.setBaggingWithReplacement(true);
        return trainConf;
    }

    /**
     * Returns the epoch count used by this test for an algorithm: 200 for NN, 20000 for
     * the tree ensembles (RF, GBT) and 100 for everything else.
     */
    private static int numTrainEpochsFor(ModelTrainConf.ALGORITHM alg) {
        switch (alg) {
            case NN:
                return 200;
            case RF:
            case GBT:
                return 20000;
            default:
                return 100;
        }
    }

    /** Returns {@code true} when {@code alg} is NN or LR. */
    private static boolean isNnOrLr(ModelTrainConf.ALGORITHM alg) {
        return ModelTrainConf.ALGORITHM.NN.equals(alg) || ModelTrainConf.ALGORITHM.LR.equals(alg);
    }

    /** Returns {@code true} when {@code alg} is RF or GBT. */
    private static boolean isRfOrGbt(ModelTrainConf.ALGORITHM alg) {
        return ModelTrainConf.ALGORITHM.RF.equals(alg) || ModelTrainConf.ALGORITHM.GBT.equals(alg);
    }

    /**
     * Removes the scratch JSON file. {@code deleteQuietly} is used so a missing file, or
     * a test that failed before creating it, does not turn cleanup into a second failure.
     */
    @AfterClass
    public void cleanup() {
        FileUtils.deleteQuietly(comboTrainFile);
    }
}