/*
* Copyright [2013-2016] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.pmml;
import org.apache.commons.io.IOUtils;
import org.dmg.pmml.*;
import org.jpmml.model.ImportFilter;
import org.jpmml.model.JAXBUtil;
import org.xml.sax.InputSource;
import javax.xml.transform.sax.SAXSource;
import javax.xml.transform.stream.StreamResult;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class PMMLUtils {
public static List<Extension> createExtensions(Map<String, String> extensionMap) {
List<Extension> extensions = new ArrayList<Extension>();
for(Map.Entry<String, String> entry: extensionMap.entrySet()) {
String key = entry.getKey();
Extension extension = new Extension();
extension.setName(key);
extension.setValue(entry.getValue());
extensions.add(extension);
}
return extensions;
}
public static Extension getExtension(List<Extension> extensions, String key) {
for(Extension extension: extensions) {
if(key.equals(extension.getName())) {
return extension;
}
}
throw new RuntimeException("No such extension: " + key);
}
public static UnivariateStats getUnivariateStatsByFieldName(ModelStats modelStats, FieldName fieldName) {
for(UnivariateStats univariateStats: modelStats.getUnivariateStats()) {
if(univariateStats.getField().equals(fieldName)) {
return univariateStats;
}
}
throw new RuntimeException("No UnivariateStats for field: " + fieldName);
}
public static void savePMML(PMML pmml, String path) {
OutputStream os = null;
try {
os = new FileOutputStream(path);
StreamResult result = new StreamResult(os);
JAXBUtil.marshalPMML(pmml, result);
} catch (Exception e) {
e.printStackTrace();
} finally {
IOUtils.closeQuietly(os);
}
}
public static PMML loadPMML(String path) throws Exception {
InputStream is = null;
try {
is = new FileInputStream(path);
InputSource source = new InputSource(is);
SAXSource transformedSource = ImportFilter.apply(source);
return JAXBUtil.unmarshalPMML(transformedSource);
} catch (Exception e) {
e.printStackTrace();
throw e;
} finally {
IOUtils.closeQuietly(is);
}
}
public static DataType getDefaultDataTypeByOpType(OpType optype) {
if(optype.equals(OpType.CONTINUOUS)) {
return DataType.DOUBLE;
} else {
return DataType.STRING;
}
}
public static Model createModelByType(String name) {
if(name.equalsIgnoreCase("NeuralNetwork")) {
return new NeuralNetwork();
} else {
throw new RuntimeException("Model not supported: " + name);
}
}
public static Model getModelByName(PMML pmml, String name) {
for(Model model: pmml.getModels()) {
if(model.getModelName().equals(name)) {
return model;
}
}
throw new RuntimeException("No such model: " + name);
}
public static Integer getTargetFieldNumByName(DataDictionary dataDictionary, String name) {
int size = dataDictionary.getNumberOfFields();
for(int i = 0; i < size; i++) {
DataField dataField = dataDictionary.getDataFields().get(i);
if(dataField.getName().getValue().equals(name)) {
return i;
}
}
throw new RuntimeException("Target Field Not Found: " + name);
}
public static Map<FieldName, Integer> getFieldNumMap(DataDictionary dataDictionary) {
Map<FieldName, Integer> fieldNumMap = new HashMap<FieldName, Integer>();
int size = dataDictionary.getNumberOfFields();
for(int i = 0; i < size; i++) {
DataField dataField = dataDictionary.getDataFields().get(i);
fieldNumMap.put(dataField.getName(), i);
}
return fieldNumMap;
}
public static Map<FieldName, DerivedField> getDerivedFieldMap(LocalTransformations localTransformations) {
Map<FieldName, DerivedField> derivedFieldMap = new HashMap<FieldName, DerivedField>();
for(DerivedField derivedField: localTransformations.getDerivedFields()) {
derivedFieldMap.put(derivedField.getName(), derivedField);
}
return derivedFieldMap;
}
public static Map<FieldName, MiningField> getMiningFieldMap(MiningSchema miningSchema) {
Map<FieldName, MiningField> miningFieldMap = new HashMap<FieldName, MiningField>();
for(MiningField miningField: miningSchema.getMiningFields()) {
miningFieldMap.put(miningField.getName(), miningField);
}
return miningFieldMap;
}
public static Integer getNumActiveMiningFields(MiningSchema miningSchema) {
Integer cnt = 0;
for(MiningField miningField: miningSchema.getMiningFields()) {
if(miningField.getUsageType().equals(FieldUsageType.ACTIVE)) {
cnt += 1;
}
}
return cnt;
}
public static Integer getNumTargetMiningFields(MiningSchema miningSchema) {
Integer cnt = 0;
for(MiningField miningField: miningSchema.getMiningFields()) {
if(miningField.getUsageType().equals(FieldUsageType.TARGET)) {
cnt += 1;
}
}
return cnt;
}
}