/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.ml.factories; import org.dmg.pmml.PMML; import org.elasticsearch.ml.modelinput.DataSource; import org.elasticsearch.ml.modelinput.MockDataSource; import org.elasticsearch.ml.modelinput.ModelInputEvaluator; import org.elasticsearch.ml.modelinput.SparseVectorModelInput; import org.elasticsearch.ml.modelinput.VectorModelInput; import org.elasticsearch.ml.modelinput.MapModelInput; import org.elasticsearch.ml.modelinput.ModelAndModelInputEvaluator; import org.elasticsearch.script.pmml.ProcessPMMLHelper; import org.elasticsearch.test.ESTestCase; import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import static org.elasticsearch.test.StreamsUtils.copyToStringFromClasspath; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.Matchers.closeTo; public class VectorizerPMMLSingleNodeTests extends ESTestCase { private DataSource createTestDataSource(String[] work, String education, Integer age) throws IOException { Map<String, List<Object>> data = new HashMap<>(); if (work != null) { Arrays.sort(work); data.put("work", Arrays.asList((Object[])work)); } if (education != null) { data.put("education", Collections.singletonList(education)); } if (age != null) { data.put("age", Collections.singletonList(age)); } return new MockDataSource(data); } @SuppressWarnings("unchecked") public void testGLMOnActualLookup() throws Exception { ModelFactories parser = ModelFactories.createDefaultModelFactories(); DataSource dataSource = createTestDataSource(new String[]{"Self-emp-inc"}, null, 60); final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/fake_lr_model_with_missing.xml"); PMML pmml = ProcessPMMLHelper.parsePmml(pmmlString); ModelAndModelInputEvaluator<SparseVectorModelInput, String> fieldsToVectorAndModel = parser.buildFromPMML(pmml, 0); ModelInputEvaluator<SparseVectorModelInput> vectorEntries = fieldsToVectorAndModel.getVectorRangesToVector(); SparseVectorModelInput vector = vectorEntries.convert(dataSource); assertThat(vector.getValues().length, equalTo(3)); assertThat(vector.getIndices().length, equalTo(3)); assertArrayEquals(vector.getIndices(), new int[]{0, 2, 5}); assertArrayEquals(vector.getValues(), new double[]{1.1724330344107299, 1.0, 1.0}, 1.e-7); // test missing values dataSource = createTestDataSource(new String[]{"Self-emp-inc"}, null, null); vector = vectorEntries.convert(dataSource); assertThat((vector.getValues()).length, equalTo(3)); assertThat(vector.getIndices().length, equalTo(3)); assertArrayEquals(vector.getIndices(), new int[]{0, 2, 5}); assertArrayEquals(vector.getValues(), new double[]{-48.20951464010758, 1.0, 1.0}, 1.e-7); // test missing string field - we expect in this case nothing to be in the vector although that might be a problem with the model... dataSource = createTestDataSource(null, null, 60); vector = vectorEntries.convert(dataSource); assertThat((vector.getValues()).length, equalTo(3)); assertThat(vector.getIndices().length, equalTo(3)); assertArrayEquals(vector.getIndices(), new int[]{0, 4, 5}); assertArrayEquals(vector.getValues(), new double[]{1.1724330344107299, 1.0, 1.0}, 1.e-7); } public void testGLMOnActualLookupMultipleStringValues() throws Exception { ModelFactories parser = ModelFactories.createDefaultModelFactories(); DataSource dataSource = createTestDataSource(new String[]{"Self-emp-inc", "Private"}, null, 60); final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/fake_lr_model_with_missing.xml"); PMML pmml = ProcessPMMLHelper.parsePmml(pmmlString); ModelAndModelInputEvaluator<SparseVectorModelInput, String> fieldsToVectorAndModel = parser.buildFromPMML(pmml, 0); ModelInputEvaluator<SparseVectorModelInput> vectorEntries = fieldsToVectorAndModel.getVectorRangesToVector(); SparseVectorModelInput vector = vectorEntries.convert(dataSource); assertThat((vector.getValues()).length, equalTo(4)); assertThat(vector.getIndices().length, equalTo(4)); assertArrayEquals(vector.getIndices(), new int[]{0, 1, 2, 5, }); assertArrayEquals(vector.getValues(), new double[]{1.1724330344107299, 1.0, 1.0, 1.0}, 1.e-7); } @SuppressWarnings("unchecked") public void testTreeModelOnActualLookup() throws Exception { ModelFactories parser = ModelFactories.createDefaultModelFactories(); DataSource dataSource = createTestDataSource(new String[]{"Self-emp-inc"}, "Prof-school", 60); final String pmmlString = copyToStringFromClasspath("/org/elasticsearch/script/tree-small-r.xml"); PMML pmml = ProcessPMMLHelper.parsePmml(pmmlString); ModelAndModelInputEvaluator<MapModelInput, String> fieldsToVectorAndModel = parser.buildFromPMML(pmml, 0); ModelInputEvaluator<MapModelInput> vectorEntries = fieldsToVectorAndModel.getVectorRangesToVector(); Map<String, Object> vector = vectorEntries.convert(dataSource).getAsMap(); assertThat(vector.size(), equalTo(3)); assertThat(((Number)((Set) vector.get("age_z")).iterator().next()).doubleValue(), closeTo(1.5702107070685085, 0.0)); assertThat(((Set) vector.get("education")).iterator().next(), equalTo("Prof-school")); assertThat(((Set) vector.get("work")).iterator().next(), equalTo("Self-emp-inc")); // test missing values dataSource = createTestDataSource(null, null, null); vector = vectorEntries.convert(dataSource).getAsMap(); assertThat(vector.size(), equalTo(3)); assertThat(((Number)((Set) vector.get("age_z")).iterator().next()).doubleValue(), closeTo(-76.13993490863606, 0.0)); assertThat(((Set) vector.get("education")).iterator().next(), equalTo("too-lazy-to-study")); assertThat(((Set) vector.get("work")).iterator().next(), equalTo("other")); } }