/*
* Copyright [2013-2016] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.pmml;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.FieldUsageType;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.NeuralOutput;
import org.dmg.pmml.NeuralOutputs;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.NumericPredictor;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.RegressionModel;
import org.dmg.pmml.RegressionNormalizationMethodType;
import org.dmg.pmml.RegressionTable;
import com.google.common.primitives.Ints;
/**
 * Common utilities used by the shifu PMML plugins.
 *
 * <p>
 * All methods are static and stateless. They read field names out of a PMML {@link MiningSchema} or
 * {@link DataDictionary} and build small PMML fragments (neural outputs, regression tables) from them.
 */
public class PMMLAdapterCommonUtil {

    /**
     * Collects the names of all mining fields whose usage type equals the requested one.
     *
     * @param schema
     *            the schema to scan
     * @param type
     *            the usage type to match
     * @return names of the matching fields, in schema order; empty if none match
     */
    private static List<String> getSchemaFieldViaUsageType(final MiningSchema schema, final FieldUsageType type) {
        List<String> names = new ArrayList<String>();
        for(MiningField field: schema.getMiningFields()) {
            if(field.getUsageType() == type) {
                names.add(field.getName().getValue());
            }
        }
        return names;
    }

    /**
     * Returns the target field names of the given mining schema.
     *
     * @param schema
     *            the schema
     * @return target field names, in schema order
     */
    public static List<String> getSchemaTargetFields(final MiningSchema schema) {
        return getSchemaFieldViaUsageType(schema, FieldUsageType.TARGET);
    }

    /**
     * Returns the active field names of the given mining schema.
     *
     * @param schema
     *            the schema
     * @return active field names, in schema order
     */
    public static List<String> getSchemaActiveFields(final MiningSchema schema) {
        return getSchemaFieldViaUsageType(schema, FieldUsageType.ACTIVE);
    }

    /**
     * Returns the names of all used fields of the given mining schema, i.e. every field whose usage type is
     * either target or active.
     *
     * @param schema
     *            the schema
     * @return field names, in schema order
     */
    public static List<String> getSchemaSelectedFields(final MiningSchema schema) {
        List<String> selected = new ArrayList<String>();
        for(MiningField field: schema.getMiningFields()) {
            FieldUsageType usage = field.getUsageType();
            if(usage == FieldUsageType.TARGET || usage == FieldUsageType.ACTIVE) {
                selected.add(field.getName().getValue());
            }
        }
        return selected;
    }

