/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.checkpoint.hooks; import org.apache.flink.core.io.SimpleVersionedSerializer; import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; import org.apache.flink.runtime.concurrent.Future; import org.apache.flink.util.TestLogger; import org.junit.Test; import javax.annotation.Nullable; import java.net.URL; import java.net.URLClassLoader; import java.util.concurrent.Executor; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; /** * Tests for the MasterHooks utility class. */ public class MasterHooksTest extends TestLogger { // ------------------------------------------------------------------------ // hook management // ------------------------------------------------------------------------ @Test public void wrapHook() throws Exception { final String id = "id"; Thread thread = Thread.currentThread(); final ClassLoader originalClassLoader = thread.getContextClassLoader(); final ClassLoader userClassLoader = new URLClassLoader(new URL[0]); final Runnable command = spy(new Runnable() { @Override public void run() { assertEquals(userClassLoader, Thread.currentThread().getContextClassLoader()); } }); MasterTriggerRestoreHook<String> hook = spy(new MasterTriggerRestoreHook<String>() { @Override public String getIdentifier() { assertEquals(userClassLoader, Thread.currentThread().getContextClassLoader()); return id; } @Nullable @Override public Future<String> triggerCheckpoint(long checkpointId, long timestamp, Executor executor) throws Exception { assertEquals(userClassLoader, Thread.currentThread().getContextClassLoader()); executor.execute(command); return null; } @Override public void restoreCheckpoint(long checkpointId, @Nullable String checkpointData) throws Exception { assertEquals(userClassLoader, Thread.currentThread().getContextClassLoader()); } @Nullable @Override public SimpleVersionedSerializer<String> createCheckpointDataSerializer() { assertEquals(userClassLoader, Thread.currentThread().getContextClassLoader()); return null; } }); MasterTriggerRestoreHook<String> wrapped = MasterHooks.wrapHook(hook, userClassLoader); // verify getIdentifier wrapped.getIdentifier(); verify(hook, times(1)).getIdentifier(); assertEquals(originalClassLoader, thread.getContextClassLoader()); // verify triggerCheckpoint and its wrapped executor TestExecutor testExecutor = new TestExecutor(); wrapped.triggerCheckpoint(0L, 0, testExecutor); assertEquals(originalClassLoader, thread.getContextClassLoader()); assertNotNull(testExecutor.command); testExecutor.command.run(); verify(command, times(1)).run(); assertEquals(originalClassLoader, thread.getContextClassLoader()); // verify restoreCheckpoint wrapped.restoreCheckpoint(0L, ""); verify(hook, times(1)).restoreCheckpoint(eq(0L), eq("")); assertEquals(originalClassLoader, thread.getContextClassLoader()); // verify createCheckpointDataSerializer wrapped.createCheckpointDataSerializer(); verify(hook, times(1)).createCheckpointDataSerializer(); assertEquals(originalClassLoader, thread.getContextClassLoader()); } private static class TestExecutor implements Executor { Runnable command; @Override public void execute(Runnable command) { this.command = command; } } }