package hex.grid;
import hex.*;
import water.*;
import water.api.schemas3.KeyV3;
import water.fvec.Frame;
import water.util.*;
import water.util.PojoUtils.FieldNaming;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Collection;
/**
* A Grid of Models representing result of hyper-parameter space exploration.
* Lazily filled in, this object represents the potentially infinite variety
* of hyperparameters of a given model & dataset.
*
* @param <MP> type of model build parameters
*/
public class Grid<MP extends Model.Parameters> extends Lockable<Grid<MP>> {
/**
* Publicly available Grid prototype - used by REST API.
*
* @see hex.schemas.GridSchemaV99
*/
public static final Grid GRID_PROTO = new Grid(null, null, null, null);
// A cache mapping the checksum of model parameters to the key of the built Model.
private final IcedHashMap<IcedLong, Key<Model>> _models = new IcedHashMap<>();
// "Base" model parameters used for this grid search.
private final MP _params;
// Failed model parameters - represents points in hyper space for which model
// generation failed. If the element is null, then look into
// <code>_failed_raw_params</code> for the raw hyper-parameter values.
private MP[] _failed_params;
// Detailed messages about a failure for given failed model parameters in
// <code>_failed_params</code>.
private String[] _failure_details;
// Collected stack trace for failure.
private String[] _failure_stack_traces;
// Contains "raw" representation of parameters which failed. The parameters are
// represented in textual form, since simple <code>java.lang.Object</code>
// cannot be serialized by H2O serialization.
private String[][] _failed_raw_params;
// Names of used hyper parameters for this grid search.
private final String[] _hyper_names;
private final FieldNaming _field_naming_strategy;
private ScoringInfo[] _scoring_infos = null;
/**
 * Creates a grid object that collects the results of a grid search.
 *
 * <p>When {@code params} is non-null it is cloned and an empty array of failed parameters is
 * allocated using the runtime class of {@code params}; when it is {@code null} both the base
 * parameters and the failed-parameters array stay {@code null}. The failure message, raw
 * parameter and stack trace arrays always start out empty.</p>
 *
 * @param key         reference to this object
 * @param params      initial parameters used by grid search, may be null
 * @param hyperNames  names of used hyper parameters
 * @param fieldNaming naming strategy passed to {@code PojoUtils.getFieldValue} when hyper
 *                    parameter values are read
 */
protected Grid(Key key, MP params, String[] hyperNames, FieldNaming fieldNaming) {
  super(key);
  _hyper_names = hyperNames;
  _field_naming_strategy = fieldNaming;
  if (params == null) {
    _params = null;
    _failed_params = null;
  } else {
    _params = (MP) params.clone();
    _failed_params = (MP[]) Array.newInstance(params.getClass(), 0);
  }
  _failure_details = new String[0];
  _failed_raw_params = new String[0][];
  _failure_stack_traces = new String[0];
}
/**
 * Returns the algorithm name reported by the base parameters of this grid. Note: only
 * sensible for Grids which search over a single class of Models.
 *
 * @return name of model (for example, "DRF", "GBM")
 */
public String getModelName() {
  return this._params.algoName();
}
/**
 * Returns the scoring information currently stored on this grid.
 *
 * @return the stored array; {@code null} until {@link #setScoringInfos(ScoringInfo[])} is called
 *         with a non-null value
 */
public ScoringInfo[] getScoringInfos() {
  return this._scoring_infos;
}
/**
 * Replaces the scoring information stored on this grid. The array reference is stored as-is,
 * without copying.
 *
 * @param scoring_infos new scoring information, may be null
 */
public void setScoringInfos(ScoringInfo[] scoring_infos) {
  _scoring_infos = scoring_infos;
}
/**
* Ask the Grid for a suggested next hyperparameter value, given an existing Model as a starting
* point and the complete set of hyperparameter limits. Returning a NaN signals there is no next
* suggestion, which is reasonable if the obvious "next" value does not exist (e.g. exhausted all
* possibilities of a categorical). It is OK if a Model for the suggested value already exists; this
* will be checked before building any model.
*
* @param h The h-th hyperparameter
* @param m A model to act as a starting point
* @param hyperLimits Upper bounds for this search
* @return Suggested next value for hyperparameter h or NaN if no next value
protected double suggestedNextHyperValue(int h, Model m, double[] hyperLimits) {
throw H2O.fail();
}*/
/**
 * Returns the training frame referenced by the base parameters of this grid.
 * <p> All models are trained on the same data frame, but might be validated on multiple
 * different frames. </p>
 *
 * @return training frame shared among all models
 */
public Frame getTrainingFrame() {
  return this._params.train();
}
/**
 * Looks up the model built for the given combination of model parameters.
 *
 * @param params parameters of the model
 * @return the model registered for these parameters, or null if no model key is registered
 *         for them
 */
public Model getModel(MP params) {
  final Key<Model> key = getModelKey(params);
  if (key == null) {
    return null;
  }
  return key.get();
}
/**
 * Returns the key of the model registered under the checksum of the given parameters.
 *
 * @param params parameters of the model
 * @return key of the matching model, or null if none is registered
 */
public Key<Model> getModelKey(MP params) {
  return getModelKey(params.checksum());
}
/**
 * Returns the key of the model registered under the given parameters checksum.
 *
 * @param paramsChecksum checksum of model parameters
 * @return key of the matching model, or null if none is registered
 */
Key<Model> getModelKey(long paramsChecksum) {
  return _models.get(IcedLong.valueOf(paramsChecksum));
}
/* FIXME: should pass model parameters instead of checksum, but model
* parameters are not immutable and model builder modifies them! */
/**
 * Registers a model key under the given parameters checksum (package-private).
 *
 * @param checksum checksum of the parameters the model was built with
 * @param modelKey key of the built model
 * @return whatever the underlying map's {@code put} returns for this checksum
 */
synchronized Key<Model> putModel(long checksum, Key<Model> modelKey) {
  final IcedLong checksumKey = IcedLong.valueOf(checksum);
  return _models.put(checksumKey, modelKey);
}
/**
 * Appends a new item to the list of failed model parameters.
 *
 * <p> The failed parameters object represents a point in hyper space which cannot be used for
 * model building. The four failure arrays (parameters, messages, raw parameters and stack
 * traces) are each replaced by a copy that is one element longer. </p>
 *
 * @param params model parameters which caused model builder failure, can be null
 * @param rawParams array of "raw" parameter values
 * @param failureDetails textual description of model building failure
 * @param stackTrace stringified stack trace
 */
private void appendFailedModelParameters(MP params, String[] rawParams, String failureDetails, String stackTrace) {
  assert rawParams != null : "API has to always pass rawParams";
  _failed_params = appendElement(_failed_params, params);
  _failure_details = appendElement(_failure_details, failureDetails);
  _failed_raw_params = appendElement(_failed_raw_params, rawParams);
  _failure_stack_traces = appendElement(_failure_stack_traces, stackTrace);
}

// Returns a copy of the given array that is one slot longer and holds the element in the last slot.
private static <T> T[] appendElement(T[] array, T element) {
  final T[] grown = Arrays.copyOf(array, array.length + 1);
  grown[array.length] = element;
  return grown;
}
/**
 * Records a failure of the model builder for the given model parameters.
 *
 * <p> The failed parameters object represents a point in hyper space which cannot be used for
 * model building. The raw parameter values are derived from {@link #getHyperValues}, and the
 * failure message and stack trace are taken from the exception.</p>
 *
 * <p> Should be used only from <code>GridSearch</code> job.</p>
 *
 * @param params model parameters which caused model builder failure
 * @param e exception causing a failure
 */
void appendFailedModelParameters(MP params, Exception e) {
  assert params != null : "Model parameters should be always != null !";
  final Object[] hyperValues = getHyperValues(params);
  final String[] rawParams = ArrayUtils.toString(hyperValues);
  final String message = e.getMessage();
  final String stackTrace = StringUtils.toString(e);
  appendFailedModelParameters(params, rawParams, message, stackTrace);
}
/**
 * Records a failure to construct model parameters from the given raw hyper values
 * (package-private).
 *
 * <p> The failed parameters object represents a point in hyper space which cannot be used to
 * construct new model parameters; the entry is stored with {@code null} model parameters.</p>
 *
 * <p> Should be used only from <code>GridSearch</code> job.</p>
 *
 * @param rawParams list of "raw" hyper values which caused a failure to prepare model parameters
 * @param e exception causing a failure
 */
void appendFailedModelParameters(Object[] rawParams, Exception e) {
  assert rawParams != null : "Raw parameters should be always != null !";
  final String[] rawAsText = ArrayUtils.toString(rawParams);
  final String message = e.getMessage();
  final String stackTrace = StringUtils.toString(e);
  appendFailedModelParameters(null, rawAsText, message, stackTrace);
}
/**
 * Returns the keys of all models registered in this grid.
 *
 * @return array of model keys
 */
public Key<Model>[] getModelKeys() {
  final Collection<Key<Model>> keys = _models.values();
  return keys.toArray(new Key[_models.size()]);
}
/**
 * Resolves every registered model key and returns the resulting models.
 *
 * @return all models in this grid; an element is null when its key is null
 */
public Model[] getModels() {
  final Collection<Key<Model>> keys = _models.values();
  final Model[] result = new Model[keys.size()];
  int idx = 0;
  for (Key<Model> key : keys) {
    // Slots are null by default, so only non-null keys need to be resolved.
    if (key != null) {
      result[idx] = key.get();
    }
    idx++;
  }
  return result;
}
/**
 * Returns the number of models registered in this grid.
 *
 * @return size of the model cache
 */
public int getModelCount() {
  return this._models.size();
}
/**
 * Returns number of unsuccessful attempts to build a model.
 *
 * <p>The count is read from the failure-details array rather than from the failed-parameters
 * array: the constructor leaves the failed-parameters array {@code null} when the grid is
 * created without base parameters, whereas the failure-details array is always allocated.
 * Both arrays grow by one element per recorded failure, so their lengths agree whenever
 * both exist.</p>
 *
 * @return number of recorded failures, 0 if none
 */
public int getFailureCount() {
  return _failure_details.length;
}
/**
 * Returns an array of model parameters which caused model build failure.
 * <p/>
 * A null element in the array means that model parameters could not be constructed, and the
 * client should use {@link #getFailedRawParameters()} to obtain "raw" model parameters.
 * <p/>
 * Note: cannot return <code>MP[]</code> because of PUBDEV-1863 See:
 * https://0xdata.atlassian.net/browse/PUBDEV-1863
 *
 * @return the internal array of failed parameters (not a copy)
 */
public Model.Parameters[] getFailedParameters() {
  return this._failed_params;
}
/**
 * Returns detailed messages about model build failures.
 *
 * @return the internal array of failure messages (not a copy)
 */
public String[] getFailureDetails() {
  return this._failure_details;
}
/**
 * Returns string representations of the stack traces of model build failures.
 *
 * @return the internal array of stack traces (not a copy)
 */
public String[] getFailureStackTraces() {
  return this._failure_stack_traces;
}
/**
 * Returns list of raw model parameters causing model building failure.
 *
 * @return the internal array of textual raw parameter values, one row per failure (not a copy)
 */
public String[][] getFailedRawParameters() {
  return this._failed_raw_params;
}
/**
 * Reads the values of this grid's hyper parameters from the given model parameters.
 *
 * <p>Each value is fetched with {@code PojoUtils.getFieldValue} using the grid's field naming
 * strategy; the result is ordered like {@link #getHyperNames()}.</p>
 *
 * @param parms model parameters
 * @return values of hyper parameters used by grid search producing this grid object.
 */
public Object[] getHyperValues(MP parms) {
  final Object[] values = new Object[_hyper_names.length];
  int idx = 0;
  for (String hyperName : _hyper_names) {
    values[idx++] = PojoUtils.getFieldValue(parms, hyperName, _field_naming_strategy);
  }
  return values;
}
/**
 * Returns the names of the hyper parameters used by this grid search.
 *
 * @return the internal array of hyper parameter names (not a copy)
 */
public String[] getHyperNames() {
  return this._hyper_names;
}
/**
 * Cleans up the grid: removes every registered model via its key, then empties the model cache.
 *
 * @param fs futures passed to each key removal
 * @return the same {@code fs} instance
 */
@Override
protected Futures remove_impl(final Futures fs) {
  final Collection<Key<Model>> modelKeys = _models.values();
  for (Key<Model> modelKey : modelKeys) {
    modelKey.remove(fs);
  }
  _models.clear();
  return fs;
}
/**
 * Writes out K/V pairs: puts the key of every registered model into the buffer, then
 * delegates to the superclass implementation.
 *
 * @param ab buffer to write to
 * @return the result of {@code super.writeAll_impl(ab)}
 */
@Override
protected AutoBuffer writeAll_impl(AutoBuffer ab) {
  final Collection<Key<Model>> modelKeys = _models.values();
  for (Key<Model> modelKey : modelKeys) {
    ab.putKey(modelKey);
  }
  return super.writeAll_impl(ab);
}
/**
 * Not implemented for grids: always throws the error produced by {@code H2O.unimpl()}.
 * Neither argument is used.
 */
@Override
protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
throw H2O.unimpl();
}
/**
 * Not implemented for grids: always throws the error produced by {@code H2O.unimpl()}.
 */
