package hex.klime; import hex.Model; import hex.ModelBuilder; import hex.ModelCategory; import hex.ModelMetrics; import hex.ModelMetricsSupervised; import hex.glm.GLM; import hex.glm.GLMModel; import hex.kmeans.KMeans; import hex.kmeans.KMeansModel; import water.*; import water.fvec.Frame; import water.fvec.Vec; import water.util.FrameUtils; import java.util.HashSet; import java.util.Set; import static hex.kmeans.KMeansModel.KMeansParameters; public class KLime extends ModelBuilder<KLimeModel, KLimeModel.KLimeParameters, KLimeModel.KLimeOutput> { @Override public ModelCategory[] can_build() { return new ModelCategory[]{ModelCategory.Regression}; } @Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Experimental; } @Override public boolean isSupervised() { return true; } public KLime(boolean startup_once) { super(new KLimeModel.KLimeParameters(), startup_once); } public KLime(KLimeModel.KLimeParameters parms) { super(parms); init(false); } @Override public boolean haveMojo() { return true; } @Override protected Driver trainModelImpl() { return new KLimeDriver(); } private class KLimeDriver extends Driver { @Override public void computeImpl() { KLimeModel model = null; Set<Key<Frame>> frameKeys = new HashSet<>(); // temporary frame key try { init(true); // The model to be built model = new KLimeModel(dest(), _parms, new KLimeModel.KLimeOutput(KLime.this)); model.delete_and_lock(_job); _job.update(0, "Building global k-LIME regression model"); Key<Frame> globalKey = Key.make("klime_train_global_" + _parms._train); frameKeys.add(globalKey); DKV.put(globalKey, train()); Key<Model> globalRegressionKey = Key.make("klime_glm_global_" + model._key); Job globalJob = new Job<>(globalRegressionKey, ModelBuilder.javaName("glm"), "k-LIME Regression (Global GLM)"); GLM globalBuilder = ModelBuilder.make("GLM", globalJob, globalRegressionKey); _parms.fillRegressionParms(globalBuilder._parms, globalKey, false); final GLMModel globalRegressionModel = globalBuilder.trainModel().get(); _job.update(1); // global regression done final Frame adaptedTrain = new Frame(globalRegressionModel.names(), train().vecs(globalRegressionModel.names())); model._output.setNames(adaptedTrain._names); model._output._domains = adaptedTrain.domains(); Key<Frame> clusteringTrainKey = Key.<Frame>make("klime_clustering_" + _parms._train); frameKeys.add(clusteringTrainKey); Frame clusteringTrain = new Frame(clusteringTrainKey); clusteringTrain.add(adaptedTrain); clusteringTrain.remove(_parms._response_column); DKV.put(clusteringTrain); KMeansParameters kmeansParms = _parms.fillClusteringParms(new KMeansParameters(), clusteringTrain._key); KMeans clustering = new KMeans(kmeansParms, _job); KMeansModel clusteringModel = clustering.trainModelNested(null); Frame clusterLabels = Scope.track(clusteringModel.score(clusteringTrain)); final int K = clusteringModel._output._k[clusteringModel._output._k.length - 1]; model._output._clustering = clusteringModel; model._output._globalRegressionModel = globalRegressionModel; model._output._regressionModels = new GLMModel[K]; model.update(_job); // this calculates R2 on each cluster using a global model model.score(_parms.train()).delete(); // This scores on the training data and appends a ModelMetrics model._output._training_metrics = ModelMetrics.getFromDKV(model, _parms.train()); String[] clusterNames = new String[K]; for (int i = 0; i < K; i++) clusterNames[i] = "cluster" + i; clusterLabels.vec(0).setDomain(clusterNames); DKV.put(clusterLabels); Frame clusterWeights = Scope.track(FrameUtils.categoricalEncoder(clusterLabels, new String[0], Model.Parameters.CategoricalEncodingScheme.OneHotExplicit, null)); ModelBuilder[] localBuilders = new ModelBuilder[K]; int localBuilderCnt = 0; for (int i = 0; i < K; i++) { Vec weightVec = clusterWeights.vec(clusterWeights.find("predict." + clusterNames[i])); if (weightVec.nzCnt() < _parms._min_cluster_size) continue; // do not build a local model for too small clusters localBuilderCnt++; Key<Frame> key = Key.<Frame>make("klime_train_cluster_" + i + "-" + _parms._train); frameKeys.add(key); Frame frame = new Frame(key); frame.add("__cluster_weights", weightVec); frame.add(train()); DKV.put(frame); Key<Model> glmKey = Key.<Model>make("klime_glm_cluster_" + i + "-" + model._key); Job glmJob = new Job<>(glmKey, ModelBuilder.javaName("glm"), "k-LIME Regression (GLM, cluster = " + i + ")"); DKV.put(glmJob); Scope.track_generic(glmJob); GLM glmBuilder = ModelBuilder.make("GLM", glmJob, glmKey); _parms.fillRegressionParms(glmBuilder._parms, key, true); localBuilders[i] = glmBuilder; } ModelBuilder[] allBuilders = new ModelBuilder[localBuilderCnt]; int localIdx = 0; for (ModelBuilder localBuilder : localBuilders) { if (localBuilder != null) allBuilders[localIdx++] = localBuilder; } assert localIdx == localBuilderCnt; bulkBuildModels(_job, allBuilders, 1); double global_r2 = ((ModelMetricsSupervised) globalRegressionModel._output._training_metrics).r2(); GLMModel[] regressionModels = new GLMModel[K]; for (int i = 0; i < localBuilders.length; i++) { if (localBuilders[i] != null) { GLMModel localModel = DKV.getGet(localBuilders[i]._job._result); double local_r2 = ((ModelMetricsSupervised) localModel._output._training_metrics).r2(); if (local_r2 > global_r2) regressionModels[i] = localModel; // local model is better, keep it else Scope.track_generic(localModel); // global model is better, delete the local one } else _job.update(1); // model won't be built } model._output._regressionModels = regressionModels; model.score(_parms.train()).delete(); // This scores on the training data and appends a ModelMetrics model._output._training_metrics = ModelMetrics.getFromDKV(model, _parms.train()); model.update(_job); } finally { if (model != null) model.unlock(_job); Futures fs = new Futures(); for (Key<Frame> k : frameKeys) DKV.remove(k, fs); fs.blockForPending(); } } } }