/*
* Copyright (c) 2016 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package sklearn.neighbors;
import java.util.ArrayList;
import java.util.List;
import javax.xml.parsers.DocumentBuilder;
import org.dmg.pmml.CityBlock;
import org.dmg.pmml.CompareFunction;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Euclidean;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.Measure;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Minkowski;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Row;
import org.dmg.pmml.nearest_neighbor.InstanceField;
import org.dmg.pmml.nearest_neighbor.InstanceFields;
import org.dmg.pmml.nearest_neighbor.KNNInput;
import org.dmg.pmml.nearest_neighbor.KNNInputs;
import org.dmg.pmml.nearest_neighbor.NearestNeighborModel;
import org.dmg.pmml.nearest_neighbor.TrainingInstances;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.DOMUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.sklearn.ClassDictUtil;
import sklearn.Estimator;
public class KNeighborsUtil {
private KNeighborsUtil(){
}
static
public <E extends Estimator & HasNeighbors & HasTrainingData> NearestNeighborModel encodeNeighbors(E estimator, MiningFunction miningFunction, int numberOfInstances, int numberOfFeatures, Schema schema){
List<String> keys = new ArrayList<>();
InstanceFields instanceFields = new InstanceFields();
KNNInputs knnInputs = new KNNInputs();
Label label = schema.getLabel();
if(label != null){
InstanceField instanceField = new InstanceField(label.getName())
.setColumn("y");
instanceFields.addInstanceFields(instanceField);
keys.add(instanceField.getColumn());
}
List<Feature> features = schema.getFeatures();
for(int i = 0; i < features.size(); i++){
Feature feature = features.get(i);
ContinuousFeature continuousFeature = feature.toContinuousFeature(estimator.getDataType());
FieldName name = continuousFeature.getName();
InstanceField instanceField = new InstanceField(name)
.setColumn("x" + String.valueOf(i + 1));
instanceFields.addInstanceFields(instanceField);
keys.add(instanceField.getColumn());
KNNInput knnInput = new KNNInput(name);
knnInputs.addKNNInputs(knnInput);
}
DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder();
InlineTable inlineTable = new InlineTable();
List<?> y = estimator.getY();
List<? extends Number> fitX = estimator.getFitX();
ClassDictUtil.checkSize(numberOfInstances, y);
for(int i = 0; i < numberOfInstances; i++){
List<Object> values = new ArrayList<>(1 + numberOfFeatures);
values.add(y.get(i));
values.addAll(CMatrixUtil.getRow(fitX, numberOfInstances, numberOfFeatures, i));
Row row = DOMUtil.createRow(documentBuilder, keys, values);
inlineTable.addRows(row);
}
TrainingInstances trainingInstances = new TrainingInstances(instanceFields)
.setInlineTable(inlineTable)
.setTransformed(true);
ComparisonMeasure comparisonMeasure = encodeComparisonMeasure(estimator.getMetric(), estimator.getP());
String weights = estimator.getWeights();
if(!(weights).equals("uniform")){
throw new IllegalArgumentException(weights);
}
int numberOfNeighbors = estimator.getNumberOfNeighbors();
List<OutputField> outputFields = new ArrayList<>(numberOfNeighbors);
for(int i = 0; i < numberOfNeighbors; i++){
int rank = (i + 1);
OutputField outputField = new OutputField(FieldName.create("neighbor(" + rank + ")"), DataType.STRING)
.setOpType(OpType.CATEGORICAL)
.setResultFeature(ResultFeature.ENTITY_ID)
.setRank(rank);
outputFields.add(outputField);
}
Output output = new Output(outputFields);
NearestNeighborModel nearestNeighborModel = new NearestNeighborModel(MiningFunction.REGRESSION, numberOfNeighbors, ModelUtil.createMiningSchema(schema), trainingInstances, comparisonMeasure, knnInputs)
.setOutput(output);
return nearestNeighborModel;
}
static
private ComparisonMeasure encodeComparisonMeasure(String metric, int p){
switch(metric){
case "minkowski":
{
Measure measure;
switch(p){
case 1:
measure = new CityBlock();
break;
case 2:
measure = new Euclidean();
break;
default:
measure = new Minkowski(p);
break;
}
ComparisonMeasure comparisonMeasure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE)
.setCompareFunction(CompareFunction.ABS_DIFF)
.setMeasure(measure);
return comparisonMeasure;
}
default:
throw new IllegalArgumentException(metric);
}
}
}