/*
* Copyright 2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.stream.module.pmml.processor;
import java.io.IOException;
import java.io.InputStream;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.PostConstruct;
import javax.xml.bind.JAXBException;
import javax.xml.transform.Source;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.model.ImportFilter;
import org.jpmml.model.JAXBUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.cloud.stream.annotation.EnableBinding;
import org.springframework.cloud.stream.config.SpelExpressionConverterConfiguration;
import org.springframework.cloud.stream.messaging.Processor;
import org.springframework.tuple.MutableTuple;
import org.springframework.tuple.Tuple;
import org.springframework.tuple.TupleBuilder;
import org.springframework.context.annotation.Import;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import org.springframework.expression.spel.SpelEvaluationException;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.integration.annotation.ServiceActivator;
import org.springframework.integration.context.IntegrationContextUtils;
import org.springframework.integration.support.MutableMessage;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;
/**
 * A processor that evaluates a machine learning model stored in PMML format.
 *
 * <p>The PMML document is read once, in {@link #setUp()}, from the configured model
 * location. For every incoming message, {@link #evaluate(Message)} selects a model,
 * resolves the value of each of its active fields using SpEL expressions evaluated
 * against the message, runs the evaluator and writes every result field back onto a
 * mutable message, again using SpEL expressions.
 *
 * <p>When no explicit expression is configured for an input or output field, the
 * expression {@code payload.<fieldName>} is used.
 *
 * @author Eric Bottard
 */
@EnableBinding(Processor.class)
@EnableConfigurationProperties(PmmlProcessorProperties.class)
@Import({CustomConversionServiceRegistrar.class, SpelExpressionConverterConfiguration.class})
public class PmmlProcessor {

	private static final Logger logger = LoggerFactory.getLogger(PmmlProcessor.class);

	private final ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();

	@Autowired
	@Qualifier(IntegrationContextUtils.INTEGRATION_EVALUATION_CONTEXT_BEAN_NAME)
	private EvaluationContext evaluationContext;

	@Autowired
	private PmmlProcessorProperties properties;

	private final SpelExpressionParser spelExpressionParser = new SpelExpressionParser();

	/** The unmarshalled PMML document, assigned in {@link #setUp()}. */
	private PMML pmml;

	@Autowired
	private BeanFactory beanFactory;

	/**
	 * Reads and unmarshals the PMML document found at the configured model location.
	 *
	 * @throws IOException if the model resource cannot be opened or read
	 * @throws SAXException if the import filter cannot process the document
	 * @throws JAXBException if the document cannot be unmarshalled
	 * @throws IllegalStateException if the document does not contain any model
	 */
	@PostConstruct
	public void setUp() throws IOException, SAXException, JAXBException {
		try (InputStream is = properties.getModelLocation().getInputStream()) {
			Source transformedSource = ImportFilter.apply(new InputSource(is));
			pmml = JAXBUtil.unmarshalPMML(transformedSource);
			Assert.state(!pmml.getModels().isEmpty(), "The provided PMML file at " + properties.getModelLocation() + " does not contain any model");
		}
	}

	/**
	 * Evaluates the selected model against the incoming message.
	 *
	 * @param input the message providing the values of the model's active fields
	 * @return a mutable message carrying the input payload and headers, onto which each
	 * result field of the evaluation has been set
	 */
	@ServiceActivator(inputChannel = Processor.INPUT, outputChannel = Processor.OUTPUT)
	public Object evaluate(Message<?> input) {
		Model model = selectModel(input);
		Evaluator evaluator = modelEvaluatorFactory.newModelManager(pmml, model);

		Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
		List<FieldName> activeFields = evaluator.getActiveFields();
		for (FieldName activeField : activeFields) {
			// The raw (ie. user-supplied) value could be any Java primitive value
			Object rawValue = resolveActiveValue(input, activeField.getValue());

			// The raw value is passed through:
			// 1) outlier treatment,
			// 2) missing value treatment,
			// 3) invalid value treatment
			// and 4) type conversion
			FieldValue activeValue = evaluator.prepare(activeField, rawValue);
			arguments.put(activeField, activeValue);
		}

		Map<FieldName, ?> results = evaluator.evaluate(arguments);

		MutableMessage<?> result = convertToMutable(input);
		for (Map.Entry<FieldName, ?> entry : results.entrySet()) {
			String fieldName = entry.getKey().getValue();
			Expression expression = properties.getOutputs().get(fieldName);
			if (expression == null) {
				// Assume same-name mapping on payload properties
				expression = defaultPayloadExpression(fieldName);
			}
			logger.debug("Setting result field named {} using SpEL[{} = {}]", fieldName,
					expression.getExpressionString(), entry.getValue());
			expression.setValue(evaluationContext, result, entry.getValue());
		}
		return result;
	}

