package sample.simple.service;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import org.springframework.http.*;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.RestTemplate;
import java.io.File;
import java.io.PrintWriter;
import java.util.HashMap;
/**
 * Created by paragsanghavi on 2/3/15.
 *
 * Small client for the H2O "/2/" REST API served at {@link #H2O_HOST_URL}.
 * The public methods correspond to the steps of a GBM workflow:
 * <ul>
 * <li>{@link #ImportCSVFile} and {@link #ParseCSVFile}: import a csv file and parse it into a frame</li>
 * <li>{@link #BuildGBMModel}: set a model (GBM) with certain parameters and train it; the label
 *     column ("CAPSULE") and the ignored column ("0") are fixed in the code</li>
 * <li>{@link #PredictGBM}: test the model on another data set</li>
 * <li>{@link #CalculateAUC}: calculate the AUC score</li>
 * <li>{@link #DownloadPOJO}: dump the model java file</li>
 * </ul>
 * Most methods report failure by returning {@code null}; {@link #ImportCSVFile} returns the
 * string {@code "error"} instead and lets HTTP/JSON exceptions propagate.
 */
@Service
public class h2oService {
    static final String H2O_HOST_URL = "http://localhost:54321";
    static final String H2O_IMPORT_URL = "/2/ImportFiles2.json?path=" ;
    static final String H2O_PARSE_URL = "/2/Parse2.json?source_key="; //http://localhost:54321/2/Parse2.query?source_key=nfs://tmp/etsy_images/image_deep_features_csv
    static final String H2O_PROGRESS_URL = "/2/Progress2.json?"; //job_key=%240301ac10022e32d4ffffffff%24_9c2f4bf32b3bd2471dec44fc936d4363&destination_key=image_deep_features_csv.hex
    static final String H2O_GBM_MODEL_URL = "/2/GBM.json?";
    static final String H2O_GBM_MODEL_STATUS_URL = "/2/GBMProgressPage.json?";
    static final String H2O_GBM_MODEL_PREDICT_URL = "/2/Predict.json?";
    static final String H2O_GBM_MODEL_PREDICT_STATUS_URL = "/2/Inspect2.json?";
    static final String H2O_GBM_MODEL_AUC_URL = "/2/AUC.json?";
    static final String H2O_GBM_MODEL_POJO_URL = "/2/GBMModelView.java?_modelKey=" ;

    /** Key under which {@link #PredictGBM} stores predictions and {@link #CalculateAUC} reads them. */
    private static final String PREDICTION_KEY = "Predict_GBM";
    /** Value returned by {@link #ImportCSVFile} when the import did not yield a key. */
    private static final String IMPORT_ERROR = "error";
    private static final String STATUS_DONE = "done";
    private static final String STATUS_REDIRECT = "redirect";
    // NOTE(review): assumed to be the status H2O reports for a failed request/job - confirm against the H2O REST API.
    private static final String STATUS_ERROR = "error";
    /** Pause between two status polls, in milliseconds. */
    private static final long POLL_INTERVAL_MILLIS = 2000L;

    private static final Logger log = LoggerFactory.getLogger(h2oService.class);

    /** One client reused by every call instead of constructing a new RestTemplate per request. */
    private final RestTemplate restTemplate = new RestTemplate();

