/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.elasticsearch.ml.modelinput;

import org.dmg.pmml.Apply;
import org.dmg.pmml.Constant;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.NormContinuous;
import org.elasticsearch.common.collect.Tuple;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Maps a single field to vector entries. Includes pre processing.
 * <p>
 * Each instance holds an ordered chain of {@link PreProcessingStep}s that is applied to every raw value of the
 * field before the concrete subclass turns the processed values into an {@link EsVector}.
 */
public abstract class PMMLVectorRange extends VectorRange {

    /** Pre processing chain, applied in array order by {@link #applyPreProcessing(Map)}. Never contains null slots. */
    protected PreProcessingStep[] preProcessingSteps;

    /**
     * Runs every value of this range's field through the pre processing chain.
     * <p>
     * A field that is absent from {@code fieldValues}, or present with an empty list, is treated as one single
     * {@code null} value so that a missing value step in the chain gets the chance to replace it.
     *
     * @param fieldValues raw values keyed by field name
     * @return one processed value per input value; holds exactly one element when the field is missing
     */
    protected List<Object> applyPreProcessing(Map<String, List<Object>> fieldValues) {
        List<Object> rawValues = fieldValues.get(field);
        List<Object> valueList = new ArrayList<>();
        if (rawValues == null || rawValues.isEmpty()) {
            // represent "no value" as a single null so the chain still runs once
            valueList.add(null);
        } else {
            valueList.addAll(rawValues);
        }

        List<Object> processedValues = new ArrayList<>(valueList.size());
        for (Object value : valueList) {
            for (PreProcessingStep step : preProcessingSteps) {
                value = step.apply(value);
            }
            processedValues.add(value);
        }
        return processedValues;
    }

    /**
     * Builds the range for a data field and sets up its pre processing chain: an optional missing value
     * replacement taken from the mining field, followed by one step per derived field.
     *
     * @param dataField     the raw input field
     * @param miningField   the mining field; its missing value replacement, if any, becomes the first step
     * @param derivedFields the derived fields; they are turned into steps starting from the last array element
     */
    public PMMLVectorRange(DataField dataField, MiningField miningField, DerivedField[] derivedFields) {
        // NOTE(review): name and type are taken from the LAST array element, while fillPreProcessingSteps applies
        // the last array element FIRST - confirm which element is really meant to be the final derived field.
        super(dataField.getName().getValue(), lastFieldName(dataField, derivedFields), lastFieldType(dataField, derivedFields));
        this.field = dataField.getName().getValue();
        if (miningField.getMissingValueReplacement() != null) {
            // reserve slot 0 for the replacement defined on the mining field
            preProcessingSteps = new PreProcessingStep[derivedFields.length + 1];
            preProcessingSteps[0] = new MissingValuePreProcess(dataField, miningField.getMissingValueReplacement());
        } else {
            preProcessingSteps = new PreProcessingStep[derivedFields.length];
        }
        fillPreProcessingSteps(derivedFields);
    }

    /**
     * Builds a range without any pre processing.
     *
     * @param field                the field name, passed through to the super class
     * @param lastDerivedFieldName passed through to the super class
     * @param type                 passed through to the super class
     */
    public PMMLVectorRange(String field, String lastDerivedFieldName, String type) {
        super(field, lastDerivedFieldName, type);
        // an empty chain instead of null, so applyPreProcessing can be called safely on such instances
        preProcessingSteps = new PreProcessingStep[0];
    }

    /** Returns the name of the last array element, or of the data field when there are no derived fields. */
    private static String lastFieldName(DataField dataField, DerivedField[] derivedFields) {
        if (derivedFields.length == 0) {
            return dataField.getName().getValue();
        }
        return derivedFields[derivedFields.length - 1].getName().getValue();
    }

    /** Returns the data type of the last array element, or of the data field when there are no derived fields. */
    private static String lastFieldType(DataField dataField, DerivedField[] derivedFields) {
        if (derivedFields.length == 0) {
            return dataField.getDataType().value();
        }
        return derivedFields[derivedFields.length - 1].getDataType().value();
    }

    /**
     * Registers a vector entry for this range.
     *
     * @param indexCounter the vector index assigned to the entry
     * @param value        the value the entry stands for; how it is used depends on the subclass
     */
    public abstract void addVectorEntry(int indexCounter, String value);

