/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.ml.data.folded;
import java.util.Iterator;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.neural.networks.training.TrainingError;
/**
* A folded data set allows you to "fold" the data into several equal (or nearly
* equal) datasets. You can then select which fold the dataset will process.
* This is very useful for cross-validation.
*
* This dataset works off of an underlying dataset. By default the data is not
* divided: there is a single fold that contains every record. Call the fold
* method to create more folds.
*
*/
public class FoldedDataSet implements MLDataSet {
    /**
     * Error message: adds are not supported.
     */
    public static final String ADD_NOT_SUPPORTED
        = "Direct adds to the folded dataset are not supported.";

    /**
     * The underlying dataset.
     */
    private final MLDataSet underlying;

    /**
     * The zero-based fold that we are currently on.
     */
    private int currentFold;

    /**
     * The total number of folds. Always at least 1 once the constructor has
     * run, even when the underlying dataset is empty.
     */
    private int numFolds;

    /**
     * The size of all folds except the last one. The last fold may be larger
     * because it also receives the remainder of the integer division.
     */
    private int foldSize;

    /**
     * The size of the last fold (foldSize plus the division remainder).
     */
    private int lastFoldSize;

    /**
     * The offset of the current fold within the underlying dataset.
     */
    private int currentFoldOffset;

    /**
     * The number of records in the current fold.
     */
    private int currentFoldSize;

    /**
     * The owner object (set for datasets created by openAdditional). When
     * non-null, the fold state is read from the owner instead of this object.
     */
    private FoldedDataSet owner;

    /**
     * Create a folded dataset. Initially there is a single fold that spans
     * the whole underlying dataset.
     *
     * @param theUnderlying
     *            The underlying dataset.
     */
    public FoldedDataSet(final MLDataSet theUnderlying) {
        this.underlying = theUnderlying;
        fold(1);
    }

    /**
     * Not supported.
     *
     * @param data1
     *            Not used.
     * @throws TrainingError
     *             always.
     */
    @Override
    public void add(final MLData data1) {
        throw new TrainingError(FoldedDataSet.ADD_NOT_SUPPORTED);
    }

    /**
     * Not supported.
     *
     * @param inputData
     *            Not used.
     * @param idealData
     *            Not used.
     * @throws TrainingError
     *             always.
     */
    @Override
    public void add(final MLData inputData, final MLData idealData) {
        throw new TrainingError(FoldedDataSet.ADD_NOT_SUPPORTED);
    }

    /**
     * Not supported.
     *
     * @param inputData
     *            Not used.
     * @throws TrainingError
     *             always.
     */
    @Override
    public void add(final MLDataPair inputData) {
        throw new TrainingError(FoldedDataSet.ADD_NOT_SUPPORTED);
    }

    /**
     * Close the dataset by closing the underlying dataset.
     */
    @Override
    public void close() {
        this.underlying.close();
    }

    /**
     * Fold the dataset into the requested number of folds and select fold 0.
     * The requested count is clamped to the number of underlying records, so
     * no fold is empty when there is data. It is never clamped below 1, so
     * an empty underlying dataset yields one empty fold.
     *
     * @param theNumFolds
     *            The number of folds; must be at least 1.
     * @throws TrainingError
     *             if theNumFolds is less than 1, or if this is not a
     *             top-level dataset (see setCurrentFold).
     */
    public void fold(final int theNumFolds) {
        if (theNumFolds < 1) {
            throw new TrainingError(
                    "The number of folds must be at least 1, but was "
                    + theNumFolds + ".");
        }
        final long recordCount = this.underlying.getRecordCount();
        // Never create more folds than records. Keep at least one fold so an
        // empty dataset does not cause a division by zero below.
        this.numFolds = (int) Math.max(1L,
                Math.min((long) theNumFolds, recordCount));
        this.foldSize = (int) (recordCount / this.numFolds);
        // The last fold receives the records left over by the division.
        this.lastFoldSize = this.foldSize
                + (int) (recordCount % this.numFolds);
        setCurrentFold(0);
    }

    /**
     * @return the current (zero-based) fold, taken from the owner if this
     *         dataset was created by openAdditional.
     */
    public int getCurrentFold() {
        if (this.owner != null) {
            return this.owner.getCurrentFold();
        }
        return this.currentFold;
    }

