/*
* Copyright (c) 2015 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package numpy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.dmg.pmml.DataType;
import org.jpmml.sklearn.CClassDict;
/**
 * Dictionary-backed representation of an unpickled NumPy <code>dtype</code> object.
 *
 * <p>
 * Constructor arguments and state arguments arrive as positional arrays; they are
 * paired with the attribute names listed in {@link #INIT_ATTRIBUTES} and
 * {@link #SETSTATE_ATTRIBUTES} respectively, and stored as dictionary entries.
 * </p>
 */
public class DType extends CClassDict {

	public DType(String module, String name){
		super(module, name);
	}

	/**
	 * Stores the positional constructor arguments under the names
	 * <code>obj</code>, <code>align</code> and <code>copy</code>.
	 *
	 * @param args The positional constructor arguments.
	 */
	@Override
	public void __init__(Object[] args){
		super.__setstate__(createAttributeMap(INIT_ATTRIBUTES, args));
	}

	/**
	 * Stores the positional state arguments under the names listed in
	 * {@link #SETSTATE_ATTRIBUTES}.
	 *
	 * <p>
	 * Reference for the state layout:
	 * https://github.com/numpy/numpy/blob/master/numpy/core/src/multiarray/descriptor.c
	 * </p>
	 *
	 * @param args The positional state arguments.
	 */
	@Override
	public void __setstate__(Object[] args){
		super.__setstate__(createAttributeMap(SETSTATE_ATTRIBUTES, args));
	}

	/**
	 * Maps the class name of this object to a PMML data type.
	 *
	 * @return The matching PMML data type.
	 *
	 * @throws IllegalArgumentException If the class name is not one of the recognized NumPy scalar type names.
	 * The exception message is the class name.
	 */
	public DataType getDataType(){
		String className = getClassName();

		switch(className){
			case "numpy.bool_":
				return DataType.BOOLEAN;
			case "numpy.int_":
			case "numpy.int8":
			case "numpy.int16":
			case "numpy.int32":
			case "numpy.int64":
				return DataType.INTEGER;
			case "numpy.float32":
				return DataType.FLOAT;
			case "numpy.float_":
			case "numpy.float64":
				return DataType.DOUBLE;
			default:
				throw new IllegalArgumentException(className);
		}
	}

	/**
	 * Builds a description of this object.
	 *
	 * <p>
	 * When the <code>values</code> attribute is not set, the description is a string:
	 * the <code>obj</code> attribute, prefixed with the <code>order</code> attribute if that is set.
	 * Otherwise the key set of <code>values</code> must be equal to one of the two supported
	 * field name sets ({@link #TREE_KEYS} or {@link #NODEDATA_KEYS}), and the description is a list of
	 * <code>{field name, field description}</code> pairs in the iteration order of that set.
	 * </p>
	 *
	 * @return Either a {@link String} or a {@link List} of two-element <code>Object[]</code> pairs.
	 *
	 * @throws IllegalArgumentException If the field names are not supported, or if a required attribute is missing.
	 */
	public Object toDescr(){
		Map<String, Object[]> values = getValues();

		if(values == null){
			String obj = getObj();
			String order = getOrder();

			return formatDescr(obj, order);
		}

		Set<String> valueKeys = values.keySet();

		// Iterate over the constant rather than over valueKeys, so that the field order of the result is fixed
		if((TREE_KEYS).equals(valueKeys)){
			return formatDescr(TREE_KEYS, values);
		} else

		if((NODEDATA_KEYS).equals(valueKeys)){
			return formatDescr(NODEDATA_KEYS, values);
		}

		throw new IllegalArgumentException("Unsupported field names: " + valueKeys);
	}

	/**
	 * @return The <code>values</code> attribute, or <code>null</code> if it is not set.
	 */
	@SuppressWarnings("unchecked")
	public Map<String, Object[]> getValues(){
		// The entry is stored untyped; the element types are verified on use, not here
		return (Map<String, Object[]>)get("values");
	}

	/**
	 * @return The <code>obj</code> attribute, or <code>null</code> if it is not set.
	 */
	public String getObj(){
		return (String)get("obj");
	}

	/**
	 * @return The <code>order</code> attribute, or <code>null</code> if it is not set.
	 */
	public String getOrder(){
		return (String)get("order");
	}

	/**
	 * Describes every requested field as a <code>{field name, field description}</code> pair.
	 *
	 * @param keys The field names, in the order in which they should appear in the result.
	 * @param values Field name to field value mapping; the first element of every field value must be a {@link DType}.
	 *
	 * @throws IllegalArgumentException If a field value is missing, empty, or does not start with a {@link DType}.
	 */
	private static List<Object[]> formatDescr(Collection<String> keys, Map<String, Object[]> values){
		List<Object[]> result = new ArrayList<>(keys.size());

		for(String key : keys){
			Object[] value = values.get(key);

			if(value == null || value.length == 0 || !(value[0] instanceof DType)){
				throw new IllegalArgumentException("Field " + key + " does not have a DType as its first element");
			}

			DType dType = (DType)value[0];

			result.add(new Object[]{key, dType.toDescr()});
		}

		return result;
	}

	/**
	 * Concatenates the order prefix (if any) and the type object.
	 *
	 * @param obj The <code>obj</code> attribute; required.
	 * @param order The <code>order</code> attribute; may be <code>null</code>.
	 *
	 * @throws IllegalArgumentException If <code>obj</code> is <code>null</code>.
	 */
	private static String formatDescr(String obj, String order){

		if(obj == null){
			throw new IllegalArgumentException("The obj attribute is not set");
		}

		return (order != null ? (order + obj) : obj);
	}

	// Names of the positional arguments accepted by __init__, in order
	private static final String[] INIT_ATTRIBUTES = {
		"obj",
		"align",
		"copy"
	};

	// Names of the positional arguments accepted by __setstate__, in order
	private static final String[] SETSTATE_ATTRIBUTES = {
		"version",
		"order",
		"subdescr",
		"names",
		"values",
		"w_size",
		"alignment",
		"flags"
	};

	// Supported field name sets; insertion-ordered, because the order determines the layout of the description
	private static final Set<String> TREE_KEYS = new LinkedHashSet<>(Arrays.asList("left_child", "right_child", "feature", "threshold", "impurity", "n_node_samples", "weighted_n_node_samples"));

	private static final Set<String> NODEDATA_KEYS = new LinkedHashSet<>(Arrays.asList("idx_start", "idx_end", "is_leaf", "radius"));
}