/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.script;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.plugin.TokenPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.script.pmml.PMMLModelScriptEngineService;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESIntegTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.test.StreamsUtils.copyToStringFromClasspath;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
/**
*/
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE)
public class FullPMMLIT extends ESIntegTestCase {
protected Collection<Class<? extends Plugin>> transportClientPlugins() {
return Collections.singletonList(TokenPlugin.class);
}
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singletonList(TokenPlugin.class);
}
public void testAdult() throws IOException, ExecutionException, InterruptedException {
indexAdultData("/org/elasticsearch/script/adult.data", this);
assertHitCount(client().prepareSearch().get(), 32560);
indexAdultModel("/org/elasticsearch/script/lr_model_adult_full.xml");
checkClassificationCorrect("/org/elasticsearch/script/knime_glm_adult_result.csv");
}
public void testSingleAdult() throws IOException, ExecutionException, InterruptedException {
indexAdultData("/org/elasticsearch/script/singlevalueforintegtest.txt", this);
assertHitCount(client().prepareSearch().get(), 1);
indexAdultModel("/org/elasticsearch/script/lr_model_adult_full.xml");
checkClassificationCorrect("/org/elasticsearch/script/singleresultforintegtest.txt");
}
public void testSingleAdultNotDebug() throws IOException, ExecutionException, InterruptedException {
indexAdultData("/org/elasticsearch/script/singlevalueforintegtest.txt", this);
assertHitCount(client().prepareSearch().get(), 1);
indexAdultModel("/org/elasticsearch/script/naive-bayes-adult-full-r.xml");
Map<String, Object> params = new HashMap<String, Object>();
params.put("debug", false);
SearchResponse searchResponse = client().prepareSearch("test").addScriptField("pmml", new Script("1", ScriptService.ScriptType
.STORED, PMMLModelScriptEngineService.NAME, params)).addStoredField("_source").setSize(10000).get();
assertSearchResponse(searchResponse);
assertThat(searchResponse.getHits().getAt(0).fields().get("pmml").getValue(), instanceOf(String.class));
assertThat(searchResponse.getHits().getAt(0).fields().get("pmml").getValue(), equalTo(">50K"));
}
private void checkClassificationCorrect(String resultFile) throws IOException {
final String testData = copyToStringFromClasspath(resultFile);
String resultLines[] = testData.split("\\r?\\n");
Map<String, String> expectedResults = new HashMap<>();
for (int i = 1; i < resultLines.length; i++) {
expectedResults.put(Integer.toString(i), resultLines[i]);
}
SearchResponse searchResponse = client().prepareSearch("test").addScriptField("pmml", new Script("1", ScriptService.ScriptType
.STORED, PMMLModelScriptEngineService.NAME, new HashMap<String, Object>())).addStoredField("_source").setSize(10000).get();
assertSearchResponse(searchResponse);
for (SearchHit hit : searchResponse.getHits().getHits()) {
@SuppressWarnings("unchecked")
String label = (String) ((Map<String, Object>) (hit.field("pmml").values().get(0))).get("class");
String[] expectedResult = expectedResults.get(hit.id()).split(",");
assertThat(label, equalTo(expectedResult[2].substring(1, expectedResult[2].length() - 1)));
}
}
public static void indexAdultModel(String modelFile) throws IOException {
final String pmmlString = copyToStringFromClasspath(modelFile);
// create spec
client().admin().cluster().preparePutStoredScript().setScriptLang("pmml_model").setId("1").setSource(
jsonBuilder().startObject()
.field("script", pmmlString)
.endObject().bytes()
).get();
}
public static void indexAdultData(String data, ESIntegTestCase testCase) throws IOException, ExecutionException, InterruptedException {
XContentBuilder mappingBuilder = jsonBuilder();
mappingBuilder.startObject();
mappingBuilder.startObject("type")
.startObject("properties")
.startObject("age")
.field("type", "double")
.endObject()
.startObject("workclass")
.field("type", "keyword")
.endObject()
.startObject("fnlwgt")
.field("type", "double")
.endObject()
.startObject("education")
.field("type", "keyword")
.endObject()
.startObject("education_num")
.field("type", "double")
.endObject()
.startObject("marital_status")
.field("type", "keyword")
.endObject()
.startObject("occupation")
.field("type", "keyword")
.endObject()
.startObject("relationship")
.field("type", "keyword")
.endObject()
.startObject("race")
.field("type", "keyword")
.endObject()
.startObject("sex")
.field("type", "keyword")
.endObject()
.startObject("capital_gain")
.field("type", "double")
.endObject()
.startObject("capital_loss")
.field("type", "double")
.endObject()
.startObject("hours_per_week")
.field("type", "double")
.endObject()
.startObject("native_country")
.field("type", "keyword")
.endObject()
.startObject("class")
.field("type", "keyword")
.endObject()
.endObject()
.endObject();
mappingBuilder.endObject();
assertAcked(client().admin().indices().prepareCreate("test").addMapping("type", mappingBuilder).get());
final String testData = copyToStringFromClasspath(data);
String testDataLines[] = testData.split("\\r?\\n");
String[] fields = testDataLines[0].split(",");
for (int i = 0; i < fields.length; i++) {
fields[i] = fields[i].trim();
fields[i] = fields[i].substring(1, fields[i].length() - 1);
}
List<IndexRequestBuilder> docs = new ArrayList<>();
for (int i = 1; i < testDataLines.length; i++) {
String[] testDataValues = testDataLines[i].split(",");
// trimm spaces and add value
Map<String, Object> input = new HashMap<>();
for (int j = 0; j < testDataValues.length; j++) {
testDataValues[j] = testDataValues[j].trim();
if (testDataValues[j].equals("") == false) {
input.put(fields[j], testDataValues[j]);
} else {
if (randomBoolean()) {
input.put(fields[j], null);
}
}
}
docs.add(client().prepareIndex("test", "type", Integer.toString(i)).setSource(input));
}
testCase.indexRandom(true, true, docs);
}
}