/*
* Copyright (c) 2015 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.sklearn;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.Enumeration;
import java.util.Properties;
import java.util.Set;
import joblib.NDArrayWrapperConstructor;
import joblib.NumpyArrayWrapper;
import net.razorvine.pickle.Opcodes;
import net.razorvine.pickle.Unpickler;
import net.razorvine.pickle.objects.ClassDict;
import numpy.core.NDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class PickleUtil {
private PickleUtil(){
}
static
public Storage createStorage(File file){
try {
InputStream is = new FileInputStream(file);
try {
return new CompressedInputStreamStorage(is);
} catch(IOException ioe){
is.close();
}
} catch(IOException ioe){
// Ignored
}
return new FileStorage(file);
}
static
public Object unpickle(Storage storage) throws IOException {
ObjectConstructor[] constructors = {
new NDArrayWrapperConstructor("joblib.numpy_pickle", "NDArrayWrapper", storage),
new NDArrayWrapperConstructor("sklearn.externals.joblib.numpy_pickle", "NDArrayWrapper", storage),
};
for(ObjectConstructor constructor : constructors){
Unpickler.registerConstructor(constructor.getModule(), constructor.getName(), constructor);
}
try(final InputStream is = storage.getObject()){
Unpickler unpickler = new Unpickler(){
@Override
protected Object dispatch(short key) throws IOException {
Object result = super.dispatch(key);;
if(key == Opcodes.BUILD){
Object head = super.stack.peek();
// Modify the stack by replacing NumpyArrayWrapper with NDArray
if(head instanceof NumpyArrayWrapper){
NumpyArrayWrapper arrayWrapper = (NumpyArrayWrapper)head;
super.stack.pop();
NDArray array = arrayWrapper.toArray(is);
super.stack.add(array);
}
}
return result;
}
};
return unpickler.load(is);
}
}
static
private void init(){
Thread thread = Thread.currentThread();
ClassLoader classLoader = thread.getContextClassLoader();
if(classLoader == null){
classLoader = ClassLoader.getSystemClassLoader();
}
Enumeration<URL> urls;
try {
urls = classLoader.getResources("META-INF/sklearn2pmml.properties");
} catch(IOException ioe){
logger.warn("Failed to find resources", ioe);
return;
}
while(urls.hasMoreElements()){
URL url = urls.nextElement();
logger.debug("Loading resource {}", url);
try(InputStream is = url.openStream()){
Properties properties = new Properties();
properties.load(is);
init(classLoader, properties);
} catch(IOException ioe){
logger.warn("Failed to load resource", ioe);
}
}
}
static
private void init(ClassLoader classLoader, Properties properties){
if(properties.isEmpty()){
return;
}
Set<String> keys = properties.stringPropertyNames();
for(String key : keys){
String value = properties.getProperty(key);
if(value == null || ("").equals(value)){
value = key;
}
logger.debug("Mapping Python class {} to Java class {}", key, value);
int dot = key.lastIndexOf('.');
if(dot < 0){
logger.warn("Failed to identify the module and name parts of Python class {}", key);
continue;
}
String module = key.substring(0, dot);
String name = key.substring(dot + 1);
Class<?> clazz;
try {
clazz = classLoader.loadClass(value);
} catch(ClassNotFoundException cnfe){
logger.warn("Failed to load Java class {}", value);
continue;
}
ObjectConstructor constructor;
if((CClassDict.class).isAssignableFrom(clazz)){
constructor = new ExtensionObjectConstructor(module, name, (Class<? extends CClassDict>)clazz);
} else
if((ClassDict.class).isAssignableFrom(clazz)){
constructor = new ObjectConstructor(module, name, (Class<? extends ClassDict>)clazz);
} else
{
logger.warn("Failed to identify the type of Java class {}", value);
continue;
}
Unpickler.registerConstructor(constructor.getModule(), constructor.getName(), constructor);
}
}
private static final Logger logger = LoggerFactory.getLogger(PickleUtil.class);
static {
PickleUtil.init();
}
}