package hex.genmodel.algos.klime; import hex.genmodel.MojoModel; import hex.genmodel.algos.glm.GlmMojoModel; public class KLimeMojoModel extends MojoModel { MojoModel _clusteringModel; MojoModel _globalRegressionModel; MojoModel[] _clusterRegressionModels; KLimeMojoModel(String[] columns, String[][] domains) { super(columns, domains); } @Override public double[] score0(double[] row, double[] preds) { assert preds.length == row.length + 2; System.arraycopy(row, 0, preds, 2, row.length); _clusteringModel.score0(row, preds); final int cluster = (int) preds[0]; GlmMojoModel regressionModel = getRegressionModel(cluster); System.arraycopy(preds, 2, row, 0, row.length); regressionModel.score0(row, preds); preds[1] = cluster; for (int i = 2; i < preds.length; i++) preds[i] = Double.NaN; // preds = {prediction, cluster, NaN, ..., NaN) regressionModel.applyCoefficients(row, preds, 2); // preds = {prediction, cluster, reason code 1, ..., reason code N} return preds; } GlmMojoModel getRegressionModel(int cluster) { return (GlmMojoModel) (_clusterRegressionModels[cluster] != null ? _clusterRegressionModels[cluster] : _globalRegressionModel); } @Override public int getPredsSize() { return nfeatures() + 2; } }