/*
* Copyright (c) 2017 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package sklearn.preprocessing;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ConstantFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureUtil;
import org.jpmml.converter.InteractionFeature;
import org.jpmml.converter.PowerFeature;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.HasNumberOfFeatures;
import sklearn.Transformer;
/**
 * Converter for the scikit-learn {@code PolynomialFeatures} transformer.
 *
 * <p>
 * Reads the pickled attributes {@code degree}, {@code include_bias}, {@code interaction_only},
 * {@code n_input_features_} and {@code n_output_features_}, and re-creates every generated polynomial term
 * as a PMML feature.
 * </p>
 */
public class PolynomialFeatures extends Transformer implements HasNumberOfFeatures {

	public PolynomialFeatures(String module, String name){
		super(module, name);
	}

	/**
	 * @return The number of input features, as given by the {@code n_input_features_} attribute.
	 */
	@Override
	public int getNumberOfFeatures(){
		return getNumberOfInputFeatures();
	}

	/**
	 * Expands the input features into polynomial terms.
	 *
	 * <p>
	 * Terms are enumerated by total degree, from 0 (only when {@code include_bias} is set) or 1 up to {@code degree}.
	 * When {@code interaction_only} is set, only products of distinct input features are generated.
	 * For every term:
	 * </p>
	 * <ul>
	 *   <li>the degree-0 term becomes a constant feature with value 1.0;</li>
	 *   <li>a term with a single factor becomes that input feature, or a power feature of it;</li>
	 *   <li>a term with several factors becomes an interaction feature named by joining the factor names with {@code ":"}.</li>
	 * </ul>
	 *
	 * @param features The input features; their count is checked against {@code n_input_features_}.
	 * @param encoder The encoder used to create the derived features.
	 *
	 * @return One feature per generated term. The number of terms is checked against {@code n_output_features_}.
	 */
	@Override
	public List<Feature> encodeFeatures(List<Feature> features, final SkLearnEncoder encoder){
		int numberOfInputFeatures = getNumberOfInputFeatures();
		int numberOfOutputFeatures = getNumberOfOutputFeatures();

		ClassDictUtil.checkSize(numberOfInputFeatures, features);

		final
		int degree = getDegree();

		boolean includeBias = getIncludeBias();
		boolean interactionOnly = getInteractionOnly();

		// Each entry is an exponent vector of length numberOfInputFeatures; its components sum to the term's total degree
		List<int[]> powers = new ArrayList<>();

		for(int i = (includeBias ? 0 : 1); i <= degree; i++){
			List<int[]> degreePowers;

			if(interactionOnly){
				degreePowers = combinations(numberOfInputFeatures, i);
			} else

			{
				degreePowers = combinations_with_replacement(numberOfInputFeatures, i);
			}

			powers.addAll(degreePowers);
		}

		ClassDictUtil.checkSize(numberOfOutputFeatures, powers);

		Feature unitFeature = new ConstantFeature(encoder, 1.0d);

		// Maps an input feature x to the array [x^1, x^2, ..., x^degree] (element k - 1 holds x^k)
		Function<Feature, Feature[]> function = new Function<Feature, Feature[]>(){

			@Override
			public Feature[] apply(Feature feature){
				Feature[] featurePowers = new Feature[degree];

				// With degree 0 only the bias term is generated, so no power of any input feature is ever looked up
				if(featurePowers.length == 0){
					return featurePowers;
				} // End if

				if(feature instanceof BinaryFeature){
					BinaryFeature binaryFeature = (BinaryFeature)feature;

					// Binary features are treated as 0/1 indicators, for which x^k == x, so the feature itself is reused for every power
					Arrays.fill(featurePowers, binaryFeature);
				} else

				{
					featurePowers[0] = feature;

					ContinuousFeature continuousFeature = feature.toContinuousFeature();

					for(int i = 2; i <= degree; i++){
						featurePowers[i - 1] = new PowerFeature(encoder, continuousFeature.getName(), continuousFeature.getDataType(), i);
					}
				}

				return featurePowers;
			}
		};

		// Lists#transform returns a lazy view; copying it makes apply() run exactly once per input feature
		List<Feature[]> transformedFeatures = new ArrayList<>(Lists.transform(features, function));

		List<Feature> result = new ArrayList<>();

		for(int[] power : powers){
			StringBuilder sb = new StringBuilder();

			String sep = "";

			List<Feature> powerFeatures = new ArrayList<>();

			for(int i = 0; i < power.length; i++){

				if(power[i] > 0){
					Feature transformedFeature = transformedFeatures.get(i)[power[i] - 1];

					sb.append(sep);

					sep = ":";

					sb.append((FeatureUtil.getName(transformedFeature)).getValue());

					powerFeatures.add(transformedFeature);
				}
			}

			if(powerFeatures.isEmpty()){
				result.add(unitFeature);
			} else

			if(powerFeatures.size() == 1){
				result.add(Iterables.getOnlyElement(powerFeatures));
			} else

			{
				String id = sb.toString();

				result.add(new InteractionFeature(encoder, FieldName.create(id), DataType.DOUBLE, powerFeatures));
			}
		}

		return result;
	}

