package hex.klime;
import hex.*;
import hex.glm.GLMModel;
import hex.klime.KLimeModel.*;
import hex.kmeans.KMeansModel;
import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import java.util.*;
import static hex.kmeans.KMeansModel.KMeansParameters;
public class KLimeModel extends Model<KLimeModel, KLimeParameters, KLimeOutput> {
/**
 * Creates a k-LIME model registered under the given key.
 *
 * @param selfKey key identifying this model
 * @param params  parameters the model was built with
 * @param output  output container holding the clustering and regression models
 */
public KLimeModel(Key<KLimeModel> selfKey, KLimeParameters params, KLimeOutput output) {
  super(selfKey, params, output);
  // Sanity check (active only with -ea): the inherited _key must carry the same bytes as the key passed in.
  assert Arrays.equals(selfKey._kb, _key._kb);
}
/**
 * Creates the metric builder used when scoring with this model.
 *
 * @param domain not used by this implementation
 * @return a builder sized for one sub-builder per entry of {@code _output._regressionModels}
 *         and one reason code per entry of {@code _output._names} except the last
 */
@Override
public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
  final int clusterCount = _output._regressionModels.length;
  // The last entry of _output._names is not counted as a reason code.
  final int reasonCodeCount = _output._names.length - 1;
  return new KLimeMetricBuilder(clusterCount, reasonCodeCount);
}
/**
 * Returns the default sort criterion for this model type.
 *
 * @return always {@code GridSortBy.R2}
 */
@Override
public GridSortBy getDefaultGridSortBy() {
return GridSortBy.R2;
}
/**
 * Scores a single row: assigns it to a cluster, predicts with the regression model selected
 * for that cluster, and emits the individual GLM terms.
 *
 * <p>Layout of {@code preds} on return: {@code preds[0]} is the value predicted by the
 * regression, {@code preds[1]} is the cluster id, and {@code preds[2..]} hold one GLM term per
 * categorical column followed by one per numeric column of the regression model's DataInfo.
 *
 * @return the {@code preds} array passed in
 */
@Override
public double[] score0(Chunk[] chks, double weight, double offset, int row_in_chunk, double[] tmp, double[] preds) {
  // Step 1: ask the clustering model for the cluster; its id is read from slot 0 of the result.
  final double[] clusterPreds = _output._clustering.score0(chks, weight, offset, row_in_chunk, tmp, preds);
  final int clusterId = (int) clusterPreds[0];

  // Step 2: score with the regression model chosen for that cluster (preds[0] = regression prediction).
  final GLMModel glm = _output.getClusterModel(clusterId);
  glm.score0(chks, weight, offset, row_in_chunk, tmp, preds);
  preds[1] = clusterId;

  // Step 3: write the GLM terms starting at slot 2.
  // NOTE(review): tmp is read here after the score0 calls above — presumably they fill it with the row values; confirm.
  final DataInfo info = glm.dinfo();
  final double[] beta = glm.beta();
  int out = 2;

  // Categorical columns: the term is the coefficient of the observed level, NaN when no coefficient id is found.
  for (int c = 0; c < info._cats; c++) {
    final int coefIdx = info.getCategoricalId(c, tmp[c]);
    if (coefIdx >= 0) {
      preds[out] = beta[coefIdx];
    } else {
      preds[out] = Double.NaN;
    }
    out++;
  }

  // Numeric columns: the term is coefficient * value; a missing value is replaced by the column mean
  // unless the DataInfo is set to skip missing values.
  final int firstNumericCoef = info.numStart();
  for (int n = 0; n < info._nums; n++) {
    double value = tmp[info._cats + n];
    final boolean imputeMean = !info._skipMissing && Double.isNaN(value);
    if (imputeMean) {
      value = info._numMeans[n];
    }
    preds[out++] = beta[firstNumericCoef + n] * value;
  }
  return preds;
}
/**
 * Builds the names of the scoring output columns: {@code predict_klime}, {@code cluster_klime},
 * then one {@code rc_<name>} column per entry of {@code _output._names} except the last.
 *
 * @return array of {@code _output._names.length + 1} column names
 */
@Override
protected String[] makeScoringNames() {
  final String[] inputNames = _output._names;
  // last item of _output._names is the response column, remove it
  final int featureCount = inputNames.length - 1;
  final String[] scoringNames = new String[featureCount + 2];
  scoringNames[0] = "predict_klime";
  scoringNames[1] = "cluster_klime";
  for (int i = 0; i < featureCount; i++) {
    scoringNames[i + 2] = "rc_" + inputNames[i];
  }
  return scoringNames;
}
/**
 * Scoring from a plain data array is deliberately unsupported for this model.
 * This overload always throws the exception produced by {@code H2O.unimpl}; the class
 * provides its scoring logic in the chunk-based {@code score0} overload instead.
 */
@Override
protected double[] score0(double[] data, double[] preds) {
throw H2O.unimpl("Intentionally not implemented, should never be called!");
}
/**
 * Squared-error deviance of a single observation.
 *
 * @param w observation weight (not used in the computation)
 * @param y actual value
 * @param f predicted value
 * @return {@code (y - f)^2}
 */
