/*
* Copyright (c) 2015 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.sklearn;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import joblib.NDArrayWrapper;
import net.razorvine.pickle.objects.ClassDict;
import numpy.core.NDArray;
import numpy.core.NDArrayUtil;
/**
 * Static helper methods for reading array-like attributes out of a {@link ClassDict}
 * and for formatting diagnostic messages about its members.
 *
 * <p>
 * This class only has static members and cannot be instantiated.
 * </p>
 */
public class ClassDictUtil {

	/**
	 * Orders dictionary entries by key, ignoring case.
	 * Does not accept <code>null</code> keys.
	 */
	private static final Comparator<Map.Entry<String, ?>> KEY_COMPARATOR = new Comparator<Map.Entry<String, ?>>(){

		@Override
		public int compare(Map.Entry<String, ?> left, Map.Entry<String, ?> right){
			return (left.getKey()).compareToIgnoreCase(right.getKey());
		}
	};

	private ClassDictUtil(){
	}

	/**
	 * Gets the content of an array-like attribute.
	 *
	 * @param dict The dictionary to read from.
	 * @param name The name of the attribute.
	 *
	 * @return The array content if the attribute value is a {@link HasArray},
	 * or a single-element list if the attribute value is a {@link Number}.
	 *
	 * @throws IllegalArgumentException If the attribute value is missing or is of any other type.
	 */
	static
	public List<?> getArray(ClassDict dict, String name){
		Object object = dict.get(name);

		if(object instanceof HasArray){
			HasArray hasArray = (HasArray)object;

			return hasArray.getArrayContent();
		} // End if

		// A scalar is exposed as a single-element list
		if(object instanceof Number){
			return Collections.singletonList(object);
		}

		throw unsupportedArrayType(dict, name, object);
	}

	/**
	 * Gets the content of an array attribute by key.
	 *
	 * @param dict The dictionary to read from.
	 * @param name The name of the attribute.
	 * @param key The key that is passed on to {@link NDArrayUtil#getContent(NDArray, String)}.
	 *
	 * @return The content of the {@link NDArray} attribute value for the specified key.
	 * A {@link NDArrayWrapper} attribute value is unwrapped first.
	 *
	 * @throws IllegalArgumentException If the (unwrapped) attribute value is not a {@link NDArray}.
	 */
	static
	public List<?> getArray(ClassDict dict, String name, String key){
		Object object = dict.get(name);

		if(object instanceof NDArrayWrapper){
			NDArrayWrapper arrayWrapper = (NDArrayWrapper)object;

			// From here on, the error message reports the type of the unwrapped content
			object = arrayWrapper.getContent();
		} // End if

		if(object instanceof NDArray){
			NDArray array = (NDArray)object;

			return NDArrayUtil.getContent(array, key);
		}

		throw unsupportedArrayType(dict, name, object);
	}

	/**
	 * Gets the shape of an array-like attribute, and checks its dimensionality.
	 *
	 * @param dict The dictionary to read from.
	 * @param name The name of the attribute.
	 * @param length The expected number of dimensions.
	 *
	 * @return The shape, which has exactly <code>length</code> elements.
	 *
	 * @throws IllegalArgumentException If the attribute value is not a supported array type,
	 * or if the number of dimensions is different from <code>length</code>.
	 *
	 * @see #getShape(ClassDict, String)
	 */
	static
	public int[] getShape(ClassDict dict, String name, int length){
		int[] shape = getShape(dict, name);

		if(shape.length != length){
			throw new IllegalArgumentException("The dimensionality of the " + ClassDictUtil.formatMember(dict, name) + " attribute (" + shape.length + ") is not " + length);
		}

		return shape;
	}

	/**
	 * Gets the shape of an array-like attribute.
	 *
	 * @param dict The dictionary to read from.
	 * @param name The name of the attribute.
	 *
	 * @return The array shape if the attribute value is a {@link HasArray},
	 * or <code>{1}</code> if the attribute value is a {@link Number}.
	 *
	 * @throws IllegalArgumentException If the attribute value is missing or is of any other type.
	 */
	static
	public int[] getShape(ClassDict dict, String name){
		Object object = dict.get(name);

		if(object instanceof HasArray){
			HasArray hasArray = (HasArray)object;

			return hasArray.getArrayShape();
		} // End if

		// A scalar is reported as a one-dimensional array of length one
		if(object instanceof Number){
			return new int[]{1};
		}

		throw unsupportedArrayType(dict, name, object);
	}

