/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.alg;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.core.AbstractTrainer;
import ml.shifu.shifu.core.ConvergeJudger;
import org.apache.commons.io.FileUtils;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.quick.QuickPropagation;
import org.encog.persist.EncogDirectoryPersistence;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
/**
* Implementation of AbstractTrainer for LogisticRegression
*/
public class LogisticRegressionTrainer extends AbstractTrainer {
private static final Logger LOG = LoggerFactory.getLogger(LogisticRegressionTrainer.class);
public static final String LEARNING_RATE = "LearningRate";
private BasicNetwork classifier;
/**
* Convergence judger instance for convergence criteria checking.
*/
private ConvergeJudger judger = new ConvergeJudger();
public LogisticRegressionTrainer(ModelConfig modelConfig, int trainerID, Boolean dryRun) {
super(modelConfig, trainerID, dryRun);
}
public void setDataSet(MLDataSet masterDataSet) throws IOException {
super.setDataSet(masterDataSet);
}
/**
* {@inheritDoc}
* <p>
* no <code>regularization</code>
* <p>
* Regular will be provide later
* <p>
*
* @throws IOException
* e
*/
@Override
public double train() throws IOException {
LOG.info("Using logistic regression algorithm...");
LOG.info("Input Size: " + trainSet.getInputSize());
LOG.info("Ideal Size: " + trainSet.getIdealSize());
classifier = new BasicNetwork();
classifier.addLayer(new BasicLayer(new ActivationLinear(), true, trainSet.getInputSize()));
classifier.addLayer(new BasicLayer(new ActivationSigmoid(), false, trainSet.getIdealSize()));
classifier.getStructure().finalizeStructure();
// resetParams(classifier);
classifier.reset();
// Propagation mlTrain = getMLTrain();
Propagation propagation = new QuickPropagation(classifier, trainSet, (Double) modelConfig.getParams().get(
"LearningRate"));
int epochs = modelConfig.getNumTrainEpochs();
// Get convergence threshold from modelConfig.
double threshold = modelConfig.getTrain().getConvergenceThreshold() == null ? 0.0 : modelConfig.getTrain()
.getConvergenceThreshold().doubleValue();
String formatedThreshold = df.format(threshold);
LOG.info("Using " + (Double) modelConfig.getParams().get("LearningRate") + " training rate");
for(int i = 0; i < epochs; i++) {
propagation.iteration();
double trainError = propagation.getError();
double validError = classifier.calculateError(this.validSet);
LOG.info("Epoch #" + (i + 1) + " Train Error:" + df.format(trainError) + " Validation Error:"
+ df.format(validError));
// Convergence judging.
double avgErr = (trainError + validError) / 2;
if(judger.judge(avgErr, threshold)) {
LOG.info("Trainer-{}> Epoch #{} converged! Average Error: {}, Threshold: {}", trainerID, (i + 1),
df.format(avgErr), formatedThreshold);
break;
} else {
LOG.info("Trainer-{}> Epoch #{} Average Error: {}, Threshold: {}", trainerID, (i + 1),
df.format(avgErr), formatedThreshold);
}
}
propagation.finishTraining();
LOG.info("#" + this.trainerID + " finish training");
saveLR();
return 0.0d;
}
private void saveLR() throws IOException {
File folder = new File("./models");
if(!folder.exists()) {
FileUtils.forceMkdir(folder);
}
EncogDirectoryPersistence.saveObject(new File("./models/model" + this.trainerID + ".lr"), classifier);
}
public BasicNetwork getClassifier() {
return classifier;
}
}