	/**
	 * Builds the message that receives the evaluation results. A {@link Tuple} payload
	 * that is not already a {@link MutableTuple} is copied into a mutable tuple so that
	 * result fields can be set on it; any other payload is reused as is.
	 *
	 * @param input the incoming message
	 * @return a {@link MutableMessage} with the (possibly copied) payload and the input headers
	 */
	private MutableMessage<?> convertToMutable(Message<?> input) {
		Object payload = input.getPayload();
		if (payload instanceof Tuple && !(payload instanceof MutableTuple)) {
			payload = TupleBuilder.mutableTuple().putAll((Tuple) payload).build();
		}
		return new MutableMessage<>(payload, input.getHeaders());
	}

	/**
	 * Resolves the raw value of a model input field from the message.
	 *
	 * @param input the incoming message, used as the root object of the expression
	 * @param fieldName the name of the active field to resolve
	 * @return the resolved value, or {@code null} if the expression could not be evaluated
	 */
	private Object resolveActiveValue(Message<?> input, String fieldName) {
		Expression expression = properties.getInputs().get(fieldName);
		if (expression == null) {
			// Assume same-name mapping on payload properties
			expression = defaultPayloadExpression(fieldName);
		}
		Object result = null;
		try {
			result = expression.getValue(evaluationContext, input);
		}
		catch (SpelEvaluationException e) {
			// Deliberately not rethrown: the evaluator will get a chance to handle
			// missing values. The failure is still traced to ease troubleshooting.
			logger.debug("Unable to resolve value for input field {} using SpEL[{}], treating it as missing",
					fieldName, expression.getExpressionString(), e);
		}
		logger.debug("Resolving value for input field {} using SpEL[{}], result is {}", fieldName,
				expression.getExpressionString(), result);
		return result;
	}

	/**
	 * Parses the expression used when no explicit mapping is configured for a field.
	 *
	 * @param fieldName the PMML field name
	 * @return the parsed {@code payload.<fieldName>} expression
	 */
	private Expression defaultPayloadExpression(String fieldName) {
		return spelExpressionParser.parseExpression("payload." + fieldName);
	}

	/**
	 * Selects the model to evaluate for the given message. When neither a model name nor
	 * a model name expression is configured, the first model of the document is used.
	 * When a model name expression is configured, its value (computed against the
	 * message) is used as the model name, even if a static model name is also configured.
	 *
	 * @param input the incoming message, used as the root object of the model name expression
	 * @return the model whose name equals the requested name
	 * @throws IllegalArgumentException if no model of the document has the requested name
	 */
	private Model selectModel(Message<?> input) {
		String modelName = properties.getModelName();
		if (modelName == null && properties.getModelNameExpression() == null) {
			return pmml.getModels().get(0);
		}
		else if (properties.getModelNameExpression() != null) {
			modelName = properties.getModelNameExpression().getValue(evaluationContext, input, String.class);
		}
		// The expression may have evaluated to null, in which case no model can match
		if (modelName != null) {
			for (Model model : pmml.getModels()) {
				// Use the requested name as the receiver: a model may have no name,
				// in which case getModelName() returns null
				if (modelName.equals(model.getModelName())) {
					return model;
				}
			}
		}
		throw new IllegalArgumentException("Unable to use model named '" + modelName + "'");
	}

}