	/**
	 * @return The {@code degree} attribute.
	 */
	public int getDegree(){
		return ValueUtil.asInt((Number)get("degree"));
	}

	/**
	 * @return The {@code include_bias} attribute.
	 */
	public Boolean getIncludeBias(){
		return (Boolean)get("include_bias");
	}

	/**
	 * @return The {@code interaction_only} attribute.
	 */
	public Boolean getInteractionOnly(){
		return (Boolean)get("interaction_only");
	}

	/**
	 * @return The {@code n_input_features_} attribute.
	 */
	public int getNumberOfInputFeatures(){
		return ValueUtil.asInt((Number)get("n_input_features_"));
	}

	/**
	 * @return The {@code n_output_features_} attribute.
	 */
	public int getNumberOfOutputFeatures(){
		return ValueUtil.asInt((Number)get("n_output_features_"));
	}

	/**
	 * Enumerates the r-length combinations of {@code n} indices without repetition, in lexicographic order,
	 * and converts each one into an exponent vector via {@link #power(int, int[])}.
	 *
	 * <p>
	 * Like Python's {@code itertools.combinations}, this yields a single empty combination (an all-zero vector)
	 * when {@code r == 0}, and nothing when {@code r > n}.
	 * </p>
	 *
	 * @see <a href="https://docs.python.org/2/library/itertools.html#itertools.combinations">itertools.combinations</a>
	 */
	static
	private List<int[]> combinations(int n, int r){
		List<int[]> result = new ArrayList<>();

		// There is no way to choose more distinct indices than are available
		if(r > n){
			return result;
		}

		int[] indices = new int[r];
		for(int i = 0; i < r; i++){
			indices[i] = i;
		}

		result.add(power(n, indices));

		while(true){
			int i = (r - 1);

			// Find the rightmost index that has not yet reached its maximum value (i + n - r)
			for(; i > -1; i--){

				if(indices[i] != (i + n - r)){
					break;
				}
			}

			if(i < 0){
				break;
			}

			indices[i] += 1;

			for(int j = (i + 1); j < r; j++){
				indices[j] = (indices[j - 1] + 1);
			}

			result.add(power(n, indices));
		}

		return result;
	}

	/**
	 * Enumerates the r-length combinations of {@code n} indices with repetition, in lexicographic order,
	 * and converts each one into an exponent vector via {@link #power(int, int[])}.
	 *
	 * <p>
	 * Like Python's {@code itertools.combinations_with_replacement}, this yields a single empty combination
	 * (an all-zero vector) when {@code r == 0}, and nothing when {@code n == 0} and {@code r > 0}.
	 * </p>
	 *
	 * @see <a href="https://docs.python.org/2/library/itertools.html#itertools.combinations_with_replacement">itertools.combinations_with_replacement</a>
	 */
	static
	private List<int[]> combinations_with_replacement(int n, int r){
		List<int[]> result = new ArrayList<>();

		// There are no indices to draw from
		if(n == 0 && r > 0){
			return result;
		}

		int[] indices = new int[r];

		result.add(power(n, indices));

		while(true){
			int i = (r - 1);

			// Find the rightmost index that has not yet reached the last position (n - 1)
			for(; i > -1; i--){

				if(indices[i] != (n - 1)){
					break;
				}
			}

			if(i < 0){
				break;
			}

			int value = (indices[i] + 1);

			for(int j = i; j < r; j++){
				indices[j] = value;
			}

			result.add(power(n, indices));
		}

		return result;
	}

	/**
	 * Converts a combination of indices into an exponent vector.
	 *
	 * @param n The length of the exponent vector.
	 * @param indices The chosen indices, each in the range [0, n).
	 *
	 * @return A vector whose element {@code k} counts how many times index {@code k} occurs in {@code indices}.
	 */
	static
	private int[] power(int n, int[] indices){
		int[] result = new int[n];

		for(int index : indices){
			result[index] += 1;
		}

		return result;
	}
}