/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.optimizer.programs; import org.apache.flink.api.common.Plan; import org.apache.flink.api.common.functions.FlatJoinFunction; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.operators.util.FieldList; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.io.DiscardingOutputFormat; import org.apache.flink.api.java.operators.DeltaIteration; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.optimizer.dag.TempMode; import org.apache.flink.optimizer.plan.DualInputPlanNode; import org.apache.flink.optimizer.plan.OptimizedPlan; import org.apache.flink.optimizer.plan.SingleInputPlanNode; import org.apache.flink.optimizer.plan.SinkPlanNode; import org.apache.flink.optimizer.plan.SourcePlanNode; import org.apache.flink.optimizer.plan.WorksetIterationPlanNode; import org.apache.flink.optimizer.plantranslate.JobGraphGenerator; import org.apache.flink.runtime.io.network.DataExchangeMode; import org.apache.flink.runtime.operators.DriverStrategy; import org.apache.flink.runtime.operators.shipping.ShipStrategyType; import org.apache.flink.runtime.operators.util.LocalStrategy; import org.apache.flink.optimizer.util.CompilerTestBase; import org.apache.flink.util.Collector; import org.junit.Assert; import org.junit.Test; @SuppressWarnings("serial") public class ConnectedComponentsTest extends CompilerTestBase { private static final String VERTEX_SOURCE = "Vertices"; private static final String ITERATION_NAME = "Connected Components Iteration"; private static final String EDGES_SOURCE = "Edges"; private static final String JOIN_NEIGHBORS_MATCH = "Join Candidate Id With Neighbor"; private static final String MIN_ID_REDUCER = "Find Minimum Candidate Id"; private static final String UPDATE_ID_MATCH = "Update Component Id"; private static final String SINK = "Result"; private final FieldList set0 = new FieldList(0); @Test public void testWorksetConnectedComponents() { Plan plan = getConnectedComponentsPlan(DEFAULT_PARALLELISM, 100, false); OptimizedPlan optPlan = compileNoStats(plan); OptimizerPlanNodeResolver or = getOptimizerPlanNodeResolver(optPlan); SourcePlanNode vertexSource = or.getNode(VERTEX_SOURCE); SourcePlanNode edgesSource = or.getNode(EDGES_SOURCE); SinkPlanNode sink = or.getNode(SINK); WorksetIterationPlanNode iter = or.getNode(ITERATION_NAME); DualInputPlanNode neighborsJoin = or.getNode(JOIN_NEIGHBORS_MATCH); SingleInputPlanNode minIdReducer = or.getNode(MIN_ID_REDUCER); SingleInputPlanNode minIdCombiner = (SingleInputPlanNode) minIdReducer.getPredecessor(); DualInputPlanNode updatingMatch = or.getNode(UPDATE_ID_MATCH); // test all drivers Assert.assertEquals(DriverStrategy.NONE, sink.getDriverStrategy()); Assert.assertEquals(DriverStrategy.NONE, vertexSource.getDriverStrategy()); Assert.assertEquals(DriverStrategy.NONE, edgesSource.getDriverStrategy()); Assert.assertEquals(DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED, neighborsJoin.getDriverStrategy()); Assert.assertTrue(!neighborsJoin.getInput1().getTempMode().isCached()); Assert.assertTrue(!neighborsJoin.getInput2().getTempMode().isCached()); Assert.assertEquals(set0, neighborsJoin.getKeysForInput1()); Assert.assertEquals(set0, neighborsJoin.getKeysForInput2()); Assert.assertEquals(DriverStrategy.HYBRIDHASH_BUILD_SECOND, updatingMatch.getDriverStrategy()); Assert.assertEquals(set0, updatingMatch.getKeysForInput1()); Assert.assertEquals(set0, updatingMatch.getKeysForInput2()); // test all the shipping strategies Assert.assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); Assert.assertEquals(ShipStrategyType.PARTITION_HASH, iter.getInitialSolutionSetInput().getShipStrategy()); Assert.assertEquals(set0, iter.getInitialSolutionSetInput().getShipStrategyKeys()); Assert.assertEquals(ShipStrategyType.PARTITION_HASH, iter.getInitialWorksetInput().getShipStrategy()); Assert.assertEquals(set0, iter.getInitialWorksetInput().getShipStrategyKeys()); Assert.assertEquals(ShipStrategyType.FORWARD, neighborsJoin.getInput1().getShipStrategy()); // workset Assert.assertEquals(ShipStrategyType.PARTITION_HASH, neighborsJoin.getInput2().getShipStrategy()); // edges Assert.assertEquals(set0, neighborsJoin.getInput2().getShipStrategyKeys()); Assert.assertEquals(ShipStrategyType.PARTITION_HASH, minIdReducer.getInput().getShipStrategy()); Assert.assertEquals(set0, minIdReducer.getInput().getShipStrategyKeys()); Assert.assertEquals(ShipStrategyType.FORWARD, minIdCombiner.getInput().getShipStrategy()); Assert.assertEquals(ShipStrategyType.FORWARD, updatingMatch.getInput1().getShipStrategy()); // min id Assert.assertEquals(ShipStrategyType.FORWARD, updatingMatch.getInput2().getShipStrategy()); // solution set // test all the local strategies Assert.assertEquals(LocalStrategy.NONE, sink.getInput().getLocalStrategy()); Assert.assertEquals(LocalStrategy.NONE, iter.getInitialSolutionSetInput().getLocalStrategy()); Assert.assertEquals(LocalStrategy.NONE, iter.getInitialWorksetInput().getLocalStrategy()); Assert.assertEquals(LocalStrategy.NONE, neighborsJoin.getInput1().getLocalStrategy()); // workset Assert.assertEquals(LocalStrategy.NONE, neighborsJoin.getInput2().getLocalStrategy()); // edges Assert.assertEquals(LocalStrategy.COMBININGSORT, minIdReducer.getInput().getLocalStrategy()); Assert.assertEquals(set0, minIdReducer.getInput().getLocalStrategyKeys()); Assert.assertEquals(LocalStrategy.NONE, minIdCombiner.getInput().getLocalStrategy()); Assert.assertEquals(LocalStrategy.NONE, updatingMatch.getInput1().getLocalStrategy()); // min id Assert.assertEquals(LocalStrategy.NONE, updatingMatch.getInput2().getLocalStrategy()); // solution set // check the dams Assert.assertEquals(TempMode.NONE, iter.getInitialWorksetInput().getTempMode()); Assert.assertEquals(TempMode.NONE, iter.getInitialSolutionSetInput().getTempMode()); Assert.assertEquals(DataExchangeMode.BATCH, iter.getInitialWorksetInput().getDataExchangeMode()); Assert.assertEquals(DataExchangeMode.BATCH, iter.getInitialSolutionSetInput().getDataExchangeMode()); JobGraphGenerator jgg = new JobGraphGenerator(); jgg.compileJobGraph(optPlan); } @Test public void testWorksetConnectedComponentsWithSolutionSetAsFirstInput() { Plan plan = getConnectedComponentsPlan(DEFAULT_PARALLELISM, 100, true); OptimizedPlan optPlan = compileNoStats(plan); OptimizerPlanNodeResolver or = getOptimizerPlanNodeResolver(optPlan); SourcePlanNode vertexSource = or.getNode(VERTEX_SOURCE); SourcePlanNode edgesSource = or.getNode(EDGES_SOURCE); SinkPlanNode sink = or.getNode(SINK); WorksetIterationPlanNode iter = or.getNode(ITERATION_NAME); DualInputPlanNode neighborsJoin = or.getNode(JOIN_NEIGHBORS_MATCH); SingleInputPlanNode minIdReducer = or.getNode(MIN_ID_REDUCER); SingleInputPlanNode minIdCombiner = (SingleInputPlanNode) minIdReducer.getPredecessor(); DualInputPlanNode updatingMatch = or.getNode(UPDATE_ID_MATCH); // test all drivers Assert.assertEquals(DriverStrategy.NONE, sink.getDriverStrategy()); Assert.assertEquals(DriverStrategy.NONE, vertexSource.getDriverStrategy()); Assert.assertEquals(DriverStrategy.NONE, edgesSource.getDriverStrategy()); Assert.assertEquals(DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED, neighborsJoin.getDriverStrategy()); Assert.assertTrue(!neighborsJoin.getInput1().getTempMode().isCached()); Assert.assertTrue(!neighborsJoin.getInput2().getTempMode().isCached()); Assert.assertEquals(set0, neighborsJoin.getKeysForInput1()); Assert.assertEquals(set0, neighborsJoin.getKeysForInput2()); Assert.assertEquals(DriverStrategy.HYBRIDHASH_BUILD_FIRST, updatingMatch.getDriverStrategy()); Assert.assertEquals(set0, updatingMatch.getKeysForInput1()); Assert.assertEquals(set0, updatingMatch.getKeysForInput2()); // test all the shipping strategies Assert.assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); Assert.assertEquals(ShipStrategyType.PARTITION_HASH, iter.getInitialSolutionSetInput().getShipStrategy()); Assert.assertEquals(set0, iter.getInitialSolutionSetInput().getShipStrategyKeys()); Assert.assertEquals(ShipStrategyType.PARTITION_HASH, iter.getInitialWorksetInput().getShipStrategy()); Assert.assertEquals(set0, iter.getInitialWorksetInput().getShipStrategyKeys()); Assert.assertEquals(ShipStrategyType.FORWARD, neighborsJoin.getInput1().getShipStrategy()); // workset Assert.assertEquals(ShipStrategyType.PARTITION_HASH, neighborsJoin.getInput2().getShipStrategy()); // edges Assert.assertEquals(set0, neighborsJoin.getInput2().getShipStrategyKeys()); Assert.assertEquals(ShipStrategyType.PARTITION_HASH, minIdReducer.getInput().getShipStrategy()); Assert.assertEquals(set0, minIdReducer.getInput().getShipStrategyKeys()); Assert.assertEquals(ShipStrategyType.FORWARD, minIdCombiner.getInput().getShipStrategy()); Assert.assertEquals(ShipStrategyType.FORWARD, updatingMatch.getInput1().getShipStrategy()); // solution set Assert.assertEquals(ShipStrategyType.FORWARD, updatingMatch.getInput2().getShipStrategy()); // min id // test all the local strategies Assert.assertEquals(LocalStrategy.NONE, sink.getInput().getLocalStrategy()); Assert.assertEquals(LocalStrategy.NONE, iter.getInitialSolutionSetInput().getLocalStrategy()); Assert.assertEquals(LocalStrategy.NONE, iter.getInitialWorksetInput().getLocalStrategy()); Assert.assertEquals(LocalStrategy.NONE, neighborsJoin.getInput1().getLocalStrategy()); // workset Assert.assertEquals(LocalStrategy.NONE, neighborsJoin.getInput2().getLocalStrategy()); // edges Assert.assertEquals(LocalStrategy.COMBININGSORT, minIdReducer.getInput().getLocalStrategy()); Assert.assertEquals(set0, minIdReducer.getInput().getLocalStrategyKeys()); Assert.assertEquals(LocalStrategy.NONE, minIdCombiner.getInput().getLocalStrategy()); Assert.assertEquals(LocalStrategy.NONE, updatingMatch.getInput1().getLocalStrategy()); // min id Assert.assertEquals(LocalStrategy.NONE, updatingMatch.getInput2().getLocalStrategy()); // solution set // check the dams Assert.assertEquals(TempMode.NONE, iter.getInitialWorksetInput().getTempMode()); Assert.assertEquals(TempMode.NONE, iter.getInitialSolutionSetInput().getTempMode()); Assert.assertEquals(DataExchangeMode.BATCH, iter.getInitialWorksetInput().getDataExchangeMode()); Assert.assertEquals(DataExchangeMode.BATCH, iter.getInitialSolutionSetInput().getDataExchangeMode()); JobGraphGenerator jgg = new JobGraphGenerator(); jgg.compileJobGraph(optPlan); } private static Plan getConnectedComponentsPlan(int parallelism, int iterations, boolean solutionSetFirst) { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(parallelism); DataSet<Tuple2<Long, Long>> verticesWithId = env.generateSequence(0, 1000).name("Vertices") .map(new MapFunction<Long, Tuple2<Long, Long>>() { @Override public Tuple2<Long, Long> map(Long value) { return new Tuple2<Long, Long>(value, value); } }).name("Assign Vertex Ids"); DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> iteration = verticesWithId.iterateDelta(verticesWithId, iterations, 0).name("Connected Components Iteration"); @SuppressWarnings("unchecked") DataSet<Tuple2<Long, Long>> edges = env.fromElements(new Tuple2<Long, Long>(0L, 0L)).name("Edges"); DataSet<Tuple2<Long, Long>> minCandidateId = iteration.getWorkset().join(edges) .where(0).equalTo(0) .projectSecond(1).<Tuple2<Long, Long>>projectFirst(1).name("Join Candidate Id With Neighbor") .groupBy(0) .min(1) .name("Find Minimum Candidate Id"); DataSet<Tuple2<Long, Long>> updateComponentId; if (solutionSetFirst) { updateComponentId = iteration.getSolutionSet().join(minCandidateId) .where(0).equalTo(0) .with(new FlatJoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>>() { @Override public void join(Tuple2<Long, Long> current, Tuple2<Long, Long> candidate, Collector<Tuple2<Long, Long>> out) { if (candidate.f1 < current.f1) { out.collect(candidate); } } }) .withForwardedFieldsFirst("0") .withForwardedFieldsSecond("0") .name("Update Component Id"); } else { updateComponentId = minCandidateId.join(iteration.getSolutionSet()) .where(0).equalTo(0) .with(new FlatJoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>>() { @Override public void join(Tuple2<Long, Long> candidate, Tuple2<Long, Long> current, Collector<Tuple2<Long, Long>> out) { if (candidate.f1 < current.f1) { out.collect(candidate); } } }) .withForwardedFieldsFirst("0") .withForwardedFieldsSecond("0") .name("Update Component Id"); } iteration.closeWith(updateComponentId, updateComponentId) .output(new DiscardingOutputFormat<Tuple2<Long, Long>>()).name("Result"); return env.createProgramPlan(); } }