/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.api.java.operators.translation; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; import java.util.Iterator; import org.apache.flink.api.common.InvalidProgramException; import org.apache.flink.api.common.Plan; import org.apache.flink.api.common.aggregators.LongSumAggregator; import org.apache.flink.api.common.operators.GenericDataSinkBase; import org.apache.flink.api.common.operators.base.DeltaIterationBase; import org.apache.flink.api.common.operators.base.InnerJoinOperatorBase; import org.apache.flink.api.common.operators.base.MapOperatorBase; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.io.DiscardingOutputFormat; import org.apache.flink.api.java.operators.DeltaIteration; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.common.functions.RichCoGroupFunction; import org.apache.flink.api.common.functions.RichJoinFunction; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.util.Collector; import org.junit.Test; @SuppressWarnings("serial") public class DeltaIterationTranslationTest implements java.io.Serializable { @Test public void testCorrectTranslation() { try { final String JOB_NAME = "Test JobName"; final String ITERATION_NAME = "Test Name"; final String BEFORE_NEXT_WORKSET_MAP = "Some Mapper"; final String AGGREGATOR_NAME = "AggregatorName"; final int[] ITERATION_KEYS = new int[] {2}; final int NUM_ITERATIONS = 13; final int DEFAULT_parallelism= 133; final int ITERATION_parallelism = 77; ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); // ------------ construct the test program ------------------ { env.setParallelism(DEFAULT_parallelism); @SuppressWarnings("unchecked") DataSet<Tuple3<Double, Long, String>> initialSolutionSet = env.fromElements(new Tuple3<Double, Long, String>(3.44, 5L, "abc")); @SuppressWarnings("unchecked") DataSet<Tuple2<Double, String>> initialWorkSet = env.fromElements(new Tuple2<Double, String>(1.23, "abc")); DeltaIteration<Tuple3<Double, Long, String>, Tuple2<Double, String>> iteration = initialSolutionSet.iterateDelta(initialWorkSet, NUM_ITERATIONS, ITERATION_KEYS); iteration.name(ITERATION_NAME).parallelism(ITERATION_parallelism); iteration.registerAggregator(AGGREGATOR_NAME, new LongSumAggregator()); // test that multiple workset consumers are supported DataSet<Tuple2<Double, String>> worksetSelfJoin = iteration.getWorkset() .map(new IdentityMapper<Tuple2<Double,String>>()) .join(iteration.getWorkset()).where(1).equalTo(1).projectFirst(0, 1); DataSet<Tuple3<Double, Long, String>> joined = worksetSelfJoin.join(iteration.getSolutionSet()).where(1).equalTo(2).with(new SolutionWorksetJoin()); DataSet<Tuple3<Double, Long, String>> result = iteration.closeWith( joined, joined.map(new NextWorksetMapper()).name(BEFORE_NEXT_WORKSET_MAP)); result.output(new DiscardingOutputFormat<Tuple3<Double, Long, String>>()); result.writeAsText("/dev/null"); } Plan p = env.createProgramPlan(JOB_NAME); // ------------- validate the plan ---------------- assertEquals(JOB_NAME, p.getJobName()); assertEquals(DEFAULT_parallelism, p.getDefaultParallelism()); // validate the iteration GenericDataSinkBase<?> sink1, sink2; { Iterator<? extends GenericDataSinkBase<?>> sinks = p.getDataSinks().iterator(); sink1 = sinks.next(); sink2 = sinks.next(); } DeltaIterationBase<?, ?> iteration = (DeltaIterationBase<?, ?>) sink1.getInput(); // check that multi consumer translation works for iterations assertEquals(iteration, sink2.getInput()); // check the basic iteration properties assertEquals(NUM_ITERATIONS, iteration.getMaximumNumberOfIterations()); assertArrayEquals(ITERATION_KEYS, iteration.getSolutionSetKeyFields()); assertEquals(ITERATION_parallelism, iteration.getParallelism()); assertEquals(ITERATION_NAME, iteration.getName()); MapOperatorBase<?, ?, ?> nextWorksetMapper = (MapOperatorBase<?, ?, ?>) iteration.getNextWorkset(); InnerJoinOperatorBase<?, ?, ?, ?> solutionSetJoin = (InnerJoinOperatorBase<?, ?, ?, ?>) iteration.getSolutionSetDelta(); InnerJoinOperatorBase<?, ?, ?, ?> worksetSelfJoin = (InnerJoinOperatorBase<?, ?, ?, ?>) solutionSetJoin.getFirstInput(); MapOperatorBase<?, ?, ?> worksetMapper = (MapOperatorBase<?, ?, ?>) worksetSelfJoin.getFirstInput(); assertEquals(IdentityMapper.class, worksetMapper.getUserCodeWrapper().getUserCodeClass()); assertEquals(NextWorksetMapper.class, nextWorksetMapper.getUserCodeWrapper().getUserCodeClass()); if (solutionSetJoin.getUserCodeWrapper().getUserCodeObject() instanceof WrappingFunction) { WrappingFunction<?> wf = (WrappingFunction<?>) solutionSetJoin.getUserCodeWrapper().getUserCodeObject(); assertEquals(SolutionWorksetJoin.class, wf.getWrappedFunction().getClass()); } else { assertEquals(SolutionWorksetJoin.class, solutionSetJoin.getUserCodeWrapper().getUserCodeClass()); } assertEquals(BEFORE_NEXT_WORKSET_MAP, nextWorksetMapper.getName()); assertEquals(AGGREGATOR_NAME, iteration.getAggregators().getAllRegisteredAggregators().iterator().next().getName()); } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); fail(e.getMessage()); } } @Test public void testRejectWhenSolutionSetKeysDontMatchJoin() { try { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); @SuppressWarnings("unchecked") DataSet<Tuple3<Double, Long, String>> initialSolutionSet = env.fromElements(new Tuple3<Double, Long, String>(3.44, 5L, "abc")); @SuppressWarnings("unchecked") DataSet<Tuple2<Double, String>> initialWorkSet = env.fromElements(new Tuple2<Double, String>(1.23, "abc")); DeltaIteration<Tuple3<Double, Long, String>, Tuple2<Double, String>> iteration = initialSolutionSet.iterateDelta(initialWorkSet, 10, 1); try { iteration.getWorkset().join(iteration.getSolutionSet()).where(1).equalTo(2); fail("Accepted invalid program."); } catch (InvalidProgramException e) { // all good! } try { iteration.getSolutionSet().join(iteration.getWorkset()).where(2).equalTo(1); fail("Accepted invalid program."); } catch (InvalidProgramException e) { // all good! } } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); fail(e.getMessage()); } } @Test public void testRejectWhenSolutionSetKeysDontMatchCoGroup() { try { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); @SuppressWarnings("unchecked") DataSet<Tuple3<Double, Long, String>> initialSolutionSet = env.fromElements(new Tuple3<Double, Long, String>(3.44, 5L, "abc")); @SuppressWarnings("unchecked") DataSet<Tuple2<Double, String>> initialWorkSet = env.fromElements(new Tuple2<Double, String>(1.23, "abc")); DeltaIteration<Tuple3<Double, Long, String>, Tuple2<Double, String>> iteration = initialSolutionSet.iterateDelta(initialWorkSet, 10, 1); try { iteration.getWorkset().coGroup(iteration.getSolutionSet()).where(1).equalTo(2).with(new SolutionWorksetCoGroup1()); fail("Accepted invalid program."); } catch (InvalidProgramException e) { // all good! } try { iteration.getSolutionSet().coGroup(iteration.getWorkset()).where(2).equalTo(1).with(new SolutionWorksetCoGroup2()); fail("Accepted invalid program."); } catch (InvalidProgramException e) { // all good! } } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); fail(e.getMessage()); } } // -------------------------------------------------------------------------------------------- public static class SolutionWorksetJoin extends RichJoinFunction<Tuple2<Double, String>, Tuple3<Double, Long, String>, Tuple3<Double, Long, String>> { @Override public Tuple3<Double, Long, String> join(Tuple2<Double, String> first, Tuple3<Double, Long, String> second){ return null; } } public static class NextWorksetMapper extends RichMapFunction<Tuple3<Double, Long, String>, Tuple2<Double, String>> { @Override public Tuple2<Double, String> map(Tuple3<Double, Long, String> value) { return null; } } public static class IdentityMapper<T> extends RichMapFunction<T, T> { @Override public T map(T value) throws Exception { return value; } } public static class SolutionWorksetCoGroup1 extends RichCoGroupFunction<Tuple2<Double, String>, Tuple3<Double, Long, String>, Tuple3<Double, Long, String>> { @Override public void coGroup(Iterable<Tuple2<Double, String>> first, Iterable<Tuple3<Double, Long, String>> second, Collector<Tuple3<Double, Long, String>> out) { } } public static class SolutionWorksetCoGroup2 extends RichCoGroupFunction<Tuple3<Double, Long, String>, Tuple2<Double, String>, Tuple3<Double, Long, String>> { @Override public void coGroup(Iterable<Tuple3<Double, Long, String>> second, Iterable<Tuple2<Double, String>> first, Collector<Tuple3<Double, Long, String>> out) { } } }