/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.giraph; import org.apache.giraph.aggregators.LongSumAggregator; import org.apache.giraph.bsp.BspService; import org.apache.giraph.conf.GiraphConfiguration; import org.apache.giraph.conf.GiraphConstants; import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration; import org.apache.giraph.edge.Edge; import org.apache.giraph.edge.EdgeFactory; import org.apache.giraph.examples.SimpleSuperstepComputation; import org.apache.giraph.graph.BasicComputation; import org.apache.giraph.graph.Vertex; import org.apache.giraph.job.GiraphJob; import org.apache.giraph.master.DefaultMasterCompute; import org.apache.giraph.worker.DefaultWorkerContext; import org.apache.giraph.zk.ZooKeeperExt; import org.apache.giraph.zk.ZooKeeperManager; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.FloatWritable; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Writable; import org.apache.log4j.Logger; import org.apache.zookeeper.CreateMode; import org.apache.zookeeper.KeeperException; import org.apache.zookeeper.ZooDefs; import org.junit.Assert; import org.junit.Test; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** * Tests that worker context and master computation * are properly saved and loaded back at checkpoint. */ public class TestCheckpointing extends BspCase { /** Class logger */ private static final Logger LOG = Logger.getLogger(TestCheckpointing.class); /** ID to be used with test job */ public static final String TEST_JOB_ID = "test_job"; private static SuperstepCallback SUPERSTEP_CALLBACK; /** * Create the test case */ public TestCheckpointing() { super(TestCheckpointing.class.getName()); } @Test public void testBspCheckpoint() throws InterruptedException, IOException, ClassNotFoundException { testBspCheckpoint(false); } @Test public void testAsyncMessageStoreCheckpoint() throws InterruptedException, IOException, ClassNotFoundException { testBspCheckpoint(true); } public void testBspCheckpoint(boolean useAsyncMessageStore) throws IOException, InterruptedException, ClassNotFoundException { Path checkpointsDir = getTempPath("checkpointing"); GiraphConfiguration conf = new GiraphConfiguration(); if (useAsyncMessageStore) { GiraphConstants.ASYNC_MESSAGE_STORE_THREADS_COUNT.set(conf, 2); } SUPERSTEP_CALLBACK = null; GiraphConstants.CLEANUP_CHECKPOINTS_AFTER_SUCCESS.set(conf, false); conf.setCheckpointFrequency(2); long idSum = runOriginalJob(checkpointsDir, conf); assertEquals(10, idSum); SUPERSTEP_CALLBACK = new SuperstepCallback() { @Override public void superstep(long superstep, ImmutableClassesGiraphConfiguration<LongWritable, IntWritable, FloatWritable> conf) { if (superstep < 2) { Assert.fail("Restarted JOB should not be executed on superstep " + superstep); } } }; runRestartedJob(checkpointsDir, conf, idSum, 2); } private void runRestartedJob(Path checkpointsDir, GiraphConfiguration conf, long idSum, long restartFrom) throws IOException, InterruptedException, ClassNotFoundException { Path outputPath; LOG.info("testBspCheckpoint: Restarting from the latest superstep " + "with checkpoint path = " + checkpointsDir); outputPath = getTempPath("checkpointing_restarted"); GiraphConstants.RESTART_JOB_ID.set(conf, TEST_JOB_ID); conf.set("mapred.job.id", "restarted_test_job"); if (restartFrom >= 0) { conf.set(GiraphConstants.RESTART_SUPERSTEP, Long.toString(restartFrom)); } GiraphJob restartedJob = prepareJob(getCallingMethodName() + "Restarted", conf, outputPath); GiraphConstants.CHECKPOINT_DIRECTORY.set(restartedJob.getConfiguration(), checkpointsDir.toString()); assertTrue(restartedJob.run(true)); if (!runningInDistributedMode()) { long idSumRestarted = CheckpointVertexWorkerContext .getFinalSum(); LOG.info("testBspCheckpoint: idSumRestarted = " + idSumRestarted); assertEquals(idSum, idSumRestarted); } } private long runOriginalJob(Path checkpointsDir, GiraphConfiguration conf) throws IOException, InterruptedException, ClassNotFoundException { Path outputPath = getTempPath("checkpointing_original"); conf.setComputationClass( CheckpointComputation.class); conf.setWorkerContextClass( CheckpointVertexWorkerContext.class); conf.setMasterComputeClass( CheckpointVertexMasterCompute.class); conf.setVertexInputFormatClass(SimpleSuperstepComputation.SimpleSuperstepVertexInputFormat.class); conf.setVertexOutputFormatClass(SimpleSuperstepComputation.SimpleSuperstepVertexOutputFormat.class); conf.set("mapred.job.id", TEST_JOB_ID); GiraphJob job = prepareJob(getCallingMethodName(), conf, outputPath); GiraphConfiguration configuration = job.getConfiguration(); GiraphConstants.CHECKPOINT_DIRECTORY.set(configuration, checkpointsDir.toString()); assertTrue(job.run(true)); long idSum = 0; if (!runningInDistributedMode()) { FileStatus fileStatus = getSinglePartFileStatus(job.getConfiguration(), outputPath); idSum = CheckpointVertexWorkerContext .getFinalSum(); LOG.info("testBspCheckpoint: idSum = " + idSum + " fileLen = " + fileStatus.getLen()); } return idSum; } /** * Actual computation. */ public static class CheckpointComputation extends BasicComputation<LongWritable, IntWritable, FloatWritable, FloatWritable> { @Override public void compute( Vertex<LongWritable, IntWritable, FloatWritable> vertex, Iterable<FloatWritable> messages) throws IOException { CheckpointVertexWorkerContext workerContext = getWorkerContext(); assertEquals(getSuperstep() + 1, workerContext.testValue); if (getSuperstep() > 4) { vertex.voteToHalt(); return; } aggregate(LongSumAggregator.class.getName(), new LongWritable(vertex.getId().get())); float msgValue = 0.0f; for (FloatWritable message : messages) { float curMsgValue = message.get(); msgValue += curMsgValue; } int vertexValue = vertex.getValue().get(); vertex.setValue(new IntWritable(vertexValue + (int) msgValue)); for (Edge<LongWritable, FloatWritable> edge : vertex.getEdges()) { FloatWritable newEdgeValue = new FloatWritable(edge.getValue().get() + (float) vertexValue); Edge<LongWritable, FloatWritable> newEdge = EdgeFactory.create(edge.getTargetVertexId(), newEdgeValue); vertex.addEdge(newEdge); sendMessage(edge.getTargetVertexId(), newEdgeValue); } } } @Test public void testManualCheckpointAtTheBeginning() throws InterruptedException, IOException, ClassNotFoundException { testManualCheckpoint(0); } @Test public void testManualCheckpoint() throws InterruptedException, IOException, ClassNotFoundException { testManualCheckpoint(2); } private void testManualCheckpoint(final int checkpointSuperstep) throws IOException, InterruptedException, ClassNotFoundException { Path checkpointsDir = getTempPath("checkpointing"); GiraphConfiguration conf = new GiraphConfiguration(); SUPERSTEP_CALLBACK = new SuperstepCallback() { @Override public void superstep(long superstep, ImmutableClassesGiraphConfiguration<LongWritable, IntWritable, FloatWritable> conf) { if (superstep == checkpointSuperstep) { try { ZooKeeperExt zooKeeperExt = new ZooKeeperExt(conf.getZookeeperList(), conf.getZooKeeperSessionTimeout(), conf.getZookeeperOpsMaxAttempts(), conf.getZookeeperOpsRetryWaitMsecs(), TestCheckpointing.this); String basePath = ZooKeeperManager.getBasePath(conf) + BspService.BASE_DIR + "/" + conf.get("mapred.job.id"); zooKeeperExt.createExt( basePath + BspService.FORCE_CHECKPOINT_USER_FLAG, null, ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT, true); } catch (IOException | InterruptedException | KeeperException e) { throw new RuntimeException(e); } } else if (superstep > checkpointSuperstep) { Assert.fail("Job should be stopped by now " + superstep); } } }; try { runOriginalJob(checkpointsDir, conf); fail("Original job should fail after checkpointing"); } catch (Exception e) { LOG.info("Original job failed, that's OK " + e); } SUPERSTEP_CALLBACK = new SuperstepCallback() { @Override public void superstep(long superstep, ImmutableClassesGiraphConfiguration<LongWritable, IntWritable, FloatWritable> conf) { if (superstep < checkpointSuperstep) { Assert.fail("Restarted JOB should not be executed on superstep " + superstep); } } }; runRestartedJob(checkpointsDir, conf, 10, -1); } /** * Worker context associated. */ public static class CheckpointVertexWorkerContext extends DefaultWorkerContext { /** User can access this after the application finishes if local */ private static long FINAL_SUM; private int testValue; public static long getFinalSum() { return FINAL_SUM; } @Override public void postSuperstep() { super.postSuperstep(); sendMessageToMyself(new LongWritable(getSuperstep())); } /** * Send message to all workers (except this worker) * * @param message Message to send */ private void sendMessageToMyself(Writable message) { sendMessageToWorker(message, getMyWorkerIndex()); } @Override public void postApplication() { setFinalSum(this.<LongWritable>getAggregatedValue( LongSumAggregator.class.getName()).get()); LOG.info("FINAL_SUM=" + FINAL_SUM); } /** * Set the final sum * * @param value sum */ private static void setFinalSum(long value) { FINAL_SUM = value; } @Override public void preSuperstep() { assertEquals(getSuperstep(), testValue++); if (getSuperstep() > 0) { List<Writable> messages = getAndClearMessagesFromOtherWorkers(); assertEquals(1, messages.size()); assertEquals(getSuperstep() - 1, ((LongWritable)(messages.get(0))).get()); } } @Override public void readFields(DataInput dataInput) throws IOException { super.readFields(dataInput); testValue = dataInput.readInt(); } @Override public void write(DataOutput dataOutput) throws IOException { super.write(dataOutput); dataOutput.writeInt(testValue); } } /** * Master compute */ public static class CheckpointVertexMasterCompute extends DefaultMasterCompute { private int testValue = 0; @Override public void compute() { long superstep = getSuperstep(); if (SUPERSTEP_CALLBACK != null) { SUPERSTEP_CALLBACK.superstep(getSuperstep(), getConf()); } assertEquals(superstep, testValue++); } @Override public void initialize() throws InstantiationException, IllegalAccessException { registerAggregator(LongSumAggregator.class.getName(), LongSumAggregator.class); } @Override public void readFields(DataInput in) throws IOException { super.readFields(in); testValue = in.readInt(); } @Override public void write(DataOutput out) throws IOException { super.write(out); out.writeInt(testValue); } } private static interface SuperstepCallback { public void superstep(long superstep, ImmutableClassesGiraphConfiguration<LongWritable, IntWritable, FloatWritable> conf); } }