/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.neural.networks.training.propagation.sgd;
import java.util.Iterator;
import org.encog.EncogError;
import org.encog.mathutil.randomize.generate.GenerateRandom;
import org.encog.mathutil.randomize.generate.MersenneTwisterGenerateRandom;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
/**
* The BatchDataSet wraps a larger dataset and breaks it up into a series of batches. This dataset was specifically
* created to be used with the StochasticGradientDescent trainer; however, it should work with the others as well.
* It is important that the BatchDataSet's advance method be called at the end of each iteration, so that the next
* batch can be prepared. All Encog-provided trainers will detect the BatchDataSet and make this call.
*
* This dataset can be used in two ways, depending on the setting of the randomBatches property. If this value is
* false (the default), then the first batch starts at the beginning of the dataset, and following batches will start
* at the end of the previous batch. This method ensures that every data item is used. If randomBatches is true, then
* each batch will be sampled from the underlying dataset (without replacement).
*/
public class BatchDataSet implements MLDataSet {

    /**
     * An iterator over the records of the current batch of the enclosing
     * {@link BatchDataSet}. This iterator does not support removes.
     *
     * @author jheaton
     */
    public class BatchedMLIterator implements Iterator<MLDataPair> {

        /**
         * The position, within the current batch, of the next record to return.
         */
        private int currentIndex = 0;

        /**
         * {@inheritDoc}
         */
        @Override
        public final boolean hasNext() {
            return this.currentIndex < BatchDataSet.this.getBatchSize();
        }

        /**
         * Return the next record of the batch.
         *
         * @return The next record, or null once the batch is exhausted. Null
         *         (rather than an exception) is kept so that existing callers
         *         continue to work.
         */
        @Override
        public final MLDataPair next() {
            if (!hasNext()) {
                return null;
            }
            return BatchDataSet.this.get(this.currentIndex++);
        }

        /**
         * Not supported; always throws.
         */
        @Override
        public final void remove() {
            throw new EncogError("Called remove, unsupported operation.");
        }
    }

    /**
     * The source dataset.
     */
    private final MLDataSet dataset;

    /**
     * The offset into the source dataset at which the current (sequential)
     * batch starts.
     */
    private int currentIndex;

    /**
     * The size of the batch.
     */
    private int batchSize;

    /**
     * Random number generator.
     */
    private final GenerateRandom random;

    /**
     * Should a random sample be taken for each batch.
     */
    private boolean randomBatches;

    /**
     * Source-dataset indexes of the current random sample; one entry per
     * batch position. Only meaningful while randomBatches is true.
     */
    private int[] randomSample;

    /**
     * Construct the batch dataset. The batch size defaults to 500, or to the
     * size of the source dataset if that is smaller.
     *
     * @param theDataset The source dataset.
     * @param theRandom The random number generator.
     */
    public BatchDataSet(MLDataSet theDataset, GenerateRandom theRandom) {
        this.dataset = theDataset;
        this.random = theRandom;
        setBatchSize(500);
    }

    /**
     * Set the batch size. The size is capped at the size of the source
     * dataset. If random batches are enabled, a new random sample of the new
     * size is drawn immediately.
     *
     * @param theSize The requested batch size.
     */
    public void setBatchSize(int theSize) {
        this.batchSize = Math.min(theSize, this.dataset.size());
        this.randomSample = new int[this.batchSize];
        if (this.randomBatches) {
            generateRandomSample();
        }
    }

    /**
     * @return The number of records in each batch.
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public Iterator<MLDataPair> iterator() {
        return new BatchedMLIterator();
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public int getIdealSize() {
        return this.dataset.getIdealSize();
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public int getInputSize() {
        return this.dataset.getInputSize();
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public boolean isSupervised() {
        return this.dataset.isSupervised();
    }

    /**
     * @return The number of records in the current batch (the batch size).
     */
    @Override
    public long getRecordCount() {
        return this.batchSize;
    }

    /**
     * Copy the record at the given position of the current batch into the
     * supplied pair. Resolves the position the same way as {@link #get(int)},
     * so both accessors see the same batch.
     *
     * @param index The position within the current batch.
     * @param pair The pair to receive the record.
     */
    @Override
    public void getRecord(long index, MLDataPair pair) {
        if (this.randomBatches) {
            // Random mode: the sample table maps batch positions directly to
            // source indexes.
            this.dataset.getRecord(this.randomSample[(int) index], pair);
        } else {
            this.dataset.getRecord((index + this.currentIndex) % this.dataset.size(), pair);
        }
    }

