package uk.ac.ox.zoo.seeg.abraid.mp.common.service.workflow; import org.joda.time.DateTime; import org.joda.time.DateTimeUtils; import org.junit.Before; import org.junit.Test; import org.mockito.InOrder; import uk.ac.ox.zoo.seeg.abraid.mp.common.domain.DiseaseGroup; import uk.ac.ox.zoo.seeg.abraid.mp.common.domain.DiseaseOccurrence; import uk.ac.ox.zoo.seeg.abraid.mp.common.domain.DiseaseProcessType; import uk.ac.ox.zoo.seeg.abraid.mp.common.service.core.DiseaseService; import uk.ac.ox.zoo.seeg.abraid.mp.common.service.workflow.support.*; import uk.ac.ox.zoo.seeg.abraid.mp.common.service.workflow.support.extent.DiseaseExtentGenerator; import uk.ac.ox.zoo.seeg.abraid.mp.common.service.workflow.support.ModelRunOccurrencesSelector; import uk.ac.ox.zoo.seeg.abraid.mp.common.service.workflow.support.runrequest.ModelRunRequester; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.*; import static org.mockito.Mockito.never; /** * Tests the ModelRunWorkflowServiceImpl class. * * Copyright (c) 2014 University of Oxford */ public class ModelRunWorkflowServiceTest { private WeightingsCalculator weightingsCalculator; private ModelRunRequester modelRunRequester; private DiseaseOccurrenceReviewManager reviewManager; private DiseaseService diseaseService; private DiseaseExtentGenerator diseaseExtentGenerator; private ModelRunWorkflowServiceImpl modelRunWorkflowService; private AutomaticModelRunsEnabler automaticModelRunsEnabler; private MachineWeightingPredictor machineWeightingPredictor; private ModelRunOccurrencesSelector modelRunOccurrencesSelector; @Before public void setUp() { weightingsCalculator = mock(WeightingsCalculator.class); modelRunRequester = mock(ModelRunRequester.class); reviewManager = mock(DiseaseOccurrenceReviewManager.class); diseaseService = mock(DiseaseService.class); diseaseExtentGenerator = mock(DiseaseExtentGenerator.class); automaticModelRunsEnabler = mock(AutomaticModelRunsEnabler.class); machineWeightingPredictor = mock(MachineWeightingPredictor.class); modelRunOccurrencesSelector = mock(ModelRunOccurrencesSelector.class); modelRunWorkflowService = new ModelRunWorkflowServiceImpl(weightingsCalculator, modelRunRequester, reviewManager, diseaseService, modelRunOccurrencesSelector, diseaseExtentGenerator, automaticModelRunsEnabler, machineWeightingPredictor); } @Test public void prepareForAndRequestModelRunForAutomaticProcess() { // Arrange int diseaseGroupId = 87; DateTimeUtils.setCurrentMillisFixed(DateTime.now().getMillis()); DateTime lastModelRunPrepDate = DateTime.now().minusWeeks(1); DiseaseGroup diseaseGroup = new DiseaseGroup(diseaseGroupId); diseaseGroup.setLastModelRunPrepDate(lastModelRunPrepDate); diseaseGroup.setAutomaticModelRunsStartDate(DateTime.now()); DateTime minimumOccurrenceDate = DateTime.now(); List<DiseaseOccurrence> occurrences = createListWithDate(minimumOccurrenceDate); List<DiseaseOccurrence> occurrencesForTrainingPredictor = new ArrayList<>(); when(diseaseService.getDiseaseGroupById(diseaseGroupId)).thenReturn(diseaseGroup); when(modelRunOccurrencesSelector.selectOccurrencesForModelRun(diseaseGroupId, false)).thenReturn(occurrences); when(diseaseService.getDiseaseOccurrencesForTrainingPredictor(diseaseGroupId)).thenReturn( occurrencesForTrainingPredictor); // Act modelRunWorkflowService.prepareForAndRequestModelRun(diseaseGroupId, DiseaseProcessType.AUTOMATIC, null, null); // Assert //// No prep verify(weightingsCalculator, never()).updateDiseaseOccurrenceExpertWeightings(anyInt()); verify(weightingsCalculator, never()).updateExpertsWeightings(); verify(reviewManager, never()).updateDiseaseOccurrenceStatus(anyInt(), anyBoolean()); verify(diseaseExtentGenerator, never()).generateDiseaseExtent(any(DiseaseGroup.class), any(DateTime.class), any(DiseaseProcessType.class)); verify(machineWeightingPredictor, never()).train(anyInt(), anyListOf(DiseaseOccurrence.class)); //// Request run verify(modelRunRequester).requestModelRun(eq(diseaseGroupId), same(occurrences), isNull(DateTime.class), isNull(DateTime.class)); verify(diseaseService).saveDiseaseGroup(same(diseaseGroup)); } @Test public void prepareForAndRequestModelRunForManualProcess() { // Arrange int diseaseGroupId = 87; DateTimeUtils.setCurrentMillisFixed(DateTime.now().getMillis()); DateTime lastModelRunPrepDate = DateTime.now().minusWeeks(1); DiseaseGroup diseaseGroup = new DiseaseGroup(diseaseGroupId); diseaseGroup.setLastModelRunPrepDate(lastModelRunPrepDate); DateTime batchStartDate = new DateTime("2012-11-13T00:00:00.000"); DateTime batchEndDate = new DateTime("2012-11-14T23:59:59.999"); DateTime minimumOccurrenceDate = DateTime.now(); List<DiseaseOccurrence> occurrences = createListWithDate(minimumOccurrenceDate); List<DiseaseOccurrence> occurrencesForTrainingPredictor = new ArrayList<>(); when(diseaseService.getDiseaseGroupById(diseaseGroupId)).thenReturn(diseaseGroup); when(modelRunOccurrencesSelector.selectOccurrencesForModelRun(diseaseGroupId, false)).thenReturn(occurrences); when(diseaseService.getDiseaseOccurrencesForTrainingPredictor(diseaseGroupId)).thenReturn( occurrencesForTrainingPredictor); // Act modelRunWorkflowService.prepareForAndRequestModelRun(diseaseGroupId, DiseaseProcessType.MANUAL, batchStartDate, batchEndDate); // Assert // Prep InOrder order = inOrder(weightingsCalculator, reviewManager, diseaseExtentGenerator, machineWeightingPredictor, modelRunRequester, diseaseService); order.verify(weightingsCalculator).updateExpertsWeightings(); order.verify(weightingsCalculator).updateDiseaseOccurrenceExpertWeightings(eq(diseaseGroupId)); order.verify(reviewManager).updateDiseaseOccurrenceStatus(eq(diseaseGroupId), eq(false)); order.verify(machineWeightingPredictor).train(eq(diseaseGroupId), same(occurrencesForTrainingPredictor)); order.verify(diseaseExtentGenerator).generateDiseaseExtent(eq(diseaseGroup), isNull(DateTime.class), eq(DiseaseProcessType.MANUAL)); //// Request run order.verify(modelRunRequester).requestModelRun(eq(diseaseGroupId), same(occurrences), eq(batchStartDate), eq(batchEndDate)); order.verify(diseaseService).saveDiseaseGroup(same(diseaseGroup)); } @Test public void prepareForAndRequestModelRunForGoldStandard() { // Arrange int diseaseGroupId = 87; DateTimeUtils.setCurrentMillisFixed(DateTime.now().getMillis()); DateTime lastModelRunPrepDate = DateTime.now().minusWeeks(1); DiseaseGroup diseaseGroup = new DiseaseGroup(diseaseGroupId); diseaseGroup.setLastModelRunPrepDate(lastModelRunPrepDate); DateTime minimumOccurrenceDate = DateTime.now(); List<DiseaseOccurrence> occurrences = createListWithDate(minimumOccurrenceDate); List<DiseaseOccurrence> occurrencesForTrainingPredictor = new ArrayList<>(); when(diseaseService.getDiseaseGroupById(diseaseGroupId)).thenReturn(diseaseGroup); when(modelRunOccurrencesSelector.selectOccurrencesForModelRun(diseaseGroupId, true)).thenReturn(occurrences); when(diseaseService.getDiseaseOccurrencesForTrainingPredictor(diseaseGroupId)).thenReturn( occurrencesForTrainingPredictor); // Act modelRunWorkflowService.prepareForAndRequestModelRun(diseaseGroupId, DiseaseProcessType.MANUAL_GOLD_STANDARD, null, null); // Assert // Prep InOrder order = inOrder(weightingsCalculator, reviewManager, diseaseExtentGenerator, machineWeightingPredictor, modelRunRequester, diseaseService); order.verify(weightingsCalculator).updateExpertsWeightings(); order.verify(weightingsCalculator).updateDiseaseOccurrenceExpertWeightings(eq(diseaseGroupId)); order.verify(reviewManager).updateDiseaseOccurrenceStatus(eq(diseaseGroupId), eq(false)); order.verify(machineWeightingPredictor).train(eq(diseaseGroupId), same(occurrencesForTrainingPredictor)); order.verify(diseaseExtentGenerator).generateDiseaseExtent(eq(diseaseGroup), isNull(DateTime.class), eq(DiseaseProcessType.MANUAL_GOLD_STANDARD)); //// Request run order.verify(modelRunRequester).requestModelRun(eq(diseaseGroupId), same(occurrences), isNull(DateTime.class), isNull(DateTime.class)); order.verify(diseaseService).saveDiseaseGroup(same(diseaseGroup)); } @Test public void enableAutomaticModelRuns() { // Arrange int diseaseGroupId = 87; DiseaseGroup diseaseGroup = new DiseaseGroup(diseaseGroupId); when(diseaseService.getDiseaseGroupById(diseaseGroupId)).thenReturn(diseaseGroup); List<DiseaseOccurrence> occurrencesForTrainingPredictor = new ArrayList<>(); when(diseaseService.getDiseaseOccurrencesForTrainingPredictor(diseaseGroupId)).thenReturn( occurrencesForTrainingPredictor); // Act modelRunWorkflowService.enableAutomaticModelRuns(diseaseGroupId); // Assert InOrder order = inOrder(automaticModelRunsEnabler, weightingsCalculator, reviewManager, machineWeightingPredictor); order.verify(weightingsCalculator).updateExpertsWeightings(); order.verify(weightingsCalculator).updateDiseaseOccurrenceExpertWeightings(eq(diseaseGroupId)); order.verify(reviewManager).updateDiseaseOccurrenceStatus(eq(diseaseGroupId), eq(false)); order.verify(machineWeightingPredictor).train(eq(diseaseGroupId), same(occurrencesForTrainingPredictor)); order.verify(automaticModelRunsEnabler).enable(eq(diseaseGroupId)); } @Test public void processOccurrencesOnDataValidatorForAuto() { // Arrange int diseaseGroupId = 87; DiseaseGroup diseaseGroup = new DiseaseGroup(diseaseGroupId); when(diseaseService.getDiseaseGroupById(diseaseGroupId)).thenReturn(diseaseGroup); List<DiseaseOccurrence> occurrencesForTrainingPredictor = new ArrayList<>(); when(diseaseService.getDiseaseOccurrencesForTrainingPredictor(diseaseGroupId)).thenReturn( occurrencesForTrainingPredictor); // Act modelRunWorkflowService.processOccurrencesOnDataValidator(diseaseGroupId, DiseaseProcessType.AUTOMATIC); // Assert InOrder order = inOrder(automaticModelRunsEnabler, weightingsCalculator, reviewManager, machineWeightingPredictor); order.verify(weightingsCalculator).updateDiseaseOccurrenceExpertWeightings(eq(diseaseGroupId)); order.verify(reviewManager).updateDiseaseOccurrenceStatus(eq(diseaseGroupId), eq(true)); order.verify(machineWeightingPredictor).train(eq(diseaseGroupId), same(occurrencesForTrainingPredictor)); } @Test public void processOccurrencesOnDataValidatorForManual() { // Arrange int diseaseGroupId = 87; DiseaseGroup diseaseGroup = new DiseaseGroup(diseaseGroupId); when(diseaseService.getDiseaseGroupById(diseaseGroupId)).thenReturn(diseaseGroup); List<DiseaseOccurrence> occurrencesForTrainingPredictor = new ArrayList<>(); when(diseaseService.getDiseaseOccurrencesForTrainingPredictor(diseaseGroupId)).thenReturn( occurrencesForTrainingPredictor); // Act modelRunWorkflowService.processOccurrencesOnDataValidator(diseaseGroupId, DiseaseProcessType.MANUAL); // Assert InOrder order = inOrder(automaticModelRunsEnabler, weightingsCalculator, reviewManager, machineWeightingPredictor); order.verify(weightingsCalculator).updateDiseaseOccurrenceExpertWeightings(eq(diseaseGroupId)); order.verify(reviewManager).updateDiseaseOccurrenceStatus(eq(diseaseGroupId), eq(false)); order.verify(machineWeightingPredictor).train(eq(diseaseGroupId), same(occurrencesForTrainingPredictor)); } @Test public void processOccurrencesOnDataValidatorForManualGoldStandard() { // Arrange int diseaseGroupId = 87; DiseaseGroup diseaseGroup = new DiseaseGroup(diseaseGroupId); when(diseaseService.getDiseaseGroupById(diseaseGroupId)).thenReturn(diseaseGroup); List<DiseaseOccurrence> occurrencesForTrainingPredictor = new ArrayList<>(); when(diseaseService.getDiseaseOccurrencesForTrainingPredictor(diseaseGroupId)).thenReturn( occurrencesForTrainingPredictor); // Act modelRunWorkflowService.processOccurrencesOnDataValidator(diseaseGroupId, DiseaseProcessType.MANUAL_GOLD_STANDARD); // Assert InOrder order = inOrder(automaticModelRunsEnabler, weightingsCalculator, reviewManager, machineWeightingPredictor); order.verify(weightingsCalculator).updateDiseaseOccurrenceExpertWeightings(eq(diseaseGroupId)); order.verify(reviewManager).updateDiseaseOccurrenceStatus(eq(diseaseGroupId), eq(false)); order.verify(machineWeightingPredictor).train(eq(diseaseGroupId), same(occurrencesForTrainingPredictor)); } @Test public void generateDiseaseExtentForManualProcess() { // Arrange int diseaseGroupId = 1; DiseaseGroup diseaseGroup = new DiseaseGroup(diseaseGroupId); when(diseaseService.getDiseaseGroupById(diseaseGroupId)).thenReturn(diseaseGroup); // Act modelRunWorkflowService.generateDiseaseExtent(diseaseGroupId, DiseaseProcessType.MANUAL); // Assert verify(diseaseExtentGenerator).generateDiseaseExtent(eq(diseaseGroup), eq((DateTime) null), eq(DiseaseProcessType.MANUAL)); verify(modelRunOccurrencesSelector, never()).selectOccurrencesForModelRun(anyInt(), anyBoolean()); } @Test public void generateDiseaseExtentForAutomaticProcess() { // Arrange int diseaseGroupId = 1; DiseaseGroup diseaseGroup = new DiseaseGroup(diseaseGroupId); diseaseGroup.setAutomaticModelRunsStartDate(DateTime.now()); when(diseaseService.getDiseaseGroupById(diseaseGroupId)).thenReturn(diseaseGroup); DateTime minimumOccurrenceDate = DateTime.now(); List<DiseaseOccurrence> occurrences = createListWithMultipleDates(minimumOccurrenceDate.plus(1), minimumOccurrenceDate, minimumOccurrenceDate.plusWeeks(1)); when(modelRunOccurrencesSelector.selectOccurrencesForModelRun(diseaseGroupId, false)).thenReturn(occurrences); // Act modelRunWorkflowService.generateDiseaseExtent(diseaseGroupId, DiseaseProcessType.AUTOMATIC); // Assert verify(diseaseExtentGenerator).generateDiseaseExtent(eq(diseaseGroup), same(minimumOccurrenceDate), eq(DiseaseProcessType.AUTOMATIC)); } @Test public void generateDiseaseExtentForGoldStandard() { // Arrange int diseaseGroupId = 1; DiseaseGroup diseaseGroup = new DiseaseGroup(diseaseGroupId); when(diseaseService.getDiseaseGroupById(diseaseGroupId)).thenReturn(diseaseGroup); // Act modelRunWorkflowService.generateDiseaseExtent(diseaseGroupId, DiseaseProcessType.MANUAL_GOLD_STANDARD); // Assert verify(diseaseExtentGenerator).generateDiseaseExtent(eq(diseaseGroup), eq((DateTime) null), eq(DiseaseProcessType.MANUAL_GOLD_STANDARD)); } private List<DiseaseOccurrence> createListWithDate(DateTime minimumOccurrenceDate) { DiseaseOccurrence occurrence = new DiseaseOccurrence(); occurrence.setOccurrenceDate(minimumOccurrenceDate); return Arrays.asList(occurrence); } private List<DiseaseOccurrence> createListWithMultipleDates(DateTime... dates) { List<DiseaseOccurrence> diseaseOccurrences = new ArrayList<>(); for (DateTime date : dates) { DiseaseOccurrence occurrence = new DiseaseOccurrence(); occurrence.setOccurrenceDate(date); diseaseOccurrences.add(occurrence); } return diseaseOccurrences; } }