package hex.genmodel;
import hex.genmodel.utils.ParseUtils;
import hex.genmodel.utils.StringEscapeUtils;
import java.io.BufferedReader;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
/**
* Helper class to deserialize a model from MOJO format. This is a counterpart to `ModelMojoWriter`.
*/
public abstract class ModelMojoReader<M extends MojoModel> {

  protected M _model;
  protected MojoReaderBackend _reader;
  private Map<String, Object> _lkv;

  /**
   * Deserializes a model from the MOJO exposed by the given backend. The `algorithm` entry of `model.ini` selects
   * the concrete reader (via {@code ModelMojoFactory}); that reader builds and populates the model. If the selected
   * reader implements {@link Closeable}, it is closed afterwards, even when reading fails.
   *
   * @param reader backend giving access to the files of the MOJO
   * @return the deserialized model
   * @throws IOException if a MOJO file cannot be read or `model.ini` / a domain file is malformed
   * @throws IllegalStateException if `model.ini` does not specify the model's algorithm
   */
  public static MojoModel readFrom(MojoReaderBackend reader) throws IOException {
    Map<String, Object> info = parseModelInfo(reader);
    if (! info.containsKey("algorithm"))
      throw new IllegalStateException("Unable to find information about the model's algorithm.");
    String algo = String.valueOf(info.get("algorithm"));
    ModelMojoReader<?> mmr = ModelMojoFactory.getMojoReader(algo);
    mmr._lkv = info;
    mmr._reader = reader;
    try {
      mmr.readAll();
    } finally {
      if (mmr instanceof Closeable)
        ((Closeable) mmr).close();
    }
    return mmr._model;
  }

  //--------------------------------------------------------------------------------------------------------------------
  // Inheritance interface: ModelMojoReader subclasses are expected to override these methods to provide custom behavior
  //--------------------------------------------------------------------------------------------------------------------

  /**
   * Reads the algorithm-specific part of the model. Invoked after {@link #makeModel(String[], String[][])} has
   * created {@code _model} and the fields common to all models have been populated.
   */
  protected abstract void readModelData() throws IOException;

  /**
   * Creates the model instance that will subsequently be populated.
   *
   * @param columns column names, in the order listed in the `[columns]` section of `model.ini`
   * @param domains per-column domains; {@code domains[i]} is null for columns that have no entry in the `[domains]`
   *                section of `model.ini`
   */
  protected abstract M makeModel(String[] columns, String[][] domains);

  //--------------------------------------------------------------------------------------------------------------------
  // Interface for subclasses
  //--------------------------------------------------------------------------------------------------------------------

  /**
   * Retrieve value from the model's kv store which was previously put there using `writekv(key, value)`. We will
   * attempt to cast to your expected type, but this is obviously unsafe. Note that the value is deserialized from
   * the underlying string representation using {@link ParseUtils#tryParse(String, Object)}, which occasionally may get
   * the answer wrong. The `uuid` entry is stored as a plain string and is returned without parsing.
   * If the `key` is missing in the local kv store, null will be returned. However when assigning to a primitive type
   * this would result in an NPE, so beware.
   */
  @SuppressWarnings("unchecked")
  protected <T> T readkv(String key) {
    return (T) readkv(key, null);
  }

  /**
   * Retrieves the value associated with a given key. If no value is set for the key, the given default value is
   * returned instead. Uses the same parsing logic as {@link ModelMojoReader#readkv(String)}. If the default value is
   * not null, its type is used to help the parser determine the return type.
   * @param key name of the key
   * @param defVal default value
   * @param <T> return type
   * @return parsed value, or {@code defVal} if the key is missing
   */
  @SuppressWarnings("unchecked")
  protected <T> T readkv(String key, T defVal) {
    Object val = _lkv.get(key);
    if (! (val instanceof RawValue))
      return val != null ? (T) val : defVal;
    return ((RawValue) val).parse(defVal);
  }

  /**
   * Retrieve binary data previously saved to the mojo file using `writeblob(key, blob)`.
   */
  protected byte[] readblob(String name) throws IOException {
    return _reader.getBinaryFile(name);
  }

  /**
   * Checks, via the reader backend, whether a file with the given name exists in the MOJO.
   */
  protected boolean exists(String name) {
    return _reader.exists(name);
  }

  /**
   * Retrieve text previously saved using `startWritingTextFile` + `writeln` as a list of lines. Each line is
   * trimmed to remove the leading and trailing whitespace.
   */
  protected Iterable<String> readtext(String name) throws IOException {
    return readtext(name, false);
  }

  /**
   * Retrieve text previously saved using `startWritingTextFile` + `writeln` as a list of lines. Each line is
   * trimmed to remove the leading and trailing whitespace. Removes escaping of the new line characters if enabled.
   * The underlying reader is closed before returning.
   */
  protected Iterable<String> readtext(String name, boolean unescapeNewlines) throws IOException {
    ArrayList<String> res = new ArrayList<>(50);
    try (BufferedReader br = _reader.getTextFile(name)) {
      String line;
      while ((line = br.readLine()) != null) {
        if (unescapeNewlines)
          line = StringEscapeUtils.unescapeNewlines(line);
        res.add(line.trim());
      }
    }
    return res;
  }

  //--------------------------------------------------------------------------------------------------------------------
  // Private
  //--------------------------------------------------------------------------------------------------------------------