@Override
protected long checksum_impl() {
throw H2O.unimpl();
}
/**
 * Returns the schema class used for grid keys.
 *
 * @return {@code KeyV3.GridKeyV3.class}
 */
@Override
public Class<KeyV3.GridKeyV3> makeSchema() {
return KeyV3.GridKeyV3.class;
}
/**
 * Builds a summary table with one row per given model.
 *
 * <p>The columns are the grid's hyper parameter names, followed by a {@code model_ids} column
 * and, when {@code sort_by} is non-null, a column named after {@code sort_by} holding the
 * value of {@code ModelMetrics.getMetricFromModel} for that model. All columns are typed as
 * strings. Rows are filled in the order of {@code model_ids}; this method does not sort them.
 * The finished table is also written to the log at info level.</p>
 *
 * <p>The number of rows equals {@code model_ids.length}, so the table always matches the
 * models that are actually written into it, even if the caller passes a different number of
 * keys than the grid currently holds.</p>
 *
 * @param model_ids  keys of the models to list, in display order
 * @param sort_by    name of the metric to add as the last column, or null for no metric column
 * @param decreasing only used to word the table description when {@code sort_by} is non-null
 * @return the summary table, or null if the grid has no hyper names or {@code model_ids} is
 *         null or empty
 */
public TwoDimTable createSummaryTable(Key<Model>[] model_ids, String sort_by, boolean decreasing) {
  if (_hyper_names == null || model_ids == null || model_ids.length == 0) return null;
  final int hyperCount = _hyper_names.length;
  final boolean withMetric = sort_by != null;
  final int colCount = hyperCount + (withMetric ? 2 : 1);

  String[] colTypes = new String[colCount];
  Arrays.fill(colTypes, "string");
  String[] colFormats = new String[colCount];
  Arrays.fill(colFormats, "%s");
  String[] colNames = Arrays.copyOf(_hyper_names, colCount);
  colNames[hyperCount] = "model_ids";
  if (withMetric)
    colNames[hyperCount + 1] = sort_by;

  // Size the rows by the keys we iterate over, not by the grid's model cache, so that the
  // writes below can never run past the last row or leave blank trailing rows.
  TwoDimTable table = new TwoDimTable("Hyper-Parameter Search Summary",
          withMetric ? "ordered by " + (decreasing ? "decreasing " : "increasing ") + sort_by : null,
          new String[model_ids.length], colNames, colTypes, colFormats, "");

  int row = 0;
  for (Key<Model> km : model_ids) {
    // NOTE(review): m is dereferenced without a null check; presumably a removed model makes
    // DKV.getGet return null here - confirm whether callers can pass such keys.
    Model m = DKV.getGet(km);
    Model.Parameters parms = m._parms;
    for (int col = 0; col < hyperCount; ++col)
      table.set(row, col, PojoUtils.getFieldValue(parms, _hyper_names[col], _field_naming_strategy));
    table.set(row, hyperCount, km.toString());
    if (withMetric) table.set(row, hyperCount + 1, ModelMetrics.getMetricFromModel(km, sort_by));
    row++;
  }
  Log.info(table);
  return table;
}
/**
 * Builds the scoring history table for this grid from the stored scoring infos.
 *
 * <p>The model category is taken from the first model returned by the model cache. When the
 * grid holds no models, or that first model can no longer be resolved (a warning is logged),
 * the table is built with {@code ModelCategory.Binomial} and all flags set to false.
 * Otherwise the validation, cross-validation and autoencoder flags are taken from the first
 * scoring info, or are false when there is none.</p>
 *
 * @return the table produced by {@code ScoringInfo.createScoringHistoryTable}
 */
public TwoDimTable createScoringHistoryTable() {
  Model model = null;
  if (!_models.values().isEmpty()) {
    final Key<Model> firstKey = _models.values().iterator().next();
    model = firstKey.get();
    if (model == null) {
      Log.warn("Cannot create grid scoring history table; Model has been removed: " + firstKey);
    }
  }
  if (model == null) {
    // No usable model: fall back to a default category with every flag switched off.
    return ScoringInfo.createScoringHistoryTable(_scoring_infos, false, false, ModelCategory.Binomial, false);
  }

  final ScoringInfo first =
      (_scoring_infos == null || _scoring_infos.length == 0) ? null : _scoring_infos[0];
  final boolean hasValidation = first != null && first.validation;
  final boolean hasCrossValidation = first != null && first.cross_validation;
  final boolean isAutoencoder = first != null && first.is_autoencoder;
  return ScoringInfo.createScoringHistoryTable(
      _scoring_infos, hasValidation, hasCrossValidation, model._output.getModelCategory(), isAutoencoder);
}
}