    /**
     * Converts a 1 of k feature into a vector that has a 1 where the field value is the nth category and 0 everywhere else.
     * Categories will be numbered according to the order given in categories parameter.
     */
    public static class SparseCategoricalVectorRange extends PMMLVectorRange {

        /** Category value to vector index, filled via {@link #addVectorEntry(int, String)}. */
        Map<String, Integer> categoryToIndexHashMap = new HashMap<>();

        public SparseCategoricalVectorRange(DataField dataField, MiningField miningField, DerivedField[] derivedFields) {
            super(dataField, miningField, derivedFields);
        }

        @Override
        public EsVector getVector(DataSource dataSource) {
            throw new UnsupportedOperationException("Remove this later, we should not get here.");
        }

        /**
         * Builds a sparse vector with value 1.0 at the index of every known category found in the field values.
         * Values that are not registered categories are skipped.
         */
        @Override
        public EsVector getVector(Map<String, List<Object>> fieldValues) {
            List<Object> processedCategories = applyPreProcessing(fieldValues);
            List<Integer> indices = new ArrayList<>();
            int lastIndex = -1;
            for (Object category : processedCategories) {
                Integer index = categoryToIndexHashMap.get(category);
                if (index != null) {
                    indices.add(index);
                    // indices are expected to arrive in strictly increasing order
                    assert lastIndex < index;
                    lastIndex = index;
                }
            }

            int[] indicesArray = new int[indices.size()];
            double[] values = new double[indices.size()];
            for (int i = 0; i < indicesArray.length; i++) {
                indicesArray[i] = indices.get(i);
                values[i] = 1.0;
            }
            return new EsSparseNumericVector(new Tuple<>(indicesArray, values));
        }

        @Override
        public void addVectorEntry(int indexCounter, String value) {
            categoryToIndexHashMap.put(value, indexCounter);
        }

        @Override
        public int size() {
            return categoryToIndexHashMap.size();
        }
    }

    /**
     * Maps a numeric field to a single vector entry: the pre processed value of the field is written to the one
     * index assigned via {@link #addVectorEntry(int, String)}.
     */
    public static class ContinousSingleEntryVectorRange extends PMMLVectorRange {

        /** Vector index of the single entry; -1 until {@link #addVectorEntry(int, String)} was called. */
        int index = -1;

        /**
         * The derived fields must be given in backwards order of the processing chain.
         */
        public ContinousSingleEntryVectorRange(DataField dataField, MiningField miningField, DerivedField... derivedFields) {
            super(dataField, miningField, derivedFields);
        }

        @Override
        public EsVector getVector(DataSource dataSource) {
            throw new UnsupportedOperationException("Remove this later, we should not get here.");
        }

        /**
         * Returns a sparse vector holding the first processed value at this range's index, or an empty sparse
         * vector when there is no value (field missing and not replaced by any pre processing step).
         */
        @Override
        public EsVector getVector(Map<String, List<Object>> fieldValues) {
            List<Object> finalValues = applyPreProcessing(fieldValues);
            // applyPreProcessing reports a missing field as a single null, so null has to be checked as well
            Object finalValue = finalValues.isEmpty() ? null : finalValues.get(0);
            if (finalValue == null) {
                return new EsSparseNumericVector(new Tuple<>(new int[]{}, new double[]{}));
            }
            double numericValue = ((Number) finalValue).doubleValue();
            return new EsSparseNumericVector(new Tuple<>(new int[]{index}, new double[]{numericValue}));
        }

        @Override
        public void addVectorEntry(int indexCounter, String value) {
            index = indexCounter;
        }

        @Override
        public int size() {
            return 1;
        }
    }

    /**
     * Creates one pre processing step per derived field. The derived fields are consumed from the last array
     * element to the first while the step slots are filled in ascending order.
     *
     * @param derivedFields the derived fields to convert
     * @throws UnsupportedOperationException if a derived field cannot be converted into a step
     */
    protected void fillPreProcessingSteps(DerivedField[] derivedFields) {
        int derivedFieldIndex = derivedFields.length - 1;
        // don't start at the beginning, we might have a pre processing step there already from the mining field
        int firstFreeStep = preProcessingSteps.length - derivedFields.length;
        for (int stepIndex = firstFreeStep; stepIndex < preProcessingSteps.length; stepIndex++) {
            DerivedField derivedField = derivedFields[derivedFieldIndex];
            if (derivedField.getExpression() == null) {
                throw new UnsupportedOperationException("So far only Apply implemented.");
            }
            handleExpression(stepIndex, derivedField);
            derivedFieldIndex--;
        }
    }

