/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.elasticsearch.script.pmml;

import org.dmg.pmml.Apply;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.Model;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.TransformationDictionary;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.inject.Provider;
import org.elasticsearch.ml.modelinput.PMMLVectorRange;
import org.jpmml.model.ImportFilter;
import org.jpmml.model.JAXBUtil;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import javax.xml.transform.Source;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;

/**
 * Static helpers for turning the fields of a PMML document into {@link PMMLVectorRange}s:
 * parsing the document, tracing a field name back through derived fields to a data
 * dictionary entry, and building the vector range for that field.
 */
public class ProcessPMMLHelper {

    /**
     * Charset used to encode the PMML string before it is handed to the XML parser.
     * A fixed charset keeps parsing independent of the platform default.
     * NOTE(review): assumes the document declares UTF-8 or no encoding at all - confirm.
     */
    private static final Charset PMML_CHARSET = StandardCharsets.UTF_8;

    /**
     * Looks up the data dictionary entry with the given name.
     *
     * @param dataDictionary the data dictionary to search
     * @param rawFieldName   the name of the data field to find
     * @return the first data field whose name equals {@code rawFieldName}
     * @throws UnsupportedOperationException if the dictionary contains no field with that name
     */
    public static DataField getRawDataField(DataDictionary dataDictionary, String rawFieldName) {
        for (DataField dataField : dataDictionary.getDataFields()) {
            if (dataField.getName().getValue().equals(rawFieldName)) {
                return dataField;
            }
        }
        throw new UnsupportedOperationException("Could not trace back " + rawFieldName + " to a raw input field. Maybe something is "
            + "not implemented yet or the PMML file is faulty.");
    }

    /**
     * Traces {@code fieldName} back through the derived fields that define it.
     * <p>
     * Every derived field on the way is appended to {@code derivedFields}, starting with the
     * one named {@code fieldName}. The list of all derived fields is rescanned until a full
     * pass finds no derived field with the current name.
     *
     * @param fieldName        the name to start from
     * @param allDerivedFields all derived fields that are to be considered
     * @param derivedFields    output list that receives the derived fields found for {@code fieldName}
     * @return the name referenced by the last derived field found, or {@code fieldName} itself
     *         if no derived field has that name
     * @throws IllegalArgumentException      if the derived fields reference each other in a cycle
     * @throws UnsupportedOperationException if a derived field uses an expression that is not supported
     */
    public static String getDerivedFields(String fieldName, List<DerivedField> allDerivedFields, List<DerivedField> derivedFields) {
        String lastFieldName = fieldName;
        // Without a cycle every derived field can be resolved at most once, so resolving more
        // fields than exist means the definitions reference each other and would loop forever.
        int resolvedFields = 0;
        boolean foundDerivedField;
        do {
            foundDerivedField = false;
            for (DerivedField derivedField : allDerivedFields) {
                if (derivedField.getName().getValue().equals(lastFieldName)) {
                    resolvedFields++;
                    if (resolvedFields > allDerivedFields.size()) {
                        throw new IllegalArgumentException("Derived fields reference each other in a cycle, cannot trace back "
                            + fieldName + " to a raw input field.");
                    }
                    derivedFields.add(derivedField);
                    // continue the search with the field this derived field references
                    lastFieldName = getReferencedFieldName(derivedField);
                    foundDerivedField = true;
                }
            }
        } while (foundDerivedField);
        return lastFieldName;
    }

    /**
     * Extracts the name of the field a derived field is computed from.
     * <p>
     * Handles {@link Apply} expressions (the last {@link FieldRef} among the direct arguments
     * wins) and {@link NormContinuous} expressions.
     *
     * @param derivedField the derived field to inspect
     * @return the name of the referenced field
     * @throws UnsupportedOperationException if the expression is missing, is of another type,
     *                                       or contains no direct field reference
     */
    private static String getReferencedFieldName(DerivedField derivedField) {
        Expression expression = derivedField.getExpression();
        if (expression == null) {
            // there is a million ways in which derived fields can reference other fields.
            // need to implement them all!
            throw new UnsupportedOperationException("So far only implemented if function for derived fields.");
        }
        String referencedField = null;
        if (expression instanceof Apply) {
            // TODO throw uoe in case the function is not "if missing" - much more to implement!
            for (Expression argument : ((Apply) expression).getExpressions()) {
                if (argument instanceof FieldRef) {
                    referencedField = ((FieldRef) argument).getField().getValue();
                }
            }
        } else if (expression instanceof NormContinuous) {
            referencedField = ((NormContinuous) expression).getField().getValue();
        } else {
            throw new UnsupportedOperationException("So far only Apply and NormContinuous expressions are implemented.");
        }
        if (referencedField == null) {
            throw new UnsupportedOperationException("could not find raw field name. Maybe this derived field references another derived "
                + "field? Did not implement that yet.");
        }
        return referencedField;
    }

    /**
     * Parses a PMML document given as an XML string.
     * <p>
     * The XML is passed through {@link ImportFilter} and then unmarshalled with
     * {@link JAXBUtil}, all inside a privileged block.
     *
     * @param pmmlString the PMML document as XML
     * @return the unmarshalled PMML model
     * @throws ElasticsearchException if the XML cannot be read or converted
     */
    public static PMML parsePmml(final String pmmlString) {
        // this is bad but I have not figured out yet how to avoid the permission for suppressAccessCheck
        return AccessController.doPrivileged(new PrivilegedAction<PMML>() {
            @Override
            public PMML run() {
                try (InputStream is = new ByteArrayInputStream(pmmlString.getBytes(PMML_CHARSET))) {
                    Source transformedSource = ImportFilter.apply(new InputSource(is));
                    return JAXBUtil.unmarshalPMML(transformedSource);
                } catch (SAXException | JAXBException | IOException e) {
                    throw new ElasticsearchException("could not convert xml to pmml model", e);
                }
            }
        });
    }

