/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core; import ml.shifu.shifu.container.ModelInitInputObject; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.RawSourceData.SourceType; import ml.shifu.shifu.fs.PathFinder; import ml.shifu.shifu.fs.ShifuFileUtils; import ml.shifu.shifu.util.Constants; import ml.shifu.shifu.util.JSONUtils; import org.encog.ml.BasicML; import org.encog.ml.data.MLData; import org.encog.ml.data.MLDataPair; import org.encog.ml.data.MLDataSet; import org.encog.ml.data.basic.BasicMLData; import org.encog.ml.data.basic.BasicMLDataPair; import org.encog.ml.data.basic.BasicMLDataSet; import org.encog.ml.data.buffer.BufferedMLDataSet; import org.encog.neural.networks.BasicNetwork; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.text.DecimalFormat; import java.util.ArrayList; import java.util.List; import java.util.Random; /** * Abstract Trainer */ public abstract class AbstractTrainer { /** * Log for abstract trainer */ protected static final Logger log = LoggerFactory.getLogger(AbstractTrainer.class); /** * formatter */ protected static final DecimalFormat df = new DecimalFormat("0.000000"); /** * randomizer */ protected Random random; /** * model config instance */ protected ModelConfig modelConfig; /** * trainer id, identify the trainer */ protected int trainerID = 0; /** * dry run flag */ protected Boolean dryRun = false; /** * cross validation rate */ protected Double crossValidationRate; /** * sample rate */ protected Double baggingSampleRate; /** * training option, M: memory, D: disk */ protected String trainingOption = "M"; /** * training set instance */ protected MLDataSet trainSet; /** * validation set instance */ protected MLDataSet validSet; /** * base standard error value */ protected Double baseMSE = null; /** * path finder to locate file */ protected PathFinder pathFinder = null; public AbstractTrainer(ModelConfig modelConfig, int trainerID, Boolean dryRun) { this.random = new Random(System.currentTimeMillis() + trainerID); this.modelConfig = modelConfig; this.trainerID = trainerID; this.dryRun = dryRun; crossValidationRate = this.modelConfig.getValidSetRate(); if (crossValidationRate == null) { crossValidationRate = 0.2; } baggingSampleRate = this.modelConfig.getBaggingSampleRate(); if (baggingSampleRate == null) { baggingSampleRate = 0.8; } pathFinder = new PathFinder(modelConfig); } /* * Set up the training dataset and validation dataset */ public void setDataSet(MLDataSet masterDataSet) throws IOException { log.info("Setting Data Set..."); MLDataSet sampledDataSet; if (this.trainingOption.equalsIgnoreCase("M")) { log.info("Loading to Memory ..."); sampledDataSet = new BasicMLDataSet(); this.trainSet = new BasicMLDataSet(); this.validSet = new BasicMLDataSet(); } else if (this.trainingOption.equalsIgnoreCase("D")) { log.info("Loading to Disk ..."); sampledDataSet = new BufferedMLDataSet(new File(Constants.TMP, "sampled.egb")); this.trainSet = new BufferedMLDataSet(new File(Constants.TMP, "train.egb")); this.validSet = new BufferedMLDataSet(new File(Constants.TMP, "valid.egb")); int inputSize = masterDataSet.getInputSize(); int idealSize = masterDataSet.getIdealSize(); ((BufferedMLDataSet) sampledDataSet).beginLoad(inputSize, idealSize); ((BufferedMLDataSet) trainSet).beginLoad(inputSize, idealSize); ((BufferedMLDataSet) validSet).beginLoad(inputSize, idealSize); } else { throw new RuntimeException("Training Option is not Valid: " + this.trainingOption); } // Encog 3.1 // int masterSize = masterDataSet.size(); // Encog 3.0 int masterSize = (int) masterDataSet.getRecordCount(); if (!modelConfig.isFixInitialInput()) { // Bagging if (modelConfig.isBaggingWithReplacement()) { // Bagging With Replacement int sampledSize = (int) (masterSize * baggingSampleRate); for (int i = 0; i < sampledSize; i++) { // Encog 3.1 // sampledDataSet.add(masterDataSet.get(random.nextInt(masterSize))); // Encog 3.0 double[] input = new double[masterDataSet.getInputSize()]; double[] ideal = new double[1]; MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal)); masterDataSet.getRecord(random.nextInt(masterSize), pair); sampledDataSet.add(pair); } } else { // Bagging Without Replacement for (MLDataPair pair : masterDataSet) { if (random.nextDouble() < baggingSampleRate) { sampledDataSet.add(pair); } } } } else { List<Integer> list = loadSampleInput((int) (masterSize * baggingSampleRate), masterSize, modelConfig.isBaggingWithReplacement()); for (Integer i : list) { double[] input = new double[masterDataSet.getInputSize()]; double[] ideal = new double[1]; MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal)); masterDataSet.getRecord(i, pair); sampledDataSet.add(pair); } } if (this.trainingOption.equalsIgnoreCase("D")) { ((BufferedMLDataSet) sampledDataSet).endLoad(); } // Cross Validation log.info("Generating Training Set and Validation Set ..."); if (!modelConfig.isFixInitialInput()) { // Encog 3.0 for (int i = 0; i < sampledDataSet.getRecordCount(); i++) { double[] input = new double[sampledDataSet.getInputSize()]; double[] ideal = new double[1]; MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal)); sampledDataSet.getRecord(i, pair); if (random.nextDouble() > crossValidationRate) { trainSet.add(pair); } else { validSet.add(pair); } } } else { long sampleSize = sampledDataSet.getRecordCount(); long trainSetSize = (long) (sampleSize * (1 - crossValidationRate)); int i = 0; for (; i < trainSetSize; i++) { double[] input = new double[sampledDataSet.getInputSize()]; double[] ideal = new double[1]; MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal)); sampledDataSet.getRecord(i, pair); trainSet.add(pair); } for (; i < sampleSize; i++) { double[] input = new double[sampledDataSet.getInputSize()]; double[] ideal = new double[1]; MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal)); sampledDataSet.getRecord(i, pair); validSet.add(pair); } } if (this.trainingOption.equalsIgnoreCase("D")) { ((BufferedMLDataSet) trainSet).endLoad(); ((BufferedMLDataSet) validSet).endLoad(); } log.info(" - # Records of the Master Data Set: " + masterSize); log.info(" - Bagging Sample Rate: " + baggingSampleRate); log.info(" - Bagging With Replacement: " + modelConfig.isBaggingWithReplacement()); log.info(" - # Records of the Selected Data Set: " + sampledDataSet.getRecordCount()); log.info(" - Cross Validation Rate: " + crossValidationRate); log.info(" - # Records of the Training Set: " + this.getTrainSetSize()); log.info(" - # Records of the Validation Set: " + this.getValidSetSize()); } /** * get the training data set size * * @return number of size */ public int getTrainSetSize() { // Encog 3.1 // return trainSet.size(); // Encog 3.0 return (int) trainSet.getRecordCount(); } /** * get the training dataset * * @return the trainSet */ public MLDataSet getTrainSet() { return trainSet; } /** * @param trainSet the trainSet to set */ public void setTrainSet(MLDataSet trainSet) { this.trainSet = trainSet; } /** * @param validSet the validSet to set */ public void setValidSet(MLDataSet validSet) { this.validSet = validSet; } /** * get the validation set size * * @return the validation set number */ public int getValidSetSize() { // Encog 3.1 // return validSet.size(); // Encog 3.0 return (int) validSet.getRecordCount(); } /* * set the training option, M/D */ public void setTrainingOption(String trainingOption) { this.trainingOption = trainingOption; } /** * get the validation dataset * * @return the validation dataset */ public MLDataSet getValidSet() { return validSet; } /* * load/save the sampling data from pre-initialization file */ private List<Integer> loadSampleInput(int sampleSize, int masterSize, boolean replaceable) throws IOException { List<Integer> list = null; File file = new File("./init" + trainerID + ".json"); if (!file.exists()) { list = randomSetSampleIndex(sampleSize, masterSize, replaceable); ModelInitInputObject io = new ModelInitInputObject(); io.setNumSample(sampleSize); io.setSampleIndex(list); JSONUtils.writeValue(file, io); } else { BufferedReader reader = ShifuFileUtils.getReader("./init" + trainerID + ".json", SourceType.LOCAL); ModelInitInputObject io = JSONUtils.readValue(reader, ModelInitInputObject.class); if (io == null) { io = new ModelInitInputObject(); } if (io.getNumSample() != sampleSize) { list = randomSetSampleIndex(sampleSize, masterSize, replaceable); io.setNumSample(sampleSize); io.setSampleIndex(list); JSONUtils.writeValue(file, io); } else { list = io.getSampleIndex(); } reader.close(); } return list; } /* * randomizer the input data */ private List<Integer> randomSetSampleIndex(int sampleSize, int masterSize, boolean replaceable) { List<Integer> list = new ArrayList<Integer>(); if (replaceable) { for (int i = 0; i < sampleSize; i++) { list.add(random.nextInt(masterSize)); } } else { for (int i = 0; i < masterSize; i++) { if (random.nextDouble() < baggingSampleRate) { list.add(i); } } } return list; } /* * reset the weights in trainer */ public void resetParams(BasicML classifier) { if (modelConfig.isFixInitialInput()) { } else { if (modelConfig.getAlgorithm() == "NN" || modelConfig.getAlgorithm() == "LR") { ((BasicNetwork) classifier).reset(); } } } /* * A training start function, and print the training error and validation errors */ public abstract double train() throws IOException; /* * non-synchronously version update error * * @return the standard error */ public static Double calculateMSE(BasicNetwork network, MLDataSet dataSet) { double mse = 0; long numRecords = dataSet.getRecordCount(); for (int i = 0; i < numRecords; i++) { // Encog 3.1 // MLDataPair pair = dataSet.get(i); // Encog 3.0 double[] input = new double[dataSet.getInputSize()]; double[] ideal = new double[1]; MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal)); dataSet.getRecord(i, pair); MLData result = network.compute(pair.getInput()); double tmp = result.getData()[0] - pair.getIdeal().getData()[0]; mse += tmp * tmp; } mse = mse / numRecords; return mse; } public Double getBaseMSE() { return baseMSE; } public void setBaseMSE(Double baseMSE) { this.baseMSE = baseMSE; } }