/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.action.support.replication;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionWriteResponse;
import org.elasticsearch.action.NoShardAvailableActionException;
import org.elasticsearch.action.UnavailableShardsException;
import org.elasticsearch.action.admin.indices.flush.FlushRequest;
import org.elasticsearch.action.admin.indices.flush.FlushResponse;
import org.elasticsearch.action.admin.indices.flush.TransportFlushAction;
import org.elasticsearch.action.support.ActionFilter;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.broadcast.BroadcastRequest;
import org.elasticsearch.action.support.broadcast.BroadcastResponse;
import org.elasticsearch.cluster.ClusterService;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.routing.ShardRoutingState;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.cluster.TestClusterService;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.transport.local.LocalTransport;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import java.io.IOException;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.state;
import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithAssignedPrimariesAndOneReplica;
import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithNoShard;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
public class BroadcastReplicationTests extends ESTestCase {
private static ThreadPool threadPool;
private TestClusterService clusterService;
private static CircuitBreakerService circuitBreakerService;
private TransportService transportService;
private LocalTransport transport;
private TestBroadcastReplicationAction broadcastReplicationAction;
@BeforeClass
public static void beforeClass() {
threadPool = new ThreadPool("BroadcastReplicationTests");
circuitBreakerService = new NoneCircuitBreakerService();
}
@Override
@Before
public void setUp() throws Exception {
super.setUp();
transport = new LocalTransport(Settings.EMPTY, threadPool, Version.CURRENT, new NamedWriteableRegistry(), circuitBreakerService);
clusterService = new TestClusterService(threadPool);
transportService = new TransportService(transport, threadPool);
transportService.start();
transportService.acceptIncomingRequests();
broadcastReplicationAction = new TestBroadcastReplicationAction(Settings.EMPTY, threadPool, clusterService, transportService, new ActionFilters(new HashSet<ActionFilter>()), new IndexNameExpressionResolver(Settings.EMPTY), null);
}
@AfterClass
public static void afterClass() {
ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS);
threadPool = null;
}
@Test
public void testNotStartedPrimary() throws InterruptedException, ExecutionException, IOException {
final String index = "test";
clusterService.setState(state(index, randomBoolean(),
randomBoolean() ? ShardRoutingState.INITIALIZING : ShardRoutingState.UNASSIGNED, ShardRoutingState.UNASSIGNED));
logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
Future<BroadcastResponse> response = (broadcastReplicationAction.execute(new BroadcastRequest().indices(index)));
for (Tuple<ShardId, ActionListener<ActionWriteResponse>> shardRequests : broadcastReplicationAction.capturedShardRequests) {
if (randomBoolean()) {
shardRequests.v2().onFailure(new NoShardAvailableActionException(shardRequests.v1()));
} else {
shardRequests.v2().onFailure(new UnavailableShardsException(shardRequests.v1(), "test exception"));
}
}
response.get();
logger.info("total shards: {}, ", response.get().getTotalShards());
// we expect no failures here because UnavailableShardsException does not count as failed
assertBroadcastResponse(2, 0, 0, response.get(), null);
}
@Test
public void testStartedPrimary() throws InterruptedException, ExecutionException, IOException {
final String index = "test";
clusterService.setState(state(index, randomBoolean(),
ShardRoutingState.STARTED));
logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
Future<BroadcastResponse> response = (broadcastReplicationAction.execute(new BroadcastRequest().indices(index)));
for (Tuple<ShardId, ActionListener<ActionWriteResponse>> shardRequests : broadcastReplicationAction.capturedShardRequests) {
ActionWriteResponse actionWriteResponse = new ActionWriteResponse();
actionWriteResponse.setShardInfo(new ActionWriteResponse.ShardInfo(1, 1, new ActionWriteResponse.ShardInfo.Failure[0]));
shardRequests.v2().onResponse(actionWriteResponse);
}
logger.info("total shards: {}, ", response.get().getTotalShards());
assertBroadcastResponse(1, 1, 0, response.get(), null);
}
@Test
public void testResultCombine() throws InterruptedException, ExecutionException, IOException {
final String index = "test";
int numShards = randomInt(2)+1;
clusterService.setState(stateWithAssignedPrimariesAndOneReplica(index, numShards));
logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
Future<BroadcastResponse> response = (broadcastReplicationAction.execute(new BroadcastRequest().indices(index)));
int succeeded = 0;
int failed = 0;
for (Tuple<ShardId, ActionListener<ActionWriteResponse>> shardRequests : broadcastReplicationAction.capturedShardRequests) {
if (randomBoolean()) {
ActionWriteResponse.ShardInfo.Failure[] failures = new ActionWriteResponse.ShardInfo.Failure[0];
int shardsSucceeded = randomInt(1) + 1;
succeeded += shardsSucceeded;
ActionWriteResponse actionWriteResponse = new ActionWriteResponse();
if (shardsSucceeded == 1 && randomBoolean()) {
//sometimes add failure (no failure means shard unavailable)
failures = new ActionWriteResponse.ShardInfo.Failure[1];
failures[0] = new ActionWriteResponse.ShardInfo.Failure(index, shardRequests.v1().id(), null, new Exception("pretend shard failed"), RestStatus.GATEWAY_TIMEOUT, false);
failed++;
}
actionWriteResponse.setShardInfo(new ActionWriteResponse.ShardInfo(2, shardsSucceeded, failures));
shardRequests.v2().onResponse(actionWriteResponse);
} else {
// sometimes fail
failed += 2;
// just add a general exception and see if failed shards will be incremented by 2
shardRequests.v2().onFailure(new Exception("pretend shard failed"));
}
}
assertBroadcastResponse(2 * numShards, succeeded, failed, response.get(), Exception.class);
}
@Test
public void testNoShards() throws InterruptedException, ExecutionException, IOException {
clusterService.setState(stateWithNoShard());
logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
BroadcastResponse response = executeAndAssertImmediateResponse(broadcastReplicationAction, new BroadcastRequest());
assertBroadcastResponse(0, 0, 0, response, null);
}
@Test
public void testShardsList() throws InterruptedException, ExecutionException {
final String index = "test";
final ShardId shardId = new ShardId(index, 0);
ClusterState clusterState = state(index, randomBoolean(),
randomBoolean() ? ShardRoutingState.INITIALIZING : ShardRoutingState.UNASSIGNED, ShardRoutingState.UNASSIGNED);
logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
List<ShardId> shards = broadcastReplicationAction.shards(new BroadcastRequest().indices(shardId.index().name()), clusterState);
assertThat(shards.size(), equalTo(1));
assertThat(shards.get(0), equalTo(shardId));
}
private class TestBroadcastReplicationAction extends TransportBroadcastReplicationAction<BroadcastRequest, BroadcastResponse, ReplicationRequest, ActionWriteResponse> {
protected final Set<Tuple<ShardId, ActionListener<ActionWriteResponse>>> capturedShardRequests = ConcurrentCollections.newConcurrentSet();
public TestBroadcastReplicationAction(Settings settings, ThreadPool threadPool, ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, TransportReplicationAction replicatedBroadcastShardAction) {
super("test-broadcast-replication-action", BroadcastRequest.class, settings, threadPool, clusterService, transportService, actionFilters, indexNameExpressionResolver, replicatedBroadcastShardAction);
}
@Override
protected ActionWriteResponse newShardResponse() {
return new ActionWriteResponse();
}
@Override
protected ReplicationRequest newShardRequest(BroadcastRequest request, ShardId shardId) {
return new ReplicationRequest().setShardId(shardId);
}
@Override
protected BroadcastResponse newResponse(int successfulShards, int failedShards, int totalNumCopies, List shardFailures) {
return new BroadcastResponse(totalNumCopies, successfulShards, failedShards, shardFailures);
}
@Override
protected void shardExecute(Task task, BroadcastRequest request, ShardId shardId, ActionListener<ActionWriteResponse> shardActionListener) {
capturedShardRequests.add(new Tuple<>(shardId, shardActionListener));
}
}
public FlushResponse assertImmediateResponse(String index, TransportFlushAction flushAction) throws InterruptedException, ExecutionException {
Date beginDate = new Date();
FlushResponse flushResponse = flushAction.execute(new FlushRequest(index)).get();
Date endDate = new Date();
long maxTime = 500;
assertThat("this should not take longer than " + maxTime + " ms. The request hangs somewhere", endDate.getTime() - beginDate.getTime(), lessThanOrEqualTo(maxTime));
return flushResponse;
}
public BroadcastResponse executeAndAssertImmediateResponse(TransportBroadcastReplicationAction broadcastAction, BroadcastRequest request) throws InterruptedException, ExecutionException {
return (BroadcastResponse) broadcastAction.execute(request).actionGet("5s");
}
private void assertBroadcastResponse(int total, int successful, int failed, BroadcastResponse response, Class exceptionClass) {
assertThat(response.getSuccessfulShards(), equalTo(successful));
assertThat(response.getTotalShards(), equalTo(total));
assertThat(response.getFailedShards(), equalTo(failed));
for (int i = 0; i < failed; i++) {
assertThat(response.getShardFailures()[0].getCause().getCause(), instanceOf(exceptionClass));
}
}
}