    /**
     * Creates the PMML neural outputs for a neural network model. One output is created per target field of
     * the schema; the i-th output references the neuron id {@code "<layerID>,<i>"}.
     *
     * @param schema
     *            the schema
     * @param layerID
     *            which layer the output neuron lies
     * @return neural outputs
     */
    public static NeuralOutputs getOutputFields(final MiningSchema schema, final int layerID) {
        List<String> targetNames = getSchemaFieldViaUsageType(schema, FieldUsageType.TARGET);
        int outputCount = targetNames.size();

        NeuralOutputs outputs = new NeuralOutputs();
        outputs.setNumberOfOutputs(outputCount);
        for(int i = 0; i < outputCount; i++) {
            DerivedField field = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE);
            field.withExpression(new FieldRef(new FieldName(targetNames.get(i))));
            // The concatenation already yields a String, so no extra conversion is needed.
            outputs.withNeuralOutputs(new NeuralOutput(field, layerID + "," + i));
        }
        return outputs;
    }

    /**
     * Generates the regression table from the weight list, the intercept and the partial PMML model, and
     * attaches it to that model.
     *
     * <p>
     * The model is modified in place: its function name is set to regression, its normalization method to
     * logit and its target field to the first target field of its mining schema. Weights are consumed in the
     * order of the model's local derived fields; only derived fields that are a {@link NormContinuous} over
     * an active mining field take a weight.
     *
     * @param weights
     *            weight list for the Regression Table
     * @param intercept
     *            the intercept
     * @param pmmlModel
     *            partial PMML model
     * @return regression model instance (the same object as {@code pmmlModel})
     * @throws IndexOutOfBoundsException
     *             if the mining schema has no target field, or if {@code weights} holds fewer entries than
     *             there are qualifying derived fields
     */
    public static RegressionModel getRegressionTable(final double[] weights, final double intercept,
            RegressionModel pmmlModel) {
        MiningSchema schema = pmmlModel.getMiningSchema();
        // TODO may not need target field in LRModel
        pmmlModel.withFunctionName(MiningFunctionType.REGRESSION).withNormalizationMethod(
                RegressionNormalizationMethodType.LOGIT);

        List<String> outputFields = getSchemaFieldViaUsageType(schema, FieldUsageType.TARGET);
        // TODO only one outputField, what if we have more than one outputField
        String targetName = outputFields.get(0);
        pmmlModel.withTargetFieldName(new FieldName(targetName));

        RegressionTable table = new RegressionTable();
        // Apply the caller-supplied intercept; previously it was ignored and the table kept its default.
        table.setIntercept(intercept);
        table.withTargetCategory(targetName);

        List<String> activeFields = getSchemaFieldViaUsageType(schema, FieldUsageType.ACTIVE);
        int index = 0;
        for(DerivedField dField: pmmlModel.getLocalTransformations().getDerivedFields()) {
            Expression expression = dField.getExpression();
            if(expression instanceof NormContinuous) {
                NormContinuous norm = (NormContinuous) expression;
                if(activeFields.contains(norm.getField().getValue())) {
                    table.withNumericPredictors(new NumericPredictor(dField.getName(), weights[index++]));
                }
            }
        }

        pmmlModel.withRegressionTables(table);
        return pmmlModel;
    }

    /**
     * Gets the header names from the PMML data dictionary.
     *
     * @param pmml
     *            the pmml model
     * @return headers, one per data field, in dictionary order
     */
    public static String[] getDataDicHeaders(final PMML pmml) {
        DataDictionary dictionary = pmml.getDataDictionary();
        List<DataField> fields = dictionary.getDataFields();
        String[] headers = new String[fields.size()];
        int position = 0;
        for(DataField field: fields) {
            headers[position++] = field.getName().getValue();
        }
        return headers;
    }

    /**
     * Gets the column indexes for all active fields in the input data set.
     *
     * @param pmml
     *            the pmml model
     * @return active id
     */
    public static int[] getActiveID(PMML pmml) {
        return getDicFieldIDViaType(pmml, FieldUsageType.ACTIVE);
    }

    /**
     * Gets the column indexes for the target fields in the input data set.
     *
     * @param pmml
     *            the pmml model
     * @return target id
     */
    public static int[] getTargetID(PMML pmml) {
        return getDicFieldIDViaType(pmml, FieldUsageType.TARGET);
    }

    /**
     * Based on the usage type, gets the column indexes for the corresponding fields in the input data set.
     * A column index is the position of the field in the PMML data dictionary; only the mining schema of the
     * first model in {@code pmml} is consulted.
     *
     * @param pmml
     *            the pmml model
     * @param type
     *            the type
     * @return data dictionary positions of the matching mining fields, in mining schema order
     * @throws IllegalArgumentException
     *             if a matching mining field is not declared in the data dictionary
     */
    public static int[] getDicFieldIDViaType(PMML pmml, FieldUsageType type) {
        // Position of every data field in the dictionary, keyed by field name.
        HashMap<String, Integer> columnByName = new HashMap<String, Integer>();
        int column = 0;
        for(DataField dataField: pmml.getDataDictionary().getDataFields()) {
            columnByName.put(dataField.getName().getValue(), column++);
        }

        List<Integer> matchedColumns = new ArrayList<Integer>();
        for(MiningField miningField: pmml.getModels().get(0).getMiningSchema().getMiningFields()) {
            if(miningField.getUsageType() != type) {
                continue;
            }
            String fieldName = miningField.getName().getValue();
            Integer columnId = columnByName.get(fieldName);
            if(columnId == null) {
                // Fail with the offending name instead of a message-less NullPointerException when unboxing.
                throw new IllegalArgumentException("Mining field '" + fieldName
                        + "' is not declared in the PMML data dictionary");
            }
            matchedColumns.add(columnId);
        }
        return Ints.toArray(matchedColumns);
    }
}