package com.skp.experiment.common;
import static org.junit.Assert.assertTrue;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.mahout.common.Pair;
import org.junit.Before;
import org.junit.Test;
import com.skp.experiment.cf.als.hadoop.KFoldCrossValidationUtils;
public class KFoldCrossValidatoinUtilsTest {
private ArrayList<String> seed;
private int testSize = 10;
@Before
public void setup() {
seed = new ArrayList<String>();
for (int i = 0; i < testSize; i++) {
seed.add(String.valueOf(i));
}
}
private void printSeed(List<String> list) {
for (int i = 0; i < list.size(); i++) {
System.out.print(list.get(i) + " ");
}
System.out.println();
}
@Test
public void testRandomSuffleInPlace() {
for (int k = 0; k < 10; k++) {
Map<String, Integer> counts = new HashMap<String, Integer>();
KFoldCrossValidationUtils.randomSuffleInPlace(seed);
// check size
assertTrue(testSize == seed.size());
for (int i = 0; i < seed.size(); i++) {
if (counts.containsKey(seed.get(i)) == false) {
counts.put(seed.get(i), 0);
}
counts.put(seed.get(i), counts.get(seed.get(i)) + 1);
}
// check distinct element
for (Entry<String, Integer> e : counts.entrySet()) {
assertTrue(e.getValue() == 1);
}
//printSeed(seed);
}
}
@Test
public void testSplitNth() {
KFoldCrossValidationUtils.randomSuffleInPlace(seed);
printSeed(seed);
int kfold = 5;
for (int nth = 0; nth < kfold; nth++) {
System.out.println("Nth: " + nth);
Pair<List<String>, List<String>> ret = KFoldCrossValidationUtils.splitNth(seed, 5, nth);
printSeed(ret.getFirst());
printSeed(ret.getSecond());
}
}
}