/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.action.search;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.search.internal.ShardSearchTransportRequest;
import org.elasticsearch.transport.Transport;
import org.junit.Assert;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
/**
* SearchPhaseContext for tests
*/
public final class MockSearchPhaseContext implements SearchPhaseContext {
private static final Logger logger = Loggers.getLogger(MockSearchPhaseContext.class);
public AtomicReference<Throwable> phaseFailure = new AtomicReference<>();
final int numShards;
final AtomicInteger numSuccess;
List<ShardSearchFailure> failures = Collections.synchronizedList(new ArrayList<>());
SearchTransportService searchTransport;
Set<Long> releasedSearchContexts = new HashSet<>();
SearchRequest searchRequest = new SearchRequest();
AtomicInteger phasesExecuted = new AtomicInteger();
public MockSearchPhaseContext(int numShards) {
this.numShards = numShards;
numSuccess = new AtomicInteger(numShards);
}
public void assertNoFailure() {
if (phaseFailure.get() != null) {
throw new AssertionError(phaseFailure.get());
}
}
@Override
public int getNumShards() {
return numShards;
}
@Override
public Logger getLogger() {
return logger;
}
@Override
public SearchTask getTask() {
return new SearchTask(0, "n/a", "n/a", "test", null);
}
@Override
public SearchRequest getRequest() {
return searchRequest;
}
@Override
public SearchResponse buildSearchResponse(InternalSearchResponse internalSearchResponse, String scrollId) {
return new SearchResponse(internalSearchResponse, scrollId, numShards, numSuccess.get(), 0,
failures.toArray(new ShardSearchFailure[0]));
}
@Override
public void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) {
phaseFailure.set(cause);
}
@Override
public void onShardFailure(int shardIndex, @Nullable SearchShardTarget shardTarget, Exception e) {
failures.add(new ShardSearchFailure(e, shardTarget));
numSuccess.decrementAndGet();
}
@Override
public Transport.Connection getConnection(String clusterAlias, String nodeId) {
return null; // null is ok here for this test
}
@Override
public SearchTransportService getSearchTransport() {
Assert.assertNotNull(searchTransport);
return searchTransport;
}
@Override
public ShardSearchTransportRequest buildShardSearchRequest(SearchShardIterator shardIt) {
Assert.fail("should not be called");
return null;
}
@Override
public void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase) {
phasesExecuted.incrementAndGet();
try {
nextPhase.run();
} catch (Exception e) {
onPhaseFailure(nextPhase, "phase failed", e);
}
}
@Override
public void execute(Runnable command) {
command.run();
}
@Override
public void onResponse(SearchResponse response) {
Assert.fail("should not be called");
}
@Override
public void onFailure(Exception e) {
Assert.fail("should not be called");
}
@Override
public void sendReleaseSearchContext(long contextId, Transport.Connection connection, OriginalIndices originalIndices) {
releasedSearchContexts.add(contextId);
}
}