/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.ml.training; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.trainmodel.TrainModelRequestBuilder; import org.elasticsearch.action.trainmodel.TrainModelResponse; import org.elasticsearch.cluster.metadata.IndexMetaData; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.plugin.TokenPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.script.FullPMMLIT; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptService; import org.elasticsearch.script.pmml.PMMLModelScriptEngineService; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.test.ESIntegTestCase; import java.io.IOException; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Map; import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; import static org.elasticsearch.search.aggregations.AggregationBuilders.terms; import static 
org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;

/**
 * Integration tests for training a naive Bayes model through {@link TrainModelRequestBuilder}
 * and evaluating the resulting stored PMML script as a script field in a search.
 */
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE)
public class NaiveBayesModelTrainerIT extends ESIntegTestCase {

    /** Name of the script field that carries the model output in every search below. */
    private static final String PMML_FIELD = "pmml";

    @Override
    protected Collection<Class<? extends Plugin>> nodePlugins() {
        return Collections.singletonList(TokenPlugin.class);
    }

    @Override
    protected Collection<Class<? extends Plugin>> transportClientPlugins() {
        return Collections.singletonList(TokenPlugin.class);
    }

    /**
     * Trains a model on the small hand-written corpus created by {@link #indexDocs()} and checks
     * that every document is classified with one of the two labels present in the training data.
     */
    public void testNaiveBayesTraining() throws Exception {
        indexDocs();
        refresh();

        TrainModelRequestBuilder builder = new TrainModelRequestBuilder(client())
            .setModelId("abcd")
            .setModelType("naive_bayes")
            .addFields("text", "num")
            .setTargetField(new ModelTargetField("label"))
            .setTrainingSet(new DataSet("index", "type"));
        TrainModelResponse response = builder.get();
        assertThat(response.getId(), equalTo("abcd"));

        SearchResponse searchResponse = client().prepareSearch("index")
            .addScriptField(PMML_FIELD, storedPmmlScript(response.getId()))
            .addStoredField("_source")
            .setSize(10000)
            .get();
        assertSearchResponse(searchResponse);

        for (SearchHit hit : searchResponse.getHits().getHits()) {
            String label = (String) pmmlResult(hit).get("class");
            assertThat(label, anyOf(equalTo("good"), equalTo("bad")));
        }
    }

    /**
     * Trains a model on the adult data set and checks that, document by document, it yields the
     * same class and the same class probabilities (within 1e-5) as the reference model loaded by
     * {@code FullPMMLIT.indexAdultModel}.
     */
    public void testNaiveBayesTrainingInElasticsearchSameAsInR() throws Exception {
        FullPMMLIT.indexAdultData("/org/elasticsearch/script/adult.data", this);
        FullPMMLIT.indexAdultModel("/org/elasticsearch/script/naive-bayes-adult-full-r-no-missing-values.xml");
        refresh();

        // Sanity check: the indexed data must contain exactly two distinct target classes.
        SearchResponse aggResponse = client().prepareSearch("test")
            .addAggregation(terms("class").field("class").size(Integer.MAX_VALUE)
                .shardMinDocCount(1).minDocCount(1).order(Terms.Order.term(true)))
            .get();
        Terms classTerms = (Terms) aggResponse.getAggregations().getAsMap().get("class");
        assertThat(classTerms.getBuckets().size(), equalTo(2));

        TrainModelRequestBuilder builder = new TrainModelRequestBuilder(client())
            .setModelId("abcd")
            .setModelType("naive_bayes")
            .addFields("age", "fnlwgt", "education", "education_num", "marital_status", "relationship", "race", "sex",
                "capital_gain", "capital_loss", "hours_per_week")
            .setTargetField("class")
            .setTrainingSet(new DataSet("test", "type"));
        TrainModelResponse response = builder.get();

        // The result is unused; the call is kept from the original test, presumably to confirm
        // that the trained model can be fetched as a stored script.
        client().admin().cluster().prepareGetStoredScript(PMMLModelScriptEngineService.NAME, response.getId()).get();

        // Both searches sort by _uid so that hit i of one response is the same document as hit i of the other.
        SearchResponse searchResponseEsModel = client().prepareSearch("test")
            .addScriptField(PMML_FIELD, storedPmmlScript(response.getId()))
            .addStoredField("_source")
            .setSize(10000)
            .addSort("_uid", SortOrder.ASC)
            .get();
        assertSearchResponse(searchResponseEsModel);

        // NOTE(review): "1" is presumably the id under which FullPMMLIT.indexAdultModel stores the R model - confirm.
        SearchResponse searchResponseRModel = client().prepareSearch("test")
            .addScriptField(PMML_FIELD, storedPmmlScript("1"))
            .addStoredField("_source")
            .setSize(10000)
            .addSort("_uid", SortOrder.ASC)
            .get();
        assertSearchResponse(searchResponseRModel);

        SearchHit[] rHits = searchResponseRModel.getHits().getHits();
        SearchHit[] esHits = searchResponseEsModel.getHits().getHits();
        assertThat("both models must score the same number of documents", esHits.length, equalTo(rHits.length));

        // Compare the R model's hits against the ES model's hits. The previous loop read both sides
        // from the ES response, so it compared the trained model with itself and could never fail.
        for (int i = 0; i < rHits.length; i++) {
            assertThat("result " + i + " refers to different documents:", esHits[i].getId(), equalTo(rHits[i].getId()));

            Map<String, Object> rResult = pmmlResult(rHits[i]);
            Map<String, Object> esResult = pmmlResult(esHits[i]);
            Map<String, Double> rProbs = probabilities(rResult);
            Map<String, Double> esProbs = probabilities(esResult);

            assertThat("result " + i + " has wrong prob:", esProbs.get(">50K"), closeTo(rProbs.get(">50K"), 1.e-5));
            assertThat("result " + i + " has wrong prob:", esProbs.get("<=50K"), closeTo(rProbs.get("<=50K"), 1.e-5));
            assertThat("result " + i + " has wrong class:", (String) esResult.get("class"), equalTo((String) rResult.get("class")));
        }
    }

