/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.action.bulk;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.DocWriteRequest.OpType;
import org.elasticsearch.action.delete.DeleteResponse;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
import org.junit.After;
import org.junit.Before;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
/**
 * Tests for {@link Retry} driven by a mock client that answers the first {@code CALLS_TO_FAIL} bulk calls with
 * item-level {@link EsRejectedExecutionException} failures and succeeds afterwards.
 * <p>
 * The tests check that a {@link BackoffPolicy} allowing enough retries ends in a response without failures, that a
 * policy allowing one retry too few surfaces the item failures in the returned response (not via {@code onFailure}),
 * and - through the header check in the mock client - that the thread-context headers set up before the first call
 * are present on every bulk call.
 */
public class RetryTests extends ESTestCase {
    // no need to wait for a long time in tests
    private static final TimeValue DELAY = TimeValue.timeValueMillis(1L);
    private static final int CALLS_TO_FAIL = 5;

    private MockBulkClient bulkClient;
    /**
     * Headers that are expected to be sent with all bulk requests.
     */
    private final Map<String, String> expectedHeaders = new HashMap<>();

    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        this.bulkClient = new MockBulkClient(getTestName(), CALLS_TO_FAIL);
        // Stash some random headers so we can assert that we preserve them
        bulkClient.threadPool().getThreadContext().stashContext();
        expectedHeaders.clear();
        expectedHeaders.put(randomAlphaOfLength(5), randomAlphaOfLength(5));
        bulkClient.threadPool().getThreadContext().putHeader(expectedHeaders);
    }

    @Override
    @After
    public void tearDown() throws Exception {
        // Release the client before delegating to the base class, with super.tearDown() in a finally block, so that
        // neither cleanup step can prevent the other from running. The null check covers a setUp() that failed
        // before the client was assigned.
        try {
            if (bulkClient != null) {
                bulkClient.close();
            }
        } finally {
            super.tearDown();
        }
    }

    /**
     * Builds a bulk request with five update requests (ids "1" to "5") against index "shop", type "products".
     */
    private BulkRequest createBulkRequest() {
        BulkRequest request = new BulkRequest();
        request.add(new UpdateRequest("shop", "products", "1"));
        request.add(new UpdateRequest("shop", "products", "2"));
        request.add(new UpdateRequest("shop", "products", "3"));
        request.add(new UpdateRequest("shop", "products", "4"));
        request.add(new UpdateRequest("shop", "products", "5"));
        return request;
    }

    public void testRetryBacksOff() throws Exception {
        // the policy allows as many retries as the mock client has failing calls, so the last attempt should succeed
        BackoffPolicy backoff = BackoffPolicy.constantBackoff(DELAY, CALLS_TO_FAIL);
        BulkRequest bulkRequest = createBulkRequest();
        BulkResponse response = new Retry(EsRejectedExecutionException.class, backoff, bulkClient.threadPool())
            .withBackoff(bulkClient::bulk, bulkRequest, bulkClient.settings())
            .actionGet();
        assertFalse(response.hasFailures());
        assertThat(response.getItems().length, equalTo(bulkRequest.numberOfActions()));
    }

    public void testRetryFailsAfterBackoff() throws Exception {
        // one retry too few: the item failures of the last failing call are expected in the response
        BackoffPolicy backoff = BackoffPolicy.constantBackoff(DELAY, CALLS_TO_FAIL - 1);
        BulkRequest bulkRequest = createBulkRequest();
        BulkResponse response = new Retry(EsRejectedExecutionException.class, backoff, bulkClient.threadPool())
            .withBackoff(bulkClient::bulk, bulkRequest, bulkClient.settings())
            .actionGet();
        assertTrue(response.hasFailures());
        assertThat(response.getItems().length, equalTo(bulkRequest.numberOfActions()));
    }

    public void testRetryWithListenerBacksOff() throws Exception {
        BackoffPolicy backoff = BackoffPolicy.constantBackoff(DELAY, CALLS_TO_FAIL);
        AssertingListener listener = new AssertingListener();
        BulkRequest bulkRequest = createBulkRequest();
        Retry retry = new Retry(EsRejectedExecutionException.class, backoff, bulkClient.threadPool());
        retry.withBackoff(bulkClient::bulk, bulkRequest, listener, bulkClient.settings());
        listener.awaitCallbacksCalled();
        listener.assertOnResponseCalled();
        listener.assertResponseWithoutFailures();
        listener.assertResponseWithNumberOfItems(bulkRequest.numberOfActions());
        listener.assertOnFailureNeverCalled();
    }

    public void testRetryWithListenerFailsAfterBacksOff() throws Exception {
        BackoffPolicy backoff = BackoffPolicy.constantBackoff(DELAY, CALLS_TO_FAIL - 1);
        AssertingListener listener = new AssertingListener();
        BulkRequest bulkRequest = createBulkRequest();
        Retry retry = new Retry(EsRejectedExecutionException.class, backoff, bulkClient.threadPool());
        retry.withBackoff(bulkClient::bulk, bulkRequest, listener, bulkClient.settings());
        listener.awaitCallbacksCalled();
        listener.assertOnResponseCalled();
        listener.assertResponseWithFailures();
        listener.assertResponseWithNumberOfItems(bulkRequest.numberOfActions());
        listener.assertOnFailureNeverCalled();
    }

    /**
     * Listener that records the response or failure it receives and offers assertion helpers over them.
     * A single latch is released by whichever callback arrives first.
     */
    private static class AssertingListener implements ActionListener<BulkResponse> {
        // Upper bound for waiting on a callback, so a lost callback fails the test instead of hanging it.
        private static final long CALLBACK_TIMEOUT_SECONDS = 30;

        private final CountDownLatch latch;
        private final AtomicInteger countOnResponseCalled = new AtomicInteger();
        private volatile Throwable lastFailure;
        private volatile BulkResponse response;

        private AssertingListener() {
            latch = new CountDownLatch(1);
        }

        /**
         * Blocks until {@link #onResponse} or {@link #onFailure} has been called, failing the test if neither
         * happens within {@code CALLBACK_TIMEOUT_SECONDS}.
         */
        public void awaitCallbacksCalled() throws InterruptedException {
            assertTrue("listener was not called within " + CALLBACK_TIMEOUT_SECONDS + " seconds",
                latch.await(CALLBACK_TIMEOUT_SECONDS, TimeUnit.SECONDS));
        }

        @Override
        public void onResponse(BulkResponse bulkItemResponses) {
            this.response = bulkItemResponses;
            countOnResponseCalled.incrementAndGet();
            latch.countDown();
        }

        @Override
        public void onFailure(Exception e) {
            this.lastFailure = e;
            latch.countDown();
        }

        public void assertOnResponseCalled() {
            // include the recorded failure (if any) so an unexpected onFailure is visible in the assertion message
            assertThat("onResponse should be called exactly once; onFailure received: " + lastFailure,
                countOnResponseCalled.get(), equalTo(1));
        }

        public void assertResponseWithNumberOfItems(int numItems) {
            assertThat(response, notNullValue());
            assertThat(response.getItems().length, equalTo(numItems));
        }

        public void assertResponseWithoutFailures() {
            assertThat(response, notNullValue());
            assertFalse("Response should not have failures", response.hasFailures());
        }

        public void assertResponseWithFailures() {
            assertThat(response, notNullValue());
            assertTrue("Response should have failures", response.hasFailures());
        }

        public void assertOnFailureNeverCalled() {
            assertThat(lastFailure, nullValue());
        }
    }

    /**
     * Client whose bulk calls complete synchronously. The first {@code numberOfCallsToFail} calls fail at least one
     * item (chosen at random, plus further random items) with an {@link EsRejectedExecutionException}; later calls
     * succeed for every item. Every call first verifies that the thread context carries {@code expectedHeaders}
     * and reports a {@link RuntimeException} via {@code onFailure} if it does not.
     * <p>
     * Non-static because it reads the enclosing test's {@code expectedHeaders}.
     */
    private class MockBulkClient extends NoOpClient {
        private int numberOfCallsToFail;

        private MockBulkClient(String testName, int numberOfCallsToFail) {
            super(testName);
            this.numberOfCallsToFail = numberOfCallsToFail;
        }

        @Override
        public ActionFuture<BulkResponse> bulk(BulkRequest request) {
            PlainActionFuture<BulkResponse> responseFuture = new PlainActionFuture<>();
            bulk(request, responseFuture);
            return responseFuture;
        }

        @Override
        public void bulk(BulkRequest request, ActionListener<BulkResponse> listener) {
            if (false == expectedHeaders.equals(threadPool().getThreadContext().getHeaders())) {
                listener.onFailure(
                    new RuntimeException("Expected " + expectedHeaders + " but got " + threadPool().getThreadContext().getHeaders()));
                return;
            }

            // do everything synchronously, that's fine for a test
            boolean shouldFail = numberOfCallsToFail > 0;
            numberOfCallsToFail--;

            BulkItemResponse[] itemResponses = new BulkItemResponse[request.requests().size()];
            // if we have to fail, we need to fail at least once "reliably", the rest can be random
            int itemToFail = randomInt(request.requests().size() - 1);
            for (int idx = 0; idx < request.requests().size(); idx++) {
                if (shouldFail && (randomBoolean() || idx == itemToFail)) {
                    itemResponses[idx] = failedResponse();
                } else {
                    itemResponses[idx] = successfulResponse();
                }
            }
            listener.onResponse(new BulkResponse(itemResponses, 1000L));
        }

        private BulkItemResponse successfulResponse() {
            return new BulkItemResponse(1, OpType.DELETE, new DeleteResponse());
        }

        private BulkItemResponse failedResponse() {
            BulkItemResponse.Failure failure =
                new BulkItemResponse.Failure("test", "test", "1", new EsRejectedExecutionException("pool full"));
            return new BulkItemResponse(1, OpType.INDEX, failure);
        }
    }
}