    /**
     * Imports a file into H2O via ImportFiles2.
     *
     * @param path file path as seen by the H2O node, e.g. {@code /tmp/etsy_images/image_deep_features_csv};
     *             passed as a URI template variable so the client escapes it
     * @return the first entry of the response's {@code keys} array, or {@code "error"} when the
     *         response is empty, its status is not "done", or no key was returned
     * @throws org.springframework.web.client.RestClientException if the HTTP call fails
     * @throws org.json.JSONException if the response is not the expected JSON
     */
    public String ImportCSVFile (String path){
        //example - url http://localhost:54321/2/ImportFiles2.json?path=%2Ftmp%2Fetsy_images%2Fimage_deep_features_csv#
        String urlTemplate = H2O_HOST_URL + H2O_IMPORT_URL + "{path}";
        log.debug("@@@ Calling endpoint {} with path {}", urlTemplate, path);
        String result = restTemplate.getForObject(urlTemplate, String.class, path);
        //{"Request2":0,"response_info":{"h2o":"paragsanghavi","node":"/172.16.2.45:54321","time":1,
        // "status":"done","redirect_url":null},"prefix":"nfs://tmp/etsy_images/deep_features_csv",
        // "files":["/tmp/etsy_images/deep_features_csv"],"keys":["nfs://tmp/etsy_images/deep_features_csv"],
        // "fails":[],"dels":["nfs://tmp/etsy_images/deep_features_csv"]}
        log.debug("@@@ Response json from h2o {}", result);
        if (result == null) {
            log.error("Import of {} returned an empty response", path);
            return IMPORT_ERROR;
        }
        JSONObject json = new JSONObject(result);
        String status = readStatus(json);
        log.debug("!!!!!! Import Status {}", status);
        if (!STATUS_DONE.equalsIgnoreCase(status)) {
            return IMPORT_ERROR;
        }
        JSONArray keys = json.getJSONArray("keys");
        if (keys.length() == 0) {
            // Previously this case surfaced as an exception from get(0).
            log.error("Import of {} finished but returned no keys", path);
            return IMPORT_ERROR;
        }
        String key = keys.getString(0);
        log.info("!!!!!! Import key {}", key);
        return key;
    }

    /**
     * Starts a Parse2 job for an imported key and waits for it to finish.
     *
     * NOTE(review): {@code framename} is currently NOT transmitted. The original code put the
     * parse options below into the body of a GET request, and RestTemplate's default request
     * factory does not send a body for GET, so H2O only ever received {@code source_key}.
     * That wire behaviour is kept here; to honour the options they have to be added to the
     * query string (option names to be checked against the H2O Parse2 API):
     * parser_type=CSV, separator=44, header=1, single_quotes=0, destination_key=framename,
     * delete_on_done=1.
     *
     * @param key       source key, as returned by {@link #ImportCSVFile}
     * @param framename intended destination frame name (see note above)
     * @return the {@code destination_key} reported by H2O, or {@code null} if waiting for the job failed
     * @throws org.json.JSONException if the response lacks {@code job_key} / {@code destination_key}
     */
    public String ParseCSVFile (String key, String framename ) throws org.json.JSONException{
        //http://localhost:54321/2/Parse2.query?source_key=nfs://tmp/etsy_images/image_deep_features_csv
        String urlTemplate = H2O_HOST_URL + H2O_PARSE_URL + "{source_key}";
        log.debug("@@@ Calling endpoint {} with source_key {}", urlTemplate, key);
        HttpHeaders headers = new HttpHeaders();
        headers.add("Accept", MediaType.APPLICATION_JSON_VALUE);
        HttpEntity<String> request = new HttpEntity<String>(headers);
        ResponseEntity<String> response =
                restTemplate.exchange(urlTemplate, HttpMethod.GET, request, String.class, key);
        String responseBody = response.getBody();
        /*
        {"Request2":0,
        "response_info":{"h2o":"paragsanghavi","node":"/172.16.2.46:54321","time":25,"status":"redirect",
        "redirect_url":"/2/Progress2.json?job_key=%240301ac10022e32d4ffffffff%24_96ec9394a616c1e5d4ff8a63f17b428e&destination_key=image_deep_features_csv.hex"},
        "job_key":"$0301ac10022e32d4ffffffff$_96ec9394a616c1e5d4ff8a63f17b428e",
        "destination_key":"image_deep_features_csv.hex"}
        */
        log.debug("@@@ Response json from h2o {}", responseBody);
        JSONObject json = new JSONObject(responseBody);
        String jobKey = json.getString("job_key");
        log.debug("!!!!!! Job name {}", jobKey);
        String destinationKey = json.getString("destination_key");
        String jobStatus = JobStatus(jobKey, destinationKey);
        return jobStatus != null ? destinationKey : null;
    }

