/*
* Copyright (c) 2015 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package sklearn_pandas;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import com.google.common.base.Function;
import com.google.common.collect.Lists;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.jpmml.converter.Feature;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.HasArray;
import org.jpmml.sklearn.SkLearnEncoder;
import org.jpmml.sklearn.TupleUtil;
import sklearn.Initializer;
import sklearn.Transformer;
public class DataFrameMapper extends Initializer {
public DataFrameMapper(String module, String name){
super(module, name);
}
@Override
public List<Feature> initializeFeatures(SkLearnEncoder encoder){
Object _default = getDefault();
List<Object[]> rows = getFeatures();
if(!(Boolean.FALSE).equals(_default)){
throw new IllegalArgumentException();
}
List<Feature> result = new ArrayList<>();
for(Object[] row : rows){
List<Feature> rowFeatures = new ArrayList<>();
List<String> columns = getColumnList(row);
for(String column : columns){
FieldName name = FieldName.create(column);
DataField dataField = encoder.getDataField(name);
if(dataField == null){
dataField = encoder.createDataField(name);
}
rowFeatures.add(new WildcardFeature(encoder, dataField));
}
List<Transformer> transformers = getTransformerList(row);
for(Transformer transformer : transformers){
encoder.updateFeatures(rowFeatures, transformer);
rowFeatures = transformer.encodeFeatures(rowFeatures, encoder);
}
if(row.length > 2){
Map<String, ?> options = (Map)row[2];
String alias = (String)options.get("alias");
if(alias != null){
for(int i = 0; i < rowFeatures.size(); i++){
Feature rowFeature = rowFeatures.get(i);
encoder.renameField(rowFeature.getName(), rowFeatures.size() > 1 ? FieldName.create(alias + "_" + i) : FieldName.create(alias));
}
}
}
result.addAll(rowFeatures);
}
return result;
}
public Object getDefault(){
return get("default");
}
public DataFrameMapper setDefault(Object _default){
put("default", _default);
return this;
}
public List<Object[]> getFeatures(){
return (List)get("features");
}
public DataFrameMapper setFeatures(List<Object[]> features){
put("features", features);
return this;
}
static
private List<String> getColumnList(Object[] feature){
Function<Object, String> function = new Function<Object, String>(){
@Override
public String apply(Object object){
if(object instanceof String){
return (String)object;
}
throw new IllegalArgumentException("The key object (" + ClassDictUtil.formatClass(object) + ") is not a String");
}
};
try {
if(feature[0] instanceof HasArray){
HasArray hasArray = (HasArray)feature[0];
return (List)hasArray.getArrayContent();
} // End if
if(feature[0] instanceof List){
return Lists.transform(((List)feature[0]), function);
}
return Collections.singletonList(function.apply(feature[0]));
} catch(RuntimeException re){
throw new IllegalArgumentException("Invalid mapping key", re);
}
}
static
private List<Transformer> getTransformerList(Object[] feature){
Function<Object, Transformer> function = new Function<Object, Transformer>(){
@Override
public Transformer apply(Object object){
if(object instanceof Transformer){
return (Transformer)object;
}
throw new IllegalArgumentException("The value object (" + ClassDictUtil.formatClass(object) + ") is not a Transformer or is not a supported Transformer subclass");
}
};
try {
if(feature[1] == null){
return Collections.emptyList();
} // End if
if(feature[1] instanceof TransformerPipeline){
TransformerPipeline transformerPipeline = (TransformerPipeline)feature[1];
List<Object[]> steps = transformerPipeline.getSteps();
return Lists.transform((List)TupleUtil.extractElementList(steps, 1), function);
} // End if
if(feature[1] instanceof List){
return Lists.transform((List)feature[1], function);
}
return Collections.singletonList(function.apply(feature[1]));
} catch(RuntimeException re){
throw new IllegalArgumentException("Invalid mapping value", re);
}
}
}