package uk.ac.ox.zoo.seeg.abraid.mp.common.service.workflow.support; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectWriter; import org.springframework.web.util.UriComponentsBuilder; import uk.ac.ox.zoo.seeg.abraid.mp.common.domain.DiseaseOccurrence; import uk.ac.ox.zoo.seeg.abraid.mp.common.dto.json.AbraidJsonObjectMapper; import uk.ac.ox.zoo.seeg.abraid.mp.common.dto.json.JsonDiseaseOccurrenceDataPoint; import uk.ac.ox.zoo.seeg.abraid.mp.common.dto.json.JsonDiseaseOccurrenceDataSet; import uk.ac.ox.zoo.seeg.abraid.mp.common.web.WebServiceClient; import uk.ac.ox.zoo.seeg.abraid.mp.common.web.WebServiceClientException; import java.util.ArrayList; import java.util.List; import static java.lang.Double.parseDouble; /** * Web Service for calling out to Machine Learning predictor. * Copyright (c) 2014 University of Oxford */ public class MachineLearningWebService { /** Server response indicating that a trusted prediction was not returned and point should be validated manually. */ private static final String EXPECTED_PREDICTION_FAILURE_RESPONSE = "No prediction"; /** URL component for training method name. */ private static final String TRAIN_METHOD = "/train"; /** URL component for prediction method name. */ private static final String PREDICT_METHOD = "/predict"; private WebServiceClient webServiceClient; private AbraidJsonObjectMapper objectMapper; private String rootUrl; public MachineLearningWebService(WebServiceClient webServiceClient, AbraidJsonObjectMapper objectMapper, String rootUrl) { this.webServiceClient = webServiceClient; this.objectMapper = objectMapper; this.rootUrl = rootUrl; } /** * Send the data points of one disease group, with which to train the predictor, as Json. * @param diseaseGroupId The ID of the disease group the occurrences belong to. * @param occurrences The training data points. * @throws JsonProcessingException if the JSON is invalid * @throws WebServiceClientException if the web service client cannot execute the request */ public void sendTrainingData(int diseaseGroupId, List<DiseaseOccurrence> occurrences) throws JsonProcessingException, WebServiceClientException { String url = buildUrl(diseaseGroupId, TRAIN_METHOD); JsonDiseaseOccurrenceDataSet data = convertToDTO(occurrences); String bodyAsJson = writeRequestBodyAsJson(data); webServiceClient.makePostRequestWithJSON(url, bodyAsJson); } /** * Find the predicted weighting of the given disease occurrence. * @param occurrence The disease occurrence. * @return The predicted weighting. * @throws JsonProcessingException If the JSON is invalid * @throws WebServiceClientException If the web service client fails to execute request * @throws NumberFormatException If the response string cannot be parsed as double */ public Double getPrediction(DiseaseOccurrence occurrence) throws JsonProcessingException, WebServiceClientException, NumberFormatException { Integer diseaseGroupId = getDiseaseGroupId(occurrence); if (diseaseGroupId != null) { String url = buildUrl(diseaseGroupId, PREDICT_METHOD); JsonDiseaseOccurrenceDataPoint data = new JsonDiseaseOccurrenceDataPoint(occurrence); String bodyAsJson = writeRequestBodyAsJson(data); String response = webServiceClient.makePostRequestWithJSON(url, bodyAsJson); if (response.equals(EXPECTED_PREDICTION_FAILURE_RESPONSE)) { return null; } else { return parseDouble(response); } } else { throw new ModelRunWorkflowException("No disease group"); } } private String buildUrl(int diseaseGroupId, String action) { UriComponentsBuilder builder = UriComponentsBuilder.fromUriString(rootUrl) .path("/" + Integer.toString(diseaseGroupId)) .path(action); return builder.build().toString(); } private JsonDiseaseOccurrenceDataSet convertToDTO(List<DiseaseOccurrence> occurrences) { List<JsonDiseaseOccurrenceDataPoint> data = new ArrayList<>(); for (DiseaseOccurrence occurrence : occurrences) { data.add(new JsonDiseaseOccurrenceDataPoint(occurrence)); } return new JsonDiseaseOccurrenceDataSet(data); } private Integer getDiseaseGroupId(DiseaseOccurrence occurrence) { return (occurrence.getDiseaseGroup() != null) ? occurrence.getDiseaseGroup().getId() : null; } private String writeRequestBodyAsJson(Object data) throws JsonProcessingException { ObjectWriter writer = objectMapper.writer(); return writer.writeValueAsString(data); } }