/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.api.functions.async;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.FoldFunction;
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.streaming.api.functions.async.collector.AsyncCollector;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Test cases for {@link RichAsyncFunction}.
 *
 * <p>Both tests install a mocked runtime context through {@code setRuntimeContext}. They then check
 * which calls on the context handed out by the function are answered from the mock, and which
 * calls are expected to be rejected with an {@link UnsupportedOperationException}.
 */
public class RichAsyncFunctionTest {

	/**
	 * Test the set of iteration runtime context methods in the context of a
	 * {@link RichAsyncFunction}.
	 *
	 * <p>The superstep number must come from the mocked {@link IterationRuntimeContext}. Both
	 * iteration aggregator accessors must throw {@link UnsupportedOperationException}.
	 */
	@Test
	public void testIterationRuntimeContext() throws Exception {
		// Only the runtime context handling is under test, so the async logic is a no-op.
		RichAsyncFunction<Integer, Integer> function = new RichAsyncFunction<Integer, Integer>() {
			private static final long serialVersionUID = -2023923961609455894L;

			@Override
			public void asyncInvoke(Integer input, AsyncCollector<Integer> collector) throws Exception {
				// no op
			}
		};

		int superstepNumber = 42;

		IterationRuntimeContext mockedIterationRuntimeContext = mock(IterationRuntimeContext.class);
		when(mockedIterationRuntimeContext.getSuperstepNumber()).thenReturn(superstepNumber);
		function.setRuntimeContext(mockedIterationRuntimeContext);

		IterationRuntimeContext iterationRuntimeContext = function.getIterationRuntimeContext();

		assertEquals(superstepNumber, iterationRuntimeContext.getSuperstepNumber());

		try {
			iterationRuntimeContext.getIterationAggregator("foobar");
			fail("Expected getIterationAggregator to fail with unsupported operation exception");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			iterationRuntimeContext.getPreviousIterationAggregate("foobar");
			fail("Expected getPreviousIterationAggregate to fail with unsupported operation exception");
		} catch (UnsupportedOperationException e) {
			// expected
		}
	}

