/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.io.network.partition.consumer;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.mockito.stubbing.OngoingStubbing;
import java.io.IOException;
import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Test helper wrapping a Mockito mock of {@link InputChannel}.
 *
 * <p>The {@code read*} methods script, in call order, what the mocked channel hands out from
 * {@code getNextBuffer()}. Each of them returns {@code this} so that calls can be chained.
 */
public class TestInputChannel {

	/** The mocked channel whose behaviour is scripted by this helper. */
	private final InputChannel mock = Mockito.mock(InputChannel.class);

	/** Gate passed to the constructor; checked for null there, not read anywhere else in this class. */
	private final SingleInputGate inputGate;

	/**
	 * In-progress stubbing of {@code mock.getNextBuffer()}; {@code null} until the first
	 * {@code read*} call. The handle is retained so that later calls append further results to
	 * the same chain instead of opening a new {@code when(...)} on the already stubbed method.
	 */
	protected OngoingStubbing<InputChannel.BufferAndAvailability> stubbing;

	/**
	 * Creates the helper and stubs the mock's channel index.
	 *
	 * @param inputGate gate this channel is created for; must pass {@code checkNotNull}
	 * @param channelIndex value returned by the mock's {@code getChannelIndex()}; must be {@code >= 0}
	 */
	public TestInputChannel(SingleInputGate inputGate, int channelIndex) {
		checkArgument(channelIndex >= 0);
		this.inputGate = checkNotNull(inputGate);

		when(mock.getChannelIndex()).thenReturn(channelIndex);
	}

	/**
	 * Appends the given buffer, wrapped with the flag {@code true}, to the results returned by
	 * the mock's {@code getNextBuffer()}.
	 *
	 * @param buffer buffer to hand out
	 * @return this instance, for chaining
	 */
	public TestInputChannel read(Buffer buffer) throws IOException, InterruptedException {
		stubbing = stubbingToExtend().thenReturn(new InputChannel.BufferAndAvailability(buffer, true));
		return this;
	}

	/**
	 * Appends a freshly mocked {@link Buffer}, whose {@code isBuffer()} returns {@code true}, via
	 * {@link #read(Buffer)}.
	 *
	 * @return this instance, for chaining
	 */
	public TestInputChannel readBuffer() throws IOException, InterruptedException {
		final Buffer dataBuffer = mock(Buffer.class);
		when(dataBuffer.isBuffer()).thenReturn(true);

		return read(dataBuffer);
	}

	/**
	 * Appends an end-of-partition result to the mock's {@code getNextBuffer()} chain.
	 *
	 * <p>When that result is actually requested, the mock's {@code isReleased()} is stubbed to
	 * return {@code true}, and the serialized {@link EndOfPartitionEvent} is handed out wrapped
	 * with the flag {@code false}.
	 *
	 * @return this instance, for chaining
	 */
	public TestInputChannel readEndOfPartitionEvent() throws IOException, InterruptedException {
		final Answer<InputChannel.BufferAndAvailability> endOfPartition =
			new Answer<InputChannel.BufferAndAvailability>() {
				@Override
				public InputChannel.BufferAndAvailability answer(InvocationOnMock invocation) throws Throwable {
					// The channel reports itself as released only once this event is consumed.
					when(mock.isReleased()).thenReturn(true);

					return new InputChannel.BufferAndAvailability(
						EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE),
						false);
				}
			};

		stubbing = stubbingToExtend().thenAnswer(endOfPartition);
		return this;
	}

	/**
	 * Returns the mocked channel scripted by this helper.
	 */
	public InputChannel getInputChannel() {
		return mock;
	}

	/**
	 * Returns the stubbing handle to which the next result must be appended: the retained chain
	 * if one exists, otherwise a new stubbing of {@code mock.getNextBuffer()}.
	 */
	private OngoingStubbing<InputChannel.BufferAndAvailability> stubbingToExtend()
			throws IOException, InterruptedException {

		if (stubbing != null) {
			return stubbing;
		}
		return when(mock.getNextBuffer());
	}

	// ------------------------------------------------------------------------

	/**
	 * Creates test input channels with indices {@code 0..numberOfInputChannels-1} and registers
	 * each one's mocked channel at the given gate under a newly created
	 * {@link IntermediateResultPartitionID}.
	 *
	 * @param inputGate gate to attach the channels to; must pass {@code checkNotNull}
	 * @param numberOfInputChannels number of channels to create; must be {@code > 0}
	 * @return The created test input channels.
	 */
	public static TestInputChannel[] createInputChannels(SingleInputGate inputGate, int numberOfInputChannels) {
		checkNotNull(inputGate);
		checkArgument(numberOfInputChannels > 0);

		final TestInputChannel[] channels = new TestInputChannel[numberOfInputChannels];

		for (int index = 0; index < channels.length; index++) {
			final TestInputChannel channel = new TestInputChannel(inputGate, index);
			channels[index] = channel;

			inputGate.setInputChannel(new IntermediateResultPartitionID(), channel.getInputChannel());
		}

		return channels;
	}
}