    /**
     * @return the offset of the current fold within the underlying dataset,
     *         taken from the owner if there is one.
     */
    public int getCurrentFoldOffset() {
        if (this.owner != null) {
            return this.owner.getCurrentFoldOffset();
        }
        return this.currentFoldOffset;
    }

    /**
     * @return the number of records in the current fold, taken from the owner
     *         if there is one.
     */
    public int getCurrentFoldSize() {
        if (this.owner != null) {
            return this.owner.getCurrentFoldSize();
        }
        return this.currentFoldSize;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public int getIdealSize() {
        return this.underlying.getIdealSize();
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public int getInputSize() {
        return this.underlying.getInputSize();
    }

    /**
     * @return the number of folds, taken from the owner if there is one. This
     *         is consistent with the other fold getters; a dataset created by
     *         openAdditional otherwise reports only its own single fold.
     */
    public int getNumFolds() {
        if (this.owner != null) {
            return this.owner.getNumFolds();
        }
        return this.numFolds;
    }

    /**
     * @return The owner, or null if this is a top-level dataset.
     */
    public FoldedDataSet getOwner() {
        return this.owner;
    }

    /**
     * Read a record from the current fold. The index is relative to the start
     * of the current fold.
     *
     * NOTE(review): the index is not checked against the current fold size.
     * An out-of-range index reads whatever record sits at that offset in the
     * underlying dataset (possibly from another fold).
     *
     * @param index
     *            The index within the current fold.
     * @param pair
     *            The pair to load the record into.
     */
    @Override
    public void getRecord(final long index, final MLDataPair pair) {
        this.underlying.getRecord(getCurrentFoldOffset() + index, pair);
    }

    /**
     * @return the number of records in the current fold.
     */
    @Override
    public long getRecordCount() {
        return getCurrentFoldSize();
    }

    /**
     * @return The underlying dataset.
     */
    public MLDataSet getUnderlying() {
        return this.underlying;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public boolean isSupervised() {
        return this.underlying.isSupervised();
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public Iterator<MLDataPair> iterator() {
        return new FoldedIterator(this);
    }

    /**
     * Open an additional folded dataset over a fresh view of the underlying
     * dataset. The new dataset takes its fold state from this one.
     *
     * @return The additional dataset.
     */
    @Override
    public MLDataSet openAdditional() {
        final FoldedDataSet folded = new FoldedDataSet(
                this.underlying.openAdditional());
        folded.setOwner(this);
        return folded;
    }

    /**
     * Set the current fold.
     *
     * @param theCurrentFold
     *            the zero-based fold to select; must be in
     *            [0, number of folds).
     * @throws TrainingError
     *             if this is not a top-level dataset, or the fold index is out
     *             of range.
     */
    public void setCurrentFold(final int theCurrentFold) {
        if (this.owner != null) {
            throw new TrainingError(
                    "Can't set the fold on a non-top-level set.");
        }
        if (theCurrentFold < 0) {
            throw new TrainingError(
                    "Can't set the current fold to a negative value: "
                    + theCurrentFold + ".");
        }
        if (theCurrentFold >= this.numFolds) {
            throw new TrainingError(
                    "Can't set the current fold to be greater than "
                    + "the number of folds.");
        }
        this.currentFold = theCurrentFold;
        this.currentFoldOffset = this.foldSize * this.currentFold;
        // Only the last fold can differ in size (it holds the remainder).
        this.currentFoldSize = (this.currentFold == this.numFolds - 1)
                ? this.lastFoldSize
                : this.foldSize;
    }

    /**
     * @param theOwner
     *            The owner whose fold state this dataset should follow.
     */
    public void setOwner(final FoldedDataSet theOwner) {
        this.owner = theOwner;
    }

    /**
     * @return the number of records in the current fold.
     */
    @Override
    public int size() {
        return (int) getRecordCount();
    }

    /**
     * Read a record from the current fold into a newly created pair.
     *
     * @param index
     *            The index within the current fold.
     * @return The record.
     */
    @Override
    public MLDataPair get(final int index) {
        final MLDataPair result = BasicMLDataPair.createPair(getInputSize(),
                getIdealSize());
        this.getRecord(index, result);
        return result;
    }
}