/*
* Copyright (c) 2010-2016. Axon Framework
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.axonframework.messaging.unitofwork;
import org.axonframework.common.MockException;
import org.axonframework.messaging.GenericMessage;
import org.axonframework.messaging.Message;
import org.junit.Before;
import org.junit.Test;
import java.util.*;
import static junit.framework.TestCase.*;
import static org.axonframework.messaging.unitofwork.UnitOfWork.Phase.*;
/**
* @author Rene de Waele
*/
public class BatchingUnitOfWorkTest {
private List<PhaseTransition> transitions;
private BatchingUnitOfWork<?> subject;
@Before
public void setUp() {
transitions = new ArrayList<>();
}
@Test
public void testExecuteTask() throws Exception {
List<Message<?>> messages = Arrays.asList(toMessage(0), toMessage(1), toMessage(2));
subject = new BatchingUnitOfWork<>(messages);
subject.executeWithResult(() -> {
registerListeners(subject);
return resultFor(subject.getMessage());
});
validatePhaseTransitions(Arrays.asList(PREPARE_COMMIT, COMMIT, AFTER_COMMIT, CLEANUP), messages);
Map<Message<?>, ExecutionResult> expectedResults = new HashMap<>();
messages.forEach(m -> expectedResults.put(m, new ExecutionResult(resultFor(m))));
assertEquals(expectedResults, subject.getExecutionResults());
}
@Test
public void testRollback() throws Exception {
List<Message<?>> messages = Arrays.asList(toMessage(0), toMessage(1), toMessage(2));
subject = new BatchingUnitOfWork<>(messages);
MockException e = new MockException();
try {
subject.executeWithResult(() -> {
registerListeners(subject);
if (subject.getMessage().getPayload().equals(1)) {
throw e;
}
return resultFor(subject.getMessage());
});
} catch (Exception ignored) {
}
validatePhaseTransitions(Arrays.asList(ROLLBACK, CLEANUP), messages.subList(0, 2));
Map<Message<?>, ExecutionResult> expectedResult = new HashMap<>();
messages.forEach(m -> expectedResult.put(m, new ExecutionResult(e)));
assertEquals(expectedResult, subject.getExecutionResults());
}
@Test
public void testSuppressedExceptionOnRollback() throws Exception {
List<Message<?>> messages = Arrays.asList(toMessage(0), toMessage(1), toMessage(2));
subject = new BatchingUnitOfWork<>(messages);
MockException taskException = new MockException("task exception");
MockException commitException = new MockException("commit exception");
try {
subject.executeWithResult(() -> {
registerListeners(subject);
if (subject.getMessage().getPayload().equals(2)) {
subject.addHandler(PREPARE_COMMIT, u -> {
throw commitException;
});
throw taskException;
}
return resultFor(subject.getMessage());
}, e -> false);
} catch (Exception ignored) {
}
validatePhaseTransitions(Arrays.asList(PREPARE_COMMIT, ROLLBACK, CLEANUP), messages);
Map<Message<?>, ExecutionResult> expectedResult = new HashMap<>();
expectedResult.put(messages.get(0), new ExecutionResult(commitException));
expectedResult.put(messages.get(1), new ExecutionResult(commitException));
expectedResult.put(messages.get(2), new ExecutionResult(taskException));
assertEquals(expectedResult, subject.getExecutionResults());
assertSame(commitException, taskException.getSuppressed()[0]);
}
private void registerListeners(UnitOfWork<?> unitOfWork) {
unitOfWork.onPrepareCommit(u -> transitions.add(new PhaseTransition(u.getMessage(), PREPARE_COMMIT)));
unitOfWork.onCommit(u -> transitions.add(new PhaseTransition(u.getMessage(), COMMIT)));
unitOfWork.afterCommit(u -> transitions.add(new PhaseTransition(u.getMessage(), AFTER_COMMIT)));
unitOfWork.onRollback(u -> transitions.add(new PhaseTransition(u.getMessage(), ROLLBACK)));
unitOfWork.onCleanup(u -> transitions.add(new PhaseTransition(u.getMessage(), CLEANUP)));
}
private static Message<?> toMessage(Object payload) {
return new GenericMessage<>(payload);
}
public static Object resultFor(Message<?> message) {
return "Result for: " + message.getPayload();
}
private void validatePhaseTransitions(List<UnitOfWork.Phase> phases, List<Message<?>> messages) {
Iterator<PhaseTransition> iterator = transitions.iterator();
for (UnitOfWork.Phase phase : phases) {
Iterator<Message<?>> messageIterator = phase.isReverseCallbackOrder()
? new LinkedList<>(messages).descendingIterator() : messages.iterator();
messageIterator.forEachRemaining(message -> {
PhaseTransition expected = new PhaseTransition(message, phase);
assertTrue(iterator.hasNext());
PhaseTransition actual = iterator.next();
assertEquals(expected, actual);
});
}
}
private static class PhaseTransition {
private final UnitOfWork.Phase phase;
private final Message<?> message;
public PhaseTransition(Message<?> message, UnitOfWork.Phase phase) {
this.message = message;
this.phase = phase;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PhaseTransition that = (PhaseTransition) o;
return phase == that.phase &&
Objects.equals(message, that.message);
}
@Override
public int hashCode() {
return Objects.hash(phase, message);
}
@Override
public String toString() {
return phase + " -> " + message.getPayload();
}
}
}