	/**
	 * Test the set of runtime context methods in the context of a {@link RichAsyncFunction}.
	 *
	 * <p>Task metadata, the metric group, the execution config and the user code class loader must
	 * be taken from the mocked {@link RuntimeContext}. The distributed cache, every kind of state,
	 * accumulators/counters/histograms and broadcast variables must all throw
	 * {@link UnsupportedOperationException}. Every rejected call is followed by {@code fail(...)},
	 * so a call that unexpectedly succeeds makes the test fail instead of passing silently.
	 */
	@Test
	public void testRuntimeContext() throws Exception {
		// Only the runtime context handling is under test, so the async logic is a no-op.
		RichAsyncFunction<Integer, Integer> function = new RichAsyncFunction<Integer, Integer>() {
			private static final long serialVersionUID = 1707630162838967972L;

			@Override
			public void asyncInvoke(Integer input, AsyncCollector<Integer> collector) throws Exception {
				// no op
			}
		};

		final String taskName = "foobarTask";
		final MetricGroup metricGroup = mock(MetricGroup.class);
		final int numberOfParallelSubtasks = 42;
		final int indexOfSubtask = 43;
		final int attemptNumber = 1337;
		final String taskNameWithSubtask = "barfoo";
		final ExecutionConfig executionConfig = mock(ExecutionConfig.class);
		final ClassLoader userCodeClassLoader = mock(ClassLoader.class);

		RuntimeContext mockedRuntimeContext = mock(RuntimeContext.class);

		when(mockedRuntimeContext.getTaskName()).thenReturn(taskName);
		when(mockedRuntimeContext.getMetricGroup()).thenReturn(metricGroup);
		when(mockedRuntimeContext.getNumberOfParallelSubtasks()).thenReturn(numberOfParallelSubtasks);
		when(mockedRuntimeContext.getIndexOfThisSubtask()).thenReturn(indexOfSubtask);
		when(mockedRuntimeContext.getAttemptNumber()).thenReturn(attemptNumber);
		when(mockedRuntimeContext.getTaskNameWithSubtasks()).thenReturn(taskNameWithSubtask);
		when(mockedRuntimeContext.getExecutionConfig()).thenReturn(executionConfig);
		when(mockedRuntimeContext.getUserCodeClassLoader()).thenReturn(userCodeClassLoader);

		function.setRuntimeContext(mockedRuntimeContext);

		RuntimeContext runtimeContext = function.getRuntimeContext();

		// Calls that are expected to return the mocked values.
		assertEquals(taskName, runtimeContext.getTaskName());
		assertEquals(metricGroup, runtimeContext.getMetricGroup());
		assertEquals(numberOfParallelSubtasks, runtimeContext.getNumberOfParallelSubtasks());
		assertEquals(indexOfSubtask, runtimeContext.getIndexOfThisSubtask());
		assertEquals(attemptNumber, runtimeContext.getAttemptNumber());
		assertEquals(taskNameWithSubtask, runtimeContext.getTaskNameWithSubtasks());
		assertEquals(executionConfig, runtimeContext.getExecutionConfig());
		assertEquals(userCodeClassLoader, runtimeContext.getUserCodeClassLoader());

		// Calls that are expected to be rejected.
		try {
			runtimeContext.getDistributedCache();
			fail("Expected getDistributedCache to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.getState(new ValueStateDescriptor<>("foobar", Integer.class, 42));
			fail("Expected getState to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.getListState(new ListStateDescriptor<>("foobar", Integer.class));
			fail("Expected getListState to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.getReducingState(new ReducingStateDescriptor<>("foobar", new ReduceFunction<Integer>() {
				private static final long serialVersionUID = 2136425961884441050L;

				@Override
				public Integer reduce(Integer value1, Integer value2) throws Exception {
					return value1;
				}
			}, Integer.class));
			fail("Expected getReducingState to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.getFoldingState(new FoldingStateDescriptor<>("foobar", 0, new FoldFunction<Integer, Integer>() {
				private static final long serialVersionUID = -3498281497284753618L;

				@Override
				public Integer fold(Integer accumulator, Integer value) throws Exception {
					return accumulator;
				}
			}, Integer.class));
			// Without this, the check would pass even if getFoldingState succeeded.
			fail("Expected getFoldingState to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.getMapState(new MapStateDescriptor<>("foobar", Integer.class, String.class));
			// Without this, the check would pass even if getMapState succeeded.
			fail("Expected getMapState to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.addAccumulator("foobar", new Accumulator<Integer, Integer>() {
				private static final long serialVersionUID = -4673320336846482358L;

				@Override
				public void add(Integer value) {
					// no op
				}

				@Override
				public Integer getLocalValue() {
					return null;
				}

				@Override
				public void resetLocal() {
				}

				@Override
				public void merge(Accumulator<Integer, Integer> other) {
				}

				@Override
				public Accumulator<Integer, Integer> clone() {
					return null;
				}
			});
			fail("Expected addAccumulator to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.getAccumulator("foobar");
			fail("Expected getAccumulator to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.getAllAccumulators();
			fail("Expected getAllAccumulators to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.getIntCounter("foobar");
			fail("Expected getIntCounter to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.getLongCounter("foobar");
			fail("Expected getLongCounter to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.getDoubleCounter("foobar");
			fail("Expected getDoubleCounter to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.getHistogram("foobar");
			fail("Expected getHistogram to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.hasBroadcastVariable("foobar");
			fail("Expected hasBroadcastVariable to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.getBroadcastVariable("foobar");
			fail("Expected getBroadcastVariable to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}

		try {
			runtimeContext.getBroadcastVariableWithInitializer("foobar", new BroadcastVariableInitializer<Object, Object>() {
				@Override
				public Object initializeBroadcastVariable(Iterable<Object> data) {
					return null;
				}
			});
			fail("Expected getBroadcastVariableWithInitializer to fail with unsupported operation exception.");
		} catch (UnsupportedOperationException e) {
			// expected
		}
	}
}