package hex.anomaly;
import hex.deeplearning.DeepLearningModel;
import water.*;
import water.api.*;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import java.util.HashSet;
/**
* Deep Learning Based Anomaly Detection
*/
public class Anomaly extends Job.FrameJob {
static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
public static DocGen.FieldDoc[] DOC_FIELDS;
public static final String DOC_GET = "Anomaly Detection via Deep Learning";
@API(help = "Deep Learning Auto-Encoder Model ", required=true, filter= Default.class, json = true)
public Key dl_autoencoder_model;
@API(help = "(Optional) Threshold of reconstruction error for rows to be displayed in logs (default: 10x training MSE)", filter= Default.class, json = true)
public double thresh = -1;
@Override
protected final void execImpl() {
if (dl_autoencoder_model == null) throw new IllegalArgumentException("Deep Learning Model must be specified.");
DeepLearningModel dlm = UKV.get(dl_autoencoder_model);
if (dlm == null) throw new IllegalArgumentException("Deep Learning Model not found.");
if (!dlm.get_params().autoencoder) throw new IllegalArgumentException("Deep Learning Model must be build with autoencoder = true.");
if (thresh == -1) {
Log.info("Mean reconstruction error (MSE) of model on training data: " + dlm.mse());
thresh = 10*dlm.mse();
Log.info("Setting MSE threshold for anomaly to: " + thresh + ".");
}
StringBuilder sb = new StringBuilder();
sb.append("\nFinding outliers in frame " + source._key.toString() + ".\n");
Frame mse = dlm.scoreAutoEncoder(source);
sb.append("Storing the reconstruction error (MSE) for all rows under: " + dest() + ".\n");
Frame output = new Frame(dest(), new String[]{"Reconstruction.MSE"}, new Vec[]{mse.vecs()[0]});
output.delete_and_lock(null);
output.unlock(null);
final Vec mse_test = mse.anyVec();
sb.append("Mean reconstruction error (MSE): " + mse_test.mean() + ".\n");
// print stats and potential outliers
sb.append("The following data points have a reconstruction error greater than " + thresh + ":\n");
HashSet<Long> outliers = new HashSet<Long>();
for( long i=0; i<mse_test.length(); i++ ) {
if (mse_test.at(i) > thresh) {
outliers.add(i);
sb.append(String.format("row %d : MSE = %5f\n", i, mse_test.at(i)));
}
}
Log.info(sb);
}
}