/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.ml.training;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.AliasMetaData;
import org.elasticsearch.cluster.metadata.IndexMetaData;
import org.elasticsearch.cluster.metadata.MappingMetaData;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.indices.query.IndicesQueriesRegistry;
import org.elasticsearch.ml.training.ModelTrainer.TrainingSession;
import org.elasticsearch.plugins.SearchPlugin.QuerySpec;
import org.elasticsearch.search.SearchRequestParsers;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.Matcher;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Unit tests for {@link TrainingService}. The cluster state, the {@link Client}, the search request and the model trainer are
 * mocked, so these tests check how the service resolves the training index/alias and type, passes the training query to the
 * search request, and reports either the trained model or an error to its listener.
 */
public class TrainingServiceTests extends ESTestCase {

    /** Query registry backing {@link #searchParsers}; only {@code match_all} and {@code term} queries are registered. */
    private static final IndicesQueriesRegistry indicesQueriesRegistry = new IndicesQueriesRegistry();

    static {
        // Register a few simple queries to test parsing
        registerQuery(new QuerySpec<>(MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new, MatchAllQueryBuilder::fromXContent));
        registerQuery(new QuerySpec<>(TermQueryBuilder.NAME, TermQueryBuilder::new, TermQueryBuilder::fromXContent));
    }

    private static void registerQuery(QuerySpec<?> spec) {
        indicesQueriesRegistry.register(spec.getParser(), spec.getName());
    }

    /** Upper bound on how long {@link ExpectedListener#await()} waits for the listener to be notified. */
    private static final long LISTENER_TIMEOUT_SECONDS = 10;

    /** Minimal settings shared by every index put into the mocked cluster state ({@link Settings} is immutable). */
    private static final Settings MINIMAL_INDEX_SETTINGS = Settings.builder()
        .put(IndexMetaData.SETTING_VERSION_CREATED, Version.CURRENT)
        .put(IndexMetaData.SETTING_NUMBER_OF_REPLICAS, 0)
        .put(IndexMetaData.SETTING_NUMBER_OF_SHARDS, 1)
        .put(IndexMetaData.SETTING_INDEX_UUID, UUIDs.randomBase64UUID())
        .build();

    private final SearchRequestParsers searchParsers = new SearchRequestParsers(indicesQueriesRegistry, null, null, null);

    /**
     * Creates a mocked {@link SearchRequestBuilder} that accepts any query.
     *
     * @param index              the only index name {@code setIndices} is stubbed for
     * @param aggregationBuilder the only aggregation {@code addAggregation} is stubbed for
     */
    private SearchRequestBuilder mockSearchRequestBuilder(String index, AggregationBuilder aggregationBuilder) {
        return mockSearchRequestBuilder(index, aggregationBuilder, null);
    }

    /**
     * Creates a mocked {@link SearchRequestBuilder} whose fluent setters return the builder itself, but only for the expected
     * arguments. Any other argument falls back to Mockito's default answer ({@code null}).
     *
     * @param index              the only index name {@code setIndices} is stubbed for
     * @param aggregationBuilder the only aggregation {@code addAggregation} is stubbed for
     * @param queryBuilder       the query {@code setQuery} must receive, or {@code null} to accept any query
     */
    private SearchRequestBuilder mockSearchRequestBuilder(String index, AggregationBuilder aggregationBuilder, QueryBuilder queryBuilder) {
        SearchRequestBuilder searchRequestBuilder = mock(SearchRequestBuilder.class);
        when(searchRequestBuilder.setIndices(index)).thenReturn(searchRequestBuilder);
        if (queryBuilder == null) {
            when(searchRequestBuilder.setQuery(any(QueryBuilder.class))).thenReturn(searchRequestBuilder);
        } else {
            when(searchRequestBuilder.setQuery(eq(queryBuilder))).thenReturn(searchRequestBuilder);
        }
        when(searchRequestBuilder.addAggregation(aggregationBuilder)).thenReturn(searchRequestBuilder);
        return searchRequestBuilder;
    }

