/**
 * Copyright [2012-2014] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.dtrain.nn;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Random;
import java.util.Set;

import ml.shifu.guagua.GuaguaRuntimeException;
import ml.shifu.guagua.hadoop.io.GuaguaWritableAdapter;
import ml.shifu.guagua.worker.AbstractWorkerComputable;
import ml.shifu.guagua.worker.WorkerContext;
import ml.shifu.guagua.worker.WorkerContext.WorkerCompletionCallBack;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.dtrain.CommonConstants;
import ml.shifu.shifu.core.dtrain.DTrainUtils;
import ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLData;
import ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair;
import ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataSet;
import ml.shifu.shifu.core.dtrain.dataset.BufferedFloatMLDataSet;
import ml.shifu.shifu.core.dtrain.dataset.FloatFlatNetwork;
import ml.shifu.shifu.core.dtrain.dataset.FloatMLDataPair;
import ml.shifu.shifu.core.dtrain.dataset.FloatMLDataSet;
import ml.shifu.shifu.core.dtrain.dataset.MemoryDiskFloatMLDataSet;
import ml.shifu.shifu.core.dtrain.gs.GridSearch;
import ml.shifu.shifu.util.CommonUtils;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.math.RandomUtils;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.neural.error.LinearErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.BasicNetwork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Splitter;

/**
 * {@link AbstractNNWorker} is refactored as a common class for different NN input formats.
 */
public abstract class AbstractNNWorker<VALUE extends Writable> extends
        AbstractWorkerComputable<NNParams, NNParams, GuaguaWritableAdapter<LongWritable>, GuaguaWritableAdapter<VALUE>> {

    protected static final Logger LOG = LoggerFactory.getLogger(AbstractNNWorker.class);

    /**
     * Default splitter used to split input records. A single shared instance avoids creating a new Splitter for
     * every call to Splitter.on.
     */
    protected static final Splitter DEFAULT_SPLITTER = Splitter.on(CommonConstants.DEFAULT_COLUMN_SEPARATOR)
            .trimResults();

    /**
     * Training data set
     */
    protected FloatMLDataSet trainingData = null;

    /**
     * Validation data set
     */
    protected FloatMLDataSet validationData = null;

    /**
     * NN algorithm runner instance.
     */
    protected ParallelGradient gradient;
    /**
     * Model Config read from HDFS
     */
    protected ModelConfig modelConfig;

    /**
     * Column Config list read from HDFS
     */
    protected List<ColumnConfig> columnConfigList;

    /**
     * Basic input node count for NN model
     */
    protected int inputNodeCount;

    /**
     * Basic output node count for NN model
     */
    protected int outputNodeCount;

    /**
     * {@link #candidateCount} is used to check if no variable is selected. If {@link #inputNodeCount} equals
     * {@link #candidateCount}, it means either no column is selected or all columns are selected.
     */
    protected int candidateCount;

    /**
     * Trainer id used to tag bagging training jobs, starting from 0, 1, 2 ...
     */
    protected int trainerId = 0;

    /**
     * Input record count, incremented one by one.
     */
    protected long count;

    /**
     * Whether the training is dry training.
     */
    protected boolean isDry;

    /**
     * In each iteration, how many epochs will be run.
     */
    protected int epochsPerIteration = 1;

    /**
     * Whether to alternate training and testing elements.
     */
    protected boolean isCrossOver = false;

    /**
     * Whether to enable Poisson bagging with replacement.
     */
    protected boolean poissonSampler;

    /**
     * PoissonDistribution which is used for Poisson sampling for bagging with replacement.
     */
    protected PoissonDistribution rng = null;

    /**
     * PoissonDistribution which is used for up sampling positive records.
     */
    protected PoissonDistribution upSampleRng = null;

    /**
     * An instance from context properties which is from the job configuration.
     */
    protected Properties props;

    /**
     * Indicates if a specific validation data set is provided.
     */
    protected boolean isSpecificValidation = false;

    /**
     * Valid params, specially for grid search
     */
    private Map<String, Object> validParams;

    /**
     * Whether stratified sampling (rather than random sampling) is used
     */
    protected boolean isStratifiedSampling = false;

    /**
     * Positive count in training data list, only effective in 0-1 regression or onevsall classification
     */
    protected long positiveTrainCount;

    /**
     * Positive count in training data list and being selected in training, only effective in 0-1 regression or
     * onevsall classification
     */
    protected long positiveSelectedTrainCount;

    /**
     * Negative count in training data list, only effective in 0-1 regression or onevsall classification
     */
    protected long negativeTrainCount;

    /**
     * Negative count in training data list and being selected, only effective in 0-1 regression or onevsall
     * classification
     */
    protected long negativeSelectedTrainCount;

    /**
     * Positive count in validation data list, only effective in 0-1 regression or onevsall classification
     */
    protected long positiveValidationCount;

    /**
     * Negative count in validation data list, only effective in 0-1 regression or onevsall classification
     */
    protected long negativeValidationCount;

    /**
     * PoissonDistribution instances used for Poisson sampling for bagging with replacement, keyed by class value.
     */
    protected Map<Integer, PoissonDistribution> baggingRngMap = new HashMap<Integer, PoissonDistribution>();

    /**
     * Bagging random map for different classes. For stratified sampling, this is useful for per-class sampling.
     */
    protected Map<Integer, Random> baggingRandomMap = new HashMap<Integer, Random>();

    /**
     * Validation random map for different classes. For stratified sampling, this is useful for per-class sampling.
     */
    protected Map<Integer, Random> validationRandomMap = new HashMap<Integer, Random>();
    /**
     * Random object to sample negative records only
     */
    protected Random sampelNegOnlyRandom = new Random(System.currentTimeMillis() + 1000L);

    /**
     * Whether k-fold cross validation is enabled
     */
    private boolean isKFoldCV;

    /**
     * Whether extreme learning machine (ELM) is enabled: https://en.wikipedia.org/wiki/Extreme_learning_machine
     */
    private boolean isELM;

    /**
     * Cache all features with feature index for searching
     */
    protected List<Integer> allFeatures;

    /**
     * Cache subset features with feature index for searching
     */
    protected List<Integer> subFeatures;

    /**
     * Set of sub features to quickly check if a column is in the sub feature list
     */
    protected Set<Integer> subFeatureSet;

    protected boolean isUpSampleEnabled() {
        // only enabled in regression
        return this.upSampleRng != null
                && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain()
                        .isOneVsAll()));
    }

    /**
     * Load all configurations for modelConfig and columnConfigList from source type.
     */
    private void loadConfigFiles(final Properties props) {
        try {
            SourceType sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE,
                    SourceType.HDFS.toString()));
            this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG),
                    sourceType);
            this.isCrossOver = this.modelConfig.getTrain().getIsCrossOver().booleanValue();
            LOG.info("Parameter isCrossOver:{}", this.isCrossOver);
            this.columnConfigList = CommonUtils.loadColumnConfigList(
                    props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Create memory data set object
     */
    @SuppressWarnings("unused")
    private void initMemoryDataSet() {
        this.trainingData = new BasicFloatMLDataSet();
        this.validationData = new BasicFloatMLDataSet();
    }

    /**
     * For disk data set, initialize it with parameters and do other work for creating files.
     *
     * @throws IOException
     *             if any exception on local fs operations.
     * @throws RuntimeException
     *             if error on deleting testing or training file.
     */
    private void initDiskDataSet() throws IOException {
        Path trainingFile = DTrainUtils.getTrainingFile();
        Path testingFile = DTrainUtils.getTestingFile();

        LOG.debug("Use disk to store training data and testing data. Training data file:{}; Testing data file:{} ",
                trainingFile.toString(), testingFile.toString());

        this.trainingData = new BufferedFloatMLDataSet(new File(trainingFile.toString()));
        ((BufferedFloatMLDataSet) this.trainingData).beginLoad(this.subFeatures.size(), getOutputNodeCount());

        this.validationData = new BufferedFloatMLDataSet(new File(testingFile.toString()));
        ((BufferedFloatMLDataSet) this.validationData).beginLoad(this.subFeatures.size(), getOutputNodeCount());
    }
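    /**
     * Initialize the worker: load model and column configs, resolve training params (per-trainer params when grid
     * search is enabled), set up the Poisson samplers for bagging and up-sampling, resolve the feature subset, and
     * allocate either disk-backed or memory/disk hybrid data sets depending on trainOnDisk.
     */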
    @Override
    public void init(WorkerContext<NNParams, NNParams> context) {
        // load props firstly
        this.props = context.getProps();

        loadConfigFiles(context.getProps());

        this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
        GridSearch gs = new GridSearch(modelConfig.getTrain().getParams());
        this.validParams = this.modelConfig.getTrain().getParams();
        if(gs.hasHyperParam()) {
            this.validParams = gs.getParams(trainerId);
            LOG.info("Start grid search master with params: {}", validParams);
        }

        Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
        if(kCrossValidation != null && kCrossValidation > 0) {
            isKFoldCV = true;
        }

        this.poissonSampler = Boolean.TRUE.toString().equalsIgnoreCase(
                context.getProps().getProperty(NNConstants.NN_POISON_SAMPLER));
        this.rng = new PoissonDistribution(1.0d);
        Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight();
        if(Double.compare(upSampleWeight, 1d) != 0
                && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain()
                        .isOneVsAll()))) {
            // set mean to upSampleWeight - 1 and add 1 to each sample to make sure no zero sample value
            LOG.info("Enable up sampling with weight {}.", upSampleWeight);
            this.upSampleRng = new PoissonDistribution(upSampleWeight - 1);
        }
        Integer epochsPerIterationInteger = this.modelConfig.getTrain().getEpochsPerIteration();
        this.epochsPerIteration = epochsPerIterationInteger == null ? 1 : epochsPerIterationInteger.intValue();
        LOG.info("epochsPerIteration in worker is :{}", epochsPerIteration);

        Object elmObject = validParams.get(DTrainUtils.IS_ELM);
        isELM = elmObject == null ? false : "true".equalsIgnoreCase(elmObject.toString());
        LOG.info("Check isELM: {}", isELM);

        int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(this.columnConfigList);
        this.inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
        // if is one vs all classification, outputNodeCount is set to 1
        this.outputNodeCount = modelConfig.isRegression() ? inputOutputIndex[1]
                : (modelConfig.getTrain().isOneVsAll() ? inputOutputIndex[1] : modelConfig.getTags().size());
        this.candidateCount = inputOutputIndex[2];
        boolean isAfterVarSelect = inputOutputIndex[0] != 0;
        LOG.info("Input count {}, output count {}, candidate count {}", inputNodeCount, outputNodeCount,
                candidateCount);

        // cache all feature list for sampling features
        this.allFeatures = CommonUtils.getAllFeatureList(columnConfigList, isAfterVarSelect);
        String subsetStr = context.getProps().getProperty(CommonConstants.SHIFU_NN_FEATURE_SUBSET);
        if(StringUtils.isBlank(subsetStr)) {
            this.subFeatures = this.allFeatures;
        } else {
            String[] splits = subsetStr.split(",");
            this.subFeatures = new ArrayList<Integer>(splits.length);
            for(String split: splits) {
                int featureIndex = Integer.parseInt(split);
                this.subFeatures.add(featureIndex);
            }
        }
        this.subFeatureSet = new HashSet<Integer>(this.subFeatures);
        LOG.info("subFeatures size is {}", subFeatures.size());

        this.isDry = Boolean.TRUE.toString().equalsIgnoreCase(
                context.getProps().getProperty(CommonConstants.SHIFU_DRY_DTRAIN));
        this.isSpecificValidation = (modelConfig.getValidationDataSetRawPath() != null && !"".equals(modelConfig
                .getValidationDataSetRawPath()));
        this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample();

        if(isOnDisk()) {
            LOG.info("NNWorker is loading data into disk.");
            try {
                initDiskDataSet();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            // cannot find a good place to close these two data sets, use a shutdown hook
            Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
                @Override
                public void run() {
                    ((BufferedFloatMLDataSet) (AbstractNNWorker.this.trainingData)).close();
                    ((BufferedFloatMLDataSet) (AbstractNNWorker.this.validationData)).close();
                }
            }));
        } else {
            LOG.info("NNWorker is loading data into memory.");
            double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction",
                    "0.6"));
            long memoryStoreSize = (long) (Runtime.getRuntime().maxMemory() * memoryFraction);
            LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction);
            double crossValidationRate = this.modelConfig.getValidSetRate();
            try {
                if(StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) {
                    // fixed 0.6 and 0.4 of max memory for trainingData and validationData
                    this.trainingData = new MemoryDiskFloatMLDataSet((long) (memoryStoreSize * 0.6), DTrainUtils
                            .getTrainingFile().toString(), this.subFeatures.size(), this.outputNodeCount);
                    this.validationData = new MemoryDiskFloatMLDataSet((long) (memoryStoreSize * 0.4), DTrainUtils
                            .getTestingFile().toString(), this.subFeatures.size(), this.outputNodeCount);
                } else {
                    this.trainingData = new MemoryDiskFloatMLDataSet(
                            (long) (memoryStoreSize * (1 - crossValidationRate)), DTrainUtils.getTrainingFile()
                                    .toString(), this.subFeatures.size(), this.outputNodeCount);
                    this.validationData = new MemoryDiskFloatMLDataSet(
                            (long) (memoryStoreSize * crossValidationRate), DTrainUtils.getTestingFile().toString(),
                            this.subFeatures.size(), this.outputNodeCount);
                }
                // cannot find a good place to close these two data sets, use a shutdown hook
                Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
                    @Override
                    public void run() {
                        ((MemoryDiskFloatMLDataSet) (AbstractNNWorker.this.trainingData)).close();
                        ((MemoryDiskFloatMLDataSet) (AbstractNNWorker.this.validationData)).close();
                    }
                }));
            } catch (IOException e) {
                throw new GuaguaRuntimeException(e);
            }
        }
    }
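    /*
     * Sizing example for the memory/disk hybrid data sets created above: with a 4 GB max heap and the default
     * guagua.data.memoryFraction of 0.6, memoryStoreSize is about 2.4 GB. With a separate validation path it is
     * split 60/40 between training and validation; otherwise it is split by (1 - validSetRate) / validSetRate.
     * Records beyond the in-memory budget are backed by the on-disk file.
     */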
    private boolean isOnDisk() {
        return this.modelConfig.getTrain().getTrainOnDisk() != null
                && this.modelConfig.getTrain().getTrainOnDisk().booleanValue();
    }
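    /*
     * One training iteration from the worker side: in the first iteration only empty params are returned so that
     * every worker starts from the same master-initialized weights; in each later iteration the worker copies the
     * master weights into its local network, runs epochsPerIteration epochs of gradient computation over its data
     * split, and returns the accumulated gradients together with train and validation errors for the master to
     * aggregate.
     */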
    @Override
    public NNParams doCompute(WorkerContext<NNParams, NNParams> context) {
        // For dry option, return empty result.
        // For the first iteration, we don't do anything, just wait for master to update weights in the next
        // iteration. This makes sure all workers in the first iteration get the same weights.
        if(this.isDry || context.isFirstIteration()) {
            return buildEmptyNNParams(context);
        }

        if(context.getLastMasterResult() == null) {
            // This may not happen since master will set initialization weights firstly.
            LOG.warn("Master result of last iteration is null.");
            return null;
        }
        LOG.debug("Set current model with params {}", context.getLastMasterResult());

        // initialize gradients if null
        double[] weights = context.getLastMasterResult().getWeights();
        if(gradient == null) {
            initGradient(this.trainingData, this.validationData, weights, this.isCrossOver);
            // register call back for shutting down the thread pool.
            context.addCompletionCallBack(new WorkerCompletionCallBack<NNParams, NNParams>() {
                @Override
                public void callback(WorkerContext<NNParams, NNParams> context) {
                    AbstractNNWorker.this.gradient.shutdown();
                }
            });
        } else {
            if(this.isCrossOver) {
                // each iteration reset seed
                this.gradient.setSeed(System.currentTimeMillis());
            }
        }

        // use the weights from master to train the model in the current iteration
        this.gradient.getNetwork().setWeights(weights);

        double[] gradients = null;
        for(int i = 0; i < epochsPerIteration; i++) {
            gradients = this.gradient.computeGradients();
            if(this.epochsPerIteration > 1) {
                this.gradient.resetNetworkWeights();
            }
        }

        // get train errors and test errors
        double trainError = this.gradient.getTrainError();

        long start = System.currentTimeMillis();
        double testError = this.validationData.getRecordCount() > 0 ? (this.gradient.calculateError())
                : this.gradient.getTrainError();
        LOG.info("Computing test error time: {}ms", (System.currentTimeMillis() - start));

        // if the validation set is 0%, then the validation error should be "N/A"
        LOG.info("NNWorker compute iteration {} (train error {} validation error {})",
                new Object[] { context.getCurrentIteration(), trainError,
                        (this.validationData.getRecordCount() > 0 ? testError : "N/A") });

        NNParams params = new NNParams();
        params.setTestError(testError);
        params.setTrainError(trainError);
        params.setGradients(gradients);
        // prevent null pointer
        params.setWeights(new double[0]);
        params.setTrainSize(this.trainingData.getRecordCount());
        params.setCount(count);
        return params;
    }

    @SuppressWarnings("unchecked")
    private void initGradient(FloatMLDataSet training, FloatMLDataSet testing, double[] weights, boolean isCrossOver) {
        int numLayers = (Integer) this.validParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
        List<String> actFunc = (List<String>) this.validParams.get(CommonConstants.ACTIVATION_FUNC);
        List<Integer> hiddenNodeList = (List<Integer>) this.validParams.get(CommonConstants.NUM_HIDDEN_NODES);

        BasicNetwork network = DTrainUtils.generateNetwork(this.subFeatures.size(), this.outputNodeCount, numLayers,
                actFunc, hiddenNodeList, false);
        // use the weights from master
        network.getFlat().setWeights(weights);

        FlatNetwork flat = network.getFlat();
        // copied from encog's Propagation to fix the flat spot problem: add 0.1 to the sigmoid derivative so
        // gradients do not vanish when the activation saturates
        double[] flatSpot = new double[flat.getActivationFunctions().length];
        for(int i = 0; i < flat.getActivationFunctions().length; i++) {
            flatSpot[i] = flat.getActivationFunctions()[i] instanceof ActivationSigmoid ? 0.1 : 0.0;
        }

        LOG.info("Gradient computing thread count is {}.", modelConfig.getTrain().getWorkerThreadCount());

        this.gradient = new ParallelGradient((FloatFlatNetwork) flat, training, testing, flatSpot,
                new LinearErrorFunction(), isCrossOver, modelConfig.getTrain().getWorkerThreadCount(), this.isELM);
    }
    private NNParams buildEmptyNNParams(WorkerContext<NNParams, NNParams> workerContext) {
        NNParams params = new NNParams();
        params.setWeights(new double[0]);
        params.setGradients(new double[0]);
        params.setTestError(NNConstants.DRY_ERROR);
        params.setTrainError(NNConstants.DRY_ERROR);
        return params;
    }

    @Override
    protected void postLoad(WorkerContext<NNParams, NNParams> workerContext) {
        if(isOnDisk()) {
            ((BufferedFloatMLDataSet) this.trainingData).endLoad();
            if(validationData != null) {
                ((BufferedFloatMLDataSet) this.validationData).endLoad();
            }
        } else {
            ((MemoryDiskFloatMLDataSet) this.trainingData).endLoad();
            ((MemoryDiskFloatMLDataSet) this.validationData).endLoad();
            LOG.info(" - # Training Records in memory: {}.",
                    ((MemoryDiskFloatMLDataSet) this.trainingData).getMemoryCount());
            LOG.info(" - # Training Records in disk: {}.",
                    ((MemoryDiskFloatMLDataSet) this.trainingData).getDiskCount());
        }
        LOG.info(" - # Records of the Master Data Set: {}.", this.count);
        LOG.info(" - Bagging Sample Rate: {}.", this.modelConfig.getBaggingSampleRate());
        LOG.info(" - Bagging With Replacement: {}.", this.modelConfig.isBaggingWithReplacement());
        LOG.info(" - Cross Validation Rate: {}.", this.modelConfig.getValidSetRate());
        LOG.info(" - # Records of the Training Set: {}.", this.trainingData.getRecordCount());
        if(modelConfig.isRegression() || modelConfig.getTrain().isOneVsAll()) {
            LOG.info(" - # Positive Bagging Selected Records of the Training Set: {}.",
                    this.positiveSelectedTrainCount);
            LOG.info(" - # Negative Bagging Selected Records of the Training Set: {}.",
                    this.negativeSelectedTrainCount);
            LOG.info(" - # Positive Raw Records of the Training Set: {}.", this.positiveTrainCount);
            LOG.info(" - # Negative Raw Records of the Training Set: {}.", this.negativeTrainCount);
        }
        if(validationData != null) {
            LOG.info(" - # Records of the Validation Set: {}.", this.validationData.getRecordCount());
            if(modelConfig.isRegression() || modelConfig.getTrain().isOneVsAll()) {
                LOG.info(" - # Positive Records of the Validation Set: {}.", this.positiveValidationCount);
                LOG.info(" - # Negative Records of the Validation Set: {}.", this.negativeValidationCount);
            }
        }
    }

    protected float sampleWeights(float label) {
        float sampleWeights = 1f;
        // for sample-negative-only or k-fold CV, sample rate is 1d
        double sampleRate = (modelConfig.getTrain().getSampleNegOnly() || this.isKFoldCV) ? 1d : modelConfig
                .getTrain().getBaggingSampleRate();
        int classValue = (int) (label + 0.01f);
        if(!modelConfig.isBaggingWithReplacement()) {
            Random random = null;
            if(this.isStratifiedSampling) {
                random = baggingRandomMap.get(classValue);
                if(random == null) {
                    random = new Random();
                    baggingRandomMap.put(classValue, random);
                }
            } else {
                random = baggingRandomMap.get(0);
                if(random == null) {
                    random = new Random();
                    baggingRandomMap.put(0, random);
                }
            }
            if(random.nextDouble() <= sampleRate) {
                sampleWeights = 1f;
            } else {
                sampleWeights = 0f;
            }
        } else {
            // bagging with replacement sampling in training data set, use PoissonDistribution for sampling with
            // replacement
            if(this.isStratifiedSampling) {
                PoissonDistribution rng = this.baggingRngMap.get(classValue);
                if(rng == null) {
                    rng = new PoissonDistribution(sampleRate);
                    this.baggingRngMap.put(classValue, rng);
                }
                sampleWeights = rng.sample();
            } else {
                PoissonDistribution rng = this.baggingRngMap.get(0);
                if(rng == null) {
                    rng = new PoissonDistribution(sampleRate);
                    this.baggingRngMap.put(0, rng);
                }
                sampleWeights = rng.sample();
            }
        }
        return sampleWeights;
    }
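    /*
     * Note on sampleWeights above: without replacement, each record is kept (weight 1) with probability
     * baggingSampleRate; with replacement, the weight is drawn from Poisson(baggingSampleRate), so the expected
     * number of times a record is replayed equals the sample rate. For example, with a rate of 0.8 about 45% of
     * records get weight 0, about 36% get weight 1 and the rest get 2 or more, emulating sampling with replacement
     * without materializing duplicate records.
     */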
    protected void addDataPairToDataSet(long hashcode, FloatMLDataPair pair) {
        addDataPairToDataSet(hashcode, pair, false);
    }

    protected boolean isPositive(float value) {
        return Float.compare(1f, value) == 0;
    }

    /**
     * Add to training set or validation set according to validation rate.
     *
     * @param hashcode
     *            the hash code of the data
     * @param pair
     *            data instance
     * @param isValidation
     *            if it is validation
     * @return true if the record is added to the training set, false otherwise.
     */
    protected boolean addDataPairToDataSet(long hashcode, FloatMLDataPair pair, boolean isValidation) {
        if(this.isKFoldCV) {
            int k = this.modelConfig.getTrain().getNumKFold();
            if(hashcode % k == this.trainerId) {
                this.validationData.add(pair);
                if(isPositive(pair.getIdealArray()[0])) {
                    this.positiveValidationCount += 1L;
                } else {
                    this.negativeValidationCount += 1L;
                }
                return false;
            } else {
                this.trainingData.add(pair);
                if(isPositive(pair.getIdealArray()[0])) {
                    this.positiveTrainCount += 1L;
                } else {
                    this.negativeTrainCount += 1L;
                }
                return true;
            }
        }

        if(this.isSpecificValidation) {
            if(isValidation) {
                this.validationData.add(pair);
                if(isPositive(pair.getIdealArray()[0])) {
                    this.positiveValidationCount += 1L;
                } else {
                    this.negativeValidationCount += 1L;
                }
                return false;
            } else {
                this.trainingData.add(pair);
                if(isPositive(pair.getIdealArray()[0])) {
                    this.positiveTrainCount += 1L;
                } else {
                    this.negativeTrainCount += 1L;
                }
                return true;
            }
        } else {
            if(Double.compare(this.modelConfig.getValidSetRate(), 0d) != 0) {
                int classValue = (int) (pair.getIdealArray()[0] + 0.01f);
                Random random = null;
                if(this.isStratifiedSampling) {
                    // each class uses one random instance
                    random = validationRandomMap.get(classValue);
                    if(random == null) {
                        random = new Random();
                        this.validationRandomMap.put(classValue, random);
                    }
                } else {
                    // all data use one random instance
                    random = validationRandomMap.get(0);
                    if(random == null) {
                        random = new Random();
                        this.validationRandomMap.put(0, random);
                    }
                }

                if(this.modelConfig.isFixInitialInput()) {
                    // for fixed initial input: if hashcode % 100 is in [startHashCode, endHashCode), validation,
                    // otherwise training. The start hashcode differs per job to make sure bagging jobs get
                    // different data. If endHashCode is over 100, check if hashcode is in [startHashCode, 100)
                    // or [0, endHashCode - 100).
                    int startHashCode = (100 / this.modelConfig.getBaggingNum()) * this.trainerId;
                    int endHashCode = startHashCode
                            + Double.valueOf(this.modelConfig.getValidSetRate() * 100).intValue();
                    if(isInRange(hashcode, startHashCode, endHashCode)) {
                        this.validationData.add(pair);
                        if(isPositive(pair.getIdealArray()[0])) {
                            this.positiveValidationCount += 1L;
                        } else {
                            this.negativeValidationCount += 1L;
                        }
                        return false;
                    } else {
                        this.trainingData.add(pair);
                        if(isPositive(pair.getIdealArray()[0])) {
                            this.positiveTrainCount += 1L;
                        } else {
                            this.negativeTrainCount += 1L;
                        }
                        return true;
                    }
                } else {
                    // not fixed initial input: if random value >= validSetRate, training, otherwise validation.
                    if(random.nextDouble() >= this.modelConfig.getValidSetRate()) {
                        this.trainingData.add(pair);
                        if(isPositive(pair.getIdealArray()[0])) {
                            this.positiveTrainCount += 1L;
                        } else {
                            this.negativeTrainCount += 1L;
                        }
                        return true;
                    } else {
                        this.validationData.add(pair);
                        if(isPositive(pair.getIdealArray()[0])) {
                            this.positiveValidationCount += 1L;
                        } else {
                            this.negativeValidationCount += 1L;
                        }
                        return false;
                    }
                }
            } else {
                this.trainingData.add(pair);
                if(isPositive(pair.getIdealArray()[0])) {
                    this.positiveTrainCount += 1L;
                } else {
                    this.negativeTrainCount += 1L;
                }
                return true;
            }
        }
    }
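    /*
     * Worked example for the fixed-initial-input split above: with baggingNum = 3, trainerId = 2 and
     * validSetRate = 0.4, startHashCode = (100 / 3) * 2 = 66 and endHashCode = 66 + 40 = 106, so records whose
     * hashcode % 100 falls in [66, 100) or [0, 6) go to validation; the wrap-around is handled by isInRange below.
     */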
    protected boolean isInRange(long hashcode, int startHashCode, int endHashCode) {
        // check if in [start, end) or, when end wraps over 100, in [start, 100) or [0, end % 100)
        int hashCodeIn100 = (int) hashcode % 100;
        if(endHashCode <= 100) {
            // in range [start, end)
            return hashCodeIn100 >= startHashCode && hashCodeIn100 < endHashCode;
        } else {
            // in range [start, 100) or [0, endHashCode % 100)
            return hashCodeIn100 >= startHashCode || hashCodeIn100 < (endHashCode % 100);
        }
    }

    /**
     * Only when baggingWithReplacement is set, both data sets are non-empty, the total size is over
     * NNConstants.NN_BAGGING_THRESHOLD and the random value is small enough do we choose to re-use existing data to
     * extend the training and testing data sets.
     */
    @SuppressWarnings("unused")
    private boolean isBaggingReplacementTrigged(final double random) {
        long trainingSize = this.trainingData.getRecordCount();
        long testingSize = this.validationData.getRecordCount();
        // size should be equal to sampleCount:)
        long size = trainingSize + testingSize;
        return this.modelConfig.isBaggingWithReplacement() && (testingSize > 0) && (trainingSize > 0)
                && (size > NNConstants.NN_BAGGING_THRESHOLD) && (Double.compare(random, 0.5d) < 0);
    }

    /**
     * In Trainer, the logic is to randomly choose items from the master data set, but we don't want to load the data
     * twice, to save memory. This method mocks that raw random-repeat logic. The behavior can differ slightly
     * because records are picked from the already-loaded data sets instead of being re-sampled from the raw input.
     */
    @SuppressWarnings("unused")
    private void mockRandomRepeatData(double crossValidationRate, double random) {
        long trainingSize = this.trainingData.getRecordCount();
        long testingSize = this.validationData.getRecordCount();
        long size = trainingSize + testingSize;
        // a narrowing cast from long to int is fine here since this is just a random choosing algorithm
        int next = RandomUtils.nextInt((int) size);
        FloatMLDataPair dataPair = new BasicFloatMLDataPair(new BasicFloatMLData(new float[this.subFeatures.size()]),
                new BasicFloatMLData(new float[this.outputNodeCount]));
        if(next >= trainingSize) {
            this.validationData.getRecord(next - trainingSize, dataPair);
        } else {
            this.trainingData.getRecord(next, dataPair);
        }

        if(Double.compare(random, crossValidationRate) < 0) {
            this.validationData.add(dataPair);
        } else {
            this.trainingData.add(dataPair);
        }
    }
    public FloatMLDataSet getTrainingData() {
        return trainingData;
    }

    public void setTrainingData(FloatMLDataSet trainingData) {
        this.trainingData = trainingData;
    }

    public FloatMLDataSet getTestingData() {
        return validationData;
    }

    public void setTestingData(FloatMLDataSet testingData) {
        this.validationData = testingData;
    }

    public ModelConfig getModelConfig() {
        return modelConfig;
    }

    public void setModelConfig(ModelConfig modelConfig) {
        this.modelConfig = modelConfig;
    }

    public List<ColumnConfig> getColumnConfigList() {
        return columnConfigList;
    }

    public void setColumnConfigList(List<ColumnConfig> columnConfigList) {
        this.columnConfigList = columnConfigList;
    }

    public int getInputNodeCount() {
        return inputNodeCount;
    }

    public void setInputNodeCount(int inputNodeCount) {
        this.inputNodeCount = inputNodeCount;
    }

    public int getOutputNodeCount() {
        return outputNodeCount;
    }

    public void setOutputNodeCount(int outputNodeCount) {
        this.outputNodeCount = outputNodeCount;
    }
}