/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.ml.factories; import org.dmg.pmml.PMML; import org.dmg.pmml.TreeModel; import org.elasticsearch.ml.modelinput.MockDataSource; import org.elasticsearch.ml.modelinput.SparseVectorModelInput; import org.elasticsearch.ml.modelinput.VectorModelInput; import org.elasticsearch.ml.modelinput.VectorModelInputEvaluator; import org.elasticsearch.ml.modelinput.VectorRange; import org.elasticsearch.ml.modelinput.VectorRangesToVectorPMML; import org.elasticsearch.ml.modelinput.MapModelInput; import org.elasticsearch.ml.modelinput.ModelAndModelInputEvaluator; import org.elasticsearch.test.ESTestCase; import org.hamcrest.Matchers; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import static org.elasticsearch.script.pmml.ProcessPMMLHelper.parsePmml; import static org.elasticsearch.test.StreamsUtils.copyToStringFromClasspath; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.notNullValue; public class PMMLParsingTests extends ESTestCase { public void testSimplePipelineParsingGLM() throws IOException { ModelFactories factories = ModelFactories.createDefaultModelFactories(); final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/logistic_regression.xml"); PMML pmml = parsePmml(pmmlString); ModelAndModelInputEvaluator<VectorModelInput, String> fieldsToVectorAndModel = factories.buildFromPMML(pmml, 0); assertThat(((VectorModelInputEvaluator)fieldsToVectorAndModel.getVectorRangesToVector()).getVectorRangeList().size(), equalTo(15)); } public void testTwoStepPipelineParsing() throws IOException { ModelFactories factories = ModelFactories.createDefaultModelFactories(); final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/lr_model.xml"); PMML pmml = parsePmml(pmmlString); ModelAndModelInputEvaluator<VectorModelInput, String> fieldsToVectorAndModel = factories.buildFromPMML(pmml, 0); VectorModelInputEvaluator vectorEntries = (VectorModelInputEvaluator) fieldsToVectorAndModel.getVectorRangesToVector(); assertThat(vectorEntries.getVectorRangeList().size(), equalTo(3)); assertVectorsCorrect(vectorEntries); } public void testTwoStepPipelineParsingReorderedGLM() throws IOException { ModelFactories factories = ModelFactories.createDefaultModelFactories(); final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/lr_model_reordered.xml"); PMML pmml = parsePmml(pmmlString); ModelAndModelInputEvaluator<VectorModelInput, String> fieldsToVectorAndModel = factories.buildFromPMML(pmml, 0); VectorModelInputEvaluator vectorEntries = (VectorModelInputEvaluator) fieldsToVectorAndModel.getVectorRangesToVector(); assertThat(vectorEntries.getVectorRangeList().size(), equalTo(3)); assertVectorsCorrect(vectorEntries); } public void assertVectorsCorrect(VectorModelInputEvaluator vectorEntries) throws IOException { final String testData = copyToStringFromClasspath("/org/elasticsearch/script/test.data"); final String expectedResults = copyToStringFromClasspath("/org/elasticsearch/script/lr_result.txt"); String testDataLines[] = testData.split("\\r?\\n"); String expectedResultsLines[] = expectedResults.split("\\r?\\n"); for (int i = 0; i < testDataLines.length; i++) { String[] testDataValues = testDataLines[i].split(","); List<Object> ageInput = new ArrayList<>(); if (testDataValues[0].equals("") == false) { ageInput.add(Double.parseDouble(testDataValues[0])); } List<Object> workInput = new ArrayList<>(); if (testDataValues[1].trim().equals("") == false) { workInput.add(testDataValues[1].trim()); } Map<String, List<Object>> input = new HashMap<>(); input.put("age", ageInput); input.put("work", workInput); SparseVectorModelInput result = vectorEntries.convert(new MockDataSource(input)); String[] expectedResult = expectedResultsLines[i + 1].split(","); double expectedAgeValue = Double.parseDouble(expectedResult[0]); // assertThat(Double.parseDouble(expectedResult[0]), Matchers.closeTo(((double[]) result.get("values"))[0], 1.e-7)); if (workInput.size() == 0) { // this might be a problem with the model. not sure. the "other" value does not appear in it. assertArrayEquals(result.getValues(), new double[]{expectedAgeValue, 1.0d}, 1.e-7); assertArrayEquals(result.getIndices(), new int[]{0, 4}); } else if ("Private".equals(workInput.get(0))) { assertArrayEquals(result.getValues(), new double[]{expectedAgeValue, 1.0d, 1.0d}, 1.e-7); assertArrayEquals(result.getIndices(), new int[]{0, 1, 4}); } else if ("Self-emp-inc".equals(workInput.get(0))) { assertArrayEquals(result.getValues(), new double[]{expectedAgeValue, 1.0d, 1.0d}, 1.e-7); assertArrayEquals(result.getIndices(), new int[]{0, 2, 4}); } else if ("State-gov".equals(workInput.get(0))) { assertArrayEquals(result.getValues(), new double[]{expectedAgeValue, 1.0d, 1.0d}, 1.e-7); assertArrayEquals(result.getIndices(), new int[]{0, 3, 4}); } else { fail("work input was " + workInput); } } } public void testModelAndFeatureParsingGLM() throws IOException { ModelFactories factories = ModelFactories.createDefaultModelFactories(); final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/lr_model.xml"); PMML pmml = parsePmml(pmmlString); ModelAndModelInputEvaluator<VectorModelInput, String> fieldsToVectorAndModel = factories.buildFromPMML(pmml, 0); VectorModelInputEvaluator vectorEntries = (VectorModelInputEvaluator) fieldsToVectorAndModel.getVectorRangesToVector(); assertThat(vectorEntries.getVectorRangeList().size(), equalTo(3)); assertModelCorrect(fieldsToVectorAndModel); } public void testBigModelAndFeatureParsingGLM() throws IOException { ModelFactories factories = ModelFactories.createDefaultModelFactories(); final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/lr_model_adult_full.xml"); PMML pmml = parsePmml(pmmlString); ModelAndModelInputEvaluator<VectorModelInput, String> fieldsToVectorAndModel = factories.buildFromPMML(pmml, 0); VectorModelInputEvaluator vectorEntries = (VectorModelInputEvaluator) fieldsToVectorAndModel.getVectorRangesToVector(); assertThat(vectorEntries.getVectorRangeList().size(), equalTo(15)); assertBiggerModelCorrect(fieldsToVectorAndModel, "/org/elasticsearch/script/adult.data", "/org/elasticsearch/script/knime_glm_adult_result.csv"); } public void testBigModelAndFeatureParsingFromRExportGLM() throws IOException { ModelFactories factories = ModelFactories.createDefaultModelFactories(); final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/glm-adult-full-r.xml"); PMML pmml = parsePmml(pmmlString); ModelAndModelInputEvaluator<VectorModelInput, String> fieldsToVectorAndModel = factories.buildFromPMML(pmml, 0); VectorModelInputEvaluator vectorEntries = (VectorModelInputEvaluator) fieldsToVectorAndModel.getVectorRangesToVector(); assertThat(vectorEntries.getVectorRangeList().size(), equalTo(12)); assertBiggerModelCorrect(fieldsToVectorAndModel, "/org/elasticsearch/script/adult.data", "/org/elasticsearch/script/r_glm_adult_result" + ".csv"); } public void testBigModelCorrectSingleValueGLM() throws IOException { ModelFactories factories = ModelFactories.createDefaultModelFactories(); final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/lr_model_adult_full.xml"); PMML pmml = parsePmml(pmmlString); ModelAndModelInputEvaluator<VectorModelInput, String> fieldsToVectorAndModel = factories.buildFromPMML(pmml, 0); VectorModelInputEvaluator vectorEntries = (VectorModelInputEvaluator) fieldsToVectorAndModel.getVectorRangesToVector(); assertThat(vectorEntries.getVectorRangeList().size(), equalTo(15)); assertBiggerModelCorrect(fieldsToVectorAndModel, "/org/elasticsearch/script/singlevalueforintegtest.txt", "/org/elasticsearch/script/singleresultforintegtest.txt"); } private void assertModelCorrect(ModelAndModelInputEvaluator<VectorModelInput, String> fieldsToVectorAndModel) throws IOException { final String testData = copyToStringFromClasspath("/org/elasticsearch/script/test.data"); final String expectedResults = copyToStringFromClasspath("/org/elasticsearch/script/lr_result.txt"); String testDataLines[] = testData.split("\\r?\\n"); String expectedResultsLines[] = expectedResults.split("\\r?\\n"); for (int i = 0; i < testDataLines.length; i++) { String[] testDataValues = testDataLines[i].split(","); List<Object> ageInput = new ArrayList<>(); if (testDataValues[0].equals("") == false) { ageInput.add(Double.parseDouble(testDataValues[0])); } List<Object> workInput = new ArrayList<>(); if (testDataValues[1].trim().equals("") == false) { workInput.add(testDataValues[1].trim()); } Map<String, List<Object>> input = new HashMap<>(); input.put("age", ageInput); input.put("work", workInput); @SuppressWarnings("unchecked") VectorModelInput result = fieldsToVectorAndModel.getVectorRangesToVector().convert(new MockDataSource(input)); String[] expectedResult = expectedResultsLines[i + 1].split(","); String expectedClass = expectedResult[expectedResult.length - 1]; expectedClass = expectedClass.substring(1, expectedClass.length() - 1); @SuppressWarnings("unchecked") Map<String, Object> resultValues = fieldsToVectorAndModel.getModel().evaluateDebug(result); assertThat(expectedClass, equalTo(resultValues.get("class"))); } } private void assertBiggerModelCorrect(ModelAndModelInputEvaluator<VectorModelInput, String> fieldsToVectorAndModel, String inputData, String resultData) throws IOException { final String testData = copyToStringFromClasspath(inputData); final String expectedResults = copyToStringFromClasspath(resultData); String testDataLines[] = testData.split("\\r?\\n"); String expectedResultsLines[] = expectedResults.split("\\r?\\n"); String[] fields = testDataLines[0].split(","); for (int i = 0; i < fields.length; i++) { fields[i] = fields[i].trim(); fields[i] = fields[i].substring(1, fields[i].length() - 1); } for (int i = 1; i < testDataLines.length; i++) { String[] testDataValues = testDataLines[i].split(","); // trimm spaces and add value Map<String, List<Object>> input = new HashMap<>(); for (int j = 0; j < testDataValues.length; j++) { testDataValues[j] = testDataValues[j].trim(); if (testDataValues[j].equals("") == false) { List<Object> fieldInput = new ArrayList<>(); if (j == 0 || j == 2 || j == 4 || j == 10 || j == 11 || j == 12) { fieldInput.add(Double.parseDouble(testDataValues[j])); } else { fieldInput.add(testDataValues[j]); } input.put(fields[j], fieldInput); } else { if (randomBoolean()) { input.put(fields[j], new ArrayList<>()); } } } VectorModelInput vectorModelInput = fieldsToVectorAndModel.getVectorRangesToVector().convert(new MockDataSource(input)); String[] expectedResult = expectedResultsLines[i].split(","); String expectedClass = expectedResult[2]; expectedClass = expectedClass.substring(1, expectedClass.length() - 1); @SuppressWarnings("unchecked") Map<String, Object> resultValues = fieldsToVectorAndModel.getModel().evaluateDebug(vectorModelInput); @SuppressWarnings("unchecked") double prob0 = (Double) ((Map<String, Object>) resultValues.get("probs")).get("<=50K"); @SuppressWarnings("unchecked") double prob1 = (Double) ((Map<String, Object>) resultValues.get("probs")).get(">50K"); assertThat("result " + i + " had wrong probability for class " + "<=50K", prob0, Matchers.closeTo(Double.parseDouble(expectedResult[0]), 1.e-7)); assertThat("result " + i + " had wrong probability for class " + ">50K", prob1, Matchers.closeTo(Double.parseDouble(expectedResult[1]), 1.e-7)); assertThat(expectedClass, equalTo(resultValues.get("class"))); } } /*tests for tree model*/ public void testBigModelAndFeatureParsingFromRExportTreeModel() throws IOException { ModelFactories factories = ModelFactories.createDefaultModelFactories(); final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/tree-adult-full-r.xml"); PMML pmml = parsePmml(pmmlString); ModelAndModelInputEvaluator<MapModelInput, String> fieldsToVectorAndModel = factories.buildFromPMML(pmml, 0); VectorRangesToVectorPMML.VectorRangesToVectorPMMLTreeModel vectorEntries = (VectorRangesToVectorPMML .VectorRangesToVectorPMMLTreeModel) fieldsToVectorAndModel.getVectorRangesToVector(); assertThat(vectorEntries.getEntries().size(), equalTo(11)); assertTreeModelModelCorrect(fieldsToVectorAndModel, "/org/elasticsearch/script/adult.data", "/org/elasticsearch/script/r_tree_adult_result.csv"); } private void assertTreeModelModelCorrect(ModelAndModelInputEvaluator<MapModelInput, String> fieldsToVectorAndModel, String inputData, String resultData) throws IOException { assertThat(fieldsToVectorAndModel.getModel(), notNullValue()); final String testData = copyToStringFromClasspath(inputData); final String expectedResults = copyToStringFromClasspath(resultData); String testDataLines[] = testData.split("\\r?\\n"); String expectedResultsLines[] = expectedResults.split("\\r?\\n"); String[] fields = testDataLines[0].split(","); for (int i = 0; i < fields.length; i++) { fields[i] = fields[i].trim(); fields[i] = fields[i].substring(1, fields[i].length() - 1); } for (int i = 1; i < testDataLines.length; i++) { String[] testDataValues = testDataLines[i].split(","); // trimm spaces and add value Map<String, List<Object>> input = new HashMap<>(); for (int j = 0; j < testDataValues.length; j++) { testDataValues[j] = testDataValues[j].trim(); if (testDataValues[j].equals("") == false) { List<Object> fieldInput = new ArrayList<>(); if (j == 0 || j == 2 || j == 4 || j == 10 || j == 11 || j == 12) { fieldInput.add(Double.parseDouble(testDataValues[j])); } else { fieldInput.add(testDataValues[j]); } input.put(fields[j], fieldInput); } else { if (randomBoolean()) { input.put(fields[j], new ArrayList<>()); } } } @SuppressWarnings("unchecked") Map<String, Object> result = (Map<String, Object>) ((VectorRangesToVectorPMML) fieldsToVectorAndModel.getVectorRangesToVector()) .vector(input); String[] expectedResult = expectedResultsLines[i].split(","); String expectedClass = expectedResult[expectedResult.length - 1]; expectedClass = expectedClass.substring(1, expectedClass.length() - 1); @SuppressWarnings("unchecked") Map<String, Object> resultValues = fieldsToVectorAndModel.getModel().evaluateDebug(new MapModelInput(result)); assertThat("result " + i + " has wrong prediction", expectedClass, equalTo(resultValues.get("class"))); } } public void testExtractFieldNames() throws IOException { final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/tree-adult-full-r.xml"); PMML pmml = parsePmml(pmmlString); TreeModel treeModel = (TreeModel) pmml.getModels().get(0); Set<String> expectedFieldNames = new HashSet<>(); expectedFieldNames.addAll(Arrays.asList(new String[]{"age_z", "relationship", "marital_status", "hours_per_week_z", "sex", "occupation", "education", "education_num_z", "native_country", "race", "workclass"})); Set<String> fieldNames = new HashSet<>(); TreeModelFactory.getFieldNamesFromNode(fieldNames, treeModel.getNode()); assertThat(expectedFieldNames.size(), equalTo(fieldNames.size())); for (String fieldName : expectedFieldNames) { assertTrue(fieldNames.contains(fieldName)); } } public void testFieldTypeMapExtract() throws IOException { final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/tree-small-r.xml"); PMML pmml = parsePmml(pmmlString); TreeModel treeModel = (TreeModel) pmml.getModels().get(0); List<VectorRange> fields = TreeModelFactory.getFieldValuesList(treeModel, pmml.getDataDictionary(), pmml.getTransformationDictionary()); Map<String, String> fieldToTypeMap = TreeModelFactory.getFieldToTypeMap(fields); assertTrue(fieldToTypeMap.containsKey("age_z")); assertThat(fieldToTypeMap.get("age_z"), equalTo("double")); assertTrue(fieldToTypeMap.containsKey("work")); assertThat(fieldToTypeMap.get("work"), equalTo("string")); assertTrue(fieldToTypeMap.containsKey("education")); assertThat(fieldToTypeMap.get("education"), equalTo("string")); } /*tests for naive bayes model*/ public void testBigModelAndFeatureParsingFromRExportNaiveBayesModel() throws IOException { ModelFactories factories = ModelFactories.createDefaultModelFactories(); final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/naive-bayes-adult-full-r.xml"); PMML pmml = parsePmml(pmmlString); ModelAndModelInputEvaluator<VectorModelInput, String> fieldsToVectorAndModel = factories.buildFromPMML(pmml, 0); VectorModelInputEvaluator vectorEntries = (VectorModelInputEvaluator) fieldsToVectorAndModel.getVectorRangesToVector(); assertThat(vectorEntries.getVectorRangeList().size(), equalTo(10)); assertBiggerModelCorrect(fieldsToVectorAndModel, "/org/elasticsearch/script/naive_bayes_full_single_value.txt", "/org/elasticsearch/script/naive_bayes_full_single_result.txt"); } /*tests for naive bayes model*/ public void testBigModelAndFeatureParsingFromRExportNaiveBayesModelReorderdParams() throws IOException { ModelFactories factories = ModelFactories.createDefaultModelFactories(); final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/naive-bayes-adult-full-r-reordered.xml"); PMML pmml = parsePmml(pmmlString); ModelAndModelInputEvaluator<VectorModelInput, String> fieldsToVectorAndModel = factories.buildFromPMML(pmml, 0); VectorModelInputEvaluator vectorEntries = (VectorModelInputEvaluator) fieldsToVectorAndModel.getVectorRangesToVector(); assertThat(vectorEntries.getVectorRangeList().size(), equalTo(10)); assertBiggerModelCorrect(fieldsToVectorAndModel, "/org/elasticsearch/script/naive_bayes_full_single_value.txt", "/org/elasticsearch/script/naive_bayes_full_single_result.txt"); } }