/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.ml.training; import org.elasticsearch.action.trainmodel.TrainModelRequest; import org.elasticsearch.common.ParseFieldMatcher; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.ESTestCase; import java.util.Arrays; import java.util.Collections; import java.util.Map; import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; import static org.hamcrest.Matchers.equalTo; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; /** */ public class TrainingRequestParsingTests extends ESTestCase { public void testMinimalModel() throws Exception { TrainModelRequest trainModelRequest = new TrainModelRequest(); XContentBuilder sourceBuilder = jsonBuilder(); sourceBuilder.startObject(); { sourceBuilder.field("id", "abcd"); sourceBuilder.array("fields", "field1", "field2", "field3"); sourceBuilder.field("target_field", "class"); sourceBuilder.startObject("training_set"); { sourceBuilder.field("index", "test"); sourceBuilder.field("type", "type"); } sourceBuilder.endObject(); } sourceBuilder.endObject(); trainModelRequest.source(sourceBuilder.bytes()); assertThat(trainModelRequest.getModelId(), equalTo("abcd")); assertThat(trainModelRequest.getFields(), equalTo( Arrays.asList(new ModelInputField("field1"), new ModelInputField("field2"), new ModelInputField("field3")))); assertThat(trainModelRequest.getTargetField(), equalTo(new ModelTargetField("class"))); assertThat(trainModelRequest.getTrainingSet(), equalTo(new DataSet("test", "type"))); assertThat(trainModelRequest.getModelSettings(), equalTo(Settings.EMPTY)); } public void testModelWithCustomSettings() throws Exception { TrainModelRequest trainModelRequest = new TrainModelRequest(); XContentBuilder sourceBuilder = jsonBuilder(); sourceBuilder.startObject(); { sourceBuilder.field("id", "abcd"); sourceBuilder.startArray("fields"); { sourceBuilder.startObject().field("name", "field1").endObject(); sourceBuilder.startObject().field("name", "field2").endObject(); sourceBuilder.startObject().field("name", "field3").endObject(); } sourceBuilder.endArray(); sourceBuilder.startObject("target_field").field("name", "class").endObject(); sourceBuilder.startObject("training_set"); { sourceBuilder.field("index", "test"); sourceBuilder.field("type", "type"); } sourceBuilder.endObject(); sourceBuilder.startObject("settings"); { sourceBuilder.field("foo", "bar"); sourceBuilder.field("bar", "baz"); } sourceBuilder.endObject(); } sourceBuilder.endObject(); trainModelRequest.source(sourceBuilder.bytes()); assertThat(trainModelRequest.getModelId(), equalTo("abcd")); assertThat(trainModelRequest.getFields(), equalTo( Arrays.asList(new ModelInputField("field1"), new ModelInputField("field2"), new ModelInputField("field3")))); assertThat(trainModelRequest.getTargetField(), equalTo(new ModelTargetField("class"))); assertThat(trainModelRequest.getTrainingSet(), equalTo(new DataSet("test", "type"))); assertThat(trainModelRequest.getModelSettings(), equalTo(Settings.builder().put("foo", "bar").put("bar","baz").build())); } public void testModelWithTrainingQuery() throws Exception { TrainModelRequest trainModelRequest = new TrainModelRequest(); XContentBuilder sourceBuilder = jsonBuilder(); sourceBuilder.startObject(); { sourceBuilder.field("id", "abcd"); sourceBuilder.array("fields", "field1", "field2", "field3"); sourceBuilder.field("target_field", "class"); sourceBuilder.startObject("training_set"); { sourceBuilder.field("index", "test"); sourceBuilder.field("type", "type"); } sourceBuilder.endObject(); sourceBuilder.startObject("training_set"); { sourceBuilder.field("index", "test"); sourceBuilder.field("type", "type"); sourceBuilder.startObject("query").startObject("match_all").endObject().endObject(); } sourceBuilder.endObject(); } sourceBuilder.endObject(); trainModelRequest.source(sourceBuilder.bytes()); assertThat(trainModelRequest.getModelId(), equalTo("abcd")); assertThat(trainModelRequest.getFields(), equalTo( Arrays.asList(new ModelInputField("field1"), new ModelInputField("field2"), new ModelInputField("field3")))); assertThat(trainModelRequest.getTargetField(), equalTo(new ModelTargetField("class"))); Map<String, Object> matchAllQuery = Collections.singletonMap("match_all", Collections.emptyMap()); assertThat(trainModelRequest.getTrainingSet(), equalTo(new DataSet("test", "type", matchAllQuery))); } }