/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.action.support.replication; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.WriteRequest.RefreshPolicy; import org.elasticsearch.action.support.WriteResponse; import org.elasticsearch.action.support.replication.ReplicationOperation.ReplicaResponse; import org.elasticsearch.client.transport.NoNodeAvailableException; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.action.shard.ShardStateAction; import org.elasticsearch.cluster.metadata.IndexMetaData; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.RoutingNode; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.routing.ShardRoutingState; import org.elasticsearch.cluster.routing.TestShardRouting; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexService; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.IndexShardState; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardNotFoundException; import org.elasticsearch.index.translog.Translog; import org.elasticsearch.indices.IndicesService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.transport.CapturingTransport; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportService; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.mockito.ArgumentCaptor; import java.util.HashSet; import java.util.Locale; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.stream.Collectors; import static org.elasticsearch.test.ClusterServiceUtils.createClusterService; import static org.hamcrest.Matchers.arrayWithSize; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class TransportWriteActionTests extends ESTestCase { private static ThreadPool threadPool; private ClusterService clusterService; private IndexShard indexShard; private Translog.Location location; @BeforeClass public static void beforeClass() { threadPool = new TestThreadPool("ShardReplicationTests"); } @Before public void initCommonMocks() { indexShard = mock(IndexShard.class); location = mock(Translog.Location.class); clusterService = createClusterService(threadPool); } @After public void tearDown() throws Exception { super.tearDown(); clusterService.close(); } @AfterClass public static void afterClass() { ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); threadPool = null; } <T> void assertListenerThrows(String msg, PlainActionFuture<T> listener, Class<?> klass) throws InterruptedException { try { listener.get(); fail(msg); } catch (ExecutionException ex) { assertThat(ex.getCause(), instanceOf(klass)); } } public void testPrimaryNoRefreshCall() throws Exception { TestRequest request = new TestRequest(); request.setRefreshPolicy(RefreshPolicy.NONE); // The default, but we'll set it anyway just to be explicit TestAction testAction = new TestAction(); TransportWriteAction.WritePrimaryResult<TestRequest, TestResponse> result = testAction.shardOperationOnPrimary(request, indexShard); CapturingActionListener<TestResponse> listener = new CapturingActionListener<>(); result.respond(listener); assertNotNull(listener.response); assertNull(listener.failure); verify(indexShard, never()).refresh(any()); verify(indexShard, never()).addRefreshListener(any(), any()); } public void testReplicaNoRefreshCall() throws Exception { TestRequest request = new TestRequest(); request.setRefreshPolicy(RefreshPolicy.NONE); // The default, but we'll set it anyway just to be explicit TestAction testAction = new TestAction(); TransportWriteAction.WriteReplicaResult<TestRequest> result = testAction.shardOperationOnReplica(request, indexShard); CapturingActionListener<TransportResponse.Empty> listener = new CapturingActionListener<>(); result.respond(listener); assertNotNull(listener.response); assertNull(listener.failure); verify(indexShard, never()).refresh(any()); verify(indexShard, never()).addRefreshListener(any(), any()); } public void testPrimaryImmediateRefresh() throws Exception { TestRequest request = new TestRequest(); request.setRefreshPolicy(RefreshPolicy.IMMEDIATE); TestAction testAction = new TestAction(); TransportWriteAction.WritePrimaryResult<TestRequest, TestResponse> result = testAction.shardOperationOnPrimary(request, indexShard); CapturingActionListener<TestResponse> listener = new CapturingActionListener<>(); result.respond(listener); assertNotNull(listener.response); assertNull(listener.failure); assertTrue(listener.response.forcedRefresh); verify(indexShard).refresh("refresh_flag_index"); verify(indexShard, never()).addRefreshListener(any(), any()); } public void testReplicaImmediateRefresh() throws Exception { TestRequest request = new TestRequest(); request.setRefreshPolicy(RefreshPolicy.IMMEDIATE); TestAction testAction = new TestAction(); TransportWriteAction.WriteReplicaResult<TestRequest> result = testAction.shardOperationOnReplica(request, indexShard); CapturingActionListener<TransportResponse.Empty> listener = new CapturingActionListener<>(); result.respond(listener); assertNotNull(listener.response); assertNull(listener.failure); verify(indexShard).refresh("refresh_flag_index"); verify(indexShard, never()).addRefreshListener(any(), any()); } public void testPrimaryWaitForRefresh() throws Exception { TestRequest request = new TestRequest(); request.setRefreshPolicy(RefreshPolicy.WAIT_UNTIL); TestAction testAction = new TestAction(); TransportWriteAction.WritePrimaryResult<TestRequest, TestResponse> result = testAction.shardOperationOnPrimary(request, indexShard); CapturingActionListener<TestResponse> listener = new CapturingActionListener<>(); result.respond(listener); assertNull(listener.response); // Haven't reallresponded yet @SuppressWarnings({ "unchecked", "rawtypes" }) ArgumentCaptor<Consumer<Boolean>> refreshListener = ArgumentCaptor.forClass((Class) Consumer.class); verify(indexShard, never()).refresh(any()); verify(indexShard).addRefreshListener(any(), refreshListener.capture()); // Now we can fire the listener manually and we'll get a response boolean forcedRefresh = randomBoolean(); refreshListener.getValue().accept(forcedRefresh); assertNotNull(listener.response); assertNull(listener.failure); assertEquals(forcedRefresh, listener.response.forcedRefresh); } public void testReplicaWaitForRefresh() throws Exception { TestRequest request = new TestRequest(); request.setRefreshPolicy(RefreshPolicy.WAIT_UNTIL); TestAction testAction = new TestAction(); TransportWriteAction.WriteReplicaResult<TestRequest> result = testAction.shardOperationOnReplica(request, indexShard); CapturingActionListener<TransportResponse.Empty> listener = new CapturingActionListener<>(); result.respond(listener); assertNull(listener.response); // Haven't responded yet @SuppressWarnings({ "unchecked", "rawtypes" }) ArgumentCaptor<Consumer<Boolean>> refreshListener = ArgumentCaptor.forClass((Class) Consumer.class); verify(indexShard, never()).refresh(any()); verify(indexShard).addRefreshListener(any(), refreshListener.capture()); // Now we can fire the listener manually and we'll get a response boolean forcedRefresh = randomBoolean(); refreshListener.getValue().accept(forcedRefresh); assertNotNull(listener.response); assertNull(listener.failure); } public void testDocumentFailureInShardOperationOnPrimary() throws Exception { TestRequest request = new TestRequest(); TestAction testAction = new TestAction(true, true); TransportWriteAction.WritePrimaryResult<TestRequest, TestResponse> writePrimaryResult = testAction.shardOperationOnPrimary(request, indexShard); CapturingActionListener<TestResponse> listener = new CapturingActionListener<>(); writePrimaryResult.respond(listener); assertNull(listener.response); assertNotNull(listener.failure); } public void testDocumentFailureInShardOperationOnReplica() throws Exception { TestRequest request = new TestRequest(); TestAction testAction = new TestAction(randomBoolean(), true); TransportWriteAction.WriteReplicaResult<TestRequest> writeReplicaResult = testAction.shardOperationOnReplica(request, indexShard); CapturingActionListener<TransportResponse.Empty> listener = new CapturingActionListener<>(); writeReplicaResult.respond(listener); assertNull(listener.response); assertNotNull(listener.failure); } public void testReplicaProxy() throws InterruptedException, ExecutionException { CapturingTransport transport = new CapturingTransport(); TransportService transportService = new TransportService(clusterService.getSettings(), transport, threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> clusterService.localNode(), null); transportService.start(); transportService.acceptIncomingRequests(); ShardStateAction shardStateAction = new ShardStateAction(Settings.EMPTY, clusterService, transportService, null, null, threadPool); TestAction action = action = new TestAction(Settings.EMPTY, "testAction", transportService, clusterService, shardStateAction, threadPool); ReplicationOperation.Replicas proxy = action.newReplicasProxy(); final String index = "test"; final ShardId shardId = new ShardId(index, "_na_", 0); ClusterState state = ClusterStateCreationUtils.stateWithActivePrimary(index, true, 1 + randomInt(3), randomInt(2)); logger.info("using state: {}", state); ClusterServiceUtils.setState(clusterService, state); // check that at unknown node fails PlainActionFuture<ReplicaResponse> listener = new PlainActionFuture<>(); proxy.performOn( TestShardRouting.newShardRouting(shardId, "NOT THERE", false, randomFrom(ShardRoutingState.values())), new TestRequest(), randomNonNegativeLong(), listener); assertTrue(listener.isDone()); assertListenerThrows("non existent node should throw a NoNodeAvailableException", listener, NoNodeAvailableException.class); final IndexShardRoutingTable shardRoutings = state.routingTable().shardRoutingTable(shardId); final ShardRouting replica = randomFrom(shardRoutings.replicaShards().stream() .filter(ShardRouting::assignedToNode).collect(Collectors.toList())); listener = new PlainActionFuture<>(); proxy.performOn(replica, new TestRequest(), randomNonNegativeLong(), listener); assertFalse(listener.isDone()); CapturingTransport.CapturedRequest[] captures = transport.getCapturedRequestsAndClear(); assertThat(captures, arrayWithSize(1)); if (randomBoolean()) { final TransportReplicationAction.ReplicaResponse response = new TransportReplicationAction.ReplicaResponse(randomAlphaOfLength(10), randomLong()); transport.handleResponse(captures[0].requestId, response); assertTrue(listener.isDone()); assertThat(listener.get(), equalTo(response)); } else if (randomBoolean()) { transport.handleRemoteError(captures[0].requestId, new ElasticsearchException("simulated")); assertTrue(listener.isDone()); assertListenerThrows("listener should reflect remote error", listener, ElasticsearchException.class); } else { transport.handleError(captures[0].requestId, new TransportException("simulated")); assertTrue(listener.isDone()); assertListenerThrows("listener should reflect remote error", listener, TransportException.class); } AtomicReference<Object> failure = new AtomicReference<>(); AtomicReference<Object> ignoredFailure = new AtomicReference<>(); AtomicBoolean success = new AtomicBoolean(); proxy.failShardIfNeeded(replica, randomIntBetween(1, 10), "test", new ElasticsearchException("simulated"), () -> success.set(true), failure::set, ignoredFailure::set ); CapturingTransport.CapturedRequest[] shardFailedRequests = transport.getCapturedRequestsAndClear(); // A write replication action proxy should fail the shard assertEquals(1, shardFailedRequests.length); CapturingTransport.CapturedRequest shardFailedRequest = shardFailedRequests[0]; ShardStateAction.ShardEntry shardEntry = (ShardStateAction.ShardEntry) shardFailedRequest.request; // the shard the request was sent to and the shard to be failed should be the same assertEquals(shardEntry.getShardId(), replica.shardId()); assertEquals(shardEntry.getAllocationId(), replica.allocationId().getId()); if (randomBoolean()) { // simulate success transport.handleResponse(shardFailedRequest.requestId, TransportResponse.Empty.INSTANCE); assertTrue(success.get()); assertNull(failure.get()); assertNull(ignoredFailure.get()); } else if (randomBoolean()) { // simulate the primary has been demoted transport.handleRemoteError(shardFailedRequest.requestId, new ShardStateAction.NoLongerPrimaryShardException(replica.shardId(), "shard-failed-test")); assertFalse(success.get()); assertNotNull(failure.get()); assertNull(ignoredFailure.get()); } else { // simulated an "ignored" exception transport.handleRemoteError(shardFailedRequest.requestId, new NodeClosedException(state.nodes().getLocalNode())); assertFalse(success.get()); assertNull(failure.get()); assertNotNull(ignoredFailure.get()); } } private class TestAction extends TransportWriteAction<TestRequest, TestRequest, TestResponse> { private final boolean withDocumentFailureOnPrimary; private final boolean withDocumentFailureOnReplica; protected TestAction() { this(false, false); } protected TestAction(boolean withDocumentFailureOnPrimary, boolean withDocumentFailureOnReplica) { super(Settings.EMPTY, "test", new TransportService(Settings.EMPTY, null, null, TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> null, null), null, null, null, null, new ActionFilters(new HashSet<>()), new IndexNameExpressionResolver(Settings.EMPTY), TestRequest::new, TestRequest::new, ThreadPool.Names.SAME); this.withDocumentFailureOnPrimary = withDocumentFailureOnPrimary; this.withDocumentFailureOnReplica = withDocumentFailureOnReplica; } protected TestAction(Settings settings, String actionName, TransportService transportService, ClusterService clusterService, ShardStateAction shardStateAction, ThreadPool threadPool) { super(settings, actionName, transportService, clusterService, mockIndicesService(clusterService), threadPool, shardStateAction, new ActionFilters(new HashSet<>()), new IndexNameExpressionResolver(Settings.EMPTY), TestRequest::new, TestRequest::new, ThreadPool.Names.SAME); this.withDocumentFailureOnPrimary = false; this.withDocumentFailureOnReplica = false; } @Override protected TestResponse newResponseInstance() { return new TestResponse(); } @Override protected WritePrimaryResult<TestRequest, TestResponse> shardOperationOnPrimary( TestRequest request, IndexShard primary) throws Exception { final WritePrimaryResult<TestRequest, TestResponse> primaryResult; if (withDocumentFailureOnPrimary) { primaryResult = new WritePrimaryResult<>(request, null, null, new RuntimeException("simulated"), primary, logger); } else { primaryResult = new WritePrimaryResult<>(request, new TestResponse(), location, null, primary, logger); } return primaryResult; } @Override protected WriteReplicaResult<TestRequest> shardOperationOnReplica(TestRequest request, IndexShard replica) throws Exception { final WriteReplicaResult<TestRequest> replicaResult; if (withDocumentFailureOnReplica) { replicaResult = new WriteReplicaResult<>(request, null, new RuntimeException("simulated"), replica, logger); } else { replicaResult = new WriteReplicaResult<>(request, location, null, replica, logger); } return replicaResult; } } final IndexService mockIndexService(final IndexMetaData indexMetaData, ClusterService clusterService) { final IndexService indexService = mock(IndexService.class); when(indexService.getShard(anyInt())).then(invocation -> { int shard = (Integer) invocation.getArguments()[0]; final ShardId shardId = new ShardId(indexMetaData.getIndex(), shard); if (shard > indexMetaData.getNumberOfShards()) { throw new ShardNotFoundException(shardId); } return mockIndexShard(shardId, clusterService); }); return indexService; } final IndicesService mockIndicesService(ClusterService clusterService) { final IndicesService indicesService = mock(IndicesService.class); when(indicesService.indexServiceSafe(any(Index.class))).then(invocation -> { Index index = (Index)invocation.getArguments()[0]; final ClusterState state = clusterService.state(); final IndexMetaData indexSafe = state.metaData().getIndexSafe(index); return mockIndexService(indexSafe, clusterService); }); when(indicesService.indexService(any(Index.class))).then(invocation -> { Index index = (Index) invocation.getArguments()[0]; final ClusterState state = clusterService.state(); if (state.metaData().hasIndex(index.getName())) { final IndexMetaData indexSafe = state.metaData().getIndexSafe(index); return mockIndexService(clusterService.state().metaData().getIndexSafe(index), clusterService); } else { return null; } }); return indicesService; } private final AtomicInteger count = new AtomicInteger(0); private final AtomicBoolean isRelocated = new AtomicBoolean(false); private IndexShard mockIndexShard(ShardId shardId, ClusterService clusterService) { final IndexShard indexShard = mock(IndexShard.class); doAnswer(invocation -> { ActionListener<Releasable> callback = (ActionListener<Releasable>) invocation.getArguments()[0]; count.incrementAndGet(); callback.onResponse(count::decrementAndGet); return null; }).when(indexShard).acquirePrimaryOperationLock(any(ActionListener.class), anyString()); doAnswer(invocation -> { long term = (Long)invocation.getArguments()[0]; ActionListener<Releasable> callback = (ActionListener<Releasable>) invocation.getArguments()[1]; final long primaryTerm = indexShard.getPrimaryTerm(); if (term < primaryTerm) { throw new IllegalArgumentException(String.format(Locale.ROOT, "%s operation term [%d] is too old (current [%d])", shardId, term, primaryTerm)); } count.incrementAndGet(); callback.onResponse(count::decrementAndGet); return null; }).when(indexShard).acquireReplicaOperationLock(anyLong(), any(ActionListener.class), anyString()); when(indexShard.routingEntry()).thenAnswer(invocationOnMock -> { final ClusterState state = clusterService.state(); final RoutingNode node = state.getRoutingNodes().node(state.nodes().getLocalNodeId()); final ShardRouting routing = node.getByShardId(shardId); if (routing == null) { throw new ShardNotFoundException(shardId, "shard is no longer assigned to current node"); } return routing; }); when(indexShard.state()).thenAnswer(invocationOnMock -> isRelocated.get() ? IndexShardState.RELOCATED : IndexShardState.STARTED); doThrow(new AssertionError("failed shard is not supported")).when(indexShard).failShard(anyString(), any(Exception.class)); when(indexShard.getPrimaryTerm()).thenAnswer(i -> clusterService.state().metaData().getIndexSafe(shardId.getIndex()).primaryTerm(shardId.id())); return indexShard; } private static class TestRequest extends ReplicatedWriteRequest<TestRequest> { TestRequest() { setShardId(new ShardId("test", "test", 1)); } @Override public String toString() { return "TestRequest{}"; } } private static class TestResponse extends ReplicationResponse implements WriteResponse { boolean forcedRefresh; @Override public void setForcedRefresh(boolean forcedRefresh) { this.forcedRefresh = forcedRefresh; } } private static class CapturingActionListener<R> implements ActionListener<R> { private R response; private Exception failure; @Override public void onResponse(R response) { this.response = response; } @Override public void onFailure(Exception failure) { this.failure = failure; } } }