/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.test.accumulators;

import akka.actor.ActorRef;
import akka.actor.ActorSystem;
import akka.pattern.Patterns;
import akka.testkit.JavaTestKit;
import akka.util.Timeout;

import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.io.OutputFormat;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.LocalEnvironment;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.optimizer.DataStatistics;
import org.apache.flink.optimizer.Optimizer;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plantranslate.JobGraphGenerator;
import org.apache.flink.runtime.akka.AkkaUtils;
import org.apache.flink.runtime.akka.ListeningBehaviour;
import org.apache.flink.runtime.instance.ActorGateway;
import org.apache.flink.runtime.instance.AkkaActorGateway;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.messages.JobManagerMessages;
import org.apache.flink.runtime.testingUtils.TestingCluster;
import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages;
import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages;
import org.apache.flink.runtime.testingUtils.TestingUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.util.Collector;
import org.apache.flink.util.TestLogger;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.concurrent.Await;
import scala.concurrent.Future;
import scala.concurrent.duration.FiniteDuration;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.junit.Assert.fail;

/**
 * Tests the availability of accumulator results during runtime. The test verifies a user-defined
 * accumulator for two consecutive tasks:
 *
 * CHAINED[Source -> Map] -> Sink
 *
 * Checks are performed as the elements arrive at the operators. Each check consists of a message
 * sent by the task to the task manager, which notifies the job manager and forwards the current
 * accumulators. The task blocks until the test has been notified about the current accumulator
 * values.
 *
 * A barrier between the operators ensures that pipelining is disabled for the streaming test.
 * The batch job reads the records one at a time, while the streaming code buffers records
 * beforehand; exact guarantees about the number of records read are therefore hard to make, so
 * we only check an upper bound on the number of elements read.
 */
public class AccumulatorLiveITCase extends TestLogger {

	private static final Logger LOG = LoggerFactory.getLogger(AccumulatorLiveITCase.class);

	private static ActorSystem system;
	private static TestingCluster testingCluster;
	private static ActorGateway jobManagerGateway;
	private static ActorRef taskManager;

	private static JobID jobID;
	private static JobGraph jobGraph;

	// name of the user accumulator
	private static final String ACCUMULATOR_NAME = "test";

	// number of heartbeat intervals to check
	private static final int NUM_ITERATIONS = 5;

	private static final List<String> inputData = new ArrayList<>(NUM_ITERATIONS);

	private static final FiniteDuration TIMEOUT = new FiniteDuration(10, TimeUnit.SECONDS);

	@Before
	public void before() throws Exception {
		system = AkkaUtils.createLocalActorSystem(new Configuration());

		Configuration config = new Configuration();
		config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, 1);
		config.setInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, 1);
		config.setString(ConfigConstants.AKKA_ASK_TIMEOUT, TestingUtils.DEFAULT_AKKA_ASK_TIMEOUT());

		testingCluster = new TestingCluster(config, false, true);
		testingCluster.start();

		jobManagerGateway = testingCluster.getLeaderGateway(TestingUtils.TESTING_DURATION());
		taskManager = testingCluster.getTaskManagersAsJava().get(0);

		// generate test data
		for (int i = 0; i < NUM_ITERATIONS; i++) {
			inputData.add(i, String.valueOf(i + 1));
		}

		NotifyingMapper.finished = false;
	}

	@After
	public void after() throws Exception {
		// shut down the cluster and the local actor system so no resources leak across tests
		testingCluster.stop();
		JavaTestKit.shutdownActorSystem(system);
		inputData.clear();
	}
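	// For context, a minimal sketch of the regular accumulator API that this test
	// deliberately bypasses: in a normal program, results only become visible after a
	// blocking execute(), e.g.
	//
	//     JobExecutionResult result = env.execute();
	//     int sum = result.getAccumulatorResult(ACCUMULATOR_NAME);
	//
	// The tests below instead observe the accumulator values *while* the job runs.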
	@Test
	public void testBatch() throws Exception {

		/* The program */
		ExecutionEnvironment env = new BatchPlanExtractor();
		env.setParallelism(1);

		DataSet<String> input = env.fromCollection(inputData);
		input
				.flatMap(new NotifyingMapper())
				.output(new NotifyingOutputFormat());

		env.execute();

		// Extract job graph and set job id for the task to notify of accumulator changes.
		jobGraph = getOptimizedPlan(((BatchPlanExtractor) env).plan);
		jobID = jobGraph.getJobID();

		verifyResults();
	}

	@Test
	public void testStreaming() throws Exception {

		StreamExecutionEnvironment env = new DummyStreamExecutionEnvironment();
		env.setParallelism(1);

		DataStream<String> input = env.fromCollection(inputData);
		input
				.flatMap(new NotifyingMapper())
				.writeUsingOutputFormat(new NotifyingOutputFormat()).disableChaining();

		jobGraph = env.getStreamGraph().getJobGraph();
		jobID = jobGraph.getJobID();

		verifyResults();
	}
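	// Expected progression asserted below: the mapper emits the values 1..NUM_ITERATIONS
	// and adds each one to the counter, so after round i the user accumulator holds
	// 1 + 2 + ... + i = i * (i + 1) / 2, i.e. 15 once all NUM_ITERATIONS = 5 rounds are done.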
	private static void verifyResults() {
		new JavaTestKit(system) {{

			ActorGateway selfGateway = new AkkaActorGateway(getRef(), jobManagerGateway.leaderSessionID());

			// register for accumulator changes
			jobManagerGateway.tell(new TestingJobManagerMessages.NotifyWhenAccumulatorChange(jobID), selfGateway);
			expectMsgEquals(TIMEOUT, true);

			// submit job
			jobManagerGateway.tell(
					new JobManagerMessages.SubmitJob(
							jobGraph,
							ListeningBehaviour.EXECUTION_RESULT),
					selfGateway);
			expectMsgClass(TIMEOUT, JobManagerMessages.JobSubmitSuccess.class);

			TestingJobManagerMessages.UpdatedAccumulators msg =
					(TestingJobManagerMessages.UpdatedAccumulators) receiveOne(TIMEOUT);
			Map<String, Accumulator<?, ?>> userAccumulators = msg.userAccumulators();

			// check the initial accumulator values
			if (checkUserAccumulators(0, userAccumulators)) {
				LOG.info("Passed initial check for map task.");
			} else {
				fail("Wrong accumulator results when map task begins execution.");
			}

			int expectedAccVal = 0;

			// one message per element for the mapper task
			for (int i = 1; i <= NUM_ITERATIONS; i++) {
				expectedAccVal += i;

				// receive message
				msg = (TestingJobManagerMessages.UpdatedAccumulators) receiveOne(TIMEOUT);
				userAccumulators = msg.userAccumulators();

				LOG.info("{}", userAccumulators);

				if (checkUserAccumulators(expectedAccVal, userAccumulators)) {
					LOG.info("Passed round #" + i);
				} else {
					fail("Failed in round #" + i);
				}
			}

			msg = (TestingJobManagerMessages.UpdatedAccumulators) receiveOne(TIMEOUT);
			userAccumulators = msg.userAccumulators();

			if (checkUserAccumulators(expectedAccVal, userAccumulators)) {
				LOG.info("Passed initial check for sink task.");
			} else {
				fail("Wrong accumulator results when sink task begins execution.");
			}

			// one message per element for the sink task; the mapper has finished by now,
			// so the accumulator value stays constant
			for (int i = 1; i <= NUM_ITERATIONS; i++) {

				// receive message
				msg = (TestingJobManagerMessages.UpdatedAccumulators) receiveOne(TIMEOUT);
				userAccumulators = msg.userAccumulators();

				LOG.info("{}", userAccumulators);

				if (checkUserAccumulators(expectedAccVal, userAccumulators)) {
					LOG.info("Passed round #" + i);
				} else {
					fail("Failed in round #" + i);
				}
			}

			expectMsgClass(TIMEOUT, JobManagerMessages.JobResultSuccess.class);
		}};
	}

	private static boolean checkUserAccumulators(int expected, Map<String, Accumulator<?, ?>> accumulatorMap) {
		LOG.info("checking user accumulators");
		return accumulatorMap.containsKey(ACCUMULATOR_NAME)
				&& expected == ((IntCounter) accumulatorMap.get(ACCUMULATOR_NAME)).getLocalValue();
	}

	/**
	 * UDF that notifies the task manager whenever it changes the accumulator value.
	 */
	private static class NotifyingMapper extends RichFlatMapFunction<String, Integer> {
		private static final long serialVersionUID = 1L;

		private final IntCounter counter = new IntCounter();

		private static boolean finished = false;

		@Override
		public void open(Configuration parameters) throws Exception {
			getRuntimeContext().addAccumulator(ACCUMULATOR_NAME, counter);
			notifyTaskManagerOfAccumulatorUpdate();
		}

		@Override
		public void flatMap(String value, Collector<Integer> out) throws Exception {
			int val = Integer.parseInt(value);
			counter.add(val);
			out.collect(val);
			LOG.debug("Emitting value {}.", value);
			notifyTaskManagerOfAccumulatorUpdate();
		}

		@Override
		public void close() throws Exception {
			finished = true;
		}
	}
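	// Note: NotifyingMapper and NotifyingOutputFormat coordinate through the static
	// "finished" flag above. This only works because all tasks of this test run inside
	// a single JVM (local TestingCluster); with remote task managers the flag would not
	// be shared and the busy-wait in NotifyingOutputFormat#open would never terminate.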
	/**
	 * Output format which notifies of accumulator changes and waits for the previous
	 * mapper to finish.
	 */
	private static class NotifyingOutputFormat implements OutputFormat<Integer> {
		private static final long serialVersionUID = 1L;

		@Override
		public void configure(Configuration parameters) {
		}

		@Override
		public void open(int taskNumber, int numTasks) throws IOException {
			while (!NotifyingMapper.finished) {
				try {
					Thread.sleep(1000);
				} catch (InterruptedException e) {
					// ignore the interrupt and keep waiting for the mapper to finish
				}
			}
			notifyTaskManagerOfAccumulatorUpdate();
		}

		@Override
		public void writeRecord(Integer record) throws IOException {
			notifyTaskManagerOfAccumulatorUpdate();
		}

		@Override
		public void close() throws IOException {
		}
	}

	/**
	 * Notifies the task manager of an accumulator update and waits until the heartbeat
	 * containing the message has been reported.
	 */
	public static void notifyTaskManagerOfAccumulatorUpdate() {
		new JavaTestKit(system) {{
			Timeout timeout = new Timeout(TIMEOUT);
			Future<Object> ask = Patterns.ask(
					taskManager,
					new TestingTaskManagerMessages.AccumulatorsChanged(jobID),
					timeout);
			try {
				Await.result(ask, timeout.duration());
			} catch (Exception e) {
				fail("Failed to notify task manager of accumulator update.");
			}
		}};
	}

	/**
	 * Helper to generate the JobGraph for the batch program.
	 */
	private static JobGraph getOptimizedPlan(Plan plan) {
		Optimizer pc = new Optimizer(new DataStatistics(), new Configuration());
		JobGraphGenerator jgg = new JobGraphGenerator();
		OptimizedPlan op = pc.compile(plan);
		return jgg.compileJobGraph(op);
	}

	private static class BatchPlanExtractor extends LocalEnvironment {

		private Plan plan = null;

		@Override
		public JobExecutionResult execute(String jobName) throws Exception {
			plan = createProgramPlan();
			return new JobExecutionResult(new JobID(), -1, null);
		}
	}

	/**
	 * This is used for creating the example topology. {@link #execute} is never called, we
	 * only use this to call {@link #getStreamGraph()}.
	 */
	private static class DummyStreamExecutionEnvironment extends StreamExecutionEnvironment {

		@Override
		public JobExecutionResult execute() throws Exception {
			return execute("default");
		}

		@Override
		public JobExecutionResult execute(String jobName) throws Exception {
			throw new RuntimeException("This should not be called.");
		}
	}
}