/* * Copyright © 2016 Cask Data, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package co.cask.cdap.app.runtime.spark; import co.cask.tephra.Transaction; import co.cask.tephra.TransactionFailureException; import co.cask.tephra.TransactionManager; import co.cask.tephra.TransactionSystemClient; import co.cask.tephra.inmemory.InMemoryTxSystemClient; import co.cask.tephra.persist.TransactionSnapshot; import com.google.common.collect.ImmutableSet; import org.apache.hadoop.conf.Configuration; import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.net.InetAddress; import java.net.UnknownHostException; import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashSet; import java.util.Objects; import java.util.Random; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.CompletionService; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutorCompletionService; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import javax.annotation.Nullable; /** * Unit tests for the {@link SparkTransactionService}. */ public class SparkTransactionServiceTest { private static final Logger LOG = LoggerFactory.getLogger(SparkTransactionServiceTest.class); private static TransactionManager txManager; private static TransactionSystemClient txClient; private static SparkTransactionService sparkTxService; private static SparkTransactionClient sparkTxClient; @BeforeClass public static void init() throws UnknownHostException { txManager = new TransactionManager(new Configuration()); txManager.startAndWait(); txClient = new InMemoryTxSystemClient(txManager); sparkTxService = new SparkTransactionService(txClient, InetAddress.getLoopbackAddress().getCanonicalHostName()); sparkTxService.startAndWait(); sparkTxClient = new SparkTransactionClient(sparkTxService.getBaseURI()); } @AfterClass public static void finish() { sparkTxService.stopAndWait(); txManager.stopAndWait(); } /** * Tests the basic flow of starting a job, make a tx request from a stage and ending the job. */ @Test public void testBasicJobRun() throws Exception { AtomicInteger jobIdGen = new AtomicInteger(); AtomicInteger stageIdGen = new AtomicInteger(); // A successful job run testRunJob(jobIdGen.getAndIncrement(), generateStages(stageIdGen, 3), true); // A failure job run testRunJob(jobIdGen.getAndIncrement(), generateStages(stageIdGen, 4), false); } /** * Tests concurrent jobs submission. */ @Test(timeout = 120000L) public void testConcurrentJobRun() throws Exception { final AtomicInteger jobIdGen = new AtomicInteger(); final AtomicInteger stageIdGen = new AtomicInteger(); // Start 30 jobs concurrently int threads = 30; ExecutorService executor = Executors.newFixedThreadPool(threads); try { final CyclicBarrier barrier = new CyclicBarrier(threads); final Random random = new Random(); CompletionService<Boolean> completionService = new ExecutorCompletionService<>(executor); // For each run, return the verification result for (int i = 0; i < threads; i++) { completionService.submit(new Callable<Boolean>() { @Override public Boolean call() throws Exception { barrier.await(); try { // Run job with 2-5 stages, with job either succeeded or failed testRunJob(jobIdGen.getAndIncrement(), generateStages(stageIdGen, 2 + random.nextInt(4)), random.nextBoolean()); return true; } catch (Throwable t) { LOG.error("testRunJob failed.", t); return false; } } }); } // All testRunJob must be completed successfully boolean result = true; for (int i = 0; i < threads; i++) { result = result && completionService.take().get(); } Assert.assertTrue(result); } finally { executor.shutdown(); } } /** * Tests the explicit transaction that covers multiple jobs. */ @Test public void testExplicitTransaction() throws Exception { final AtomicInteger jobIdGen = new AtomicInteger(); final AtomicInteger stageIdGen = new AtomicInteger(); Transaction transaction = txClient.startLong(); // Execute two jobs with the same explicit transaction testRunJob(jobIdGen.getAndIncrement(), generateStages(stageIdGen, 2), true, transaction); testRunJob(jobIdGen.getAndIncrement(), generateStages(stageIdGen, 3), true, transaction); // Should be able to commit the transaction Assert.assertTrue(txClient.commit(transaction)); } /** * Tests the case where starting of transaction failed. */ @Test public void testFailureTransaction() throws Exception { TransactionManager txManager = new TransactionManager(new Configuration()) { @Override public Transaction startLong() { throw new IllegalStateException("Cannot start long transaction"); } }; txManager.startAndWait(); try { SparkTransactionService sparkTxService = new SparkTransactionService( new InMemoryTxSystemClient(txManager), InetAddress.getLoopbackAddress().getCanonicalHostName()); sparkTxService.startAndWait(); try { // Start a job sparkTxService.jobStarted(1, ImmutableSet.of(2)); // Make a call to the stage transaction endpoint, it should throw TransactionFailureException try { new SparkTransactionClient(sparkTxService.getBaseURI()).getTransaction(2, 1, TimeUnit.SECONDS); Assert.fail("Should failed to get transaction"); } catch (TransactionFailureException e) { // expected } // End the job sparkTxService.jobEnded(1, false); } finally { sparkTxService.stopAndWait(); } } finally { txManager.stopAndWait(); } } /** * Tests the retry timeout logic in the {@link SparkTransactionClient}. */ @Test public void testClientRetry() throws Exception { final Set<Integer> stages = ImmutableSet.of(2); // Delay the call to jobStarted by 3 seconds Executors.newSingleThreadScheduledExecutor().schedule(new Runnable() { @Override public void run() { sparkTxService.jobStarted(1, stages); } }, 3, TimeUnit.SECONDS); // Should be able to get the transaction, hence no exception sparkTxClient.getTransaction(2, 10, TimeUnit.SECONDS); sparkTxService.jobEnded(1, true); } /** * Simulates a single job run which contains multiple stages. * * @param jobId the job id * @param stages stages of the job * @param jobSucceeded end result of the job */ private void testRunJob(int jobId, Set<Integer> stages, boolean jobSucceeded) throws Exception { testRunJob(jobId, stages, jobSucceeded, null); } /** * Simulates a single job run which contains multiple stages with an optional explicit {@link Transaction} to use. * * @param jobId the job id * @param stages stages of the job * @param jobSucceeded end result of the job * @param explicitTransaction the job transaction to use if not {@code null} */ private void testRunJob(int jobId, Set<Integer> stages, boolean jobSucceeded, @Nullable final Transaction explicitTransaction) throws Exception { // Before job start, no transaction will be associated with the stages verifyStagesTransactions(stages, new ClientTransactionVerifier() { @Override public boolean verify(@Nullable Transaction transaction, @Nullable Throwable failureCause) throws Exception { return transaction == null && failureCause instanceof TimeoutException; } }); // Now start the job if (explicitTransaction == null) { sparkTxService.jobStarted(jobId, stages); } else { sparkTxService.jobStarted(jobId, stages, new TransactionInfo() { @Override public Transaction getTransaction() { return explicitTransaction; } @Override public boolean commitOnJobEnded() { return false; } @Override public void onJobStarted() { // no-op } @Override public void onTransactionCompleted(boolean jobSucceeded, @Nullable TransactionFailureException failureCause) { // no-op } }); } // For all stages, it should get the same transaction final Set<Transaction> transactions = Collections.newSetFromMap(new ConcurrentHashMap<Transaction, Boolean>()); verifyStagesTransactions(stages, new ClientTransactionVerifier() { @Override public boolean verify(@Nullable Transaction transaction, @Nullable Throwable failureCause) throws Exception { transactions.add(new TransactionWrapper(transaction)); return transaction != null; } }); // Transactions returned for all stages belonging to the same job must return the same transaction Assert.assertEquals(1, transactions.size()); // The transaction must be in progress Transaction transaction = transactions.iterator().next(); Assert.assertTrue(txManager.getCurrentState().getInProgress().containsKey(transaction.getWritePointer())); // If run with an explicit transaction, then all stages' transactions must be the same as the explicit transaction if (explicitTransaction != null) { Assert.assertEquals(new TransactionWrapper(explicitTransaction), transaction); } // Now finish the job sparkTxService.jobEnded(jobId, jobSucceeded); // After job finished, no transaction will be associated with the stages verifyStagesTransactions(stages, new ClientTransactionVerifier() { @Override public boolean verify(@Nullable Transaction transaction, @Nullable Throwable failureCause) throws Exception { return transaction == null && failureCause instanceof TimeoutException; } }); // Check the transaction state based on the job result TransactionSnapshot txState = txManager.getCurrentState(); // If explicit transaction is used, the transaction should still be in-progress if (explicitTransaction != null) { Assert.assertTrue(txState.getInProgress().containsKey(transaction.getWritePointer())); } else { // With implicit transaction, after job completed, the tx shouldn't be in-progress Assert.assertFalse(txState.getInProgress().containsKey(transaction.getWritePointer())); if (jobSucceeded) { // Transaction must not be in the invalid list Assert.assertFalse(txState.getInvalid().contains(transaction.getWritePointer())); } else { // Transaction must be in the invalid list Assert.assertTrue(txState.getInvalid().contains(transaction.getWritePointer())); } } } /** * Creates a new set of stage ids. */ private Set<Integer> generateStages(AtomicInteger idGen, int stages) { Set<Integer> result = new LinkedHashSet<>(); for (int i = 0; i < stages; i++) { result.add(idGen.getAndIncrement()); } return result; } /** * Verifies the result of get stage transaction for the given set of stages. * The get transaction will be called concurrently for all stages. * * @param stages set of stages to verify * @param verifier a {@link ClientTransactionVerifier} to verify the http call result. */ private void verifyStagesTransactions(Set<Integer> stages, final ClientTransactionVerifier verifier) throws Exception { final CyclicBarrier barrier = new CyclicBarrier(stages.size()); final ExecutorService executor = Executors.newFixedThreadPool(stages.size()); try { CompletionService<Boolean> completionService = new ExecutorCompletionService<>(executor); for (final int stageId : stages) { completionService.submit(new Callable<Boolean>() { @Override public Boolean call() throws Exception { barrier.await(); try { return verifier.verify(sparkTxClient.getTransaction(stageId, 0, TimeUnit.SECONDS), null); } catch (Throwable t) { return verifier.verify(null, t); } } }); } boolean result = true; for (int i = 0; i < stages.size(); i++) { result = result && completionService.poll(10, TimeUnit.SECONDS).get(); } // All verifications must be true Assert.assertTrue(result); } finally { executor.shutdown(); } } private interface ClientTransactionVerifier { /** * Verifies the result of a call to the transaction service. */ boolean verify(@Nullable Transaction transaction, @Nullable Throwable failureCause) throws Exception; } /** * A wrapper class for Transaction to provide equals and hashCode method. */ private static final class TransactionWrapper extends Transaction { private final Transaction transaction; TransactionWrapper(Transaction tx) { super(tx.getReadPointer(), tx.getTransactionId(), tx.getWritePointer(), tx.getInvalids(), tx.getInProgress(), tx.getFirstShortInProgress(), tx.getType(), tx.getCheckpointWritePointers(), tx.getVisibilityLevel()); this.transaction = tx; } @Override public boolean equals(Object o) { if (this == o) { return true; } if (o == null || getClass() != o.getClass()) { return false; } Transaction other = ((TransactionWrapper) o).transaction; return (transaction.getReadPointer() == other.getReadPointer()) && (transaction.getTransactionId() == other.getTransactionId()) && (transaction.getWritePointer() == other.getWritePointer()) && Arrays.equals(transaction.getInvalids(), other.getInvalids()) && Arrays.equals(transaction.getInProgress(), other.getInProgress()) && (transaction.getFirstShortInProgress() == other.getFirstShortInProgress()) && (transaction.getType() == other.getType()) && Arrays.equals(transaction.getCheckpointWritePointers(), other.getCheckpointWritePointers()) && (transaction.getVisibilityLevel() == other.getVisibilityLevel()); } @Override public int hashCode() { return Objects.hash( transaction.getReadPointer(), transaction.getTransactionId(), transaction.getWritePointer(), Arrays.hashCode(transaction.getInvalids()), Arrays.hashCode(transaction.getInProgress()), transaction.getFirstShortInProgress(), transaction.getType(), Arrays.hashCode(transaction.getCheckpointWritePointers()), transaction.getVisibilityLevel() ); } } }