    /** Builds a reference to the stored PMML script with the given id, with no script parameters. */
    private static Script storedPmmlScript(String scriptId) {
        return new Script(scriptId, ScriptService.ScriptType.STORED, PMMLModelScriptEngineService.NAME, new HashMap<>());
    }

    /** Returns the first value of the {@code pmml} script field of {@code hit}, cast to a map. */
    @SuppressWarnings("unchecked") // the tests read "class" and "probs" entries from this value
    private static Map<String, Object> pmmlResult(SearchHit hit) {
        return (Map<String, Object>) hit.field(PMML_FIELD).values().get(0);
    }

    /** Returns the {@code probs} entry of a model result, cast to a map from class label to probability. */
    @SuppressWarnings("unchecked") // the tests look up ">50K" and "<=50K" in this map as doubles
    private static Map<String, Double> probabilities(Map<String, Object> modelResult) {
        return (Map<String, Double>) modelResult.get("probs");
    }

    /**
     * Creates the single-shard index {@code index} with a fielddata-enabled {@code text} field and
     * a keyword {@code label} field, then indexes eight labelled example documents into it.
     *
     * @throws IOException if building the mapping fails
     */
    private void indexDocs() throws IOException {
        XContentBuilder mapping = jsonBuilder();
        mapping.startObject();
        {
            mapping.startObject("type");
            {
                mapping.startObject("properties");
                {
                    mapping.startObject("text");
                    {
                        mapping.field("type", "text");
                        mapping.field("fielddata", true);
                    }
                    mapping.endObject();
                    mapping.startObject("label");
                    {
                        mapping.field("type", "keyword");
                    }
                    mapping.endObject();
                }
                mapping.endObject();
            }
            mapping.endObject();
        }
        mapping.endObject();

        client().admin().indices().prepareCreate("index")
            .setSettings(Settings.builder().put(IndexMetaData.SETTING_NUMBER_OF_SHARDS, 1))
            .addMapping("type", mapping)
            .get();

        indexDoc("1", "I hate json", "bad", 1);
        indexDoc("2", "json sucks", "bad", 2);
        indexDoc("3", "json is much worse than xml", "bad", 3);
        indexDoc("4", "xml is lovely", "good", 4);
        indexDoc("5", "everyone loves xml", "good", 5);
        indexDoc("6", "seriously, xml is sooo much better than json", "good", 6);
        // The misspellings in this text are part of the test data and are kept as they are.
        indexDoc("7", "if any of my fellow developers reads this, they will tar and "
            + "feather me and hang my mutilated body above the entrace to amsterdam headquaters as a warning to others",
            "good", 7);
        indexDoc("8", "obviously I am joking", "good", 8);
    }

    /** Indexes one training document with the given id, text, label and numeric feature into {@code index/type}. */
    private void indexDoc(String id, String text, String label, int num) {
        client().prepareIndex("index", "type", id)
            .setSource("text", text, "label", label, "num", num)
            .execute()
            .actionGet();
    }
}