/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.operators.util; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.base.IntComparator; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.runtime.io.network.api.writer.ChannelSelector; import org.apache.flink.runtime.operators.shipping.OutputEmitter; import org.apache.flink.runtime.operators.shipping.ShipStrategyType; import org.apache.flink.runtime.plugable.SerializationDelegate; import org.apache.flink.runtime.testutils.recordutils.RecordComparatorFactory; import org.apache.flink.runtime.testutils.recordutils.RecordSerializerFactory; import org.apache.flink.types.DeserializationException; import org.apache.flink.types.DoubleValue; import org.apache.flink.types.IntValue; import org.apache.flink.types.KeyFieldOutOfBoundsException; import org.apache.flink.types.NullKeyFieldException; import org.apache.flink.types.Record; import org.apache.flink.types.StringValue; import org.junit.Assert; import org.junit.Test; import java.io.IOException; import java.io.PipedInputStream; import java.io.PipedOutputStream; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; public class OutputEmitterTest { @Test public void testPartitionHash() { // Test for IntValue @SuppressWarnings({"unchecked", "rawtypes"}) final TypeComparator<Record> intComp = new RecordComparatorFactory(new int[] {0}, new Class[] {IntValue.class}).createComparator(); final ChannelSelector<SerializationDelegate<Record>> oe1 = new OutputEmitter<Record>(ShipStrategyType.PARTITION_HASH, intComp); final SerializationDelegate<Record> delegate = new SerializationDelegate<Record>(new RecordSerializerFactory().getSerializer()); int numChans = 100; int numRecs = 50000; int[] hit = new int[numChans]; for (int i = 0; i < numRecs; i++) { IntValue k = new IntValue(i); Record rec = new Record(k); delegate.setInstance(rec); int[] chans = oe1.selectChannels(delegate, hit.length); for (int chan : chans) { hit[chan]++; } } int cnt = 0; for (int aHit : hit) { assertTrue(aHit > 0); cnt += aHit; } assertTrue(cnt == numRecs); // Test for StringValue @SuppressWarnings({"unchecked", "rawtypes"}) final TypeComparator<Record> stringComp = new RecordComparatorFactory(new int[] {0}, new Class[] {StringValue.class}).createComparator(); final ChannelSelector<SerializationDelegate<Record>> oe2 = new OutputEmitter<Record>(ShipStrategyType.PARTITION_HASH, stringComp); numChans = 100; numRecs = 10000; hit = new int[numChans]; for (int i = 0; i < numRecs; i++) { StringValue k = new StringValue(i + ""); Record rec = new Record(k); delegate.setInstance(rec); int[] chans = oe2.selectChannels(delegate, hit.length); for (int chan : chans) { hit[chan]++; } } cnt = 0; for (int aHit : hit) { assertTrue(aHit > 0); cnt += aHit; } assertTrue(cnt == numRecs); // test hash corner cases final TestIntComparator testIntComp = new TestIntComparator(); final ChannelSelector<SerializationDelegate<Integer>> oe3 = new OutputEmitter<Integer>(ShipStrategyType.PARTITION_HASH, testIntComp); final SerializationDelegate<Integer> intDel = new SerializationDelegate<Integer>(new IntSerializer()); numChans = 100; // MinVal hash intDel.setInstance(Integer.MIN_VALUE); int[] chans = oe3.selectChannels(intDel, numChans); assertTrue(chans.length == 1); assertTrue(chans[0] >= 0 && chans[0] <= numChans-1); // -1 hash intDel.setInstance(-1); chans = oe3.selectChannels(intDel, hit.length); assertTrue(chans.length == 1); assertTrue(chans[0] >= 0 && chans[0] <= numChans-1); // 0 hash intDel.setInstance(0); chans = oe3.selectChannels(intDel, hit.length); assertTrue(chans.length == 1); assertTrue(chans[0] >= 0 && chans[0] <= numChans-1); // 1 hash intDel.setInstance(1); chans = oe3.selectChannels(intDel, hit.length); assertTrue(chans.length == 1); assertTrue(chans[0] >= 0 && chans[0] <= numChans-1); // MaxVal hash intDel.setInstance(Integer.MAX_VALUE); chans = oe3.selectChannels(intDel, hit.length); assertTrue(chans.length == 1); assertTrue(chans[0] >= 0 && chans[0] <= numChans-1); } @Test public void testForward() { // Test for IntValue @SuppressWarnings({"unchecked", "rawtypes"}) final TypeComparator<Record> intComp = new RecordComparatorFactory(new int[] {0}, new Class[] {IntValue.class}).createComparator(); final ChannelSelector<SerializationDelegate<Record>> oe1 = new OutputEmitter<Record>(ShipStrategyType.FORWARD, intComp); final SerializationDelegate<Record> delegate = new SerializationDelegate<Record>(new RecordSerializerFactory().getSerializer()); int numChannels = 100; int numRecords = 50000 + numChannels / 2; int[] hit = new int[numChannels]; for (int i = 0; i < numRecords; i++) { IntValue k = new IntValue(i); Record rec = new Record(k); delegate.setInstance(rec); int[] chans = oe1.selectChannels(delegate, hit.length); for (int chan : chans) { hit[chan]++; } } assertTrue(hit[0] == numRecords); for (int i = 1; i < hit.length; i++) { assertTrue(hit[i] == 0); } // Test for StringValue @SuppressWarnings({"unchecked", "rawtypes"}) final TypeComparator<Record> stringComp = new RecordComparatorFactory(new int[] {0}, new Class[] {StringValue.class}).createComparator(); final ChannelSelector<SerializationDelegate<Record>> oe2 = new OutputEmitter<Record>(ShipStrategyType.FORWARD, stringComp); numChannels = 100; numRecords = 10000 + numChannels / 2; hit = new int[numChannels]; for (int i = 0; i < numRecords; i++) { StringValue k = new StringValue(i + ""); Record rec = new Record(k); delegate.setInstance(rec); int[] chans = oe2.selectChannels(delegate, hit.length); for (int chan : chans) { hit[chan]++; } } assertTrue(hit[0] == numRecords); for (int i = 1; i < hit.length; i++) { assertTrue(hit[i] == 0); } } @Test public void testForcedRebalance() { // Test for IntValue int numChannels = 100; int toTaskIndex = numChannels * 6/7; int fromTaskIndex = toTaskIndex + numChannels; int extraRecords = numChannels / 3; int numRecords = 50000 + extraRecords; final ChannelSelector<SerializationDelegate<Record>> oe1 = new OutputEmitter<Record>(ShipStrategyType.PARTITION_FORCED_REBALANCE, fromTaskIndex); final SerializationDelegate<Record> delegate = new SerializationDelegate<Record>(new RecordSerializerFactory().getSerializer()); int[] hit = new int[numChannels]; for (int i = 0; i < numRecords; i++) { IntValue k = new IntValue(i); Record rec = new Record(k); delegate.setInstance(rec); int[] chans = oe1.selectChannels(delegate, hit.length); for (int chan : chans) { hit[chan]++; } } int cnt = 0; for (int i = 0; i < hit.length; i++) { if (toTaskIndex <= i || i < toTaskIndex+extraRecords-numChannels) { assertTrue(hit[i] == (numRecords/numChannels)+1); } else { assertTrue(hit[i] == numRecords/numChannels); } cnt += hit[i]; } assertTrue(cnt == numRecords); // Test for StringValue numChannels = 100; toTaskIndex = numChannels / 5; fromTaskIndex = toTaskIndex + 2 * numChannels; extraRecords = numChannels * 2/9; numRecords = 10000 + extraRecords; final ChannelSelector<SerializationDelegate<Record>> oe2 = new OutputEmitter<Record>(ShipStrategyType.PARTITION_FORCED_REBALANCE, fromTaskIndex); hit = new int[numChannels]; for (int i = 0; i < numRecords; i++) { StringValue k = new StringValue(i + ""); Record rec = new Record(k); delegate.setInstance(rec); int[] chans = oe2.selectChannels(delegate, hit.length); for (int chan : chans) { hit[chan]++; } } cnt = 0; for (int i = 0; i < hit.length; i++) { if (toTaskIndex <= i && i < toTaskIndex+extraRecords) { assertTrue(hit[i] == (numRecords/numChannels)+1); } else { assertTrue(hit[i] == numRecords/numChannels); } cnt += hit[i]; } assertTrue(cnt == numRecords); } @Test public void testBroadcast() { // Test for IntValue @SuppressWarnings({"unchecked", "rawtypes"}) final TypeComparator<Record> intComp = new RecordComparatorFactory(new int[] {0}, new Class[] {IntValue.class}).createComparator(); final ChannelSelector<SerializationDelegate<Record>> oe1 = new OutputEmitter<Record>(ShipStrategyType.BROADCAST, intComp); final SerializationDelegate<Record> delegate = new SerializationDelegate<Record>(new RecordSerializerFactory().getSerializer()); int numChannels = 100; int numRecords = 50000; int[] hit = new int[numChannels]; for (int i = 0; i < numRecords; i++) { IntValue k = new IntValue(i); Record rec = new Record(k); delegate.setInstance(rec); int[] chans = oe1.selectChannels(delegate, hit.length); for (int chan : chans) { hit[chan]++; } } for (int aHit : hit) { assertTrue(aHit + "", aHit == numRecords); } // Test for StringValue @SuppressWarnings({"unchecked", "rawtypes"}) final TypeComparator<Record> stringComp = new RecordComparatorFactory(new int[] {0}, new Class[] {StringValue.class}).createComparator(); final ChannelSelector<SerializationDelegate<Record>> oe2 = new OutputEmitter<Record>(ShipStrategyType.BROADCAST, stringComp); numChannels = 100; numRecords = 5000; hit = new int[numChannels]; for (int i = 0; i < numRecords; i++) { StringValue k = new StringValue(i + ""); Record rec = new Record(k); delegate.setInstance(rec); int[] chans = oe2.selectChannels(delegate, hit.length); for (int chan : chans) { hit[chan]++; } } for (int aHit : hit) { assertTrue(aHit + "", aHit == numRecords); } } @Test public void testMultiKeys() { @SuppressWarnings({"unchecked", "rawtypes"}) final TypeComparator<Record> multiComp = new RecordComparatorFactory(new int[] {0,1,3}, new Class[] {IntValue.class, StringValue.class, DoubleValue.class}).createComparator(); final ChannelSelector<SerializationDelegate<Record>> oe1 = new OutputEmitter<Record>(ShipStrategyType.PARTITION_HASH, multiComp); final SerializationDelegate<Record> delegate = new SerializationDelegate<Record>(new RecordSerializerFactory().getSerializer()); int numChannels = 100; int numRecords = 5000; int[] hit = new int[numChannels]; for (int i = 0; i < numRecords; i++) { Record rec = new Record(4); rec.setField(0, new IntValue(i)); rec.setField(1, new StringValue("AB"+i+"CD"+i)); rec.setField(3, new DoubleValue(i*3.141d)); delegate.setInstance(rec); int[] chans = oe1.selectChannels(delegate, hit.length); for (int chan : chans) { hit[chan]++; } } int cnt = 0; for (int aHit : hit) { assertTrue(aHit > 0); cnt += aHit; } assertTrue(cnt == numRecords); } @Test public void testMissingKey() { // Test for IntValue @SuppressWarnings({"unchecked", "rawtypes"}) final TypeComparator<Record> intComp = new RecordComparatorFactory(new int[] {1}, new Class[] {IntValue.class}).createComparator(); final ChannelSelector<SerializationDelegate<Record>> oe1 = new OutputEmitter<Record>(ShipStrategyType.PARTITION_HASH, intComp); final SerializationDelegate<Record> delegate = new SerializationDelegate<Record>(new RecordSerializerFactory().getSerializer()); Record rec = new Record(0); rec.setField(0, new IntValue(1)); delegate.setInstance(rec); try { oe1.selectChannels(delegate, 100); } catch (KeyFieldOutOfBoundsException re) { Assert.assertEquals(1, re.getFieldNumber()); return; } Assert.fail("Expected a KeyFieldOutOfBoundsException."); } @Test public void testNullKey() { // Test for IntValue @SuppressWarnings({"unchecked", "rawtypes"}) final TypeComparator<Record> intComp = new RecordComparatorFactory(new int[] {0}, new Class[] {IntValue.class}).createComparator(); final ChannelSelector<SerializationDelegate<Record>> oe1 = new OutputEmitter<Record>(ShipStrategyType.PARTITION_HASH, intComp); final SerializationDelegate<Record> delegate = new SerializationDelegate<Record>(new RecordSerializerFactory().getSerializer()); Record rec = new Record(2); rec.setField(1, new IntValue(1)); delegate.setInstance(rec); try { oe1.selectChannels(delegate, 100); } catch (NullKeyFieldException re) { Assert.assertEquals(0, re.getFieldNumber()); return; } Assert.fail("Expected a NullKeyFieldException."); } @Test public void testWrongKeyClass() { // Test for IntValue @SuppressWarnings({"unchecked", "rawtypes"}) final TypeComparator<Record> doubleComp = new RecordComparatorFactory(new int[] {0}, new Class[] {DoubleValue.class}).createComparator(); final ChannelSelector<SerializationDelegate<Record>> oe1 = new OutputEmitter<Record>(ShipStrategyType.PARTITION_HASH, doubleComp); final SerializationDelegate<Record> delegate = new SerializationDelegate<Record>(new RecordSerializerFactory().getSerializer()); ; Record rec = null; try { PipedInputStream pipedInput = new PipedInputStream(1024*1024); DataInputView in = new DataInputViewStreamWrapper(pipedInput); DataOutputView out = new DataOutputViewStreamWrapper(new PipedOutputStream(pipedInput)); rec = new Record(1); rec.setField(0, new IntValue()); rec.write(out); rec = new Record(); rec.read(in); } catch (IOException e) { fail("Test erroneous"); } try { delegate.setInstance(rec); oe1.selectChannels(delegate, 100); } catch (DeserializationException re) { return; } Assert.fail("Expected a NullKeyFieldException."); } @SuppressWarnings({"serial", "rawtypes"}) private static class TestIntComparator extends TypeComparator<Integer> { private TypeComparator[] comparators = new TypeComparator[]{new IntComparator(true)}; @Override public int hash(Integer record) { return record; } @Override public void setReference(Integer toCompare) { throw new UnsupportedOperationException(); } @Override public boolean equalToReference(Integer candidate) { throw new UnsupportedOperationException(); } @Override public int compareToReference( TypeComparator<Integer> referencedComparator) { throw new UnsupportedOperationException(); } @Override public int compare(Integer first, Integer second) { throw new UnsupportedOperationException(); } @Override public int compareSerialized(DataInputView firstSource, DataInputView secondSource) { throw new UnsupportedOperationException(); } @Override public boolean supportsNormalizedKey() { throw new UnsupportedOperationException(); } @Override public boolean supportsSerializationWithKeyNormalization() { throw new UnsupportedOperationException(); } @Override public int getNormalizeKeyLen() { throw new UnsupportedOperationException(); } @Override public boolean isNormalizedKeyPrefixOnly(int keyBytes) { throw new UnsupportedOperationException(); } @Override public void putNormalizedKey(Integer record, MemorySegment target, int offset, int numBytes) { throw new UnsupportedOperationException(); } @Override public void writeWithKeyNormalization(Integer record, DataOutputView target) throws IOException { throw new UnsupportedOperationException(); } @Override public Integer readWithKeyDenormalization(Integer reuse, DataInputView source) throws IOException { throw new UnsupportedOperationException(); } @Override public boolean invertNormalizedKey() { throw new UnsupportedOperationException(); } @Override public TypeComparator<Integer> duplicate() { throw new UnsupportedOperationException(); } @Override public int extractKeys(Object record, Object[] target, int index) { target[index] = record; return 1; } @Override public TypeComparator[] getFlatComparators() { return comparators; } } }