@Override
public double deviance(double w, double y, double f) {
  final double residual = y - f;
  return residual * residual;
}
/**
 * Parameters of a k-LIME model, plus helpers that configure the parameter objects of the
 * underlying k-means clustering and GLM regression models.
 */
public static class KLimeParameters extends Model.Parameters {
  public String algoName() {
    return "KLime";
  }

  public String fullName() {
    return "k-LIME";
  }

  public String javaName() {
    return KLimeModel.class.getName();
  }

  // NOTE(review): not referenced in this class — presumably the minimum cluster size required
  // for a cluster to get its own regression model; confirm against the builder.
  public int _min_cluster_size = 20;
  // Passed to k-means as its k, and counted as the number of local GLMs in progressUnits().
  public int _max_k = 20;
  // Used as the single alpha value of every GLM configured by fillRegressionParms.
  public double _alpha = 0.5;
  // Forwarded unchanged to the k-means parameters.
  public boolean _estimate_k = true;

  /**
   * Total work units: the units reported by the configured k-means parameters, plus one unit
   * per potential local GLM ({@code _max_k}), plus one for the global GLM.
   */
  @Override
  public long progressUnits() {
    final long clusteringUnits = fillClusteringParms(new KMeansParameters(), null).progressUnits();
    return clusteringUnits + _max_k /*local GLMs*/ + 1 /*global GLM*/;
  }

  /**
   * Copies the clustering-related settings into the given k-means parameters.
   *
   * @param p                  parameter object to configure (modified in place)
   * @param clusteringFrameKey key of the frame to train the clustering on
   * @return the same {@code p} instance
   */
  KMeansParameters fillClusteringParms(KMeansParameters p, Key<Frame> clusteringFrameKey) {
    // what to cluster
    p._train = clusteringFrameKey;
    // how many clusters
    p._k = _max_k;
    p._estimate_k = _estimate_k;
    // execution settings
    p._seed = _seed;
    p._auto_rebalance = false;
    return p;
  }

  /**
   * Configures the given GLM parameters for a gaussian regression with lambda search and
   * an intercept on the supplied frame.
   *
   * @param p          parameter object to configure (modified in place)
   * @param frameKey   key of the training frame
   * @param isWeighted when true, the weights column is set to {@code __cluster_weights}
   * @return the same {@code p} instance
   */
  GLMModel.GLMParameters fillRegressionParms(GLMModel.GLMParameters p, Key<Frame> frameKey, boolean isWeighted) {
    // data
    p._train = frameKey;
    p._response_column = _response_column;
    if (isWeighted) {
      p._weights_column = "__cluster_weights";
    }
    // model shape
    p._family = GLMModel.GLMParameters.Family.gaussian;
    p._alpha = new double[] {_alpha};
    p._lambda_search = true;
    p._intercept = true;
    // execution settings
    p._seed = _seed;
    p._auto_rebalance = false;
    return p;
  }
}
/**
 * Output of a k-LIME model: the clustering model, the global regression model and the
 * per-cluster regression models.
 */
public static class KLimeOutput extends Model.Output {
  public KLimeOutput(KLime b) {
    super(b);
  }

  /** Clustering model; its prediction is used as the cluster id when scoring. */
  public KMeansModel _clustering;
  /** Regression model returned for clusters that have no model of their own. */
  public GLMModel _globalRegressionModel;
  /** Regression models indexed by cluster id; an entry may be null. */
  public GLMModel[] _regressionModels;

  /**
   * Selects the regression model for a cluster.
   *
   * @param cluster cluster id, expected in {@code [0, _regressionModels.length)}
   * @return the cluster's own model, or the global model when the cluster's entry is null
   * @throws IllegalStateException when the cluster id is out of range
   */
  public GLMModel getClusterModel(int cluster) {
    final boolean known = (cluster >= 0) && (cluster < _regressionModels.length);
    if (!known) {
      throw new IllegalStateException("Unknown cluster, cluster id = " + cluster);
    }
    final GLMModel local = _regressionModels[cluster];
    return (local == null) ? _globalRegressionModel : local;
  }

  @Override
  public ModelCategory getModelCategory() {
    return ModelCategory.Regression;
  }
}
/**
 * Removes the sub-models held by the output (clustering model, global regression model and
 * every non-null per-cluster regression model) before delegating to the superclass.
 *
 * @param fs futures passed to each removal
 * @return the result of {@code super.remove_impl(fs)}
 */
@Override
protected Futures remove_impl(Futures fs) {
  final KLimeOutput out = _output;
  if (out._clustering != null) {
    out._clustering.remove(fs);
  }
  if (out._globalRegressionModel != null) {
    out._globalRegressionModel.remove(fs);
  }
  final Model[] locals = out._regressionModels;
  if (locals != null) {
    for (int i = 0; i < locals.length; i++) {
      final Model local = locals[i];
      // Clusters without a dedicated model have a null entry; nothing to remove for them.
      if (local != null) {
        local.remove(fs);
      }
    }
  }
  return super.remove_impl(fs);
}
/**
 * Creates a MOJO writer for this model.
 *
 * @return a new {@code KLimeMojoWriter} constructed from this model
 */