    /**
     * Polls Progress2 every two seconds until the job reports status "redirect".
     *
     * @param job_key         job to watch
     * @param destination_key destination key of that job
     * @return the final status, or {@code null} if polling failed, was interrupted, or the
     *         job reported "error"
     */
    public String JobStatus( String job_key, String destination_key ) {
        String urlTemplate = H2O_HOST_URL + H2O_PROGRESS_URL
                + "job_key={job_key}&destination_key={destination_key}";
        log.debug("@@@ Polling job {} (destination {})", job_key, destination_key);
        return pollForStatus(urlTemplate, STATUS_REDIRECT, job_key, destination_key);
    }

    /**
     * Submits a GBM training job with a fixed parameter set and waits for it to finish.
     *
     * @param destination_key key under which H2O should store the model
     * @param source_key      key of the training frame
     * @return the model's destination key as reported by H2O, or {@code null} if the request
     *         failed or the job could not be followed to completion
     */
    public String BuildGBMModel(String destination_key,String source_key) {
        /*http://localhost:54321/2/GBM.html?destination_key=gbm&source=prostate_csv2.hex&response=CAPSULE&ignored_cols=0&
        classification=1&validation=&n_folds=0&holdout_fraction=.1&keep_cross_validation_splits=0&ntrees=50&max_depth=5&min_rows=10&nbins=20&score_each_iteration=0&importance=1&balance_classes=0
        &class_sampling_factors=&max_after_balance_size=Infinity&checkpoint=&overwrite_checkpoint=1&family=AUTO&learn_rate=0.1&grid_parallelism=1&seed=-1&group_split=1
        */
        String h2oUrlGBMEndPoint = H2O_HOST_URL + H2O_GBM_MODEL_URL +
                "destination_key={destination_key}&source={source}&response={response}&ignored_cols={ignored_cols}" +
                "&classification={classification}&validation={validation}" +
                "&ntrees={ntrees}&max_depth={max_depth}&min_rows={min_rows}&nbins={nbins}&score_each_iteration={score_each_iteration}" +
                "&importance={importance}&balance_classes={balance_classes}" +
                "&class_sampling_factors={class_sampling_factors}&max_after_balance_size={max_after_balance_size}&checkpoint={checkpoint}" +
                "&overwrite_checkpoint={overwrite_checkpoint}&family={family}&learn_rate={learn_rate}&grid_parallelism={grid_parallelism}&seed={seed}&group_split={group_split}";
        log.debug("@@@ h2oUrlGBMEndPoint : {}", h2oUrlGBMEndPoint);

        // NOTE(review): the original map also held n_folds=10 and keep_cross_validation_splits=0,
        // but the template above has no placeholders for them, so they were never sent to H2O.
        // They are left out here; add "&n_folds={n_folds}" etc. to the template if
        // cross-validation is actually wanted.
        final HashMap<String, String> parameters = new HashMap<String, String>();
        parameters.put("destination_key", destination_key);
        parameters.put("source", source_key);
        parameters.put("response", "CAPSULE");
        parameters.put("ignored_cols", "0");
        // Presumably 1 selects classification (the earlier comment said "regression",
        // which contradicts the parameter name) - confirm against the H2O GBM docs.
        parameters.put("classification", "1");
        parameters.put("validation", "");
        parameters.put("ntrees", "50");
        parameters.put("max_depth", "5");
        parameters.put("min_rows", "10");
        parameters.put("nbins", "20");
        parameters.put("score_each_iteration", "0");
        parameters.put("importance", "1");
        parameters.put("balance_classes", "0");
        parameters.put("class_sampling_factors","");
        parameters.put("max_after_balance_size","Infinity");
        parameters.put("checkpoint","");
        parameters.put("overwrite_checkpoint","");
        parameters.put("family", "AUTO");
        parameters.put("learn_rate", "0.1");
        parameters.put("grid_parallelism", "1");
        parameters.put("seed", "-1");
        parameters.put("group_split", "1");
        try {
            ResponseEntity<String> responseEntity =
                    restTemplate.getForEntity(h2oUrlGBMEndPoint, String.class, parameters);
            JSONObject json = new JSONObject(responseEntity.getBody());
            String jobKey = json.getString("job_key");
            String modelKey = json.getString("destination_key");
            log.info("!!!!!! GBM Job Key : {}", jobKey);
            log.info("!!!!!! GBM Destination Key : {}", modelKey);
            String gbmStatus = GBMJobStatus(jobKey, modelKey);
            if (gbmStatus == null) {
                // Same contract as ParseCSVFile: no key is handed out for a job whose outcome is unknown.
                log.error("GBM job {} could not be followed to completion", jobKey);
                return null;
            }
            log.info("gbm_status : {}", gbmStatus);
            return modelKey;
        } catch (HttpClientErrorException e) {
            // Do not dereference e.getRootCause(): it is null when the exception has no cause,
            // which used to turn an HTTP 4xx into a NullPointerException thrown from this handler.
            log.error("Error occured in building model: {} {}",
                    e.getStatusCode(), e.getResponseBodyAsString(), e);
            return null;
        } catch (Exception ex) {
            log.error("Error occured in building GBM model", ex);
            return null;
        }
    }

