package hex.deepwater;
import deepwater.backends.BackendModel;
import deepwater.backends.BackendParams;
import deepwater.backends.BackendTrain;
import deepwater.backends.RuntimeOptions;
import deepwater.datasets.ImageDataSet;
import hex.DataInfo;
import water.H2O;
import water.Iced;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.util.Log;
import water.util.PrettyPrint;
import water.util.TwoDimTable;
import java.io.*;
import java.util.Arrays;
import static hex.genmodel.algos.deepwater.DeepwaterMojoModel.createDeepWaterBackend;
/**
* This class contains the state of the Deep Learning model
* This will be shared: one per node
*/
final public class DeepWaterModelInfo extends Iced {
private int _classes;
byte[] _network; // model definition (graph)
byte[] _modelparams; // internal state of native backend (weights/biases/helpers)
private TwoDimTable summaryTable;
transient BackendTrain _backend; //interface provider
transient ThreadLocal<BackendModel> _model = new ThreadLocal<>(); //pointer to C++ process
public ThreadLocal<BackendModel> getModel() {
if(null == _model) {
_model = new ThreadLocal<>();
}
return _model;
}
int _height;
int _width;
int _channels;
float[] _meanData; //mean pixel value of the training data
DataInfo _dataInfo;
volatile boolean _unstable = false;
void nukeModel() {
if (_backend != null && getModel() != null && getModel().get() != null) {
_backend.delete(getModel().get());
}
getModel().set(null);
}
void nukeBackend() {
nukeModel();
_backend = null;
}
void saveNativeState(String path, int iteration) {
assert(_backend !=null);
assert(getModel()!=null);
_backend.saveModel(getModel().get(), path + ".json"); //independent of iterations
_backend.saveParam(getModel().get(), path + "." + iteration + ".params");
}
float[] predict(float[] data) {
assert(_backend !=null);
assert(getModel()!=null);
return _backend.predict(getModel().get(), data);
}
float[] extractLayer(String layer, float[] data) {
assert(_backend !=null);
assert(getModel()!=null);
return _backend.extractLayer(getModel().get(), layer, data);
}
String listAllLayers() {
assert(_backend !=null);
assert(getModel()!=null);
return _backend.listAllLayers(getModel().get());
}
@Override
public int hashCode() {
return Arrays.hashCode(_network) + Arrays.hashCode(_modelparams);
}
// compute model size (number of model parameters required for making predictions)
// momenta are not counted here, but they are needed for model building
public long size() {
long res = 0;
if (_network!=null) res+=_network.length;
if (_modelparams!=null) res+=_modelparams.length;
return res;
}
public DeepWaterParameters parameters;
public final DeepWaterParameters get_params() { return parameters; }
private long processed_global;
synchronized long get_processed_global() { return processed_global; }
synchronized void set_processed_global(long p) { processed_global = p; }
synchronized void add_processed_global(long p) { processed_global += p; }
private long processed_local;
synchronized long get_processed_local() { return processed_local; }
synchronized void set_processed_local(long p) { processed_local = p; }
synchronized void add_processed_local(long p) { processed_local += p; }
synchronized long get_processed_total() { return processed_global + processed_local; }
private final boolean _classification; // Classification cache (nclasses>1)
private RuntimeOptions getRuntimeOptions() {
RuntimeOptions opts = new RuntimeOptions();
opts.setSeed((int) get_params().getOrMakeRealSeed());
opts.setUseGPU(get_params()._gpu);
opts.setDeviceID(get_params()._device_id);
return opts;
}
private BackendParams getBackendParams() {
BackendParams backendParams = new BackendParams();
backendParams.set("mini_batch_size", get_params()._mini_batch_size);
backendParams.set("clip_gradient", get_params()._clip_gradient);
String network = parameters._network == null ? null : parameters._network.toString();
if (network==null) {
assert (parameters._activation != null);
assert (parameters._hidden != null);
String[] acts = new String[parameters._hidden.length];
String acti;
if (parameters._activation.toString().startsWith("Rectifier")) acti = "relu";
else if (parameters._activation.toString().startsWith("Tanh")) acti = "tanh";
else throw H2O.unimpl();
Arrays.fill(acts, acti);
backendParams.set("activations", acts);
backendParams.set("hidden", parameters._hidden);
backendParams.set("input_dropout_ratio", parameters._input_dropout_ratio);
backendParams.set("hidden_dropout_ratios", parameters._hidden_dropout_ratios);
}
return backendParams;
}
private ImageDataSet getImageDataSet() {
return new ImageDataSet(_width, _height, _channels, _classes);
}
/**
* Main constructor
* @param origParams Model parameters
* @param nClasses number of classes (1 for regression, 0 for autoencoder)
*/
DeepWaterModelInfo(final DeepWaterParameters origParams, int nClasses, int nFeatures) {
_classes = nClasses;
_classification = _classes > 1;
parameters = (DeepWaterParameters) origParams.clone(); //make a copy, don't change model's parameters
_width = nFeatures;
_height = 0;
_channels = 0;
if (parameters._problem_type == DeepWaterParameters.ProblemType.image) {
_width=parameters._image_shape[0];
_height=parameters._image_shape[1];
_channels=parameters._channels;
if (_width==0 || _height==0) {
switch(parameters._network) {
case lenet:
_width = 28;
_height = 28;
break;
case auto:
case alexnet:
case googlenet:
case resnet:
_width = 224;
_height = 224;
break;
case inception_bn:
_width = 299;
_height = 299;
break;
case vgg:
_width = 320;
_height = 320;
break;
case user:
throw new H2OIllegalArgumentException("Please specify width and height for user-given model definition.");
default:
throw H2O.unimpl("Unknown network type: " + parameters._network);
}
}
assert(_width>0);
assert(_height>0);
} else if (parameters._problem_type == DeepWaterParameters.ProblemType.dataset) {
if (parameters._image_shape != null) {
if (parameters._image_shape[0]>0)
_width = parameters._image_shape[0];
if (parameters._image_shape[1]>0)
_height = parameters._image_shape[1];
if (_width>0 && _height>0)
_channels = parameters._channels;
else
_channels = 0;
}
} else if (parameters._problem_type == DeepWaterParameters.ProblemType.text) {
_width =56; //FIXME
} else {
Log.warn("unknown problem_type:", parameters._problem_type);
throw H2O.unimpl();
}
setupNativeBackend();
}
private void setupNativeBackend() {
try {
_backend = createDeepWaterBackend(parameters._backend.toString());
if (_backend == null) throw new IllegalArgumentException("No backend found. Cannot build a Deep Water model.");
ImageDataSet imageDataSet = getImageDataSet();
RuntimeOptions opts = getRuntimeOptions();
BackendParams bparms = getBackendParams();
if (parameters._network != DeepWaterParameters.Network.user) {
String network = parameters._network == null ? null : parameters._network.toString();
if (network != null) {
Log.info("Creating a fresh model of the following network type: " + network);
getModel().set(_backend.buildNet(imageDataSet, opts, bparms, _classes, network));
} else {
Log.info("Creating a fresh model of the following network type: MLP");
getModel().set(_backend.buildNet(imageDataSet, opts, bparms, _classes, "MLP"));
}
}
// load a network if specified
final String networkDef = parameters._network_definition_file;
if (networkDef != null && !networkDef.isEmpty()) {
File f = new File(networkDef);
if(!f.exists() || f.isDirectory()) {
throw new RuntimeException("Network definition file " + f + " not found.");
} else {
Log.info("Loading the network from: " + f.getAbsolutePath());
Log.info("Setting the optimizer and initializing the first and last layer.");
getModel().set(_backend.buildNet(imageDataSet, opts, bparms, _classes, f.getAbsolutePath()));
}
}
if (parameters._mean_image_file != null && !parameters._mean_image_file.isEmpty())
imageDataSet.setMeanData(_backend.loadMeanImage(getModel().get(), parameters._mean_image_file));
_meanData = imageDataSet.getMeanData();
final String networkParms = parameters._network_parameters_file;
if (networkParms != null && !networkParms.isEmpty()) {
File f = new File(networkParms);
if(!paramFilesExist(networkParms)) {
throw new RuntimeException("Network parameter file " + f + " not found.");
} else {
Log.info("Loading the parameters (weights/biases) from: " + f.getAbsolutePath());
assert (getModel() != null);
_backend.loadParam(getModel().get(), f.getAbsolutePath());
}
} else {
Log.warn("No network parameters file specified. Starting from scratch.");
}
nativeToJava(); //store initial state as early as it's created
} catch(Throwable t) {
throw new RuntimeException("Unable to initialize the native Deep Learning backend: " + t.getMessage());
}
}
static boolean paramFilesExist(final String paramPath) {
final File f = new File(paramPath);
String[] list = f.getParentFile().list(new FilenameFilter() {
@Override
public boolean accept(File dir, String name) {
return name.contains(f.getName());
}
});
return !f.isDirectory() && (f.exists() || (list != null && list.length > 0));
}
String getBasePath() {
// if (_backend instanceof DeepwaterCaffeBackend)
// return System.getProperty("user.dir") + "/caffe/";
// else
return System.getProperty("java.io.tmpdir");
}
void nativeToJava() {
if (_backend ==null) return;
Log.info("Native backend -> Java.");
long now = System.currentTimeMillis();
File file = null;
// only overwrite the network definition if it's null
if (_network==null) {
try {
file = new File(getBasePath(), Key.make().toString());
_backend.saveModel(getModel().get(), file.toString());
FileInputStream is = new FileInputStream(file);
_network = new byte[(int)file.length()];
is.read(_network);
is.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (file != null)
_backend.deleteSavedModel(file.toString());
}
}
// always overwrite the parameters (weights/biases)
try {
file = new File(getBasePath(), Key.make().toString());
_backend.saveParam(getModel().get(), file.toString());
_modelparams = _backend.readBytes(file);
} catch (IOException e) {
e.printStackTrace();
} finally {
if (file !=null)
_backend.deleteSavedParam(file.toString());
}
long time = System.currentTimeMillis() - now;
Log.info("Took: " + PrettyPrint.msecs(time, true));
}
/**
* Create native backend and fill it with the model's state stored in the Java model
*/
void javaToNative() {
javaToNative(null,null);
}
/**
* Internal helper to create a native backend, and fill its state
* @param network user-given network topology
* @param parameters user-given network state (weights/biases)
*/
private void javaToNative(byte[] network, byte[] parameters) {
long now = System.currentTimeMillis();
//existing state is fine
if (_backend !=null
// either not overwriting with user-given (new) state, or we already are in sync
&& (network == null || Arrays.equals(network,_network))
&& (parameters == null || Arrays.equals(parameters,_modelparams))) {
Log.warn("No need to move the state from Java to native.");
return;
}
if (_backend ==null) {
_backend = createDeepWaterBackend(get_params()._backend.toString()); // new ImageTrain(_width, _height, _channels, _deviceID, (int)parameters.getOrMakeRealSeed(), _gpu);
if (_backend == null) throw new IllegalArgumentException("No backend found. Cannot build a Deep Water model.");
}
if (network==null) network = _network;
if (parameters==null) parameters= _modelparams;
if (network==null || parameters==null) return;
Log.info("Java state -> native backend.");
initModel(network, parameters);
long time = System.currentTimeMillis() - now;
Log.info("Took: " + PrettyPrint.msecs(time, true));
}
void initModel() {
initModel(_network, _modelparams);
}
private void initModel(byte[] network, byte[] parameters) {
File file = null;
// only overwrite the network definition if it's null
try {
file = new File(getBasePath(), Key.make().toString() + ".json");
FileOutputStream os = new FileOutputStream(file);
os.write(network);
os.close();
// Log.info("Randomizing everything.");
getModel().set(_backend.buildNet(getImageDataSet(), getRuntimeOptions(), getBackendParams(), _classes, file.toString())); //randomizing initial state
} catch (IOException e) {
e.printStackTrace();
} finally {
if (file != null)
_backend.deleteSavedModel(file.toString());
}
// always overwrite the parameters (weights/biases)
try {
file = new File(System.getProperty("java.io.tmpdir"), Key.make().toString());
_backend.writeBytes(file, parameters);
_backend.loadParam(getModel().get(), file.toString());
} catch (IOException e) {
e.printStackTrace();
} finally {
if (file != null)
_backend.deleteSavedParam(file.toString());
}
}
/**
* Create a summary table
* @return TwoDimTable with the summary of the model
*/
TwoDimTable createSummaryTable() {
TwoDimTable table = new TwoDimTable(
"Status of Deep Learning Model",
(get_params()._network == null ? ("MLP: " + Arrays.toString(get_params()._hidden)) : get_params()._network.toString()) + ", " + PrettyPrint.bytes(size()) + ", "
+ (!get_params()._autoencoder ? ("predicting " + get_params()._response_column + ", ") : "") +
(get_params()._autoencoder ? "auto-encoder" :
_classification ? (_classes + "-class classification") : "regression")
+ ", "
+ String.format("%,d", get_processed_global()) + " training samples, "
+ "mini-batch size " + String.format("%,d", get_params()._mini_batch_size),
new String[1], //rows
new String[]{"Input Neurons", "Rate", "Momentum" },
new String[]{"int", "double", "double" },
new String[]{"%d", "%5f", "%5f"},
"");
table.set(0, 0, _dataInfo!=null ? _dataInfo.fullN() : _width * _height * _channels);
table.set(0, 1, get_params().learningRate(get_processed_global()));
table.set(0, 2, get_params().momentum(get_processed_global()));
summaryTable = table;
return summaryTable;
}
/**
* Print a summary table
* @return String containing ASCII version of summary table
*/
@Override public String toString() {
StringBuilder sb = new StringBuilder();
if (!get_params()._quiet_mode) {
createSummaryTable();
if (summaryTable!=null) sb.append(summaryTable.toString(1));
}
return sb.toString();
}
/**
* Debugging printout
* @return String with useful info
*/
public String toStringAll() {
StringBuilder sb = new StringBuilder();
sb.append(toString());
sb.append("\nprocessed global: ").append(get_processed_global());
sb.append("\nprocessed local: ").append(get_processed_local());
sb.append("\nprocessed total: ").append(get_processed_total());
sb.append("\n");
return sb.toString();
}
public void add(DeepWaterModelInfo other) {
throw H2O.unimpl();
}
public void mult(double N) {
throw H2O.unimpl();
}
public void div(double N) {
throw H2O.unimpl();
}
}