package com.lordofthejars.nosqlunit.elasticsearch2;
import com.google.common.primitives.Ints;
import com.lordofthejars.nosqlunit.core.AbstractCustomizableDatabaseOperation;
import com.lordofthejars.nosqlunit.core.NoSqlAssertionError;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.count.CountResponse;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import java.io.InputStream;
import java.util.concurrent.TimeUnit;
/**
 * Database operation that inserts, deletes and compares data through an
 * Elasticsearch {@link Client}.
 *
 * <p>Insertion and comparison are delegated to the inherited
 * {@code executeInsertion} / {@code executeComparison} methods, using the
 * default Elasticsearch strategies installed by the constructor. Deletion is
 * implemented here by scrolling over every document and removing them with a
 * single bulk request.
 */
public class ElasticsearchOperation extends
        AbstractCustomizableDatabaseOperation<ElasticsearchConnectionCallback, Client> {

    /** Keep-alive applied to the scroll context on every scroll request. */
    private static final TimeValue SCROLL_KEEP_ALIVE = new TimeValue(1L, TimeUnit.MINUTES);

    private final Client client;

    /**
     * Creates an operation bound to the given client and installs the default
     * insertion and comparison strategies.
     *
     * @param client Elasticsearch client used for every request issued by this operation
     */
    public ElasticsearchOperation(Client client) {
        this.client = client;
        setInsertionStrategy(new DefaultElasticsearchInsertionStrategy());
        setComparisonStrategy(new DefaultElasticsearchComparisonStrategy());
    }

    /**
     * Inserts the data described by the given script.
     *
     * @param dataScript stream holding the data set to insert
     * @throws IllegalArgumentException wrapping anything thrown while inserting
     */
    @Override
    public void insert(InputStream dataScript) {
        insertData(dataScript);
    }

    /** Runs the inherited insertion and wraps any failure in an {@link IllegalArgumentException}. */
    private void insertData(InputStream dataScript) {
        try {
            executeInsertion(connectionCallback(), dataScript);
        } catch (Throwable e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Deletes every document reachable through the client.
     *
     * @throws IllegalStateException if the bulk delete reports at least one failed item
     */
    @Override
    public void deleteAll() {
        clearDocuments();
    }

    /**
     * Scrolls over all documents, queues one delete per hit and executes the
     * deletes as one bulk request, then refreshes the indices. Does nothing
     * when the count request reports no documents.
     */
    private void clearDocuments() {
        if (!isAnyIndexPresent()) {
            return;
        }

        // The total number of documents is used as the scroll page size.
        final SearchResponse countResponse = client.prepareSearch()
                .setSearchType(SearchType.QUERY_THEN_FETCH)
                .setQuery(QueryBuilders.matchAllQuery())
                .setSize(0)
                .execute()
                .actionGet();
        final int docCount = Ints.saturatedCast(countResponse.getHits().totalHits());

        final SearchResponse scrollResponse = client.prepareSearch()
                .setSearchType(SearchType.SCAN)
                .setScroll(SCROLL_KEEP_ALIVE)
                .setQuery(QueryBuilders.matchAllQuery())
                .setSize(docCount)
                .execute()
                .actionGet();

        final BulkRequestBuilder bulkRequestBuilder = client.prepareBulk();
        String scrollId = scrollResponse.getScrollId();
        while (true) {
            final SearchResponse page = client.prepareSearchScroll(scrollId)
                    .setScroll(SCROLL_KEEP_ALIVE)
                    .execute()
                    .actionGet();

            // Each scroll request must use the id returned by the previous
            // response; the id is not guaranteed to stay the same.
            if (page.getScrollId() != null) {
                scrollId = page.getScrollId();
            }

            final SearchHit[] hits = page.getHits().getHits();
            // Break condition: no hits are returned.
            if (hits.length == 0) {
                break;
            }
            for (SearchHit hit : hits) {
                bulkRequestBuilder.add(client.prepareDelete(hit.index(), hit.type(), hit.id()));
            }
        }

        if (bulkRequestBuilder.numberOfActions() > 0) {
            final BulkResponse bulkResponse = bulkRequestBuilder.execute().actionGet();
            // Surface failed deletes instead of silently leaving documents behind.
            if (bulkResponse.hasFailures()) {
                throw new IllegalStateException(
                        "Unable to delete all documents: " + bulkResponse.buildFailureMessage());
            }
        }
        refreshNode();
    }

    /**
     * Reports whether the count request finds at least one document.
     * Despite the name, this checks the document count, not index existence.
     */
    private boolean isAnyIndexPresent() {
        CountResponse numberOfElements = client.prepareCount().execute().actionGet();
        return numberOfElements.getCount() > 0;
    }

    /** Issues a refresh request through the indices admin client. */
    private void refreshNode() {
        client.admin().indices().prepareRefresh().execute().actionGet();
    }

    /**
     * Compares the current database content with the expected data set.
     *
     * @param expectedData stream holding the expected data set
     * @return the result of the inherited comparison
     * @throws NoSqlAssertionError propagated unchanged from the comparison
     * @throws IllegalStateException wrapping anything else thrown while comparing
     */
    @Override
    public boolean databaseIs(InputStream expectedData) {
        try {
            return executeComparison(connectionCallback(), expectedData);
        } catch (NoSqlAssertionError e) {
            throw e;
        } catch (Throwable e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Returns the client this operation was created with.
     *
     * @return the underlying Elasticsearch client
     */
    @Override
    public Client connectionManager() {
        return client;
    }

    /** Builds a callback that exposes this operation's client. */
    private ElasticsearchConnectionCallback connectionCallback() {
        return new ElasticsearchConnectionCallback() {
            @Override
            public Client nodeClient() {
                return client;
            }
        };
    }
}