/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.ml.factories;
import org.dmg.pmml.DataField;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.RegressionModel;
import org.elasticsearch.ml.modelinput.MockDataSource;
import org.elasticsearch.ml.modelinput.ModelAndModelInputEvaluator;
import org.elasticsearch.ml.modelinput.ModelInputEvaluator;
import org.elasticsearch.ml.modelinput.VectorModelInput;
import org.elasticsearch.ml.models.EsLinearSVMModel;
import org.elasticsearch.ml.models.EsLogisticRegressionModel;
import org.elasticsearch.ml.models.EsModelEvaluator;
import org.elasticsearch.ml.factories.ModelFactories;
import org.elasticsearch.script.PMMLGenerator;
import org.elasticsearch.script.pmml.ProcessPMMLHelper;
import org.elasticsearch.test.ESTestCase;
import org.jpmml.model.ImportFilter;
import org.jpmml.model.JAXBUtil;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;
import javax.xml.bind.JAXBException;
import javax.xml.transform.Source;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
/**
 * Tests for the linear models (logistic regression and linear SVM) that are described as PMML.
 * <p>
 * Two things are covered: PMML produced by {@link PMMLGenerator} is built into the matching evaluator by
 * {@link ModelFactories} and classifies a fixed input as expected, and the generated PMML has the same
 * structure as a hand-written reference document.
 */
public class ModelTests extends ESTestCase {

    /**
     * Fixed encoding used when turning PMML strings into bytes, so that parsing does not depend on the
     * platform default charset. UTF-8 is what an XML parser assumes for a document without an encoding
     * declaration.
     */
    private static final Charset PMML_CHARSET = StandardCharsets.UTF_8;

    /** Value of the {@code modelName} attribute used in the logistic regression reference document. */
    private static final String LR_MODEL_NAME = "logistic regression";

    /** Value of the {@code modelName} attribute used in the linear SVM reference document. */
    private static final String SVM_MODEL_NAME = "linear SVM";

    /**
     * Generates 100 random four-weight models, randomly as logistic regression or linear SVM, builds them
     * through {@link ModelFactories} and evaluates the input {@code (1, 1, 1, 0)}.
     * <p>
     * The expected class is "1.0" when the sum of the first three weights plus the intercept is positive
     * and "0.0" otherwise; the fourth weight does not contribute because its feature value is 0.
     */
    public void testEsLogisticRegressionModels() throws IOException, JAXBException, SAXException {
        ModelFactories factories = ModelFactories.createDefaultModelFactories();
        // Shared by model generation and by the expected-value computation below.
        final double intercept = 0.1;
        for (int i = 0; i < 100; i++) {
            double[] modelParams = new double[]{
                randomFloat() * randomIntBetween(-100, +100),
                randomFloat() * randomIntBetween(-100, +100),
                randomFloat() * randomIntBetween(-100, +100),
                randomFloat() * randomIntBetween(-100, +100)
            };
            boolean lrModel = randomBoolean();
            String pmmlString = lrModel
                ? PMMLGenerator.generateLRPMMLModel(intercept, modelParams, new double[]{1, 0})
                : PMMLGenerator.generateSVMPMMLModel(intercept, modelParams, new double[]{1, 0});

            PMML pmml = ProcessPMMLHelper.parsePmml(pmmlString);
            assertEquals(1, pmml.getModels().size());

            ModelAndModelInputEvaluator<VectorModelInput, String> modelAndInput = factories.buildFromPMML(pmml, 0);
            EsModelEvaluator<VectorModelInput, String> modelEvaluator = modelAndInput.getModel();
            ModelInputEvaluator<VectorModelInput> inputEvaluator = modelAndInput.getVectorRangesToVector();
            Class<?> expectedEvaluatorType = lrModel ? EsLogisticRegressionModel.class : EsLinearSVMModel.class;
            assertThat(modelEvaluator, instanceOf(expectedEvaluatorType));

            Map<String, List<Object>> vector = new HashMap<>();
            vector.put("field_0", Collections.singletonList(1));
            vector.put("field_1", Collections.singletonList(1));
            vector.put("field_2", Collections.singletonList(1));
            vector.put("field_3", Collections.singletonList(0));
            MockDataSource dataSource = new MockDataSource(vector);
            VectorModelInput vectorModelInput = inputEvaluator.convert(dataSource);

            String result = modelEvaluator.evaluate(vectorModelInput);
            logger.info("model = {}, result = {}", lrModel ? "lr" : "svm", result);
            // Checked first so that an unexpected label is reported as such, not as a wrong class.
            assertThat(result, anyOf(equalTo("0.0"), equalTo("1.0")));

            double val = modelParams[0] + modelParams[1] + modelParams[2] + intercept;
            String expectedClass = val > 0 ? "1.0" : "0.0";
            assertThat(result, equalTo(expectedClass));
        }
    }

