/***********************************************************************************************************************
 * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 **********************************************************************************************************************/

package eu.stratosphere.pact.runtime.sort;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import eu.stratosphere.api.common.typeutils.TypeComparator;
import eu.stratosphere.api.common.typeutils.TypePairComparator;
import eu.stratosphere.api.common.typeutils.TypeSerializer;
import eu.stratosphere.api.java.typeutils.runtime.record.RecordComparator;
import eu.stratosphere.api.java.typeutils.runtime.record.RecordPairComparator;
import eu.stratosphere.api.java.typeutils.runtime.record.RecordSerializer;
import eu.stratosphere.pact.runtime.test.util.TestData;
import eu.stratosphere.pact.runtime.test.util.TestData.Generator;
import eu.stratosphere.pact.runtime.test.util.TestData.Generator.KeyMode;
import eu.stratosphere.pact.runtime.test.util.TestData.Generator.ValueMode;
import eu.stratosphere.types.Record;
import eu.stratosphere.util.MutableObjectIterator;

/**
 * Integration test for {@link SortMergeCoGroupIterator}.
 * <p>
 * Two generated inputs with sorted keys are fed to the iterator. The values the iterator returns for
 * every key are compared against the expected co-groups, which are computed up front by draining the
 * same generators into hash maps.
 */
public class SortMergeCoGroupIteratorITCase
{
	// the size of the left and right inputs
	private static final int INPUT_1_SIZE = 20000;

	private static final int INPUT_2_SIZE = 1000;

	// random seeds for the left and right input data generators
	private static final long SEED1 = 561349061987311L;

	private static final long SEED2 = 231434613412342L;

	// left and right input data generators
	private Generator generator1;

	private Generator generator2;

	// left and right input RecordReader mocks
	private MutableObjectIterator<Record> reader1;

	private MutableObjectIterator<Record> reader2;

	private TypeSerializer<Record> serializer1;
	private TypeSerializer<Record> serializer2;
	private TypeComparator<Record> comparator1;
	private TypeComparator<Record> comparator2;
	private TypePairComparator<Record, Record> pairComparator;

	/**
	 * Creates the serializers and comparators for both inputs. Both sides use field 0, typed as
	 * {@link TestData.Key}, as the key.
	 */
	@SuppressWarnings("unchecked")
	@Before
	public void beforeTest() {
		this.serializer1 = RecordSerializer.get();
		this.serializer2 = RecordSerializer.get();
		this.comparator1 = new RecordComparator(new int[] {0}, new Class[]{TestData.Key.class});
		this.comparator2 = new RecordComparator(new int[] {0}, new Class[]{TestData.Key.class});
		this.pairComparator = new RecordPairComparator(new int[] {0}, new int[] {0}, new Class[]{TestData.Key.class});
	}

