package edu.stanford.nlp.util; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.Iterator; import java.util.List; import java.util.Random; import java.util.function.Function; import junit.framework.Assert; import junit.framework.TestCase; /** * Unit tests for Iterables utility class. * * @author dramage */ public class IterablesTest extends TestCase { public void testZip() { String[] s1 = new String[]{"a", "b", "c"}; Integer[] s2 = new Integer[]{1, 2, 3, 4}; int count = 0; for (Pair<String,Integer> pair : Iterables.zip(s1, s2)) { assertEquals(pair.first, s1[count]); assertEquals(pair.second, s2[count]); count++; } assertEquals(s1.length < s2.length ? s1.length : s2.length, count); } @SuppressWarnings("unchecked") public void testChain() { List<String> s1 = Arrays.asList(new String[]{"hi", "there"}); List<String> s2 = Arrays.asList(new String[]{}); List<String> s3 = Arrays.asList(new String[]{"yoo"}); List<String> s4 = Arrays.asList(new String[]{}); List<String> answer = Arrays.asList(new String[]{"yoo","hi","there","yoo"}); List<String> chained = new ArrayList<String>(); for (String s : Iterables.chain(s3, s1, s2, s3, s4)) { chained.add(s); } assertEquals(answer, chained); } public void testFilter() { List<String> values = Arrays.asList("a","HI","tHere","YO"); Iterator<String> iterator = Iterables.filter(values, new Function<String,Boolean>(){ public Boolean apply(String in) { return in.equals(in.toUpperCase()); } }).iterator(); assertTrue(iterator.hasNext()); assertEquals(iterator.next(), "HI"); assertEquals(iterator.next(), "YO"); assertFalse(iterator.hasNext()); } public void testTransform() { List<Integer> values = Arrays.asList(1,2,3,4); List<Integer> squares = Arrays.asList(1,4,9,16); Function<Integer,Integer> squarer = new Function<Integer,Integer>() { public Integer apply(Integer in) { return in * in; } }; for (Pair<Integer,Integer> pair : Iterables.zip(Iterables.transform(values, squarer), squares)) { assertEquals(pair.first, pair.second); } } public void testMerge() { List<String> a = Arrays.asList("a","b","d","e"); List<String> b = Arrays.asList("b","c","d","e"); Comparator<String> comparator = new Comparator<String>() { public int compare(String o1, String o2) { return o1.compareTo(o2); } }; Iterator<Pair<String,String>> iter = Iterables.merge(a, b, comparator).iterator(); assertEquals(iter.next(),new Pair<String,String>("b","b")); assertEquals(iter.next(),new Pair<String,String>("d","d")); assertEquals(iter.next(),new Pair<String,String>("e","e")); assertTrue(!iter.hasNext()); } public void testMerge3() { List<String> a = Arrays.asList("a","b","d","e"); List<String> b = Arrays.asList("b","c","d","e"); List<String> c = Arrays.asList("a", "b", "c", "e", "f"); Comparator<String> comparator = new Comparator<String>() { public int compare(String o1, String o2) { return o1.compareTo(o2); } }; Iterator<Triple<String,String,String>> iter = Iterables.merge(a, b, c, comparator).iterator(); assertEquals(iter.next(),new Triple<String,String,String>("b","b", "b")); assertEquals(iter.next(),new Triple<String,String,String>("e","e", "e")); assertTrue( ! iter.hasNext()); } public void testGroup() { String[] input = new String[]{ "0 ab", "0 bb", "0 cc", "1 dd", "2 dd", "2 kj", "3 kj", "3 kk"}; int[] counts = new int[]{3,1,2,2}; Comparator<String> fieldOne= new Comparator<String>() { public int compare(String o1, String o2) { return o1.split(" ")[0].compareTo(o2.split(" ")[0]); } }; int index = 0; int group = 0; for (Iterable<String> set : Iterables.group(Arrays.asList(input), fieldOne)) { String sharedKey = null; int thisCount = 0; for (String line : set) { String thisKey = line.split(" ")[0]; if (sharedKey == null) { sharedKey = thisKey; } else { assertEquals("Wrong key", sharedKey, thisKey); } assertEquals("Wrong input line", line, input[index++]); thisCount++; } assertEquals("Wrong number of items in this iterator", counts[group++], thisCount); } assertEquals("Didn't get all inputs", input.length, index); assertEquals("Wrong number of groups", counts.length, group); } public void testSample() { // make sure correct number of items is sampled and items are in range Iterable<Integer> items = Arrays.asList(5, 4, 3, 2, 1); int count = 0; for (Integer item: Iterables.sample(items, 5, 2, new Random())) { ++count; Assert.assertTrue(item <= 5); Assert.assertTrue(item >= 1); } Assert.assertEquals(2, count); } }