    /**
     * Checks that the logistic regression PMML generated for random weights and intercept has the same
     * structure as the hand-written reference document.
     */
    public void testGenerateLRPMML() throws JAXBException, IOException, SAXException {
        double[] weights = new double[]{randomDouble(), randomDouble(), randomDouble(), randomDouble()};
        double intercept = randomDouble();
        String generatedPMMLModel = PMMLGenerator.generateLRPMMLModel(intercept, weights, new double[]{1, 0});
        PMML hopefullyCorrectPMML = parsePmml(generatedPMMLModel);
        PMML truePMML = parsePmml(expectedRegressionPmml(LR_MODEL_NAME, "logit", intercept, weights, "-0.0"));
        compareModels(truePMML, hopefullyCorrectPMML);
    }

    /**
     * Checks that the linear SVM PMML generated for random weights and intercept has the same structure
     * as the hand-written reference document.
     */
    public void testGenerateSVMPMML() throws JAXBException, IOException, SAXException {
        double[] weights = new double[]{randomDouble(), randomDouble(), randomDouble(), randomDouble()};
        double intercept = randomDouble();
        String generatedPMMLModel = PMMLGenerator.generateSVMPMMLModel(intercept, weights, new double[]{1, 0});
        PMML hopefullyCorrectPMML = parsePmml(generatedPMMLModel);
        PMML truePMML = parsePmml(expectedRegressionPmml(SVM_MODEL_NAME, "none", intercept, weights, "0.0"));
        compareModels(truePMML, hopefullyCorrectPMML);
    }

    /**
     * Compares the structure of two PMML documents: the declared number of data fields, the data type and
     * name of each data field, the number of models, and for each model its function name, normalization
     * method and mining fields.
     * <p>
     * NOTE(review): regression tables (intercepts and coefficients) are not compared here.
     *
     * @param model1 the reference document; its model names decide which comparison is applied
     * @param model2 the document under test
     * @throws UnsupportedOperationException if a model in {@code model1} is neither named "linear SVM" nor
     *                                       "logistic regression"
     */
    public void compareModels(PMML model1, PMML model2) {
        assertThat(model1.getDataDictionary().getNumberOfFields(), equalTo(model2.getDataDictionary().getNumberOfFields()));

        // Compare sizes first: a shorter list would otherwise surface as an IndexOutOfBoundsException and a
        // longer one would go unnoticed.
        int dataFieldCount = model1.getDataDictionary().getDataFields().size();
        assertThat(model2.getDataDictionary().getDataFields().size(), equalTo(dataFieldCount));
        for (int i = 0; i < dataFieldCount; i++) {
            DataField dataField = model1.getDataDictionary().getDataFields().get(i);
            DataField otherDataField = model2.getDataDictionary().getDataFields().get(i);
            assertThat(dataField.getDataType(), equalTo(otherDataField.getDataType()));
            assertThat(dataField.getName(), equalTo(otherDataField.getName()));
        }

        assertThat(model1.getModels().size(), equalTo(model2.getModels().size()));
        int i = 0;
        for (Model model : model1.getModels()) {
            String modelName = model.getModelName();
            // Constant-first equals: a missing model name ends up in the "unsupported" branch, not in an NPE.
            if (SVM_MODEL_NAME.equals(modelName) || LR_MODEL_NAME.equals(modelName)) {
                assertThat(model, instanceOf(RegressionModel.class));
                assertThat(model2.getModels().get(i), instanceOf(RegressionModel.class));
                compareModels((RegressionModel) model, (RegressionModel) model2.getModels().get(i));
            } else {
                throw new UnsupportedOperationException("model " + modelName +
                    " is not supported and therefore not tested yet");
            }
            i++;
        }
    }