	/**
	 * Runs the co-group iterator over both generated inputs and checks that every co-group it returns
	 * contains exactly the expected values of each side, and that no expected key is left over.
	 */
	@Test
	public void testMerge() {
		try {
			generator1 = new Generator(SEED1, 500, 4096, KeyMode.SORTED, ValueMode.RANDOM_LENGTH);
			generator2 = new Generator(SEED2, 500, 2048, KeyMode.SORTED, ValueMode.RANDOM_LENGTH);

			reader1 = new TestData.GeneratorIterator(generator1, INPUT_1_SIZE);
			reader2 = new TestData.GeneratorIterator(generator2, INPUT_2_SIZE);

			// collect expected data
			Map<TestData.Key, Collection<TestData.Value>> expectedValuesMap1 = collectData(generator1, INPUT_1_SIZE);
			Map<TestData.Key, Collection<TestData.Value>> expectedValuesMap2 = collectData(generator2, INPUT_2_SIZE);
			Map<TestData.Key, List<Collection<TestData.Value>>> expectedCoGroupsMap =
					coGroupValues(expectedValuesMap1, expectedValuesMap2);

			// reset the generators, so that the readers replay the data that was just collected
			generator1.reset();
			generator2.reset();

			// compare with iterator values
			SortMergeCoGroupIterator<Record, Record> iterator = new SortMergeCoGroupIterator<Record, Record>(
					this.reader1, this.reader2, this.serializer1, this.comparator1,
					this.serializer2, this.comparator2, this.pairComparator);

			iterator.open();

			// close the iterator even when an assertion fails; AssertionError is not caught below
			try {
				// reused key object, filled from the first record of every co-group
				final TestData.Key key = new TestData.Key();

				while (iterator.next()) {
					Iterator<Record> iter1 = iterator.getValues1();
					Iterator<Record> iter2 = iterator.getValues2();

					TestData.Value v1 = null;
					TestData.Value v2 = null;

					// the key of the group is read from the first record of whichever side is non-empty
					if (iter1.hasNext()) {
						Record rec = iter1.next();
						rec.getFieldInto(0, key);
						v1 = rec.getField(1, TestData.Value.class);
					}
					else if (iter2.hasNext()) {
						Record rec = iter2.next();
						rec.getFieldInto(0, key);
						v2 = rec.getField(1, TestData.Value.class);
					}
					else {
						Assert.fail("No input on both sides.");
					}

					// assert that matches for this key exist
					Assert.assertTrue("No matches for key " + key, expectedCoGroupsMap.containsKey(key));

					List<Collection<TestData.Value>> expectedGroups = expectedCoGroupsMap.get(key);
					Collection<TestData.Value> expValues1 = expectedGroups.get(0);
					Collection<TestData.Value> expValues2 = expectedGroups.get(1);

					// the value consumed while reading the key must be an expected value as well
					if (v1 != null) {
						Assert.assertTrue("Value not in expected set of first input", expValues1.remove(v1));
					}
					else {
						Assert.assertTrue("Value not in expected set of second input", expValues2.remove(v2));
					}

					while (iter1.hasNext()) {
						Record rec = iter1.next();
						Assert.assertTrue("Value not in expected set of first input",
							expValues1.remove(rec.getField(1, TestData.Value.class)));
					}
					Assert.assertTrue("Expected set of first input not empty", expValues1.isEmpty());

					while (iter2.hasNext()) {
						Record rec = iter2.next();
						Assert.assertTrue("Value not in expected set of second input",
							expValues2.remove(rec.getField(1, TestData.Value.class)));
					}
					Assert.assertTrue("Expected set of second input not empty", expValues2.isEmpty());

					expectedCoGroupsMap.remove(key);
				}
			}
			finally {
				iterator.close();
			}

			Assert.assertTrue("Expected key set not empty", expectedCoGroupsMap.isEmpty());
		}
		catch (Exception e) {
			e.printStackTrace();
			Assert.fail("An exception occurred during the test: " + e.getMessage());
		}
	}

	// --------------------------------------------------------------------------------------------

	/**
	 * Combines the per-key values of both inputs into co-groups.
	 *
	 * @param leftMap the values of the first input, grouped by key
	 * @param rightMap the values of the second input, grouped by key
	 * @return a map from every key that occurs in either input to a two-element list; element 0 holds
	 *         the values of the first input and element 1 the values of the second input. A side
	 *         without values for a key is represented by an empty collection. The collections of the
	 *         given maps are reused, not copied.
	 */
	private Map<TestData.Key, List<Collection<TestData.Value>>> coGroupValues(
			Map<TestData.Key, Collection<TestData.Value>> leftMap,
			Map<TestData.Key, Collection<TestData.Value>> rightMap)
	{
		Map<TestData.Key, List<Collection<TestData.Value>>> map =
				new HashMap<TestData.Key, List<Collection<TestData.Value>>>(1000);

		// union of the keys of both sides
		Set<TestData.Key> keySet = new HashSet<TestData.Key>(leftMap.keySet());
		keySet.addAll(rightMap.keySet());

		for (TestData.Key key : keySet) {
			Collection<TestData.Value> leftValues = leftMap.get(key);
			Collection<TestData.Value> rightValues = rightMap.get(key);

			List<Collection<TestData.Value>> list = new ArrayList<Collection<TestData.Value>>(2);
			list.add(leftValues == null ? new ArrayList<TestData.Value>(0) : leftValues);
			list.add(rightValues == null ? new ArrayList<TestData.Value>(0) : rightValues);

			map.put(key, list);
		}
		return map;
	}

	/**
	 * Pulls {@code num} records from the generator and groups copies of their values by key.
	 *
	 * @param iter the generator to read the records from
	 * @param num the number of records to read
	 * @return a map from each key to the values generated for it, in generation order
	 * @throws Exception forwarded from the generator
	 */
	private Map<TestData.Key, Collection<TestData.Value>> collectData(Generator iter, int num)
	throws Exception
	{
		Map<TestData.Key, Collection<TestData.Value>> map = new HashMap<TestData.Key, Collection<TestData.Value>>();
		Record pair = new Record();

		for (int i = 0; i < num; i++) {
			iter.next(pair);

			TestData.Key key = pair.getField(0, TestData.Key.class);
			Collection<TestData.Value> values = map.get(key);
			if (values == null) {
				values = new ArrayList<TestData.Value>();
				// store a copy of the key, because the record object is reused for every generated pair
				map.put(new TestData.Key(key.getKey()), values);
			}

			values.add(new TestData.Value(pair.getField(1, TestData.Value.class).getValue()));
		}
		return map;
	}
}