/**
* Copyright 2017 Hortonworks.
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
**/
package com.hortonworks.registries.model.service;
import com.hortonworks.registries.common.QueryParam;
import com.hortonworks.registries.common.exception.service.exception.request.EntityNotFoundException;
import com.hortonworks.registries.model.data.MLModel;
import com.hortonworks.registries.storage.StorageManager;
import com.hortonworks.registries.storage.util.StorageUtils;
import org.apache.commons.io.IOUtils;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.IOUtil;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeatureType;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.manager.PMMLManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xml.sax.SAXException;
import javax.xml.bind.JAXBException;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
/**
 * Service for registering, looking up and removing {@link MLModel} entries held in a
 * {@link StorageManager}, and for extracting the input/output fields described by a
 * model's PMML document.
 */
public final class MLModelRegistryService {
    private static final Logger LOG = LoggerFactory.getLogger(MLModelRegistryService.class);
    private static final String ML_MODEL_NAME_SPACE = new MLModel().getNameSpace();

    private final StorageManager storageManager;

    /**
     * @param storageManager the storage backend used for every model operation of this service
     */
    public MLModelRegistryService(StorageManager storageManager) {
        this.storageManager = storageManager;
    }

    /**
     * Lists every model stored under the model namespace.
     *
     * @return the models returned by the storage manager
     */
    public Collection<MLModel> listModelInfos() {
        return storageManager.list(ML_MODEL_NAME_SPACE);
    }

    /**
     * Finds the models in the model namespace that match the given query parameters.
     *
     * @param params the query parameters handed to the storage manager
     * @return the matching models
     */
    public Collection<MLModel> listModelInfo(List<QueryParam> params) {
        return storageManager.find(ML_MODEL_NAME_SPACE, params);
    }

    /**
     * Registers a new model. An id is requested from the storage manager when the model has none,
     * the PMML document is read from the stream, and the model is validated before being stored.
     *
     * @param modelInfo       the model to register; it is mutated (id, timestamp, PMML, file name)
     * @param pmmlInputStream the stream holding the PMML document; it is read but not closed here
     * @param fileName        the name of the uploaded file
     * @return the same {@code modelInfo} instance after it has been stored
     * @throws IOException   if the PMML stream cannot be read
     * @throws SAXException  if the PMML document cannot be parsed
     * @throws JAXBException if the PMML document cannot be unmarshalled
     */
    public MLModel addModelInfo(
            MLModel modelInfo, InputStream pmmlInputStream, String fileName) throws IOException, SAXException, JAXBException {
        if (modelInfo.getId() == null) {
            modelInfo.setId(storageManager.nextId(ML_MODEL_NAME_SPACE));
        }
        // Parameterized form: the message is only assembled when debug logging is enabled.
        LOG.debug("Adding model {}", modelInfo.getName());
        populateFromUpload(modelInfo, pmmlInputStream, fileName);
        validateModelInfo(modelInfo);
        this.storageManager.add(modelInfo);
        return modelInfo;
    }

    /**
     * Stores the model under the given id through the storage manager's add-or-update operation,
     * after reading the PMML document from the stream and validating the model.
     *
     * @param modelId         the id assigned to the model before it is stored
     * @param modelInfo       the model to store; it is mutated (id, timestamp, PMML, file name)
     * @param pmmlInputStream the stream holding the PMML document; it is read but not closed here
     * @param fileName        the name of the uploaded file
     * @return the same {@code modelInfo} instance after it has been stored
     * @throws IOException   if the PMML stream cannot be read
     * @throws SAXException  if the PMML document cannot be parsed
     * @throws JAXBException if the PMML document cannot be unmarshalled
     */
    public MLModel addOrUpdateModelInfo(
            Long modelId, MLModel modelInfo,
            InputStream pmmlInputStream,
            String fileName) throws IOException, SAXException, JAXBException {
        modelInfo.setId(modelId);
        populateFromUpload(modelInfo, pmmlInputStream, fileName);
        validateModelInfo(modelInfo);
        this.storageManager.addOrUpdate(modelInfo);
        return modelInfo;
    }

    /**
     * Looks a model up by name.
     *
     * @param name the model name to search for
     * @return the first model returned by the storage manager for that name
     * @throws EntityNotFoundException if no model with that name is found
     */
    public MLModel getModelInfo(String name) {
        List<QueryParam> queryParams = Collections.singletonList(new QueryParam(MLModel.NAME, name));
        Collection<MLModel> modelInfos = this.storageManager.find(ML_MODEL_NAME_SPACE, queryParams);
        if (modelInfos.isEmpty()) {
            throw EntityNotFoundException.byName(name);
        }
        return modelInfos.iterator().next();
    }

    /**
     * Looks a model up by id.
     *
     * @param modelId the id of the model
     * @return the stored model
     * @throws EntityNotFoundException if the storage manager returns no model for that id
     */
    public MLModel getModelInfo(Long modelId) {
        MLModel storedModelInfo = this.storageManager.get(keyHolderFor(modelId).getStorableKey());
        if (storedModelInfo == null) {
            // String.valueOf keeps a null id from turning "not found" into a NullPointerException.
            throw EntityNotFoundException.byId(String.valueOf(modelId));
        }
        return storedModelInfo;
    }

