/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.core.io.IOReadableWritable; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.api.serialization.RecordSerializer; import org.apache.flink.runtime.io.network.api.serialization.SpanningRecordSerializer; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferRecycler; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.MutableObjectIterator; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import java.io.IOException; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; public class IteratorWrappingTestSingleInputGate<T extends IOReadableWritable> extends TestSingleInputGate { private final TestInputChannel inputChannel = new TestInputChannel(inputGate, 0); private final int bufferSize; private MutableObjectIterator<T> inputIterator; private RecordSerializer<T> serializer; private final T reuse; public IteratorWrappingTestSingleInputGate(int bufferSize, Class<T> recordType, MutableObjectIterator<T> iterator) throws IOException, InterruptedException { super(1, false); this.bufferSize = bufferSize; this.reuse = InstantiationUtil.instantiate(recordType); wrapIterator(iterator); } private IteratorWrappingTestSingleInputGate<T> wrapIterator(MutableObjectIterator<T> iterator) throws IOException, InterruptedException { inputIterator = iterator; serializer = new SpanningRecordSerializer<T>(); // The input iterator can produce an infinite stream. That's why we have to serialize each // record on demand and cannot do it upfront. final Answer<InputChannel.BufferAndAvailability> answer = new Answer<InputChannel.BufferAndAvailability>() { private boolean hasData = inputIterator.next(reuse) != null; @Override public InputChannel.BufferAndAvailability answer(InvocationOnMock invocationOnMock) throws Throwable { if (hasData) { final Buffer buffer = new Buffer(MemorySegmentFactory.allocateUnpooledSegment(bufferSize), mock(BufferRecycler.class)); serializer.setNextBuffer(buffer); serializer.addRecord(reuse); hasData = inputIterator.next(reuse) != null; // Call getCurrentBuffer to ensure size is set return new InputChannel.BufferAndAvailability(serializer.getCurrentBuffer(), true); } else { when(inputChannel.getInputChannel().isReleased()).thenReturn(true); return new InputChannel.BufferAndAvailability(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), false); } } }; when(inputChannel.getInputChannel().getNextBuffer()).thenAnswer(answer); inputGate.setInputChannel(new IntermediateResultPartitionID(), inputChannel.getInputChannel()); return this; } public IteratorWrappingTestSingleInputGate<T> notifyNonEmpty() { inputGate.notifyChannelNonEmpty(inputChannel.getInputChannel()); return this; } }