/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.test.iterative.aggregators; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.FileWriter; import java.util.Random; import org.apache.flink.api.common.functions.RichFilterFunction; import org.apache.flink.api.java.io.DiscardingOutputFormat; import org.apache.flink.test.util.MultipleProgramsTestBase; import org.junit.After; import org.junit.Assert; import org.apache.flink.api.common.aggregators.ConvergenceCriterion; import org.apache.flink.api.common.aggregators.LongSumAggregator; import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.test.javaApiOperators.util.CollectionDataSets; import org.apache.flink.types.LongValue; import org.apache.flink.util.Collector; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.operators.DeltaIteration; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.operators.IterativeDataSet; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import static org.junit.Assert.assertEquals; /** * Test the functionality of aggregators in bulk and delta iterative cases. */ @RunWith(Parameterized.class) public class AggregatorsITCase extends MultipleProgramsTestBase { private static final int MAX_ITERATIONS = 20; private static final int parallelism = 2; private static final String NEGATIVE_ELEMENTS_AGGR = "count.negative.elements"; private static String testString = "Et tu, Brute?"; private static String testName = "testing_caesar"; private static String testPath; public AggregatorsITCase(TestExecutionMode mode){ super(mode); } private String resultPath; private String expected; @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); @Before public void before() throws Exception{ File tempFile = tempFolder.newFile(); testPath = tempFile.toString(); resultPath = tempFile.toURI().toString(); } @After public void after() throws Exception{ compareResultsByLinesInMemory(expected, resultPath); } @Test public void testDistributedCacheWithIterations() throws Exception{ File tempFile = new File(testPath); try (FileWriter writer = new FileWriter(tempFile)) { writer.write(testString); } final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.registerCachedFile(resultPath, testName); IterativeDataSet<Long> solution = env.fromElements(1L).iterate(2); solution.closeWith(env.generateSequence(1,2).filter(new RichFilterFunction<Long>() { @Override public void open(Configuration parameters) throws Exception{ File file = getRuntimeContext().getDistributedCache().getFile(testName); BufferedReader reader = new BufferedReader(new FileReader(file)); String output = reader.readLine(); reader.close(); assertEquals(output, testString); } @Override public boolean filter(Long value) throws Exception { return false; } }).withBroadcastSet(solution, "SOLUTION")).output(new DiscardingOutputFormat<Long>()); env.execute(); expected = testString; // this will be a useless verification now. } @Test public void testAggregatorWithoutParameterForIterate() throws Exception { /* * Test aggregator without parameter for iterate */ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(parallelism); DataSet<Integer> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env); IterativeDataSet<Integer> iteration = initialSolutionSet.iterate(MAX_ITERATIONS); // register aggregator LongSumAggregator aggr = new LongSumAggregator(); iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr); // register convergence criterion iteration.registerAggregationConvergenceCriterion(NEGATIVE_ELEMENTS_AGGR, aggr, new NegativeElementsConvergenceCriterion()); DataSet<Integer> updatedDs = iteration.map(new SubtractOneMap()); iteration.closeWith(updatedDs).writeAsText(resultPath); env.execute(); expected = "-3\n" + "-2\n" + "-2\n" + "-1\n" + "-1\n" + "-1\n" + "0\n" + "0\n" + "0\n" + "0\n" + "1\n" + "1\n" + "1\n" + "1\n" + "1\n"; } @Test public void testAggregatorWithParameterForIterate() throws Exception { /* * Test aggregator with parameter for iterate */ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(parallelism); DataSet<Integer> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env); IterativeDataSet<Integer> iteration = initialSolutionSet.iterate(MAX_ITERATIONS); // register aggregator LongSumAggregatorWithParameter aggr = new LongSumAggregatorWithParameter(0); iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr); // register convergence criterion iteration.registerAggregationConvergenceCriterion(NEGATIVE_ELEMENTS_AGGR, aggr, new NegativeElementsConvergenceCriterion()); DataSet<Integer> updatedDs = iteration.map(new SubtractOneMapWithParam()); iteration.closeWith(updatedDs).writeAsText(resultPath); env.execute(); expected = "-3\n" + "-2\n" + "-2\n" + "-1\n" + "-1\n" + "-1\n" + "0\n" + "0\n" + "0\n" + "0\n" + "1\n" + "1\n" + "1\n" + "1\n" + "1\n"; } @Test public void testConvergenceCriterionWithParameterForIterate() throws Exception { /* * Test convergence criterion with parameter for iterate */ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(parallelism); DataSet<Integer> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env); IterativeDataSet<Integer> iteration = initialSolutionSet.iterate(MAX_ITERATIONS); // register aggregator LongSumAggregator aggr = new LongSumAggregator(); iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr); // register convergence criterion iteration.registerAggregationConvergenceCriterion(NEGATIVE_ELEMENTS_AGGR, aggr, new NegativeElementsConvergenceCriterionWithParam(3)); DataSet<Integer> updatedDs = iteration.map(new SubtractOneMap()); iteration.closeWith(updatedDs).writeAsText(resultPath); env.execute(); expected = "-3\n" + "-2\n" + "-2\n" + "-1\n" + "-1\n" + "-1\n" + "0\n" + "0\n" + "0\n" + "0\n" + "1\n" + "1\n" + "1\n" + "1\n" + "1\n"; } @Test public void testAggregatorWithoutParameterForIterateDelta() throws Exception { /* * Test aggregator without parameter for iterateDelta */ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(parallelism); DataSet<Tuple2<Integer, Integer>> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env).map(new TupleMakerMap()); DeltaIteration<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> iteration = initialSolutionSet.iterateDelta( initialSolutionSet, MAX_ITERATIONS, 0); // register aggregator LongSumAggregator aggr = new LongSumAggregator(); iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr); DataSet<Tuple2<Integer, Integer>> updatedDs = iteration.getWorkset().map(new AggregateMapDelta()); DataSet<Tuple2<Integer, Integer>> newElements = updatedDs.join(iteration.getSolutionSet()) .where(0).equalTo(0).flatMap(new UpdateFilter()); DataSet<Tuple2<Integer, Integer>> iterationRes = iteration.closeWith(newElements, newElements); DataSet<Integer> result = iterationRes.map(new ProjectSecondMapper()); result.writeAsText(resultPath); env.execute(); expected = "1\n" + "2\n" + "2\n" + "3\n" + "3\n" + "3\n" + "4\n" + "4\n" + "4\n" + "4\n" + "5\n" + "5\n" + "5\n" + "5\n" + "5\n"; } @Test public void testAggregatorWithParameterForIterateDelta() throws Exception { /* * Test aggregator with parameter for iterateDelta */ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(parallelism); DataSet<Tuple2<Integer, Integer>> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env).map(new TupleMakerMap()); DeltaIteration<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> iteration = initialSolutionSet.iterateDelta( initialSolutionSet, MAX_ITERATIONS, 0); // register aggregator LongSumAggregator aggr = new LongSumAggregatorWithParameter(4); iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr); DataSet<Tuple2<Integer, Integer>> updatedDs = iteration.getWorkset().map(new AggregateMapDelta()); DataSet<Tuple2<Integer, Integer>> newElements = updatedDs.join(iteration.getSolutionSet()) .where(0).equalTo(0).flatMap(new UpdateFilter()); DataSet<Tuple2<Integer, Integer>> iterationRes = iteration.closeWith(newElements, newElements); DataSet<Integer> result = iterationRes.map(new ProjectSecondMapper()); result.writeAsText(resultPath); env.execute(); expected = "1\n" + "2\n" + "2\n" + "3\n" + "3\n" + "3\n" + "4\n" + "4\n" + "4\n" + "4\n" + "5\n" + "5\n" + "5\n" + "5\n" + "5\n"; } @Test public void testConvergenceCriterionWithParameterForIterateDelta() throws Exception { /* * Test convergence criterion with parameter for iterate delta */ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(parallelism); DataSet<Tuple2<Integer, Integer>> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env).map(new TupleMakerMap()); DeltaIteration<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> iteration = initialSolutionSet.iterateDelta( initialSolutionSet, MAX_ITERATIONS, 0); // register aggregator LongSumAggregator aggr = new LongSumAggregator(); iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr); // register convergence criterion iteration.registerAggregationConvergenceCriterion(NEGATIVE_ELEMENTS_AGGR, aggr, new NegativeElementsConvergenceCriterionWithParam(3)); DataSet<Tuple2<Integer, Integer>> updatedDs = iteration.getWorkset().map(new AggregateAndSubtractOneDelta()); DataSet<Tuple2<Integer, Integer>> newElements = updatedDs.join(iteration.getSolutionSet()) .where(0).equalTo(0).projectFirst(0, 1); DataSet<Tuple2<Integer, Integer>> iterationRes = iteration.closeWith(newElements, newElements); DataSet<Integer> result = iterationRes.map(new ProjectSecondMapper()); result.writeAsText(resultPath); env.execute(); expected = "-3\n" + "-2\n" + "-2\n" + "-1\n" + "-1\n" + "-1\n" + "0\n" + "0\n" + "0\n" + "0\n" + "1\n" + "1\n" + "1\n" + "1\n" + "1\n"; } @SuppressWarnings("serial") public static final class NegativeElementsConvergenceCriterion implements ConvergenceCriterion<LongValue> { @Override public boolean isConverged(int iteration, LongValue value) { return value.getValue() > 3; } } @SuppressWarnings("serial") public static final class NegativeElementsConvergenceCriterionWithParam implements ConvergenceCriterion<LongValue> { private int value; public NegativeElementsConvergenceCriterionWithParam(int val) { this.value = val; } public int getValue() { return this.value; } @Override public boolean isConverged(int iteration, LongValue value) { return value.getValue() > this.value; } } @SuppressWarnings("serial") public static final class SubtractOneMap extends RichMapFunction<Integer, Integer> { private LongSumAggregator aggr; @Override public void open(Configuration conf) { aggr = getIterationRuntimeContext().getIterationAggregator(NEGATIVE_ELEMENTS_AGGR); } @Override public Integer map(Integer value) { Integer newValue = value - 1; // count negative numbers if (newValue < 0) { aggr.aggregate(1l); } return newValue; } } @SuppressWarnings("serial") public static final class SubtractOneMapWithParam extends RichMapFunction<Integer, Integer> { private LongSumAggregatorWithParameter aggr; @Override public void open(Configuration conf) { aggr = getIterationRuntimeContext().getIterationAggregator(NEGATIVE_ELEMENTS_AGGR); } @Override public Integer map(Integer value) { Integer newValue = value - 1; // count numbers less than the aggregator parameter if ( newValue < aggr.getValue() ) { aggr.aggregate(1l); } return newValue; } } @SuppressWarnings("serial") public static class LongSumAggregatorWithParameter extends LongSumAggregator { private int value; public LongSumAggregatorWithParameter(int val) { this.value = val; } public int getValue() { return this.value; } } @SuppressWarnings("serial") public static final class TupleMakerMap extends RichMapFunction<Integer, Tuple2<Integer, Integer>> { private Random rnd; @Override public void open(Configuration parameters){ rnd = new Random(0xC0FFEBADBEEFDEADL + getRuntimeContext().getIndexOfThisSubtask()); } @Override public Tuple2<Integer, Integer> map(Integer value) { Integer nodeId = rnd.nextInt(100000); return new Tuple2<>(nodeId, value); } } @SuppressWarnings("serial") public static final class AggregateMapDelta extends RichMapFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> { private LongSumAggregator aggr; private LongValue previousAggr; private int superstep; @Override public void open(Configuration conf) { aggr = getIterationRuntimeContext().getIterationAggregator(NEGATIVE_ELEMENTS_AGGR); superstep = getIterationRuntimeContext().getSuperstepNumber(); if (superstep > 1) { previousAggr = getIterationRuntimeContext().getPreviousIterationAggregate(NEGATIVE_ELEMENTS_AGGR); // check previous aggregator value Assert.assertEquals(superstep - 1, previousAggr.getValue()); } } @Override public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> value) { // count the elements that are equal to the superstep number if (value.f1 == superstep) { aggr.aggregate(1l); } return value; } } @SuppressWarnings("serial") public static final class UpdateFilter extends RichFlatMapFunction<Tuple2<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>>, Tuple2<Integer, Integer>> { private int superstep; @Override public void open(Configuration conf) { superstep = getIterationRuntimeContext().getSuperstepNumber(); } @Override public void flatMap(Tuple2<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> value, Collector<Tuple2<Integer, Integer>> out) { if (value.f0.f1 > superstep) { out.collect(value.f0); } } } @SuppressWarnings("serial") public static final class ProjectSecondMapper extends RichMapFunction<Tuple2<Integer, Integer>, Integer> { @Override public Integer map(Tuple2<Integer, Integer> value) { return value.f1; } } @SuppressWarnings("serial") public static final class AggregateAndSubtractOneDelta extends RichMapFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> { private LongSumAggregator aggr; private LongValue previousAggr; private int superstep; @Override public void open(Configuration conf) { aggr = getIterationRuntimeContext().getIterationAggregator(NEGATIVE_ELEMENTS_AGGR); superstep = getIterationRuntimeContext().getSuperstepNumber(); if (superstep > 1) { previousAggr = getIterationRuntimeContext().getPreviousIterationAggregate(NEGATIVE_ELEMENTS_AGGR); // check previous aggregator value Assert.assertEquals(superstep - 1, previousAggr.getValue()); } } @Override public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> value) { // count the ones if (value.f1 == 1) { aggr.aggregate(1l); } value.f1--; return value; } } }