    /**
     * Builds a {@link TrainingService} wired to mocks. The components are set up as follows:
     * <ul>
     *   <li>The cluster state contains {@code index} and {@code "other_index"}. Both share the alias
     *       {@code "other_and_me_alias"}, and only {@code index} has the alias {@code "just_me_alias"}.</li>
     *   <li>The client's search returns a mocked {@link SearchResponse}. This happens synchronously from
     *       {@code execute(listener)}.</li>
     *   <li>A single mocked {@link ModelTrainer} of type {@code "mock"} yields a training session. That session produces the
     *       model {@code "Success!"} from the search response.</li>
     * </ul>
     *
     * @param mockSearchRequestBuilder builds the search request mock for the session's training aggregation
     */
    private TrainingService mockTrainingService(String index, String trainType, Map<String, Object> mapping,
                                                Settings settings, List<ModelInputField> fields,
                                                ModelTargetField targetField,
                                                Function<AggregationBuilder, SearchRequestBuilder> mockSearchRequestBuilder)
        throws IOException {
        ClusterService clusterService = mock(ClusterService.class);
        MappingMetaData mappingMetaData = new MappingMetaData(trainType, mapping);
        IndexMetaData.Builder indexMetaData = IndexMetaData.builder(index)
            .putMapping(mappingMetaData)
            .putAlias(AliasMetaData.builder("just_me_alias"))
            .putAlias(AliasMetaData.builder("other_and_me_alias"))
            .settings(MINIMAL_INDEX_SETTINGS);
        IndexMetaData.Builder otherIndexMetaData = IndexMetaData.builder("other_index")
            .putMapping(mappingMetaData)
            .putAlias(AliasMetaData.builder("other_and_me_alias"))
            .settings(MINIMAL_INDEX_SETTINGS);
        MetaData.Builder metaData = MetaData.builder().put(indexMetaData).put(otherIndexMetaData);
        when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("test")).metaData(metaData).build());

        AggregationBuilder aggregationBuilder = mock(AggregationBuilder.class);
        SearchRequestBuilder searchRequestBuilder = mockSearchRequestBuilder.apply(aggregationBuilder);
        Client client = mock(Client.class);
        when(client.prepareSearch()).thenReturn(searchRequestBuilder);

        SearchResponse searchResponse = mock(SearchResponse.class);
        // Answer the search immediately on the calling thread.
        doAnswer(invocationOnMock -> {
            @SuppressWarnings("unchecked") ActionListener<SearchResponse> listener =
                (ActionListener<SearchResponse>) invocationOnMock.getArguments()[0];
            listener.onResponse(searchResponse);
            return null;
        }).when(searchRequestBuilder).execute(any());

        ModelTrainer modelTrainer = mock(ModelTrainer.class);
        when(modelTrainer.modelType()).thenReturn("mock");
        ModelTrainers modelTrainers = new ModelTrainers(Collections.singletonList(modelTrainer));
        TrainingSession trainingSession = mock(TrainingSession.class);
        when(trainingSession.trainingRequest()).thenReturn(aggregationBuilder);
        when(trainingSession.model(searchResponse)).thenReturn("Success!");
        // NOTE(review): modelTrainers is a real ModelTrainers, not a mock. This stubbing presumably works only because the real
        // method delegates to the mocked modelTrainer, and Mockito stubs the last mock invocation. Confirm before refactoring.
        when(modelTrainers.createTrainingSession(mappingMetaData, "mock", settings, fields, targetField)).thenReturn(trainingSession);
        return new TrainingService(Settings.EMPTY, clusterService, client, modelTrainers, searchParsers);
    }

    public void testPassingCorrectParameter() throws Exception {
        List<ModelInputField> fields = Arrays.asList(new ModelInputField("field1"), new ModelInputField("field2"));
        ModelTargetField targetField = new ModelTargetField("target");
        Settings settings = Settings.builder().put("foo", "bar").build();
        // Test single index
        TrainingService trainingService = mockTrainingService("train_index", "train_type", Collections.singletonMap("foo", "bar"),
            settings, fields, targetField, aggregationBuilder -> mockSearchRequestBuilder("train_index", aggregationBuilder));
        ExpectedListener listener = new ExpectedListener("Success!");
        trainingService.train("mock", settings, "train_index", "train_type", Collections.emptyMap(), fields, targetField, listener);
        listener.await();
    }

    public void testSingleIndexAlias() throws Exception {
        List<ModelInputField> fields = Arrays.asList(new ModelInputField("field1"), new ModelInputField("field2"));
        ModelTargetField targetField = new ModelTargetField("target");
        Settings settings = Settings.builder().put("foo", "bar").build();
        // Test single index alias: "just_me_alias" points only at train_index
        TrainingService trainingService = mockTrainingService("train_index", "train_type", Collections.singletonMap("foo", "bar"),
            settings, fields, targetField, aggregationBuilder -> mockSearchRequestBuilder("just_me_alias", aggregationBuilder));
        ExpectedListener listener = new ExpectedListener("Success!");
        trainingService.train("mock", settings, "just_me_alias", "train_type", Collections.emptyMap(), fields, targetField, listener);
        listener.await();
    }

    public void testMultiIndexAlias() throws Exception {
        List<ModelInputField> fields = Arrays.asList(new ModelInputField("field1"), new ModelInputField("field2"));
        ModelTargetField targetField = new ModelTargetField("target");
        Settings settings = Settings.builder().put("foo", "bar").build();
        // Test multi index alias: "other_and_me_alias" points at both train_index and other_index, which must be rejected
        TrainingService trainingService = mockTrainingService("train_index", "train_type", Collections.singletonMap("foo", "bar"),
            settings, fields, targetField, aggregationBuilder -> mockSearchRequestBuilder("other_and_me_alias", aggregationBuilder));
        ExpectedListener listener = new ExpectedListener(IllegalArgumentException.class, "can only train on a single index");
        trainingService.train("mock", settings, "other_and_me_alias", "train_type", Collections.emptyMap(), fields, targetField, listener);
        listener.await();
    }

    public void testMissingParameters() throws Exception {
        List<ModelInputField> fields = Arrays.asList(new ModelInputField("field1"), new ModelInputField("field2"));
        ModelTargetField targetField = new ModelTargetField("target");
        Settings settings = Settings.builder().put("foo", "bar").build();
        TrainingService trainingService = mockTrainingService("train_index", "train_type", Collections.singletonMap("foo", "bar"),
            settings, fields, targetField, aggregationBuilder -> mockSearchRequestBuilder("train_index", aggregationBuilder));

        // Unknown model type
        ExpectedListener listener = new ExpectedListener(UnsupportedOperationException.class, "Unsupported model type [blah]");
        trainingService.train("blah", settings, "train_index", "train_type", Collections.emptyMap(), fields, targetField, listener);
        listener.await();

        // Unknown index
        listener = new ExpectedListener(IndexNotFoundException.class, "no such index");
        trainingService.train("mock", settings, "unknown_index", "train_type", Collections.emptyMap(), fields, targetField, listener);
        listener.await();

        // Unknown mapping type
        listener = new ExpectedListener(ResourceNotFoundException.class, "the training type [unknown_type] not found");
        trainingService.train("mock", settings, "train_index", "unknown_type", Collections.emptyMap(), fields, targetField, listener);
        listener.await();
    }

    public void testPassingCustomTrainingQuery() throws Exception {
        List<ModelInputField> fields = Arrays.asList(new ModelInputField("field1"), new ModelInputField("field2"));
        ModelTargetField targetField = new ModelTargetField("target");
        Settings settings = Settings.builder().put("foo", "bar").build();
        QueryBuilder expectedQuery = QueryBuilders.termQuery("tag", "training");
        // setQuery is only stubbed for expectedQuery, so the parsed query map below must equal it for the fluent chain to work
        TrainingService trainingService = mockTrainingService("train_index", "train_type", Collections.singletonMap("foo", "bar"),
            settings, fields, targetField,
            aggregationBuilder -> mockSearchRequestBuilder("train_index", aggregationBuilder, expectedQuery));
        ExpectedListener listener = new ExpectedListener("Success!");
        Map<String, Object> query = Collections.singletonMap("term", Collections.singletonMap("tag", "training"));
        trainingService.train("mock", settings, "train_index", "train_type", query, fields, targetField, listener);
        listener.await();
    }

    /**
     * Listener that asserts either a successful response or a failure, and counts down a latch once it has been notified.
     * Exactly one of the success matcher and the failure matchers is non-null, depending on which constructor was used.
     */
    private class ExpectedListener implements ActionListener<String> {
        private final CountDownLatch latch = new CountDownLatch(1);
        private final Matcher<Exception> failure;
        private final Matcher<String> failureMessage;
        private final Matcher<String> success;

        /** Expects a failure of the given type with exactly the given message. */
        public ExpectedListener(Class<? extends Exception> clazz, String message) {
            this(instanceOf(clazz), equalTo(message), null);
        }

        /** Expects a response equal to {@code success}. */
        public ExpectedListener(String success) {
            this(null, null, equalTo(success));
        }

        public ExpectedListener(Matcher<Exception> failure, Matcher<String> failureMessage, Matcher<String> success) {
            this.failure = failure;
            this.failureMessage = failureMessage;
            this.success = success;
        }

        @Override
        public void onResponse(String s) {
            try {
                if (success == null) {
                    fail("Expected exception, but got response [" + s + "]");
                }
                assertThat(s, success);
            } finally {
                latch.countDown();
            }
        }

        @Override
        public void onFailure(Exception e) {
            try {
                if (failure == null) {
                    // Keep the unexpected exception as the cause so the real problem shows up in the test report.
                    throw new AssertionError("Expected response, but got failure", e);
                }
                assertThat(e, failure);
                assertThat(e.getMessage(), failureMessage);
            } finally {
                latch.countDown();
            }
        }

        /**
         * Waits for the listener to be notified and fails the test if that does not happen in time. Without this check a
         * service that never calls the listener would let the test pass silently.
         */
        public void await() throws InterruptedException {
            assertTrue("listener was not notified within [" + LISTENER_TIMEOUT_SECONDS + "] seconds",
                latch.await(LISTENER_TIMEOUT_SECONDS, TimeUnit.SECONDS));
        }
    }
}