    /**
     * Removes the model with the given id.
     *
     * @param modelId the id of the model to remove
     * @return the model that was removed
     * @throws EntityNotFoundException if the storage manager returns no removed model for that id
     */
    public MLModel removeModelInfo(Long modelId) {
        MLModel removedModelInfo = this.storageManager.remove(keyHolderFor(modelId).getStorableKey());
        if (removedModelInfo == null) {
            // String.valueOf keeps a null id from turning "not found" into a NullPointerException.
            throw EntityNotFoundException.byId(String.valueOf(modelId));
        }
        return removedModelInfo;
    }

    /**
     * Extracts the output fields described by the model's PMML document.
     *
     * @param modelInfo the model whose PMML is inspected
     * @return the predicted fields followed by the output fields that are not plain predicted values
     * @throws IOException   declared for callers; not thrown by the code visible in this method
     * @throws SAXException  if the PMML document cannot be parsed
     * @throws JAXBException if the PMML document cannot be unmarshalled
     */
    public List<MLModelField> getModelOutputFields(MLModel modelInfo) throws IOException, SAXException, JAXBException {
        return doGetOutputFieldsForPMMLStream(modelInfo.getPmml());
    }

    /**
     * Collects the evaluator's predicted fields, then every output field whose feature is neither
     * {@code PREDICTED_VALUE} nor {@code PREDICTED_DISPLAY_VALUE}.
     */
    private List<MLModelField> doGetOutputFieldsForPMMLStream(String pmmlContents) throws SAXException, JAXBException {
        List<MLModelField> fields = new ArrayList<>();
        Evaluator evaluator = newEvaluator(pmmlContents);
        for (FieldName predicted : evaluator.getPredictedFields()) {
            fields.add(getModelField(evaluator.getDataField(predicted)));
        }
        for (FieldName output : evaluator.getOutputFields()) {
            OutputField outputField = evaluator.getOutputField(output);
            ResultFeatureType feature = outputField.getFeature();
            // NOTE(review): predicted-value outputs are skipped, presumably because they repeat
            // the predicted fields added above - confirm.
            if (feature != ResultFeatureType.PREDICTED_VALUE
                    && feature != ResultFeatureType.PREDICTED_DISPLAY_VALUE) {
                fields.add(getModelField(outputField));
            }
        }
        return fields;
    }

    /**
     * Extracts the input fields described by the model's PMML document.
     *
     * @param modelInfo the model whose PMML is inspected
     * @return one entry per active field of the model
     * @throws IOException   declared for callers; not thrown by the code visible in this method
     * @throws SAXException  if the PMML document cannot be parsed
     * @throws JAXBException if the PMML document cannot be unmarshalled
     */
    public List<MLModelField> getModelInputFields(MLModel modelInfo) throws IOException, SAXException, JAXBException {
        return doGetInputFieldsFromPMMLStream(modelInfo.getPmml());
    }

    /** Collects one {@link MLModelField} per active field reported by the evaluator. */
    private List<MLModelField> doGetInputFieldsFromPMMLStream(String pmmlContents) throws SAXException, JAXBException {
        final List<MLModelField> fields = new ArrayList<>();
        Evaluator evaluator = newEvaluator(pmmlContents);
        for (FieldName activeField : evaluator.getActiveFields()) {
            fields.add(getModelField(evaluator.getDataField(activeField)));
        }
        return fields;
    }

    /** Builds the evaluator for a PMML document; shared by the input and output field extraction. */
    private Evaluator newEvaluator(String pmmlContents) throws SAXException, JAXBException {
        // Encodes with the platform default charset, mirroring the default-charset decode that is
        // applied when the upload is read into the model.
        byte[] pmmlBytes = pmmlContents.getBytes(Charset.defaultCharset());
        PMMLManager pmmlManager = new PMMLManager(IOUtil.unmarshal(new ByteArrayInputStream(pmmlBytes)));
        return (ModelEvaluator<?>) pmmlManager.getModelManager(null, ModelEvaluatorFactory.getInstance());
    }

    /** Converts a PMML field into a name / data-type pair. */
    private MLModelField getModelField(Field dataField) {
        return new MLModelField(dataField.getName().getValue(), dataField.getDataType().toString());
    }

    /** Stamps the model with the current time, the PMML text read from the stream and the file name. */
    private void populateFromUpload(MLModel modelInfo, InputStream pmmlInputStream, String fileName)
            throws IOException {
        modelInfo.setTimestamp(System.currentTimeMillis());
        modelInfo.setPmml(IOUtils.toString(pmmlInputStream, Charset.defaultCharset()));
        modelInfo.setUploadedFileName(fileName);
    }

    /** Creates a throwaway model carrying only the id, used to obtain the storable key for lookups. */
    private MLModel keyHolderFor(Long modelId) {
        MLModel keyHolder = new MLModel();
        keyHolder.setId(modelId);
        return keyHolder;
    }

    /**
     * Rejects models whose PMML yields no output fields, then delegates the name uniqueness check
     * to {@link StorageUtils#ensureUnique}.
     */
    private void validateModelInfo(MLModel modelInfo) throws SAXException, JAXBException {
        List<MLModelField> outputFields = doGetOutputFieldsForPMMLStream(modelInfo.getPmml());
        if (outputFields.isEmpty()) {
            throw new RuntimeException(
                    String.format("PMML File %s does not support empty output", modelInfo.getUploadedFileName()));
        }
        StorageUtils.ensureUnique(modelInfo, this::listModelInfo, QueryParam.params(
                MLModel.NAME, modelInfo.getName()));
    }
}