/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

package smile.validation;

/**
 * Cross-validation is a technique for assessing how the results of a
 * statistical analysis will generalize to an independent data set.
 * It is mainly used in settings where the goal is prediction, and one
 * wants to estimate how accurately a predictive model will perform in
 * practice. One round of cross-validation involves partitioning a sample
 * of data into complementary subsets, performing the analysis on one subset
 * (called the training set), and validating the analysis on the other subset
 * (called the validation set or testing set). To reduce variability, multiple
 * rounds of cross-validation are performed using different partitions, and the
 * validation results are averaged over the rounds.
 * <p>
 * This class only produces the index partitions: for each round <code>i</code>,
 * <code>test[i]</code> holds the indices of the held-out samples and
 * <code>train[i]</code> holds all remaining indices.
 *
 * @author Haifeng Li
 */
public class CrossValidation {
    /**
     * The number of rounds of cross validation.
     */
    public final int k;
    /**
     * The index of training instances. <code>train[i]</code> is the
     * training set of the i-th round.
     */
    public final int[][] train;
    /**
     * The index of testing instances. <code>test[i]</code> is the
     * testing set of the i-th round.
     */
    public final int[][] test;

    /**
     * Constructor. The shuffled index array is cut into <code>k</code>
     * consecutive slices of <code>n / k</code> elements each (integer
     * division); the last slice additionally takes the remaining
     * <code>n % k</code> elements. Slice <code>i</code> becomes
     * <code>test[i]</code> and everything outside it becomes
     * <code>train[i]</code>.
     *
     * @param n the number of samples. Must be positive.
     * @param k the number of rounds of cross validation. Must satisfy
     *          <code>1 &lt;= k &lt;= n</code>.
     * @throws IllegalArgumentException if <code>n</code> is not positive or
     *         <code>k</code> is outside <code>[1, n]</code>.
     */
    public CrossValidation(int n, int k) {
        if (n <= 0) {
            throw new IllegalArgumentException("Invalid sample size: " + n);
        }

        // k == 0 has to be rejected here: the fold size below is n / k, which
        // would otherwise fail with an ArithmeticException (division by zero).
        if (k <= 0 || k > n) {
            throw new IllegalArgumentException("Invalid number of CV rounds: " + k);
        }

        this.k = k;

        // Presumably a random permutation of 0..n-1 (project helper, defined
        // outside this file); the seed is the current wall-clock time.
        smile.math.Random random = new smile.math.Random(System.currentTimeMillis());
        int[] index = random.permutate(n);

        train = new int[k][];
        test = new int[k][];

        int chunk = n / k;
        for (int i = 0; i < k; i++) {
            int start = chunk * i;
            // The last round absorbs the n % k leftover samples.
            int end = (i == k - 1) ? n : start + chunk;
            int size = end - start;

            // Held-out slice [start, end).
            test[i] = new int[size];
            System.arraycopy(index, start, test[i], 0, size);

            // Everything before and after the slice, in original order.
            train[i] = new int[n - size];
            System.arraycopy(index, 0, train[i], 0, start);
            System.arraycopy(index, end, train[i], start, n - end);
        }
    }
}