package org.deeplearning4j.util;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
/**
 * Guess a model (or model configuration) from the given path or stream: each supported
 * DL4J / Keras format is tried in a fixed order until one of them loads successfully.
 *
 * @author Adam Gibson
 */
@Slf4j
public class ModelGuesser {

    /**
     * A facade for {@link ModelSerializer#restoreNormalizerFromInputStream(InputStream)}
     * @param is the input stream to load from
     * @return the loaded normalizer
     * @throws IOException if the normalizer cannot be read from the stream
     */
    public static Normalizer<?> loadNormalizer(InputStream is) throws IOException {
        return ModelSerializer.restoreNormalizerFromInputStream(is);
    }

    /**
     * A facade for {@link ModelSerializer#restoreNormalizerFromFile(File)}
     * @param path the path to the file
     * @return the loaded normalizer, as returned by {@code restoreNormalizerFromFile}
     */
    public static Normalizer<?> loadNormalizer(String path) {
        return ModelSerializer.restoreNormalizerFromFile(new File(path));
    }

    /**
     * Load a model configuration from the given file path. The following loaders are tried in
     * order, and the first one that succeeds wins:
     * {@link MultiLayerConfiguration#fromJson}, Keras model configuration, Keras Sequential
     * configuration, {@link ComputationGraphConfiguration#fromJson},
     * {@link MultiLayerConfiguration#fromYaml} and {@link ComputationGraphConfiguration#fromYaml}.
     * The file content is read as UTF-8 text for the DL4J JSON/YAML loaders; the Keras loaders
     * read the file from {@code path} themselves.
     *
     * @param path the path of the file to "guess"
     * @return the first configuration that loads successfully
     * @throws Exception the failure of the last attempt ({@code ComputationGraphConfiguration.fromYaml}),
     *         with the failures of all earlier attempts attached as suppressed exceptions
     */
    public static Object loadConfigGuess(String path) throws Exception {
        String input = FileUtils.readFileToString(new File(path), StandardCharsets.UTF_8);
        List<Exception> failures = new ArrayList<>();

        // JSON is tried BEFORE YAML on purpose: JSON content also happens to load through
        // the YAML parser ("accidentally"), so JSON input must get the JSON parser first.
        try {
            return MultiLayerConfiguration.fromJson(input);
        } catch (Exception e) {
            log.warn("Tried multi layer config from json", e);
            failures.add(e);
        }
        try {
            return KerasModelImport.importKerasModelConfiguration(path);
        } catch (Exception e) {
            log.warn("Tried keras model config", e);
            failures.add(e);
        }
        try {
            return KerasModelImport.importKerasSequentialConfiguration(path);
        } catch (Exception e) {
            log.warn("Tried keras sequence config", e);
            failures.add(e);
        }
        try {
            return ComputationGraphConfiguration.fromJson(input);
        } catch (Exception e) {
            log.warn("Tried computation graph from json");
            failures.add(e);
        }
        try {
            return MultiLayerConfiguration.fromYaml(input);
        } catch (Exception e) {
            log.warn("Tried multi layer configuration from yaml");
            failures.add(e);
        }
        try {
            return ComputationGraphConfiguration.fromYaml(input);
        } catch (Exception e) {
            throw withSuppressed(e, failures);
        }
    }

    /**
     * Load a model configuration from the given input stream.
     * The remaining stream content is copied to a temporary file, and
     * {@link #loadConfigGuess(String)} is applied to that file. The temporary file is deleted
     * afterwards, and the stream itself is not closed.
     *
     * @param stream the input stream to read the configuration from
     * @return the first configuration that loads successfully
     * @throws Exception if the stream cannot be copied or no loader accepts its content
     */
    public static Object loadConfigGuess(InputStream stream) throws Exception {
        File tmp = copyToTempFile(stream);
        try {
            return loadConfigGuess(tmp.getAbsolutePath());
        } finally {
            tmp.delete();
        }
    }

    /**
     * Load a model from the given file path. The following loaders are tried in order:
     * <ol>
     * <li>multi layer network with updater</li>
     * <li>computation graph with updater</li>
     * <li>multi layer network without updater</li>
     * <li>computation graph without updater</li>
     * <li>Keras model and weights</li>
     * <li>Keras Sequential model and weights</li>
     * </ol>
     *
     * @param path the path of the file to "guess"
     * @return the first model that loads successfully
     * @throws Exception the failure of the last attempt (Keras Sequential import), with the
     *         failures of all earlier attempts attached as suppressed exceptions
     */
    public static Model loadModelGuess(String path) throws Exception {
        File file = new File(path);
        List<Exception> failures = new ArrayList<>();

        try {
            return ModelSerializer.restoreMultiLayerNetwork(file, true);
        } catch (Exception e) {
            log.warn("Tried multi layer network");
            failures.add(e);
        }
        try {
            return ModelSerializer.restoreComputationGraph(file, true);
        } catch (Exception e) {
            log.warn("Tried computation graph");
            failures.add(e);
        }
        // Retry both DL4J formats without the updater state.
        try {
            return ModelSerializer.restoreMultiLayerNetwork(file, false);
        } catch (Exception e) {
            failures.add(e);
        }
        try {
            return ModelSerializer.restoreComputationGraph(file, false);
        } catch (Exception e) {
            failures.add(e);
        }
        try {
            return KerasModelImport.importKerasModelAndWeights(path);
        } catch (Exception e) {
            log.warn("Tried multi layer network keras");
            failures.add(e);
        }
        try {
            return KerasModelImport.importKerasSequentialModelAndWeights(path);
        } catch (Exception e) {
            throw withSuppressed(e, failures);
        }
    }

    /**
     * Load a model from the given input stream.
     * Every loader needs to read the data from the start, and a stream can only be consumed
     * once. The remaining stream content is therefore copied to a temporary file, and
     * {@link #loadModelGuess(String)} is applied to that file. The temporary file is deleted
     * afterwards, and the stream itself is not closed.
     *
     * @param stream the input stream to read the model from
     * @return the first model that loads successfully
     * @throws Exception if the stream cannot be copied or no loader accepts its content
     */
    public static Model loadModelGuess(InputStream stream) throws Exception {
        File tmp = copyToTempFile(stream);
        try {
            return loadModelGuess(tmp.getAbsolutePath());
        } finally {
            tmp.delete();
        }
    }

    /**
     * Copies the remaining content of {@code stream} into a new file in the system temp
     * directory and returns it. The stream is not closed. The file is registered with
     * {@link File#deleteOnExit()} as a backstop, and callers should delete it when done.
     * If copying fails, the partially written file is removed before the exception propagates.
     */
    private static File copyToTempFile(InputStream stream) throws IOException {
        File tmp = File.createTempFile("model-", null);
        tmp.deleteOnExit();
        try (OutputStream out = new BufferedOutputStream(new FileOutputStream(tmp))) {
            IOUtils.copy(stream, out);
        } catch (IOException | RuntimeException e) {
            tmp.delete();
            throw e;
        }
        return tmp;
    }

    /**
     * Attaches every earlier failed attempt to {@code last} as a suppressed exception and
     * returns {@code last}, so the final error reports why each format was rejected.
     */
    private static Exception withSuppressed(Exception last, List<Exception> earlier) {
        for (Exception failure : earlier) {
            // addSuppressed rejects self-suppression, so guard against the same instance.
            if (failure != last) {
                last.addSuppressed(failure);
            }
        }
        return last;
    }
}