    /**
     * Polls GBMProgressPage every two seconds until the job reports status "redirect".
     *
     * @param job_key         GBM job to watch
     * @param destination_key model key of that job
     * @return the final status, or {@code null} if polling failed, was interrupted, or the
     *         job reported "error"
     */
    public String GBMJobStatus( String job_key, String destination_key ) {
        String urlTemplate = H2O_HOST_URL + H2O_GBM_MODEL_STATUS_URL
                + "job_key={job_key}&destination_key={destination_key}";
        log.debug("@@@ Polling GBM job {} (destination {})", job_key, destination_key);
        return pollForStatus(urlTemplate, STATUS_REDIRECT, job_key, destination_key);
    }

    /**
     * Scores a data set with a model; predictions are stored under the key "Predict_GBM".
     *
     * @param model        key of the trained model
     * @param new_data_key key of the frame to score
     * @return the status returned by {@link #PredictGBMStatus} for the prediction key,
     *         or {@code null} on any failure
     */
    public String PredictGBM( String model, String new_data_key ) {
        //http://localhost:54321/2/Predict.html?model=gbmmodelDestinationKey&data=prostate_csv.hex&prediction=predict_1
        String urlTemplate = H2O_HOST_URL + H2O_GBM_MODEL_PREDICT_URL
                + "model={model}&data={data}&prediction={prediction}";
        log.debug("@@@ Calling endpoint {} with model {} and data {}",
                new Object[] {urlTemplate, model, new_data_key});
        try {
            String responseBody = restTemplate.getForObject(
                    urlTemplate, String.class, model, new_data_key, PREDICTION_KEY);
            String status = readStatus(new JSONObject(responseBody));
            log.info("PREDICT GBM status: {}", status);
            return PredictGBMStatus(PREDICTION_KEY);
        } catch (Exception ex) {
            log.error("Error occurred while requesting GBM prediction", ex);
            return null;
        }
    }

    /**
     * Polls Inspect2 every two seconds until the given key reports status "done".
     *
     * @param src_key key to inspect
     * @return the final status, or {@code null} if polling failed, was interrupted, or H2O
     *         reported "error"
     */
    public String PredictGBMStatus( String src_key ) {
        //http://localhost:54321/2/Inspect2.json?src_key=1111
        String urlTemplate = H2O_HOST_URL + H2O_GBM_MODEL_PREDICT_STATUS_URL + "src_key={src_key}";
        log.debug("@@@ Polling prediction key {}", src_key);
        return pollForStatus(urlTemplate, STATUS_DONE, src_key);
    }

