/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.optimizer; import static org.junit.Assert.*; import java.util.HashSet; import java.util.Set; import org.apache.flink.api.common.functions.CoGroupFunction; import org.apache.flink.api.common.operators.Operator; import org.apache.flink.api.java.io.DiscardingOutputFormat; import org.apache.flink.api.java.io.TextOutputFormat; import org.apache.flink.api.java.operators.DeltaIteration; import org.apache.flink.optimizer.testfunctions.DummyCoGroupFunction; import org.apache.flink.optimizer.testfunctions.IdentityCoGrouper; import org.apache.flink.optimizer.testfunctions.IdentityCrosser; import org.apache.flink.optimizer.testfunctions.IdentityJoiner; import org.apache.flink.optimizer.testfunctions.SelectOneReducer; import org.apache.flink.optimizer.util.CompilerTestBase; import org.apache.flink.util.Collector; import org.junit.Assert; import org.junit.Test; import org.apache.flink.api.common.Plan; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.operators.IterativeDataSet; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.RichJoinFunction; import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.optimizer.plan.OptimizedPlan; import org.apache.flink.optimizer.plan.SinkPlanNode; import org.apache.flink.optimizer.plantranslate.JobGraphGenerator; import org.apache.flink.optimizer.testfunctions.DummyFlatJoinFunction; import org.apache.flink.optimizer.testfunctions.IdentityGroupReducer; import org.apache.flink.optimizer.testfunctions.IdentityKeyExtractor; import org.apache.flink.optimizer.testfunctions.IdentityMapper; import org.apache.flink.optimizer.testfunctions.Top1GroupReducer; @SuppressWarnings({"serial"}) public class BranchingPlansCompilerTest extends CompilerTestBase { @Test public void testCostComputationWithMultipleDataSinks() { final int SINKS = 5; try { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(DEFAULT_PARALLELISM); DataSet<Long> source = env.generateSequence(1, 10000); DataSet<Long> mappedA = source.map(new IdentityMapper<Long>()); DataSet<Long> mappedC = source.map(new IdentityMapper<Long>()); for (int sink = 0; sink < SINKS; sink++) { mappedA.output(new DiscardingOutputFormat<Long>()); mappedC.output(new DiscardingOutputFormat<Long>()); } Plan plan = env.createProgramPlan("Plans With Multiple Data Sinks"); OptimizedPlan oPlan = compileNoStats(plan); new JobGraphGenerator().compileJobGraph(oPlan); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } /** * * <pre> * (SRC A) * | * (MAP A) * / \ * (MAP B) (MAP C) * / / \ * (SINK A) (SINK B) (SINK C) * </pre> */ @SuppressWarnings("unchecked") @Test public void testBranchingWithMultipleDataSinks2() { try { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(DEFAULT_PARALLELISM); DataSet<Long> source = env.generateSequence(1, 10000); DataSet<Long> mappedA = source.map(new IdentityMapper<Long>()); DataSet<Long> mappedB = mappedA.map(new IdentityMapper<Long>()); DataSet<Long> mappedC = mappedA.map(new IdentityMapper<Long>()); mappedB.output(new DiscardingOutputFormat<Long>()); mappedC.output(new DiscardingOutputFormat<Long>()); mappedC.output(new DiscardingOutputFormat<Long>()); Plan plan = env.createProgramPlan(); Set<Operator<?>> sinks = new HashSet<Operator<?>>(plan.getDataSinks()); OptimizedPlan oPlan = compileNoStats(plan); // ---------- check the optimizer plan ---------- // number of sinks assertEquals("Wrong number of data sinks.", 3, oPlan.getDataSinks().size()); // remove matching sinks to check relation for (SinkPlanNode sink : oPlan.getDataSinks()) { assertTrue(sinks.remove(sink.getProgramOperator())); } assertTrue(sinks.isEmpty()); new JobGraphGenerator().compileJobGraph(oPlan); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } /** * <pre> * SINK * | * COGROUP * +---/ \----+ * / \ * / MATCH10 * / | \ * / | MATCH9 * MATCH5 | | \ * | \ | | MATCH8 * | MATCH4 | | | \ * | | \ | | | MATCH7 * | | MATCH3 | | | | \ * | | | \ | | | | MATCH6 * | | | MATCH2 | | | | | | * | | | | \ +--+--+--+--+--+ * | | | | MATCH1 MAP * \ | | | | | /-----------/ * (DATA SOURCE ONE) * </pre> */ @Test public void testBranchingSourceMultipleTimes() { try { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(DEFAULT_PARALLELISM); DataSet<Tuple2<Long, Long>> source = env.generateSequence(1, 10000000) .map(new Duplicator<Long>()); DataSet<Tuple2<Long, Long>> joined1 = source.join(source).where(0).equalTo(0) .with(new DummyFlatJoinFunction<Tuple2<Long, Long>>()); DataSet<Tuple2<Long, Long>> joined2 = source.join(joined1).where(0).equalTo(0) .with(new DummyFlatJoinFunction<Tuple2<Long, Long>>()); DataSet<Tuple2<Long, Long>> joined3 = source.join(joined2).where(0).equalTo(0) .with(new DummyFlatJoinFunction<Tuple2<Long, Long>>()); DataSet<Tuple2<Long, Long>> joined4 = source.join(joined3).where(0).equalTo(0) .with(new DummyFlatJoinFunction<Tuple2<Long, Long>>()); DataSet<Tuple2<Long, Long>> joined5 = source.join(joined4).where(0).equalTo(0) .with(new DummyFlatJoinFunction<Tuple2<Long, Long>>()); DataSet<Tuple2<Long, Long>> mapped = source.map( new MapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>>() { @Override public Tuple2<Long, Long> map(Tuple2<Long, Long> value) { return null; } }); DataSet<Tuple2<Long, Long>> joined6 = mapped.join(mapped).where(0).equalTo(0) .with(new DummyFlatJoinFunction<Tuple2<Long, Long>>()); DataSet<Tuple2<Long, Long>> joined7 = mapped.join(joined6).where(0).equalTo(0) .with(new DummyFlatJoinFunction<Tuple2<Long, Long>>()); DataSet<Tuple2<Long, Long>> joined8 = mapped.join(joined7).where(0).equalTo(0) .with(new DummyFlatJoinFunction<Tuple2<Long, Long>>()); DataSet<Tuple2<Long, Long>> joined9 = mapped.join(joined8).where(0).equalTo(0) .with(new DummyFlatJoinFunction<Tuple2<Long, Long>>()); DataSet<Tuple2<Long, Long>> joined10 = mapped.join(joined9).where(0).equalTo(0) .with(new DummyFlatJoinFunction<Tuple2<Long, Long>>()); joined5.coGroup(joined10) .where(1).equalTo(1) .with(new DummyCoGroupFunction<Tuple2<Long, Long>, Tuple2<Long, Long>>()) .output(new DiscardingOutputFormat<Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>>()); Plan plan = env.createProgramPlan(); OptimizedPlan oPlan = compileNoStats(plan); new JobGraphGenerator().compileJobGraph(oPlan); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } /** * * <pre> * (SINK A) * | (SINK B) (SINK C) * CROSS / / * / \ | +------+ * / \ | / * REDUCE MATCH2 * | +---/ \ * \ / | * MAP | * | | * COGROUP MATCH1 * / \ / \ * (SRC A) (SRC B) (SRC C) * </pre> */ @Test public void testBranchingWithMultipleDataSinks() { try { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(DEFAULT_PARALLELISM); DataSet<Tuple2<Long, Long>> sourceA = env.generateSequence(1, 10000000) .map(new Duplicator<Long>()); DataSet<Tuple2<Long, Long>> sourceB = env.generateSequence(1, 10000000) .map(new Duplicator<Long>()); DataSet<Tuple2<Long, Long>> sourceC = env.generateSequence(1, 10000000) .map(new Duplicator<Long>()); DataSet<Tuple2<Long, Long>> mapped = sourceA.coGroup(sourceB) .where(0).equalTo(1) .with(new CoGroupFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>>() { @Override public void coGroup(Iterable<Tuple2<Long, Long>> first, Iterable<Tuple2<Long, Long>> second, Collector<Tuple2<Long, Long>> out) { } }) .map(new IdentityMapper<Tuple2<Long, Long>>()); DataSet<Tuple2<Long, Long>> joined = sourceB.join(sourceC) .where(0).equalTo(1) .with(new DummyFlatJoinFunction<Tuple2<Long, Long>>()); DataSet<Tuple2<Long, Long>> joined2 = mapped.join(joined) .where(1).equalTo(1) .with(new DummyFlatJoinFunction<Tuple2<Long, Long>>()); DataSet<Tuple2<Long, Long>> reduced = mapped .groupBy(1) .reduceGroup(new Top1GroupReducer<Tuple2<Long, Long>>()); reduced.cross(joined2) .output(new DiscardingOutputFormat<Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>>()); joined2.output(new DiscardingOutputFormat<Tuple2<Long, Long>>()); joined2.output(new DiscardingOutputFormat<Tuple2<Long, Long>>()); Plan plan = env.createProgramPlan(); OptimizedPlan oPlan = compileNoStats(plan); new JobGraphGenerator().compileJobGraph(oPlan); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } @SuppressWarnings("unchecked") @Test public void testBranchEachContractType() { try { // construct the plan ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(DEFAULT_PARALLELISM); DataSet<Long> sourceA = env.generateSequence(0,1); DataSet<Long> sourceB = env.generateSequence(0,1); DataSet<Long> sourceC = env.generateSequence(0,1); DataSet<Long> map1 = sourceA.map(new IdentityMapper<Long>()).name("Map 1"); DataSet<Long> reduce1 = map1.groupBy("*").reduceGroup(new IdentityGroupReducer<Long>()).name("Reduce 1"); DataSet<Long> join1 = sourceB.union(sourceB).union(sourceC) .join(sourceC).where("*").equalTo("*") .with(new IdentityJoiner<Long>()).name("Join 1"); DataSet<Long> coGroup1 = sourceA.coGroup(sourceB).where("*").equalTo("*") .with(new IdentityCoGrouper<Long>()).name("CoGroup 1"); DataSet<Long> cross1 = reduce1.cross(coGroup1) .with(new IdentityCrosser<Long>()).name("Cross 1"); DataSet<Long> coGroup2 = cross1.coGroup(cross1).where("*").equalTo("*") .with(new IdentityCoGrouper<Long>()).name("CoGroup 2"); DataSet<Long> coGroup3 = map1.coGroup(join1).where("*").equalTo("*") .with(new IdentityCoGrouper<Long>()).name("CoGroup 3"); DataSet<Long> map2 = coGroup3.map(new IdentityMapper<Long>()).name("Map 2"); DataSet<Long> coGroup4 = map2.coGroup(join1).where("*").equalTo("*") .with(new IdentityCoGrouper<Long>()).name("CoGroup 4"); DataSet<Long> coGroup5 = coGroup2.coGroup(coGroup1).where("*").equalTo("*") .with(new IdentityCoGrouper<Long>()).name("CoGroup 5"); DataSet<Long> coGroup6 = reduce1.coGroup(coGroup4).where("*").equalTo("*") .with(new IdentityCoGrouper<Long>()).name("CoGroup 6"); DataSet<Long> coGroup7 = coGroup5.coGroup(coGroup6).where("*").equalTo("*") .with(new IdentityCoGrouper<Long>()).name("CoGroup 7"); coGroup7.union(sourceA) .union(coGroup3) .union(coGroup4) .union(coGroup1) .output(new DiscardingOutputFormat<Long>()); Plan plan = env.createProgramPlan(); OptimizedPlan oPlan = compileNoStats(plan); JobGraphGenerator jobGen = new JobGraphGenerator(); //Compile plan to verify that no error is thrown jobGen.compileJobGraph(oPlan); } catch (Exception e) { e.printStackTrace(); Assert.fail(e.getMessage()); } } @Test public void testBranchingUnion() { try { // construct the plan ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(DEFAULT_PARALLELISM); DataSet<Long> source1 = env.generateSequence(0,1); DataSet<Long> source2 = env.generateSequence(0,1); DataSet<Long> join1 = source1.join(source2).where("*").equalTo("*") .with(new IdentityJoiner<Long>()).name("Join 1"); DataSet<Long> map1 = join1.map(new IdentityMapper<Long>()).name("Map 1"); DataSet<Long> reduce1 = map1.groupBy("*").reduceGroup(new IdentityGroupReducer<Long>()).name("Reduce 1"); DataSet<Long> reduce2 = join1.groupBy("*").reduceGroup(new IdentityGroupReducer<Long>()).name("Reduce 2"); DataSet<Long> map2 = join1.map(new IdentityMapper<Long>()).name("Map 2"); DataSet<Long> map3 = map2.map(new IdentityMapper<Long>()).name("Map 3"); DataSet<Long> join2 = reduce1.union(reduce2).union(map2).union(map3) .join(map2, JoinHint.REPARTITION_SORT_MERGE).where("*").equalTo("*") .with(new IdentityJoiner<Long>()).name("Join 2"); join2.output(new DiscardingOutputFormat<Long>()); Plan plan = env.createProgramPlan(); OptimizedPlan oPlan = compileNoStats(plan); JobGraphGenerator jobGen = new JobGraphGenerator(); //Compile plan to verify that no error is thrown jobGen.compileJobGraph(oPlan); } catch (Exception e) { e.printStackTrace(); Assert.fail(e.getMessage()); } } /** * * <pre> * (SRC A) * / \ * (SINK A) (SINK B) * </pre> */ @Test public void testBranchingWithMultipleDataSinksSmall() { try { String outPath1 = "/tmp/out1"; String outPath2 = "/tmp/out2"; // construct the plan ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(DEFAULT_PARALLELISM); DataSet<Long> source1 = env.generateSequence(0,1); source1.writeAsText(outPath1); source1.writeAsText(outPath2); Plan plan = env.createProgramPlan(); OptimizedPlan oPlan = compileNoStats(plan); // ---------- check the optimizer plan ---------- // number of sinks Assert.assertEquals("Wrong number of data sinks.", 2, oPlan.getDataSinks().size()); // sinks contain all sink paths Set<String> allSinks = new HashSet<String>(); allSinks.add(outPath1); allSinks.add(outPath2); for (SinkPlanNode n : oPlan.getDataSinks()) { String path = ((TextOutputFormat<String>)n.getSinkNode().getOperator() .getFormatWrapper().getUserCodeObject()).getOutputFilePath().toString(); Assert.assertTrue("Invalid data sink.", allSinks.remove(path)); } // ---------- compile plan to job graph to verify that no error is thrown ---------- JobGraphGenerator jobGen = new JobGraphGenerator(); jobGen.compileJobGraph(oPlan); } catch (Exception e) { e.printStackTrace(); Assert.fail(e.getMessage()); } } /** * * <pre> * (SINK 3) (SINK 1) (SINK 2) (SINK 4) * \ / \ / * (SRC A) (SRC B) * </pre> * * NOTE: this case is currently not caught by the compiler. we should enable the test once it is caught. */ @Test public void testBranchingDisjointPlan() { // construct the plan final String out1Path = "file:///test/1"; final String out2Path = "file:///test/2"; final String out3Path = "file:///test/3"; final String out4Path = "file:///test/4"; // construct the plan ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(DEFAULT_PARALLELISM); DataSet<Long> sourceA = env.generateSequence(0,1); DataSet<Long> sourceB = env.generateSequence(0,1); sourceA.writeAsText(out1Path); sourceB.writeAsText(out2Path); sourceA.writeAsText(out3Path); sourceB.writeAsText(out4Path); Plan plan = env.createProgramPlan(); compileNoStats(plan); } @Test public void testBranchAfterIteration() { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(DEFAULT_PARALLELISM); DataSet<Long> sourceA = env.generateSequence(0,1); IterativeDataSet<Long> loopHead = sourceA.iterate(10); DataSet<Long> loopTail = loopHead.map(new IdentityMapper<Long>()).name("Mapper"); DataSet<Long> loopRes = loopHead.closeWith(loopTail); loopRes.output(new DiscardingOutputFormat<Long>()); loopRes.map(new IdentityMapper<Long>()) .output(new DiscardingOutputFormat<Long>()); Plan plan = env.createProgramPlan(); try { compileNoStats(plan); } catch (Exception e) { e.printStackTrace(); Assert.fail(e.getMessage()); } } @Test public void testBranchBeforeIteration() { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(DEFAULT_PARALLELISM); DataSet<Long> source1 = env.generateSequence(0,1); DataSet<Long> source2 = env.generateSequence(0,1); IterativeDataSet<Long> loopHead = source2.iterate(10).name("Loop"); DataSet<Long> loopTail = source1.map(new IdentityMapper<Long>()).withBroadcastSet(loopHead, "BC").name("In-Loop Mapper"); DataSet<Long> loopRes = loopHead.closeWith(loopTail); DataSet<Long> map = source1.map(new IdentityMapper<Long>()).withBroadcastSet(loopRes, "BC").name("Post-Loop Mapper"); map.output(new DiscardingOutputFormat<Long>()); Plan plan = env.createProgramPlan(); try { compileNoStats(plan); } catch (Exception e) { e.printStackTrace(); Assert.fail(e.getMessage()); } } /** * Test to ensure that sourceA is inside as well as outside of the iteration the same * node. * * <pre> * (SRC A) (SRC B) * / \ / \ * (SINK 1) (ITERATION) | (SINK 2) * / \ / * (SINK 3) (CROSS => NEXT PARTIAL SOLUTION) * </pre> */ @Test public void testClosure() { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(DEFAULT_PARALLELISM); DataSet<Long> sourceA = env.generateSequence(0,1); DataSet<Long> sourceB = env.generateSequence(0,1); sourceA.output(new DiscardingOutputFormat<Long>()); sourceB.output(new DiscardingOutputFormat<Long>()); IterativeDataSet<Long> loopHead = sourceA.iterate(10).name("Loop"); DataSet<Long> loopTail = loopHead.cross(sourceB).with(new IdentityCrosser<Long>()); DataSet<Long> loopRes = loopHead.closeWith(loopTail); loopRes.output(new DiscardingOutputFormat<Long>()); Plan plan = env.createProgramPlan(); try{ compileNoStats(plan); }catch(Exception e){ e.printStackTrace(); Assert.fail(e.getMessage()); } } /** * <pre> * (SRC A) (SRC B) (SRC C) * / \ / / \ * (SINK 1) (DELTA ITERATION) | (SINK 2) * / | \ / * (SINK 3) | (CROSS => NEXT WORKSET) * | | * (JOIN => SOLUTION SET DELTA) * </pre> */ @Test public void testClosureDeltaIteration() { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(DEFAULT_PARALLELISM); DataSet<Tuple2<Long, Long>> sourceA = env.generateSequence(0,1).map(new Duplicator<Long>()); DataSet<Tuple2<Long, Long>> sourceB = env.generateSequence(0,1).map(new Duplicator<Long>()); DataSet<Tuple2<Long, Long>> sourceC = env.generateSequence(0,1).map(new Duplicator<Long>()); sourceA.output(new DiscardingOutputFormat<Tuple2<Long,Long>>()); sourceC.output(new DiscardingOutputFormat<Tuple2<Long,Long>>()); DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> loop = sourceA.iterateDelta(sourceB, 10, 0); DataSet<Tuple2<Long, Long>> workset = loop.getWorkset().cross(sourceB).with(new IdentityCrosser<Tuple2<Long, Long>>()).name("Next work set"); DataSet<Tuple2<Long, Long>> delta = workset.join(loop.getSolutionSet()).where(0).equalTo(0).with(new IdentityJoiner<Tuple2<Long, Long>>()).name("Solution set delta"); DataSet<Tuple2<Long, Long>> result = loop.closeWith(delta, workset); result.output(new DiscardingOutputFormat<Tuple2<Long,Long>>()); Plan plan = env.createProgramPlan(); try{ compileNoStats(plan); }catch(Exception e){ e.printStackTrace(); Assert.fail(e.getMessage()); } } /** * <pre> * +----Iteration-------+ * | | * /---------< >---------join-----< >---sink * / (Solution)| / | * / | / | * /--map-------< >----\ / /--| * / (Workset)| \ / / | * src-map | join------/ | * \ | / | * \ +-----/--------------+ * \ / * \--reduce-------/ * </pre> */ @Test public void testDeltaIterationWithStaticInput() { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(DEFAULT_PARALLELISM); DataSet<Tuple2<Long, Long>> source = env.generateSequence(0,1).map(new Duplicator<Long>()); DataSet<Tuple2<Long,Long>> map = source .map(new IdentityMapper<Tuple2<Long, Long>>()); DataSet<Tuple2<Long,Long>> reduce = source .reduceGroup(new IdentityGroupReducer<Tuple2<Long, Long>>()); DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> loop = source.iterateDelta(map, 10, 0); DataSet<Tuple2<Long, Long>> workset = loop.getWorkset().join(reduce).where(0).equalTo(0) .with(new IdentityJoiner<Tuple2<Long, Long>>()).name("Next work set"); DataSet<Tuple2<Long, Long>> delta = loop.getSolutionSet().join(workset).where(0).equalTo(0) .with(new IdentityJoiner<Tuple2<Long, Long>>()).name("Solution set delta"); DataSet<Tuple2<Long, Long>> result = loop.closeWith(delta, workset); result.output(new DiscardingOutputFormat<Tuple2<Long, Long>>()); Plan plan = env.createProgramPlan(); try{ compileNoStats(plan); }catch(Exception e){ e.printStackTrace(); Assert.fail(e.getMessage()); } } /** * <pre> * +---------Iteration-------+ * | | * /--map--< >----\ | * / | \ /-------< >---sink * src-map | join------/ | * \ | / | * \ +-----/-------------------+ * \ / * \--reduce--/ * </pre> */ @Test public void testIterationWithStaticInput() { try { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(100); DataSet<Long> source = env.generateSequence(1, 1000000); DataSet<Long> mapped = source.map(new IdentityMapper<Long>()); DataSet<Long> reduced = source.groupBy(new IdentityKeyExtractor<Long>()).reduce(new SelectOneReducer<Long>()); IterativeDataSet<Long> iteration = mapped.iterate(10); iteration.closeWith( iteration.join(reduced) .where(new IdentityKeyExtractor<Long>()) .equalTo(new IdentityKeyExtractor<Long>()) .with(new DummyFlatJoinFunction<Long>())) .output(new DiscardingOutputFormat<Long>()); compileNoStats(env.createProgramPlan()); } catch(Exception e){ e.printStackTrace(); fail(e.getMessage()); } } @Test public void testBranchingBroadcastVariable() { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(100); DataSet<String> input1 = env.readTextFile(IN_FILE).name("source1"); DataSet<String> input2 = env.readTextFile(IN_FILE).name("source2"); DataSet<String> input3 = env.readTextFile(IN_FILE).name("source3"); DataSet<String> result1 = input1 .map(new IdentityMapper<String>()) .reduceGroup(new Top1GroupReducer<String>()) .withBroadcastSet(input3, "bc"); DataSet<String> result2 = input2 .map(new IdentityMapper<String>()) .reduceGroup(new Top1GroupReducer<String>()) .withBroadcastSet(input3, "bc"); result1.join(result2) .where(new IdentityKeyExtractor<String>()) .equalTo(new IdentityKeyExtractor<String>()) .with(new RichJoinFunction<String, String, String>() { @Override public String join(String first, String second) { return null; } }) .withBroadcastSet(input3, "bc1") .withBroadcastSet(input1, "bc2") .withBroadcastSet(result1, "bc3") .output(new DiscardingOutputFormat<String>()); Plan plan = env.createProgramPlan(); try{ compileNoStats(plan); }catch(Exception e){ e.printStackTrace(); Assert.fail(e.getMessage()); } } @Test public void testBCVariableClosure() { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); DataSet<String> input = env.readTextFile(IN_FILE).name("source1"); DataSet<String> reduced = input .map(new IdentityMapper<String>()) .reduceGroup(new Top1GroupReducer<String>()); DataSet<String> initialSolution = input.map(new IdentityMapper<String>()).withBroadcastSet(reduced, "bc"); IterativeDataSet<String> iteration = initialSolution.iterate(100); iteration.closeWith(iteration.map(new IdentityMapper<String>()).withBroadcastSet(reduced, "red")) .output(new DiscardingOutputFormat<String>()); Plan plan = env.createProgramPlan(); try{ compileNoStats(plan); }catch(Exception e){ e.printStackTrace(); Assert.fail(e.getMessage()); } } @Test public void testMultipleIterations() { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(100); DataSet<String> input = env.readTextFile(IN_FILE).name("source1"); DataSet<String> reduced = input .map(new IdentityMapper<String>()) .reduceGroup(new Top1GroupReducer<String>()); IterativeDataSet<String> iteration1 = input.iterate(100); IterativeDataSet<String> iteration2 = input.iterate(20); IterativeDataSet<String> iteration3 = input.iterate(17); iteration1.closeWith(iteration1.map(new IdentityMapper<String>()).withBroadcastSet(reduced, "bc1")) .output(new DiscardingOutputFormat<String>()); iteration2.closeWith(iteration2.reduceGroup(new Top1GroupReducer<String>()).withBroadcastSet(reduced, "bc2")) .output(new DiscardingOutputFormat<String>()); iteration3.closeWith(iteration3.reduceGroup(new IdentityGroupReducer<String>()).withBroadcastSet(reduced, "bc3")) .output(new DiscardingOutputFormat<String>()); Plan plan = env.createProgramPlan(); try{ compileNoStats(plan); }catch(Exception e){ e.printStackTrace(); Assert.fail(e.getMessage()); } } @Test public void testMultipleIterationsWithClosueBCVars() { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(100); DataSet<String> input = env.readTextFile(IN_FILE).name("source1"); IterativeDataSet<String> iteration1 = input.iterate(100); IterativeDataSet<String> iteration2 = input.iterate(20); IterativeDataSet<String> iteration3 = input.iterate(17); iteration1.closeWith(iteration1.map(new IdentityMapper<String>())) .output(new DiscardingOutputFormat<String>()); iteration2.closeWith(iteration2.reduceGroup(new Top1GroupReducer<String>())) .output(new DiscardingOutputFormat<String>()); iteration3.closeWith(iteration3.reduceGroup(new IdentityGroupReducer<String>())) .output(new DiscardingOutputFormat<String>()); Plan plan = env.createProgramPlan(); try{ compileNoStats(plan); }catch(Exception e){ e.printStackTrace(); Assert.fail(e.getMessage()); } } @Test public void testBranchesOnlyInBCVariables1() { try{ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(100); DataSet<Long> input = env.generateSequence(1, 10); DataSet<Long> bc_input = env.generateSequence(1, 10); input .map(new IdentityMapper<Long>()).withBroadcastSet(bc_input, "name1") .map(new IdentityMapper<Long>()).withBroadcastSet(bc_input, "name2") .output(new DiscardingOutputFormat<Long>()); Plan plan = env.createProgramPlan(); compileNoStats(plan); } catch(Exception e){ e.printStackTrace(); fail(e.getMessage()); } } @Test public void testBranchesOnlyInBCVariables2() { try{ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(100); DataSet<Tuple2<Long, Long>> input = env.generateSequence(1, 10).map(new Duplicator<Long>()).name("proper input"); DataSet<Long> bc_input1 = env.generateSequence(1, 10).name("BC input 1"); DataSet<Long> bc_input2 = env.generateSequence(1, 10).name("BC input 1"); DataSet<Tuple2<Long, Long>> joinInput1 = input.map(new IdentityMapper<Tuple2<Long,Long>>()) .withBroadcastSet(bc_input1.map(new IdentityMapper<Long>()), "bc1") .withBroadcastSet(bc_input2, "bc2"); DataSet<Tuple2<Long, Long>> joinInput2 = input.map(new IdentityMapper<Tuple2<Long,Long>>()) .withBroadcastSet(bc_input1, "bc1") .withBroadcastSet(bc_input2, "bc2"); DataSet<Tuple2<Long, Long>> joinResult = joinInput1 .join(joinInput2, JoinHint.REPARTITION_HASH_FIRST).where(0).equalTo(1) .with(new DummyFlatJoinFunction<Tuple2<Long,Long>>()); input .map(new IdentityMapper<Tuple2<Long,Long>>()) .withBroadcastSet(bc_input1, "bc1") .union(joinResult) .output(new DiscardingOutputFormat<Tuple2<Long, Long>>()); Plan plan = env.createProgramPlan(); compileNoStats(plan); } catch(Exception e){ e.printStackTrace(); fail(e.getMessage()); } } private static final class Duplicator<T> implements MapFunction<T, Tuple2<T, T>> { @Override public Tuple2<T, T> map(T value) { return new Tuple2<T, T>(value, value); } } }