@Override
public KLimeMojoWriter getMojo() {
return new KLimeMojoWriter(this);
}
/**
 * Regression metrics of a k-LIME model: the overall metrics (held by the superclass) plus
 * one set of regression metrics per cluster.
 */
public static class ModelMetricsKLime extends ModelMetricsRegression {
  /** Regression metrics per cluster, as supplied to the constructor. */
  public ModelMetricsRegression[] _clusterMetrics;
  /** Per-cluster flags, as supplied to the constructor. */
  public boolean[] _usesGlobalModel;

  /**
   * @param model           model the metrics belong to
   * @param frame           frame the metrics were computed on
   * @param globalMetrics   overall metrics; nobs, MSE, sigma, MAE, RMSLE and mean residual
   *                        deviance are copied into the superclass
   * @param clusterMetrics  metrics per cluster (stored by reference)
   * @param usesGlobalModel per-cluster flags (stored by reference)
   */
  public ModelMetricsKLime(Model model, Frame frame,
                           ModelMetricsRegression globalMetrics, ModelMetricsRegression[] clusterMetrics,
                           boolean[] usesGlobalModel) {
    super(model,
          frame,
          globalMetrics._nobs,
          globalMetrics.mse(),
          globalMetrics._sigma,
          globalMetrics.mae(),
          globalMetrics.rmsle(),
          globalMetrics._mean_residual_deviance);
    this._usesGlobalModel = usesGlobalModel;
    this._clusterMetrics = clusterMetrics;
  }
}
/**
 * Regression metric builder that, in addition to the overall metrics accumulated by the
 * superclass, keeps one regression metric builder per cluster.
 */
private static class KLimeMetricBuilder extends ModelMetricsRegression.MetricBuilderRegression<KLimeMetricBuilder> {
  // One builder per cluster, indexed by cluster id.
  private ModelMetricsRegression.MetricBuilderRegression[] _clusterMBs;

  public KLimeMetricBuilder() {} // externalizable constructor

  /**
   * @param k        number of clusters (one sub-builder is created for each)
   * @param numCodes number of reason-code slots following the prediction and the cluster id
   */
  private KLimeMetricBuilder(int k, int numCodes) {
    _clusterMBs = new ModelMetricsRegression.MetricBuilderRegression[k];
    for (int c = 0; c < _clusterMBs.length; c++) {
      _clusterMBs[c] = new ModelMetricsRegression.MetricBuilderRegression();
    }
    _work = new double[1 /*predict_klime*/ + 1 /*cluster_klime*/ + numCodes];
  }

  /**
   * Feeds the row to the builder of the cluster stored in {@code ds[1]}, then to the overall
   * builder.
   */
  @Override
  public double[] perRow(double[] ds, float[] yact, double w, double o, Model m) {
    final double rawCluster = ds[1];
    final int cluster = (int) rawCluster;
    // The cluster slot must hold an integral id within the range of known clusters (checked only with -ea).
    assert cluster == rawCluster && cluster >= 0 && cluster < _clusterMBs.length;
    _clusterMBs[cluster].perRow(ds, yact, w, o, m);
    return super.perRow(ds, yact, w, o, m);
  }

  /** Merges the other builder's per-cluster builders into this one, then the overall state. */
  @Override
  @SuppressWarnings("unchecked")
  public void reduce(KLimeMetricBuilder mb) {
    int idx = 0;
    for (ModelMetricsRegression.MetricBuilderRegression mine : _clusterMBs) {
      mine.reduce(mb._clusterMBs[idx]);
      idx++;
    }
    super.reduce(mb);
  }

  /**
   * Produces a {@code ModelMetricsKLime} from the overall and per-cluster builders and, when a
   * model is given, attaches the metrics to it.
   */
  @Override
  public ModelMetrics makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds) {
    // The intermediate metrics are created without a model (null is passed in its place).
    final ModelMetricsRegression overall = (ModelMetricsRegression) super.makeModelMetrics(null, f, null, null);
    final int clusterCount = _clusterMBs.length;
    final ModelMetricsRegression[] perCluster = new ModelMetricsRegression[clusterCount];
    final boolean[] fallsBackToGlobal = new boolean[clusterCount];
    for (int c = 0; c < clusterCount; c++) {
      perCluster[c] = (ModelMetricsRegression) _clusterMBs[c].makeModelMetrics(null, f, null, null);
      // A cluster uses the global model when the model has no regression model for it.
      boolean noLocalModel = false;
      if (m != null) {
        final GLMModel[] locals = ((KLimeModel) m)._output._regressionModels;
        noLocalModel = locals[c] == null;
      }
      fallsBackToGlobal[c] = noLocalModel;
    }
    final ModelMetricsKLime result = new ModelMetricsKLime(m, f, overall, perCluster, fallsBackToGlobal);
    if (m != null) {
      m.addModelMetrics(result);
    }
    return result;
  }
}
}