    /**
     * This will open an additional batched dataset. However, please note, the additional datasets will use a
     * mersenne twister generator that is seeded by a long sampled from this object's random number
     * generator. The batch size and the random-batches setting of this object are carried over.
     *
     * @return An additional dataset.
     */
    @Override
    public MLDataSet openAdditional() {
        BatchDataSet result = new BatchDataSet(this.dataset,
                new MersenneTwisterGenerateRandom(this.random.nextLong()));
        result.setBatchSize(getBatchSize());
        // Set the mode after the size so that only one sample is drawn.
        result.setRandomBatches(this.randomBatches);
        return result;
    }

    /**
     * This operation is not supported by this object.
     * @param data1 NA
     */
    @Override
    public void add(MLData data1) {
        throw new EncogError("Unsupported.");
    }

    /**
     * This operation is not supported by this object.
     * @param inputData NA
     * @param idealData NA
     */
    @Override
    public void add(MLData inputData, MLData idealData) {
        throw new EncogError("Unsupported.");
    }

    /**
     * This operation is not supported by this object.
     * @param inputData NA
     */
    @Override
    public void add(MLDataPair inputData) {
        throw new EncogError("Unsupported.");
    }

    /**
     * Does nothing; the source dataset is left open.
     */
    @Override
    public void close() {
    }

    /**
     * @return The number of records in the current batch (the batch size).
     */
    @Override
    public int size() {
        return this.batchSize;
    }

    /**
     * Get the record at the given position of the current batch.
     *
     * @param index The position within the current batch.
     * @return The corresponding record of the source dataset.
     */
    @Override
    public MLDataPair get(int index) {
        if (this.randomBatches) {
            // The sample table has exactly batchSize entries, so it must be
            // indexed by the batch position itself, not by an offset into the
            // source dataset.
            return this.dataset.get(this.randomSample[index]);
        }
        // Sequential mode: the batch is a window starting at currentIndex that
        // wraps around the end of the source dataset.
        return this.dataset.get((index + this.currentIndex) % this.dataset.size());
    }

    /**
     * Advance to the next batch. Should be called at the end of each training iteration.
     * In random mode a new sample is drawn; otherwise the window moves forward by one
     * batch size, wrapping around the end of the source dataset.
     */
    public void advance() {
        if (this.randomBatches) {
            generateRandomSample();
        } else {
            this.currentIndex = (this.currentIndex + this.batchSize) % this.dataset.size();
        }
    }

    /**
     * @return The offset into the source dataset at which the current
     *         sequential batch starts.
     */
    public int getCurrentIndex() {
        return currentIndex;
    }

    /**
     * Set the offset into the source dataset at which the current sequential
     * batch starts.
     *
     * @param currentIndex The new starting offset.
     */
    public void setCurrentIndex(int currentIndex) {
        this.currentIndex = currentIndex;
    }

    /**
     * @return True, if random batches are being used.
     */
    public boolean isRandomBatches() {
        return randomBatches;
    }

    /**
     * Set if random batches should be generated. When random batches are
     * switched on, a sample is drawn immediately so that the first batch is
     * already random instead of repeating source record 0.
     *
     * @param randomBatches True, if random batches should be used.
     */
    public void setRandomBatches(boolean randomBatches) {
        final boolean switchingOn = randomBatches && !this.randomBatches;
        this.randomBatches = randomBatches;
        if (switchingOn) {
            generateRandomSample();
        }
    }

    /**
     * Fill the sample table with batchSize distinct indexes into the source
     * dataset, drawn by rejection: a candidate is redrawn until it differs
     * from every index already chosen.
     */
    private void generateRandomSample() {
        final int sourceSize = this.dataset.size();
        for (int i = 0; i < this.batchSize; i++) {
            int candidate;
            boolean unique;
            do {
                // NOTE(review): assumes nextInt(low, high) yields values in
                // [low, high) - confirm against GenerateRandom.
                candidate = this.random.nextInt(0, sourceSize);
                // The flag must be reset for every candidate; otherwise a
                // single collision would make this loop spin forever.
                unique = true;
                for (int j = 0; j < i; j++) {
                    if (this.randomSample[j] == candidate) {
                        unique = false;
                        break;
                    }
                }
            } while (!unique);
            this.randomSample[i] = candidate;
        }
    }
}