// =================================================================================================
// Copyright 2011 Twitter, Inc.
// -------------------------------------------------------------------------------------------------
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this work except in compliance with the License.
// You may obtain a copy of the License in the LICENSE file, or at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =================================================================================================
package com.twitter.common.thrift.callers;
import java.lang.reflect.Method;
import java.util.concurrent.atomic.AtomicLong;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableSet;
import org.junit.Before;
import org.junit.Test;
import com.twitter.common.stats.StatsProvider;
import static org.easymock.EasyMock.expect;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
/**
* TODO(William Farner): Test async.
*
* @author William Farner
*/
public class RetryingCallerTest extends AbstractCallerTest {
private static final int NUM_RETRIES = 2;
private static final ImmutableSet<Class<? extends Exception>> NO_RETRYABLE =
ImmutableSet.of();
private static final ImmutableSet<Class<? extends Exception>> RETRYABLE =
ImmutableSet.<Class<? extends Exception>>of(IllegalArgumentException.class);
private StatsProvider statsProvider;
@Before
public void mySetUp() {
statsProvider = createMock(StatsProvider.class);
}
@Test
public void testSuccess() throws Throwable {
expectCall("foo");
control.replay();
RetryingCaller retry = makeRetry(false, NO_RETRYABLE);
assertThat(call(retry), is("foo"));
assertThat(memoizeGetCounter.get(methodA).get(), is(0L));
}
@Test
public void testException() throws Throwable {
Throwable exception = nonRetryable();
expectCall(exception);
control.replay();
RetryingCaller retry = makeRetry(false, NO_RETRYABLE);
try {
call(retry);
fail();
} catch (Throwable t) {
assertThat(t, is(exception));
}
assertThat(memoizeGetCounter.get(methodA).get(), is(0L));
}
@Test
public void testRetriesSuccess() throws Throwable {
expectCall(retryable());
expectCall(retryable());
expectCall("foo");
control.replay();
RetryingCaller retry = makeRetry(false, RETRYABLE);
assertThat(call(retry), is("foo"));
assertThat(memoizeGetCounter.get(methodA).get(), is((long) NUM_RETRIES));
}
@Test
public void testRetryLimit() throws Throwable {
expectCall(retryable());
expectCall(retryable());
Throwable exception = retryable();
expectCall(exception);
control.replay();
RetryingCaller retry = makeRetry(false, RETRYABLE);
try {
call(retry);
fail();
} catch (Throwable t) {
assertThat(t, is(exception));
}
assertThat(memoizeGetCounter.get(methodA).get(), is(2L));
}
private Throwable retryable() {
return new IllegalArgumentException();
}
private Throwable nonRetryable() {
return new NullPointerException();
}
private LoadingCache<Method, AtomicLong> memoizeGetCounter = CacheBuilder.newBuilder().build(
new CacheLoader<Method, AtomicLong>() {
@Override public AtomicLong load(Method method) {
AtomicLong atomicLong = new AtomicLong();
expect(statsProvider.makeCounter("test_" + method.getName() + "_retries"))
.andReturn(atomicLong);
return atomicLong;
}
});
@Override
protected void expectCall(String returnValue) throws Throwable {
super.expectCall(returnValue);
memoizeGetCounter.get(methodA);
}
@Override
protected void expectCall(Throwable thrown) throws Throwable {
super.expectCall(thrown);
memoizeGetCounter.get(methodA);
}
private RetryingCaller makeRetry(boolean async,
ImmutableSet<Class<? extends Exception>> retryableExceptions) {
return new RetryingCaller(caller, async, statsProvider, "test", NUM_RETRIES,
retryableExceptions, false);
}
}