/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.test.broadcastvars; import java.util.Collection; import java.util.List; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.JoinFunction; import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.operators.FlatMapOperator; import org.apache.flink.api.java.operators.JoinOperator; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.tuple.Tuple4; import org.apache.flink.configuration.Configuration; import org.apache.flink.test.util.JavaProgramTestBase; import org.apache.flink.util.Collector; public class BroadcastBranchingITCase extends JavaProgramTestBase { private static final String RESULT = "(2,112)\n"; // Sc1(id,a,b,c) -- // \ // Sc2(id,x) -------- Jn2(id) -- Mp2 -- Sk // \ / / <=BC // Jn1(id) -- Mp1 ---- // / // Sc3(id,y) -------- @Override protected void testProgram() throws Exception { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(1); // Sc1 generates M parameters a,b,c for second degree polynomials P(x) = ax^2 + bx + c identified by id DataSet<Tuple4<String, Integer, Integer, Integer>> sc1 = env .fromElements(new Tuple4<>("1", 61, 6, 29), new Tuple4<>("2", 7, 13, 10), new Tuple4<>("3", 8, 13, 27)); // Sc2 generates N x values to be evaluated with the polynomial identified by id DataSet<Tuple2<String, Integer>> sc2 = env .fromElements(new Tuple2<>("1", 5), new Tuple2<>("2", 3), new Tuple2<>("3", 6)); // Sc3 generates N y values to be evaluated with the polynomial identified by id DataSet<Tuple2<String, Integer>> sc3 = env .fromElements(new Tuple2<>("1", 2), new Tuple2<>("2", 3), new Tuple2<>("3", 7)); // Jn1 matches x and y values on id and emits (id, x, y) triples JoinOperator<Tuple2<String, Integer>, Tuple2<String, Integer>, Tuple3<String, Integer, Integer>> jn1 = sc2.join(sc3).where(0).equalTo(0).with(new Jn1()); // Jn2 matches polynomial and arguments by id, computes p = min(P(x),P(y)) and emits (id, p) tuples JoinOperator<Tuple3<String, Integer, Integer>, Tuple4<String, Integer, Integer, Integer>, Tuple2<String, Integer>> jn2 = jn1.join(sc1).where(0).equalTo(0).with(new Jn2()); // Mp1 selects (id, x, y) triples where x = y and broadcasts z (=x=y) to Mp2 FlatMapOperator<Tuple3<String, Integer, Integer>, Tuple2<String, Integer>> mp1 = jn1.flatMap(new Mp1()); // Mp2 filters out all p values which can be divided by z List<Tuple2<String, Integer>> result = jn2.flatMap(new Mp2()).withBroadcastSet(mp1, "z").collect(); JavaProgramTestBase.compareResultAsText(result, RESULT); } public static class Jn1 implements JoinFunction<Tuple2<String, Integer>, Tuple2<String, Integer>, Tuple3<String, Integer, Integer>> { private static final long serialVersionUID = 1L; @Override public Tuple3<String, Integer, Integer> join(Tuple2<String, Integer> first, Tuple2<String, Integer> second) throws Exception { return new Tuple3<>(first.f0, first.f1, second.f1); } } public static class Jn2 implements JoinFunction<Tuple3<String, Integer, Integer>, Tuple4<String, Integer, Integer, Integer>, Tuple2<String, Integer>> { private static final long serialVersionUID = 1L; private static int p(int x, int a, int b, int c) { return a * x * x + b * x + c; } @Override public Tuple2<String, Integer> join(Tuple3<String, Integer, Integer> first, Tuple4<String, Integer, Integer, Integer> second) throws Exception { int x = first.f1; int y = first.f2; int a = second.f1; int b = second.f2; int c = second.f3; int p_x = p(x, a, b, c); int p_y = p(y, a, b, c); int min = Math.min(p_x, p_y); return new Tuple2<>(first.f0, min); } } public static class Mp1 implements FlatMapFunction<Tuple3<String, Integer, Integer>, Tuple2<String, Integer>> { private static final long serialVersionUID = 1L; @Override public void flatMap(Tuple3<String, Integer, Integer> value, Collector<Tuple2<String, Integer>> out) throws Exception { if (value.f1.compareTo(value.f2) == 0) { out.collect(new Tuple2<>(value.f0, value.f1)); } } } public static class Mp2 extends RichFlatMapFunction<Tuple2<String, Integer>, Tuple2<String, Integer>> { private static final long serialVersionUID = 1L; private Collection<Tuple2<String, Integer>> zs; @Override public void open(Configuration parameters) throws Exception { this.zs = getRuntimeContext().getBroadcastVariable("z"); } @Override public void flatMap(Tuple2<String, Integer> value, Collector<Tuple2<String, Integer>> out) throws Exception { int p = value.f1; for (Tuple2<String, Integer> z : zs) { if (z.f0.equals(value.f0)) { if (p % z.f1 != 0) { out.collect(value); } } } } } }