package uk.ac.ox.zoo.seeg.abraid.mp.common.service.workflow.support;
import com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.log4j.Logger;
import uk.ac.ox.zoo.seeg.abraid.mp.common.domain.DiseaseOccurrence;
import uk.ac.ox.zoo.seeg.abraid.mp.common.web.WebServiceClientException;
import java.util.List;
/**
* Machine learning component used to predict the weighting.
* Copyright (c) 2014 University of Oxford
*/
public class MachineWeightingPredictor {
private MachineLearningWebService webService;
private static final Logger LOGGER = Logger.getLogger(MachineWeightingPredictor.class);
private static final String TRAINING_MESSAGE = "Training predictor for disease group %d with %d occurrences";
private static final String TRAINING_FAILURE = "Unable to train predictor for disease group.";
private static final String PREDICTION_FAILURE = "Unable to get prediction for occurrence.";
public MachineWeightingPredictor(MachineLearningWebService webService) {
this.webService = webService;
}
/**
* Train the model with the list of occurrences.
* @param diseaseGroupId The ID of the disease group to which the occurrences belong.
* @param occurrences The occurrences with which to train the predictor.
* @throws ModelRunWorkflowException if the json cannot be processed, or the request cannot be made.
*/
public void train(int diseaseGroupId, List<DiseaseOccurrence> occurrences)
throws ModelRunWorkflowException {
try {
LOGGER.info(String.format(TRAINING_MESSAGE, diseaseGroupId, occurrences.size()));
webService.sendTrainingData(diseaseGroupId, occurrences);
} catch (WebServiceClientException|JsonProcessingException e) {
LOGGER.error(e.getMessage());
throw new ModelRunWorkflowException(TRAINING_FAILURE);
}
}
/**
* Predict the weighting of a new occurrence.
* @param occurrence The occurrence.
* @return The predicted value for weighting.
* @throws ModelRunWorkflowException if the request cannot be made or the response cannot be handled.
*/
public Double findMachineWeighting(DiseaseOccurrence occurrence) throws ModelRunWorkflowException {
try {
return webService.getPrediction(occurrence);
} catch (JsonProcessingException|WebServiceClientException|NumberFormatException e) {
LOGGER.error(e.getMessage());
throw new ModelRunWorkflowException(PREDICTION_FAILURE);
}
}
}