  /**
   * Creates the model through {@link #makeModel}, fills in the fields shared by all models from the kv store and
   * finally lets the subclass read its own data.
   */
  private void readAll() throws IOException {
    String[] columns = (String[]) _lkv.get("[columns]");
    if (columns == null)
      throw new IOException("`[columns]` section is missing in the model info.");
    String[][] domains = parseModelDomains(columns.length);
    _model = makeModel(columns, domains);
    _model._uuid = readkv("uuid");
    _model._category = hex.ModelCategory.valueOf((String) readkv("category"));
    _model._supervised = readkv("supervised");
    _model._nfeatures = readkv("n_features");
    _model._nclasses = readkv("n_classes");
    _model._balanceClasses = readkv("balance_classes");
    _model._defaultThreshold = readkv("default_threshold");
    _model._priorClassDistrib = readkv("prior_class_distrib");
    _model._modelClassDistrib = readkv("model_class_distrib");
    _model._offsetColumn = readkv("offset_column");
    readModelData();
  }

  /**
   * Parses `model.ini` into a map.
   * - `[info]` entries become `key -> RawValue` (except `uuid`, which is kept as a plain String);
   * - the `[columns]` section is stored under key "[columns]" as a String[] sized by the `n_columns` info entry;
   * - the `[domains]` section is stored under key "[domains]" as a map (column index -> domain spec string).
   * Blank lines and lines starting with '#' are skipped. The file reader is closed before returning.
   *
   * @throws IOException if the file cannot be read or is malformed
   */
  private static Map<String, Object> parseModelInfo(MojoReaderBackend reader) throws IOException {
    Map<String, Object> info = new HashMap<>();
    int section = 0;
    int ic = 0; // Index for `columns` array
    String[] columns = new String[0]; // array of column names, will be initialized later
    Map<Integer, String> domains = new HashMap<>(); // map of (categorical column index => name of the domain file)
    try (BufferedReader br = reader.getTextFile("model.ini")) {
      String line;
      while ((line = br.readLine()) != null) {
        line = line.trim();
        if (line.startsWith("#") || line.isEmpty()) continue;
        if (line.equals("[info]")) {
          section = 1;
        } else if (line.equals("[columns]")) {
          section = 2; // Enter the [columns] section
          if (! info.containsKey("n_columns"))
            throw new IOException("`n_columns` variable is missing in the model info.");
          int n_columns = Integer.parseInt(((RawValue) info.get("n_columns"))._val);
          columns = new String[n_columns];
          info.put("[columns]", columns);
        } else if (line.equals("[domains]")) {
          section = 3; // Enter the [domains] section
          info.put("[domains]", domains);
        } else if (section == 1) {
          // [info] section: just parse key-value pairs and store them into the `info` map.
          String[] res = line.split("\\s*=\\s*", 2);
          if (res.length < 2)
            throw new IOException("Malformed line in the [info] section of the model info: " + line);
          info.put(res[0], res[0].equals("uuid") ? res[1] : new RawValue(res[1]));
        } else if (section == 2) {
          // [columns] section
          if (ic >= columns.length)
            throw new IOException("`n_columns` variable is too small.");
          columns[ic++] = line;
        } else if (section == 3) {
          // [domains] section: lines of the form "<column index>: <domain spec>"
          String[] res = line.split(":\\s*", 2);
          if (res.length < 2)
            throw new IOException("Malformed line in the [domains] section of the model info: " + line);
          int col_index = Integer.parseInt(res[0]);
          domains.put(col_index, res[1]);
        }
        // Lines appearing before any section header are ignored.
      }
    }
    return info;
  }

  /**
   * Loads the domains referenced by the `[domains]` section of the model info.
   * Each domain spec has the form "<number of elements> <domain file name>"; the file is read from "domains/".
   * Entries whose column index is not below {@code n_columns} are skipped. A missing `[domains]` section yields
   * no domains at all.
   *
   * @return array of length {@code n_columns}; entries without a domain spec are null
   * @throws IOException if a domain spec is malformed or a domain file has the wrong number of elements
   */
  private String[][] parseModelDomains(int n_columns) throws IOException {
    String[][] domains = new String[n_columns][];
    @SuppressWarnings("unchecked")
    Map<Integer, String> domass = (Map<Integer, String>) _lkv.get("[domains]");
    if (domass == null)
      return domains;
    for (Map.Entry<Integer, String> e : domass.entrySet()) {
      int col_index = e.getKey();
      // There is a file with categories of the response column, but we ignore it.
      if (col_index >= n_columns) continue;
      String[] info = e.getValue().split(" ", 2);
      if (info.length < 2)
        throw new IOException("Malformed domain specification for column " + col_index + ": " + e.getValue());
      int n_elements = Integer.parseInt(info[0]);
      domains[col_index] = readDomainFile(info[1], n_elements);
    }
    return domains;
  }

  /**
   * Reads one domain file ("domains/" + {@code domfile}); each line is one domain element (not trimmed).
   * The file must contain exactly {@code n_elements} lines. The reader is closed before returning.
   */
  private String[] readDomainFile(String domfile, int n_elements) throws IOException {
    String[] domain = new String[n_elements];
    int id = 0; // domain elements counter
    try (BufferedReader br = _reader.getTextFile("domains/" + domfile)) {
      String line;
      while ((line = br.readLine()) != null) {
        if (id >= n_elements)
          throw new IOException("Too many elements in the domain file");
        domain[id++] = line;
      }
    }
    if (id != n_elements)
      throw new IOException("Not enough elements in the domain file");
    return domain;
  }

  /** Unparsed value of an `[info]` entry; parsed lazily by {@link ParseUtils#tryParse(String, Object)}. */
  private static class RawValue {
    private final String _val;
    RawValue(String val) { _val = val; }
    @SuppressWarnings("unchecked")
    <T> T parse(T defVal) { return (T) ParseUtils.tryParse(_val, defVal); }
    @Override
    public String toString() { return _val; }
  }
}