/* * Copyright 2008-2012 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.batch.core.partition.support; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.util.Collection; import java.util.HashSet; import java.util.Set; import java.util.TreeSet; import org.junit.Before; import org.junit.Test; import org.springframework.batch.core.ExitStatus; import org.springframework.batch.core.JobExecution; import org.springframework.batch.core.JobExecutionException; import org.springframework.batch.core.JobInterruptedException; import org.springframework.batch.core.StepExecution; import org.springframework.batch.core.partition.StepExecutionSplitter; import org.springframework.batch.core.step.StepSupport; import org.springframework.core.task.SimpleAsyncTaskExecutor; import org.springframework.core.task.TaskExecutor; import org.springframework.core.task.TaskRejectedException; public class TaskExecutorPartitionHandlerTests { private TaskExecutorPartitionHandler handler = new TaskExecutorPartitionHandler(); private int count = 0; private Collection<String> stepExecutions = new TreeSet<String>(); private StepExecution stepExecution = new StepExecution("step", new JobExecution(1L)); private StepExecutionSplitter stepExecutionSplitter = new StepExecutionSplitter() { @Override public String getStepName() { return stepExecution.getStepName(); } @Override public Set<StepExecution> split(StepExecution stepExecution, int gridSize) throws JobExecutionException { HashSet<StepExecution> result = new HashSet<StepExecution>(); for (int i = gridSize; i-- > 0;) { result.add(stepExecution.getJobExecution().createStepExecution("foo" + i)); } return result; } }; @Before public void setUp() throws Exception { handler.setStep(new StepSupport() { @Override public void execute(StepExecution stepExecution) throws JobInterruptedException { count++; stepExecutions.add(stepExecution.getStepName()); } }); handler.afterPropertiesSet(); } @Test public void testNullStep() throws Exception { handler = new TaskExecutorPartitionHandler(); try { handler.handle(stepExecutionSplitter, stepExecution); fail("Expected IllegalArgumentException"); } catch (IllegalArgumentException e) { // expected String message = e.getMessage(); assertTrue("Wrong message: " + message, message.contains("Step")); } } @Test public void testSetGridSize() throws Exception { handler.setGridSize(2); handler.handle(stepExecutionSplitter, stepExecution); assertEquals(2, count); assertEquals("[foo0, foo1]", stepExecutions.toString()); } @Test public void testSetTaskExecutor() throws Exception { handler.setTaskExecutor(new SimpleAsyncTaskExecutor()); handler.handle(stepExecutionSplitter, stepExecution); assertEquals(1, count); } @Test public void testTaskExecutorFailure() throws Exception { handler.setGridSize(2); handler.setTaskExecutor(new TaskExecutor() { @Override public void execute(Runnable task) { if (count > 0) { throw new TaskRejectedException("foo"); } task.run(); } }); Collection<StepExecution> executions = handler.handle(stepExecutionSplitter, stepExecution); new DefaultStepExecutionAggregator().aggregate(stepExecution, executions); assertEquals(1, count); assertEquals(ExitStatus.FAILED.getExitCode(), stepExecution.getExitStatus().getExitCode()); } }