    /**
     * Converts the expression of one derived field into a pre processing step and stores it in the given slot.
     * Supported are an {@link Apply} that wraps an {@code isMissing} check together with a {@link Constant}
     * replacement value, and {@link NormContinuous}.
     *
     * @throws UnsupportedOperationException for any other expression, including an Apply from which no missing
     *                                       value replacement could be extracted
     */
    private void handleExpression(int preProcessingStepIndex, DerivedField derivedField) {
        Expression derivedExpression = derivedField.getExpression();
        if (derivedExpression instanceof Apply) {
            Apply apply = (Apply) derivedExpression;
            for (Expression expression : apply.getExpressions()) {
                if (expression instanceof Apply) {
                    // constant first, so a nested Apply without function name is rejected instead of causing an NPE
                    if ("isMissing".equals(((Apply) expression).getFunction())) {
                        // now find the value that is supposed to replace the missing value
                        for (Expression candidate : apply.getExpressions()) {
                            if (candidate instanceof Constant) {
                                String missingValue = ((Constant) candidate).getValue();
                                preProcessingSteps[preProcessingStepIndex] = new MissingValuePreProcess(derivedField, missingValue);
                                break;
                            }
                        }
                    } else {
                        throw new UnsupportedOperationException("So far only if isMissing implemented.");
                    }
                }
            }
            if (preProcessingSteps[preProcessingStepIndex] == null) {
                // a null slot would only surface later as an unexplained NullPointerException
                throw new UnsupportedOperationException("Could not find an isMissing check with a constant replacement "
                        + "value in the Apply expression of derived field [" + derivedField.getName() + "]");
            }
        } else if (derivedExpression instanceof NormContinuous) {
            preProcessingSteps[preProcessingStepIndex] =
                    new NormContinousPreProcess((NormContinuous) derivedExpression, derivedField.getName().getValue());
        } else {
            throw new UnsupportedOperationException("So far only Apply expression implemented.");
        }
    }

    /**
     * Vector range for an intercept term: always contributes the value 1.0 at its index, independent of any input.
     */
    public static class Intercept extends PMMLVectorRange {

        int index;
        private String interceptName;

        public Intercept(String interceptName, String type) {
            super(null, null, type);
            this.interceptName = interceptName;
        }

        @Override
        public void addVectorEntry(int indexCounter, String value) {
            this.index = indexCounter;
        }

        @Override
        public int size() {
            return 1;
        }

        @Override
        public EsVector getVector(DataSource dataSource) {
            return new EsSparseNumericVector(new Tuple<>(new int[]{index}, new double[]{1.0}));
        }

        @Override
        public EsVector getVector(Map<String, List<Object>> fieldValues) {
            return new EsSparseNumericVector(new Tuple<>(new int[]{index}, new double[]{1.0}));
        }
    }

    /**
     * Vector range that does not produce numeric entries but a map from the final field name to the set of
     * pre processed values of the field.
     */
    public static class FieldToValue extends PMMLVectorRange {

        /** Name of the last pre processing step, or the raw field name when there is no pre processing. */
        String finalFieldName;

        public FieldToValue(DataField dataField, MiningField miningField, DerivedField... derivedFields) {
            super(dataField, miningField, derivedFields);
            if (preProcessingSteps.length > 0) {
                finalFieldName = preProcessingSteps[preProcessingSteps.length - 1].name();
            } else {
                finalFieldName = field;
            }
        }

        @Override
        public void addVectorEntry(int indexCounter, String value) {
            throw new UnsupportedOperationException("Not implemented for FieldToValue");
        }

        @Override
        public int size() {
            return 1;
        }

        @Override
        public EsVector getVector(DataSource dataSource) {
            throw new UnsupportedOperationException("Not implemented for FieldToValue");
        }

        /**
         * Returns the distinct pre processed values of the field, keyed by the final field name. A missing value
         * that no step replaced shows up as a {@code null} member of the set.
         */
        @Override
        public EsVector getVector(Map<String, List<Object>> fieldValues) {
            Set<Object> valueSet = new HashSet<>(applyPreProcessing(fieldValues));
            Map<String, Set<Object>> values = new HashMap<>();
            values.put(finalFieldName, valueSet);
            return new EsValueMapVector(values);
        }
    }
}