/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.jobgraph;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.flink.api.common.io.GenericInputFormat;
import org.apache.flink.api.common.io.InitializeOnMaster;
import org.apache.flink.api.common.io.InputFormat;
import org.apache.flink.api.common.io.OutputFormat;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.core.io.GenericInputSplit;
import org.apache.flink.core.io.InputSplit;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.*;
@SuppressWarnings("serial")
public class JobTaskVertexTest {
@Test
public void testConnectDirectly() {
JobVertex source = new JobVertex("source");
JobVertex target = new JobVertex("target");
target.connectNewDataSetAsInput(source, DistributionPattern.POINTWISE, ResultPartitionType.PIPELINED);
assertTrue(source.isInputVertex());
assertFalse(source.isOutputVertex());
assertFalse(target.isInputVertex());
assertTrue(target.isOutputVertex());
assertEquals(1, source.getNumberOfProducedIntermediateDataSets());
assertEquals(1, target.getNumberOfInputs());
assertEquals(target.getInputs().get(0).getSource(), source.getProducedDataSets().get(0));
assertEquals(1, source.getProducedDataSets().get(0).getConsumers().size());
assertEquals(target, source.getProducedDataSets().get(0).getConsumers().get(0).getTarget());
}
@Test
public void testConnectMultipleTargets() {
JobVertex source = new JobVertex("source");
JobVertex target1= new JobVertex("target1");
JobVertex target2 = new JobVertex("target2");
target1.connectNewDataSetAsInput(source, DistributionPattern.POINTWISE, ResultPartitionType.PIPELINED);
target2.connectDataSetAsInput(source.getProducedDataSets().get(0), DistributionPattern.ALL_TO_ALL);
assertTrue(source.isInputVertex());
assertFalse(source.isOutputVertex());
assertFalse(target1.isInputVertex());
assertTrue(target1.isOutputVertex());
assertFalse(target2.isInputVertex());
assertTrue(target2.isOutputVertex());
assertEquals(1, source.getNumberOfProducedIntermediateDataSets());
assertEquals(2, source.getProducedDataSets().get(0).getConsumers().size());
assertEquals(target1.getInputs().get(0).getSource(), source.getProducedDataSets().get(0));
assertEquals(target2.getInputs().get(0).getSource(), source.getProducedDataSets().get(0));
}
@Test
public void testOutputFormatVertex() {
try {
final TestingOutputFormat outputFormat = new TestingOutputFormat();
final OutputFormatVertex of = new OutputFormatVertex("Name");
new TaskConfig(of.getConfiguration()).setStubWrapper(new UserCodeObjectWrapper<OutputFormat<?>>(outputFormat));
final ClassLoader cl = getClass().getClassLoader();
try {
of.initializeOnMaster(cl);
fail("Did not throw expected exception.");
} catch (TestException e) {
// all good
}
OutputFormatVertex copy = SerializationUtils.clone(of);
try {
copy.initializeOnMaster(cl);
fail("Did not throw expected exception.");
} catch (TestException e) {
// all good
}
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
@Test
public void testInputFormatVertex() {
try {
final TestInputFormat inputFormat = new TestInputFormat();
final InputFormatVertex vertex = new InputFormatVertex("Name");
new TaskConfig(vertex.getConfiguration()).setStubWrapper(new UserCodeObjectWrapper<InputFormat<?, ?>>(inputFormat));
final ClassLoader cl = getClass().getClassLoader();
vertex.initializeOnMaster(cl);
InputSplit[] splits = vertex.getInputSplitSource().createInputSplits(77);
assertNotNull(splits);
assertEquals(1, splits.length);
assertEquals(TestSplit.class, splits[0].getClass());
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
// --------------------------------------------------------------------------------------------
private static final class TestingOutputFormat extends DiscardingOutputFormat<Object> implements InitializeOnMaster {
@Override
public void initializeGlobal(int parallelism) throws IOException {
throw new TestException();
}
}
private static final class TestException extends IOException {}
// --------------------------------------------------------------------------------------------
private static final class TestSplit extends GenericInputSplit {
public TestSplit(int partitionNumber, int totalNumberOfPartitions) {
super(partitionNumber, totalNumberOfPartitions);
}
}
private static final class TestInputFormat extends GenericInputFormat<Object> {
@Override
public boolean reachedEnd() {
return false;
}
@Override
public Object nextRecord(Object reuse) {
return null;
}
@Override
public GenericInputSplit[] createInputSplits(int numSplits) throws IOException {
return new GenericInputSplit[] { new TestSplit(0, 1) };
}
}
}