/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // We have it in this package because we could not mock the methods otherwise package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.api.serialization.RecordSerializer; import org.apache.flink.runtime.io.network.api.serialization.SpanningRecordSerializer; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferRecycler; import org.apache.flink.runtime.io.network.partition.consumer.InputChannel.BufferAndAvailability; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.plugable.SerializationDelegate; import org.apache.flink.streaming.runtime.streamrecord.StreamElement; import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import java.io.IOException; import java.util.concurrent.ConcurrentLinkedQueue; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; /** * Test {@link InputGate} that allows setting multiple channels. Use * {@link #sendElement(Object, int)} to offer an element on a specific channel. Use * {@link #sendEvent(AbstractEvent, int)} to offer an event on the specified channel. Use * {@link #endInput()} to notify all channels of input end. */ public class StreamTestSingleInputGate<T> extends TestSingleInputGate { private final int numInputChannels; private final TestInputChannel[] inputChannels; private final int bufferSize; private TypeSerializer<T> serializer; private ConcurrentLinkedQueue<InputValue<Object>>[] inputQueues; @SuppressWarnings("unchecked") public StreamTestSingleInputGate( int numInputChannels, int bufferSize, TypeSerializer<T> serializer) throws IOException, InterruptedException { super(numInputChannels, false); this.bufferSize = bufferSize; this.serializer = serializer; this.numInputChannels = numInputChannels; inputChannels = new TestInputChannel[numInputChannels]; inputQueues = new ConcurrentLinkedQueue[numInputChannels]; setupInputChannels(); doReturn(bufferSize).when(inputGate).getPageSize(); } @SuppressWarnings("unchecked") private void setupInputChannels() throws IOException, InterruptedException { for (int i = 0; i < numInputChannels; i++) { final int channelIndex = i; final RecordSerializer<SerializationDelegate<Object>> recordSerializer = new SpanningRecordSerializer<SerializationDelegate<Object>>(); final SerializationDelegate<Object> delegate = (SerializationDelegate<Object>) (SerializationDelegate<?>) new SerializationDelegate<StreamElement>(new StreamElementSerializer<T>(serializer)); inputQueues[channelIndex] = new ConcurrentLinkedQueue<InputValue<Object>>(); inputChannels[channelIndex] = new TestInputChannel(inputGate, i); final Answer<BufferAndAvailability> answer = new Answer<BufferAndAvailability>() { @Override public BufferAndAvailability answer(InvocationOnMock invocationOnMock) throws Throwable { InputValue<Object> input = inputQueues[channelIndex].poll(); if (input != null && input.isStreamEnd()) { when(inputChannels[channelIndex].getInputChannel().isReleased()).thenReturn( true); return new BufferAndAvailability(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), false); } else if (input != null && input.isStreamRecord()) { Object inputElement = input.getStreamRecord(); final Buffer buffer = new Buffer( MemorySegmentFactory.allocateUnpooledSegment(bufferSize), mock(BufferRecycler.class)); recordSerializer.setNextBuffer(buffer); delegate.setInstance(inputElement); recordSerializer.addRecord(delegate); // Call getCurrentBuffer to ensure size is set return new BufferAndAvailability(recordSerializer.getCurrentBuffer(), false); } else if (input != null && input.isEvent()) { AbstractEvent event = input.getEvent(); return new BufferAndAvailability(EventSerializer.toBuffer(event), false); } else { synchronized (inputQueues[channelIndex]) { inputQueues[channelIndex].wait(); return answer(invocationOnMock); } } } }; when(inputChannels[channelIndex].getInputChannel().getNextBuffer()).thenAnswer(answer); inputGate.setInputChannel(new IntermediateResultPartitionID(), inputChannels[channelIndex].getInputChannel()); } } public void sendElement(Object element, int channel) { synchronized (inputQueues[channel]) { inputQueues[channel].add(InputValue.element(element)); inputQueues[channel].notifyAll(); } inputGate.notifyChannelNonEmpty(inputChannels[channel].getInputChannel()); } public void sendEvent(AbstractEvent event, int channel) { synchronized (inputQueues[channel]) { inputQueues[channel].add(InputValue.event(event)); inputQueues[channel].notifyAll(); } inputGate.notifyChannelNonEmpty(inputChannels[channel].getInputChannel()); } public void endInput() { for (int i = 0; i < numInputChannels; i++) { synchronized (inputQueues[i]) { inputQueues[i].add(InputValue.streamEnd()); inputQueues[i].notifyAll(); } inputGate.notifyChannelNonEmpty(inputChannels[i].getInputChannel()); } } /** * Returns true iff all input queues are empty. */ public boolean allQueuesEmpty() { // for (int i = 0; i < numInputChannels; i++) { // synchronized (inputQueues[i]) { // inputQueues[i].add(InputValue.<T>event(new DummyEvent())); // inputQueues[i].notifyAll(); // inputGate.onAvailableBuffer(inputChannels[i].getInputChannel()); // } // } for (int i = 0; i < numInputChannels; i++) { if (inputQueues[i].size() > 0) { return false; } } return true; } public static class InputValue<T> { private Object elementOrEvent; private boolean isStreamEnd; private boolean isStreamRecord; private boolean isEvent; private InputValue(Object elementOrEvent, boolean isStreamEnd, boolean isEvent, boolean isStreamRecord) { this.elementOrEvent = elementOrEvent; this.isStreamEnd = isStreamEnd; this.isStreamRecord = isStreamRecord; this.isEvent = isEvent; } public static <X> InputValue<X> element(Object element) { return new InputValue<X>(element, false, false, true); } public static <X> InputValue<X> streamEnd() { return new InputValue<X>(null, true, false, false); } public static <X> InputValue<X> event(AbstractEvent event) { return new InputValue<X>(event, false, true, false); } public Object getStreamRecord() { return elementOrEvent; } public AbstractEvent getEvent() { return (AbstractEvent) elementOrEvent; } public boolean isStreamEnd() { return isStreamEnd; } public boolean isStreamRecord() { return isStreamRecord; } public boolean isEvent() { return isEvent; } } }