/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package opennlp.tools.util.eval; import java.io.IOException; import java.util.Collection; import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.NoSuchElementException; import org.junit.Assert; import org.junit.Test; import opennlp.tools.util.ObjectStream; import opennlp.tools.util.eval.CrossValidationPartitioner.TrainingSampleStream; /** * Test for the {@link CrossValidationPartitioner} class. */ public class CrossValidationPartitionerTest { @Test public void testEmptyDataSet() throws IOException { Collection<String> emptyCollection = Collections.emptySet(); CrossValidationPartitioner<String> partitioner = new CrossValidationPartitioner<>(emptyCollection, 2); Assert.assertTrue(partitioner.hasNext()); Assert.assertNull(partitioner.next().read()); Assert.assertTrue(partitioner.hasNext()); Assert.assertNull(partitioner.next().read()); Assert.assertFalse(partitioner.hasNext()); try { // Should throw NoSuchElementException partitioner.next(); // ups, hasn't thrown one Assert.fail(); } catch (NoSuchElementException e) { // expected } } /** * Test 3-fold cross validation on a small sample data set. */ @Test public void test3FoldCV() throws IOException { List<String> data = new LinkedList<>(); data.add("01"); data.add("02"); data.add("03"); data.add("04"); data.add("05"); data.add("06"); data.add("07"); data.add("08"); data.add("09"); data.add("10"); CrossValidationPartitioner<String> partitioner = new CrossValidationPartitioner<>(data, 3); // first partition Assert.assertTrue(partitioner.hasNext()); TrainingSampleStream<String> firstTraining = partitioner.next(); Assert.assertEquals("02", firstTraining.read()); Assert.assertEquals("03", firstTraining.read()); Assert.assertEquals("05", firstTraining.read()); Assert.assertEquals("06", firstTraining.read()); Assert.assertEquals("08", firstTraining.read()); Assert.assertEquals("09", firstTraining.read()); Assert.assertNull(firstTraining.read()); ObjectStream<String> firstTest = firstTraining.getTestSampleStream(); Assert.assertEquals("01", firstTest.read()); Assert.assertEquals("04", firstTest.read()); Assert.assertEquals("07", firstTest.read()); Assert.assertEquals("10", firstTest.read()); Assert.assertNull(firstTest.read()); // second partition Assert.assertTrue(partitioner.hasNext()); TrainingSampleStream<String> secondTraining = partitioner.next(); Assert.assertEquals("01", secondTraining.read()); Assert.assertEquals("03", secondTraining.read()); Assert.assertEquals("04", secondTraining.read()); Assert.assertEquals("06", secondTraining.read()); Assert.assertEquals("07", secondTraining.read()); Assert.assertEquals("09", secondTraining.read()); Assert.assertEquals("10", secondTraining.read()); Assert.assertNull(secondTraining.read()); ObjectStream<String> secondTest = secondTraining.getTestSampleStream(); Assert.assertEquals("02", secondTest.read()); Assert.assertEquals("05", secondTest.read()); Assert.assertEquals("08", secondTest.read()); Assert.assertNull(secondTest.read()); // third partition Assert.assertTrue(partitioner.hasNext()); TrainingSampleStream<String> thirdTraining = partitioner.next(); Assert.assertEquals("01", thirdTraining.read()); Assert.assertEquals("02", thirdTraining.read()); Assert.assertEquals("04", thirdTraining.read()); Assert.assertEquals("05", thirdTraining.read()); Assert.assertEquals("07", thirdTraining.read()); Assert.assertEquals("08", thirdTraining.read()); Assert.assertEquals("10", thirdTraining.read()); Assert.assertNull(thirdTraining.read()); ObjectStream<String> thirdTest = thirdTraining.getTestSampleStream(); Assert.assertEquals("03", thirdTest.read()); Assert.assertEquals("06", thirdTest.read()); Assert.assertEquals("09", thirdTest.read()); Assert.assertNull(thirdTest.read()); Assert.assertFalse(partitioner.hasNext()); } @Test public void testFailSafty() throws IOException { List<String> data = new LinkedList<>(); data.add("01"); data.add("02"); data.add("03"); data.add("04"); CrossValidationPartitioner<String> partitioner = new CrossValidationPartitioner<>(data, 4); // Test that iterator from previous partition fails // if it is accessed TrainingSampleStream<String> firstTraining = partitioner.next(); Assert.assertEquals("02", firstTraining.read()); TrainingSampleStream<String> secondTraining = partitioner.next(); try { firstTraining.read(); Assert.fail(); } catch (IllegalStateException expected) { // the read above is expected to throw an exception } try { firstTraining.getTestSampleStream(); Assert.fail(); } catch (IllegalStateException expected) { // the read above is expected to throw an exception } // Test that training iterator fails if there is a test iterator secondTraining.getTestSampleStream(); try { secondTraining.read(); Assert.fail(); } catch (IllegalStateException expected) { // the read above is expected to throw an exception } // Test that test iterator from previous partition fails // if there is a new partition TrainingSampleStream<String> thirdTraining = partitioner.next(); ObjectStream<String> thridTest = thirdTraining.getTestSampleStream(); Assert.assertTrue(partitioner.hasNext()); partitioner.next(); try { thridTest.read(); Assert.fail(); } catch (IllegalStateException expected) { // the read above is expected to throw an exception } } @Test public void testToString() { Collection<String> emptyCollection = Collections.emptySet(); new CrossValidationPartitioner<>(emptyCollection, 10).toString(); } }