    /**
     * Compares function name, normalization method and mining fields of two regression models.
     */
    private static void compareModels(RegressionModel model1, RegressionModel model2) {
        assertThat(model1.getFunctionName().value(), equalTo(model2.getFunctionName().value()));
        assertThat(model1.getNormalizationMethod().value(), equalTo(model2.getNormalizationMethod().value()));
        compareMiningFields(model1, model2);
    }

    /**
     * Compares the mining schemas of two models field by field, in declaration order, after checking that
     * both declare the same number of mining fields.
     */
    private static void compareMiningFields(Model model1, Model model2) {
        int miningFieldCount = model1.getMiningSchema().getMiningFields().size();
        assertThat(model2.getMiningSchema().getMiningFields().size(), equalTo(miningFieldCount));
        for (int i = 0; i < miningFieldCount; i++) {
            MiningField miningField = model1.getMiningSchema().getMiningFields().get(i);
            MiningField otherMiningField = model2.getMiningSchema().getMiningFields().get(i);
            compareMiningFields(miningField, otherMiningField);
        }
    }

    /**
     * Compares name and usage type of two mining fields.
     */
    private static void compareMiningFields(MiningField miningField, MiningField otherMiningField) {
        assertThat(miningField.getName(), equalTo(otherMiningField.getName()));
        assertThat(miningField.getUsageType().value(), equalTo(otherMiningField.getUsageType().value()));
    }

    /**
     * Parses a PMML document held in a string by running it through {@link ImportFilter} and unmarshalling
     * the filtered source with {@link JAXBUtil}.
     *
     * @param pmmlXml the PMML document as XML text
     * @return the unmarshalled document
     */
    private static PMML parsePmml(String pmmlXml) throws IOException, JAXBException, SAXException {
        try (InputStream is = new ByteArrayInputStream(pmmlXml.getBytes(PMML_CHARSET))) {
            Source transformedSource = ImportFilter.apply(new InputSource(is));
            return JAXBUtil.unmarshalPMML(transformedSource);
        }
    }

    /**
     * Builds the reference PMML document for a two-class regression model over the continuous input fields
     * {@code field_0 .. field_(n-1)} (one per weight) and a categorical {@code target} field.
     *
     * @param modelName              value of the {@code modelName} attribute
     * @param normalizationMethod    value of the {@code normalizationMethod} attribute
     * @param intercept              intercept of the regression table for target category "1"
     * @param weights                coefficients of the regression table for target category "1"
     * @param zeroCategoryIntercept  literal intercept text of the empty regression table for target category "0"
     * @return the document as XML text
     */
    private static String expectedRegressionPmml(String modelName, String normalizationMethod, double intercept,
                                                 double[] weights, String zeroCategoryIntercept) {
        StringBuilder xml = new StringBuilder();
        xml.append("<PMML xmlns=\"http://www.dmg.org/PMML-4_2\">\n");

        // One data field per weight plus the target field.
        xml.append(" <DataDictionary numberOfFields=\"").append(weights.length + 1).append("\">\n");
        for (int i = 0; i < weights.length; i++) {
            xml.append(" <DataField name=\"field_").append(i).append("\" optype=\"continuous\" dataType=\"double\"/>\n");
        }
        xml.append(" <DataField name=\"target\" optype=\"categorical\" dataType=\"string\"/>\n");
        xml.append(" </DataDictionary>\n");

        xml.append(" <RegressionModel modelName=\"").append(modelName)
            .append("\" functionName=\"classification\" normalizationMethod=\"").append(normalizationMethod).append("\">\n");

        xml.append(" <MiningSchema>\n");
        for (int i = 0; i < weights.length; i++) {
            xml.append(" <MiningField name=\"field_").append(i).append("\" usageType=\"active\"/>\n");
        }
        xml.append(" <MiningField name=\"target\" usageType=\"target\"/>\n");
        xml.append(" </MiningSchema>\n");

        xml.append(" <RegressionTable intercept=\"").append(intercept).append("\" targetCategory=\"1\">\n");
        for (int i = 0; i < weights.length; i++) {
            xml.append(" <NumericPredictor name=\"field_").append(i)
                .append("\" coefficient=\"").append(weights[i]).append("\"/>\n");
        }
        xml.append(" </RegressionTable>\n");
        xml.append(" <RegressionTable intercept=\"").append(zeroCategoryIntercept).append("\" targetCategory=\"0\"/>\n");

        xml.append(" </RegressionModel>\n");
        xml.append("</PMML>");
        return xml.toString();
    }
}