    /**
     * Collects the derived fields of the transformation dictionary followed by those of the
     * model's local transformations.
     *
     * @param model                    the model whose local transformations are read, if present
     * @param transformationDictionary the global transformation dictionary, may be {@code null}
     * @return a new list with all derived fields, empty if there are none
     */
    public static List<DerivedField> getAllDerivedFields(Model model, TransformationDictionary transformationDictionary) {
        List<DerivedField> allDerivedFields = new ArrayList<>();
        if (transformationDictionary != null) {
            allDerivedFields.addAll(transformationDictionary.getDerivedFields());
        }
        if (model.getLocalTransformations() != null) {
            allDerivedFields.addAll(model.getLocalTransformations().getDerivedFields());
        }
        return allDerivedFields;
    }

    /**
     * Finds the mining field for a raw field name in the model's mining schema.
     *
     * @param model        the model whose mining schema is searched
     * @param rawFieldName the key of the mining field to find
     * @return the matching mining field (the last one if several match), or {@code null} if none matches
     */
    public static MiningField getMiningField(Model model, String rawFieldName) {
        MiningField miningField = null;
        // also pass in the mining schema for additional parameters
        for (MiningField candidate : model.getMiningSchema().getMiningFields()) {
            if (candidate.getKey().getValue().equals(rawFieldName)) {
                miningField = candidate;
            }
        }
        return miningField;
    }

    /**
     * Builds the vector range for {@code fieldName}: resolves its derived fields, looks up the
     * data field and mining field they end in and delegates to
     * {@link #getFieldVector(int, List, DataField, MiningField, Provider, Map)}.
     *
     * @param model                    the model providing local transformations and the mining schema
     * @param dataDictionary           the data dictionary holding the raw data fields
     * @param transformationDictionary the global transformation dictionary, may be {@code null}
     * @param fieldName                the name of the field to build the range for
     * @param categories               supplies the category values, only queried for categorical fields
     * @param position                 the index of the first vector entry of this field
     * @param types                    receives the op type of the field, may be {@code null}
     * @return the vector range for the field
     * @throws UnsupportedOperationException if the field cannot be traced back to a data field or
     *                                       its op type is not supported
     */
    public static PMMLVectorRange extractVectorRange(Model model, DataDictionary dataDictionary,
                                                     TransformationDictionary transformationDictionary, String fieldName,
                                                     Provider<Collection<String>> categories, int position,
                                                     Map<String, OpType> types) {
        List<DerivedField> allDerivedFields = getAllDerivedFields(model, transformationDictionary);
        List<DerivedField> derivedFields = new ArrayList<>();
        String rawFieldName = getDerivedFields(fieldName, allDerivedFields, derivedFields);
        DataField rawField = getRawDataField(dataDictionary, rawFieldName);
        MiningField miningField = getMiningField(model, rawFieldName);
        return getFieldVector(position, derivedFields, rawField, miningField, categories, types);
    }

    /**
     * Creates the vector range for one field.
     * <p>
     * The op type is taken from the first derived field, or from the raw field if there are no
     * derived fields. A continuous field gets a single entry at {@code indexCounter}; a
     * categorical field gets one entry per category, at consecutive indices starting at
     * {@code indexCounter}.
     *
     * @param indexCounter  the index of the first vector entry
     * @param derivedFields the derived fields of the field, may be empty
     * @param rawField      the data field the derived fields end in
     * @param miningField   the mining field of the raw field; passed through to the vector range
     * @param categories    supplies the category values, only queried for categorical fields
     * @param types         if not {@code null}, receives the op type under the name returned by
     *                      {@code getLastDerivedFieldName()} of the created range
     * @return the vector range for the field
     * @throws UnsupportedOperationException if the op type is neither continuous nor categorical
     */
    public static PMMLVectorRange getFieldVector(int indexCounter, List<DerivedField> derivedFields, DataField rawField,
                                                 MiningField miningField, Provider<Collection<String>> categories,
                                                 Map<String, OpType> types) {
        OpType opType = derivedFields.isEmpty() ? rawField.getOpType() : derivedFields.get(0).getOpType();
        DerivedField[] derivedFieldArray = derivedFields.toArray(new DerivedField[derivedFields.size()]);
        PMMLVectorRange featureEntries;
        // identity comparison is null-safe: a missing op type ends up in the unsupported branch
        if (opType == OpType.CONTINUOUS) {
            featureEntries = new PMMLVectorRange.ContinousSingleEntryVectorRange(rawField, miningField, derivedFieldArray);
            featureEntries.addVectorEntry(indexCounter, "dummyValue");
        } else if (opType == OpType.CATEGORICAL) {
            featureEntries = new PMMLVectorRange.SparseCategoricalVectorRange(rawField, miningField, derivedFieldArray);
            int index = indexCounter;
            for (String value : categories.get()) {
                featureEntries.addVectorEntry(index, value);
                index++;
            }
        } else {
            throw new UnsupportedOperationException("Only implemented continuous and categorical variables so far.");
        }
        if (types != null) {
            types.put(featureEntries.getLastDerivedFieldName(), opType);
        }
        return featureEntries;
    }
}