/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.ml.data.versatile.division;
import java.util.List;
import org.encog.mathutil.randomize.generate.GenerateRandom;
import org.encog.ml.data.versatile.MatrixMLDataSet;
import org.encog.ml.data.versatile.VersatileMLDataSet;
/**
 * Splits the rows of a {@link VersatileMLDataSet} among a list of
 * {@link DataDivision}s.
 *
 * Each division receives a count derived from its percentage, an index mask of
 * that size, and finally a {@link MatrixMLDataSet} built over the parent data
 * with that mask. When shuffling is enabled, the mask entries are permuted
 * across all divisions before the datasets are created.
 */
public class PerformDataDivision {

    /**
     * True, if we should shuffle during division.
     */
    private final boolean shuffle;

    /**
     * A random number generator. Used to hand out leftover rows and, when
     * shuffling is enabled, to drive the shuffle.
     */
    private final GenerateRandom rnd;

    /**
     * Construct the data division processor.
     *
     * @param theShuffle Should we shuffle?
     * @param theRandom  Random number generator, often seeded to be consistent.
     */
    public PerformDataDivision(boolean theShuffle, GenerateRandom theRandom) {
        this.shuffle = theShuffle;
        this.rnd = theRandom;
    }

    /**
     * Perform the split. On return every division in the list has had its
     * count, mask and dataset set.
     *
     * @param dataDivisionList The list of data divisions.
     * @param dataset          The dataset to split.
     * @param inputCount       The input count.
     * @param idealCount       The ideal count.
     * @throws IllegalArgumentException If the division percentages cannot be
     *                                  mapped onto the dataset (negative
     *                                  percentage, more rows requested than
     *                                  exist, or no divisions for a non-empty
     *                                  dataset).
     */
    public void perform(List<DataDivision> dataDivisionList, VersatileMLDataSet dataset,
            int inputCount, int idealCount) {
        // Read the size once; both the counting and the shuffle work on it.
        final int totalCount = dataset.getData().length;

        generateCounts(dataDivisionList, totalCount);
        generateMasks(dataDivisionList);
        if (this.shuffle) {
            performShuffle(dataDivisionList, totalCount);
        }
        createDividedDatasets(dataDivisionList, dataset, inputCount, idealCount);
    }

    /**
     * Create the datasets that we will divide into. Each division gets a
     * {@link MatrixMLDataSet} constructed from the parent's data and the
     * division's own mask, with the parent's lag and lead window sizes copied
     * over.
     *
     * @param dataDivisionList The list of divisions.
     * @param parentDataset    The data set to divide.
     * @param inputCount       The input count.
     * @param idealCount       The ideal count.
     */
    private void createDividedDatasets(List<DataDivision> dataDivisionList,
            VersatileMLDataSet parentDataset, int inputCount, int idealCount) {
        for (DataDivision division : dataDivisionList) {
            MatrixMLDataSet dataset = new MatrixMLDataSet(parentDataset.getData(), inputCount,
                    idealCount, division.getMask());
            dataset.setLagWindowSize(parentDataset.getLagWindowSize());
            dataset.setLeadWindowSize(parentDataset.getLeadWindowSize());
            division.setDataset(dataset);
        }
    }

    /**
     * Perform a Fisher-Yates shuffle over the masks of all divisions, treating
     * them as one concatenated array of {@code totalCount} entries.
     * http://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
     *
     * @param dataDivisionList The division list.
     * @param totalCount       The number of mask entries across all divisions.
     */
    private void performShuffle(List<DataDivision> dataDivisionList,
            int totalCount) {
        for (int i = totalCount - 1; i > 0; i--) {
            int n = this.rnd.nextInt(i + 1);
            virtualSwap(dataDivisionList, i, n);
        }
    }

    /**
     * Swap two items, across all divisions. The indexes address the virtual
     * array formed by concatenating every division's mask in list order.
     *
     * @param dataDivisionList The division list
     * @param a                The index of the first item to swap.
     * @param b                The index of the second item to swap.
     */
    private void virtualSwap(List<DataDivision> dataDivisionList, int a, int b) {
        // Swapping an element with itself changes nothing.
        if (a == b) {
            return;
        }

        DataDivision divA = null;
        DataDivision divB = null;
        int offsetA = 0;
        int offsetB = 0;

        // Find points a and b in the collections.
        int baseIndex = 0;
        for (DataDivision division : dataDivisionList) {
            final int start = baseIndex;
            baseIndex += division.getCount();

            if (divA == null && a < baseIndex) {
                divA = division;
                offsetA = a - start;
            }
            if (divB == null && b < baseIndex) {
                divB = division;
                offsetB = b - start;
            }
            // Both positions resolved; no need to walk the remaining divisions.
            if (divA != null && divB != null) {
                break;
            }
        }

        // Swap a and b.
        int temp = divA.getMask()[offsetA];
        divA.getMask()[offsetA] = divB.getMask()[offsetB];
        divB.getMask()[offsetB] = temp;
    }

    /**
     * Generate the masks, for all divisions. Each division's mask is allocated
     * to its count and filled with consecutive indexes; numbering starts at
     * zero and continues from one division to the next.
     *
     * @param dataDivisionList The divisions.
     */
    private void generateMasks(List<DataDivision> dataDivisionList) {
        int idx = 0;
        for (DataDivision division : dataDivisionList) {
            division.allocateMask(division.getCount());
            for (int i = 0; i < division.getCount(); i++) {
                division.getMask()[i] = idx++;
            }
        }
    }

    /**
     * Generate the counts for all divisions. Every division first receives
     * its percentage of the total, truncated to a whole number. Rows left over
     * from truncation are then handed out one at a time, each to a randomly
     * chosen division.
     *
     * @param dataDivisionList The division list.
     * @param totalCount       The total count.
     * @throws IllegalArgumentException If a percentage yields a negative
     *                                  count, if the divisions together claim
     *                                  more than {@code totalCount} rows, or if
     *                                  rows remain but there is no division to
     *                                  receive them.
     */
    private void generateCounts(List<DataDivision> dataDivisionList,
            int totalCount) {
        // First pass at division.
        int countSofar = 0;
        for (DataDivision division : dataDivisionList) {
            int count = (int) (division.getPercent() * totalCount);
            if (count < 0) {
                throw new IllegalArgumentException(
                        "Division percent must not be negative: " + division.getPercent());
            }
            division.setCount(count);
            countSofar += count;
        }

        int remaining = totalCount - countSofar;

        // Over-allocation would make the masks reference indexes past the end
        // of the data, so reject it here rather than fail later.
        if (remaining < 0) {
            throw new IllegalArgumentException("Divisions request " + countSofar
                    + " rows, but the dataset only has " + totalCount + ".");
        }
        if (remaining > 0 && dataDivisionList.isEmpty()) {
            throw new IllegalArgumentException(
                    "No divisions were provided for a dataset of " + totalCount + " rows.");
        }

        // Adjust any remaining count
        while (remaining-- > 0) {
            int idx = this.rnd.nextInt(dataDivisionList.size());
            DataDivision div = dataDivisionList.get(idx);
            div.setCount(div.getCount() + 1);
        }
    }

    /**
     * @return the shuffle
     */
    public boolean isShuffle() {
        return shuffle;
    }

    /**
     * @return the rnd
     */
    public GenerateRandom getRandom() {
        return rnd;
    }
}