    /**
     * Requests the AUC of the predictions stored under "Predict_GBM", using the
     * {@code maximum_F1} threshold criterion.
     *
     * @param actual   key of the frame holding the actual values
     * @param vactual  column of {@code actual} holding the labels
     * @param vpredict column of the prediction frame to evaluate
     * @return the {@code aucdata.AUC} value of the response, or {@code null} on any failure
     */
    public Double CalculateAUC( String actual, String vactual, String vpredict ) {
        //http://localhost:54321/2/AUC.json?actual=prostate_csv.hex&vactual=CAPSULE&predict=predict_1&vpredict=1&thresholds=&threshold_criterion=maximum_F1
        String urlTemplate = H2O_HOST_URL + H2O_GBM_MODEL_AUC_URL
                + "actual={actual}&vactual={vactual}&predict={predict}&vpredict={vpredict}"
                + "&threshold_criterion=maximum_F1";
        log.debug("@@@ Calling endpoint {}", urlTemplate);
        try {
            String responseBody = restTemplate.getForObject(
                    urlTemplate, String.class, actual, vactual, PREDICTION_KEY, vpredict);
            // getDouble accepts any JSON number; a plain (Double) cast failed when the value
            // was serialized without a fraction (e.g. 1), because org.json then yields an Integer.
            double auc = new JSONObject(responseBody).getJSONObject("aucdata").getDouble("AUC");
            log.info("AUC: {}", auc);
            return Double.valueOf(auc);
        } catch (Exception ex) {
            log.error("Error occurred while calculating AUC", ex);
            return null;
        }
    }

    /**
     * Downloads the generated Java source of a model and writes it to
     * {@code <model_key>.java}, resolved against the current working directory.
     *
     * @param model_key key of the model; also used verbatim as the file name, so it should
     *                  come from a trusted source
     * @return absolute path of the written file, or {@code null} on any failure
     */
    public String DownloadPOJO( String model_key ) {
        //http://localhost:54321/2/GBMModelView.java?_modelKey=gbmmodelDestinationKey
        String urlTemplate = H2O_HOST_URL + H2O_GBM_MODEL_POJO_URL + "{model_key}";
        log.debug("@@@ Calling endpoint {} with model key {}", urlTemplate, model_key);
        try {
            String pojoSource = restTemplate.getForObject(urlTemplate, String.class, model_key);
            if (pojoSource == null) {
                log.error("H2O returned no POJO source for model {}", model_key);
                return null;
            }
            File pojoFile = new File(model_key + ".java");
            String pojoPath = pojoFile.getAbsolutePath();
            log.info("POJO File name {}", pojoPath);
            PrintWriter out = new PrintWriter(pojoPath);
            boolean writeFailed;
            try {
                out.write(pojoSource);
                // PrintWriter never throws IOException; checkError() flushes and reports failures.
                writeFailed = out.checkError();
            } finally {
                out.close();
            }
            if (writeFailed) {
                log.error("Failed to write POJO file {}", pojoPath);
                return null;
            }
            return pojoPath;
        } catch (Exception ex) {
            log.error("Error occurred while downloading POJO for model " + model_key, ex);
            return null;
        }
    }

    /** Extracts {@code response_info.status} from an H2O JSON response. */
    private static String readStatus(JSONObject response) {
        return response.getJSONObject("response_info").getString("status");
    }

    /**
     * Requests {@code urlTemplate} every {@link #POLL_INTERVAL_MILLIS} ms until the response
     * status equals {@code expectedStatus} (case-insensitive).
     *
     * @return the matching status, or {@code null} if a request or its parsing failed, the
     *         thread was interrupted, or H2O reported status "error"
     */
    private String pollForStatus(String urlTemplate, String expectedStatus, Object... uriVariables) {
        log.debug("@@@ Calling endpoint {}", urlTemplate);
        try {
            while (true) {
                String responseBody = restTemplate.getForObject(urlTemplate, String.class, uriVariables);
                String status = readStatus(new JSONObject(responseBody));
                log.debug("!!!!!! JOB Status {}", status);
                if (expectedStatus.equalsIgnoreCase(status)) {
                    return status;
                }
                if (STATUS_ERROR.equalsIgnoreCase(status)) {
                    // Stop instead of polling a failed job forever.
                    log.error("H2O reported status '{}' for {}: {}",
                            new Object[] {status, urlTemplate, responseBody});
                    return null;
                }
                Thread.sleep(POLL_INTERVAL_MILLIS); //Should use futures here
            }
        } catch (InterruptedException ex) {
            // Restore the flag so callers further up can see the interruption.
            Thread.currentThread().interrupt();
            log.error("Interrupted while waiting for job status", ex);
            return null;
        } catch (Exception ex) {
            log.error("Error Occured while getting job status", ex);
            return null;
        }
    }
}