	/**
	 * Checks that all non-<code>null</code> collections have the same size.
	 * Each collection is compared against the closest preceding non-<code>null</code> collection.
	 *
	 * @param collections The collections to check. Elements that are <code>null</code> are skipped.
	 *
	 * @throws IllegalArgumentException If two collections have different sizes.
	 */
	static
	public void checkSize(Collection<?>... collections){
		Collection<?> prevCollection = null;

		for(Collection<?> collection : collections){

			if(collection == null){
				continue;
			} // End if

			if(prevCollection != null && collection.size() != prevCollection.size()){
				throw new IllegalArgumentException("Expected the same number of elements, got " + prevCollection.size() + " element(s) and " + collection.size() + " element(s)");
			}

			prevCollection = collection;
		}
	}

	/**
	 * Checks that all non-<code>null</code> collections have the specified size.
	 *
	 * @param size The expected size.
	 * @param collections The collections to check. Elements that are <code>null</code> are skipped.
	 *
	 * @throws IllegalArgumentException If a collection has a different size.
	 */
	static
	public void checkSize(int size, Collection<?>... collections){

		for(Collection<?> collection : collections){

			if(collection == null){
				continue;
			} // End if

			if(collection.size() != size){
				throw new IllegalArgumentException("Expected " + size + " element(s), got " + collection.size() + " element(s)");
			}
		}
	}

	/**
	 * Formats a qualified member name for use in messages.
	 *
	 * @param dict The dictionary that owns the member.
	 * @param name The name of the member.
	 *
	 * @return The value stored under the <code>__class__</code> key of the dictionary,
	 * followed by a dot and the member name.
	 */
	static
	public String formatMember(ClassDict dict, String name){
		String clazz = (String)dict.get("__class__");

		return (clazz + "." + name);
	}

	/**
	 * Formats a description of the type of an object for use in messages.
	 *
	 * @param object The object to describe. May be <code>null</code>.
	 *
	 * @return <code>null</code> if the object is <code>null</code>;
	 * <code>"Python class "</code> followed by the value stored under the <code>__class__</code> key if the object is a {@link ClassDict};
	 * <code>"Java class "</code> followed by the name of the Java class otherwise.
	 */
	static
	public String formatClass(Object object){

		if(object == null){
			return null;
		} // End if

		if(object instanceof ClassDict){
			ClassDict dict = (ClassDict)object;

			String clazz = (String)dict.get("__class__");

			return "Python class " + clazz;
		}

		Class<?> clazz = object.getClass();

		return "Java class " + clazz.getName();
	}

	/**
	 * Formats a multi-line listing of the entries of a dictionary.
	 *
	 * <p>
	 * Entries are sorted by key, ignoring case.
	 * Each entry is printed on its own tab-indented line as <code>key=value</code>,
	 * followed by a <code>//</code> comment that states the name of the Java class of the value
	 * (or <code>N/A</code> if the value is <code>null</code>).
	 * </p>
	 *
	 * @param dict The dictionary to format.
	 *
	 * @return The listing, enclosed between curly braces.
	 */
	static
	public String toString(ClassDict dict){
		// The buffer never leaves this method, so the synchronization of StringBuffer is not needed
		StringBuilder sb = new StringBuilder();

		sb.append("\n{\n");

		String sep = "";

		// Sort a copy, leaving the dictionary itself untouched
		List<? extends Map.Entry<String, ?>> entries = new ArrayList<>(dict.entrySet());

		Collections.sort(entries, KEY_COMPARATOR);

		for(Map.Entry<String, ?> entry : entries){
			sb.append(sep);

			sep = "\n";

			String key = entry.getKey();
			Object value = entry.getValue();

			sb.append("\t").append(key).append("=").append(value);
			sb.append(" // ").append(value != null ? (value.getClass()).getName() : "N/A");
		}

		sb.append("\n}\n");

		return sb.toString();
	}

	/**
	 * Creates (but does not throw) the exception that reports an attribute value of an unsupported type.
	 */
	static
	private IllegalArgumentException unsupportedArrayType(ClassDict dict, String name, Object object){
		return new IllegalArgumentException("The value of the " + ClassDictUtil.formatMember(dict, name) + " attribute (" + ClassDictUtil.formatClass(object) + ") is not a supported array type");
	}
}