package eu.dnetlib.iis.wf.documentsclassification;
import static eu.dnetlib.iis.common.report.ReportEntryFactory.createCounterReportEntry;
import static eu.dnetlib.iis.wf.documentsclassification.DocClassificationReportCounterKeys.ACM_CLASSES;
import static eu.dnetlib.iis.wf.documentsclassification.DocClassificationReportCounterKeys.ARXIV_CLASSES;
import static eu.dnetlib.iis.wf.documentsclassification.DocClassificationReportCounterKeys.CLASSIFIED_DOCUMENTS;
import static eu.dnetlib.iis.wf.documentsclassification.DocClassificationReportCounterKeys.DDC_CLASSES;
import static eu.dnetlib.iis.wf.documentsclassification.DocClassificationReportCounterKeys.MESH_EURO_PMC_CLASSES;
import static eu.dnetlib.iis.wf.documentsclassification.DocClassificationReportCounterKeys.WOS_CLASSES;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.runners.MockitoJUnitRunner;
import com.google.common.collect.Lists;
import eu.dnetlib.iis.common.schemas.ReportEntry;
import eu.dnetlib.iis.documentsclassification.schemas.DocumentClass;
import eu.dnetlib.iis.documentsclassification.schemas.DocumentClasses;
import eu.dnetlib.iis.documentsclassification.schemas.DocumentToDocumentClasses;
/**
* @author Ćukasz Dumiszewski
*/
@RunWith(MockitoJUnitRunner.class)
public class DocClassificationReportGeneratorTest {
private DocClassificationReportGenerator reportGenerator = new DocClassificationReportGenerator();
@Mock
private JavaRDD<DocumentToDocumentClasses> documentClasses;
@Mock
private JavaRDD<Long> docClassCounts;
@Mock
private DocumentToDocumentClasses docClass;
@Mock
private DocumentClasses docClasses;
@Captor
private ArgumentCaptor<Function<DocumentToDocumentClasses, Long>> mapFunction;
@Captor
private ArgumentCaptor<Function2<Long, Long, Long>> reduceFunction;
//------------------------ TESTS --------------------------
@Test(expected = NullPointerException.class)
public void generateReport_NULL() {
// execute
reportGenerator.generateReport(null);
}
@Test
public void generateReport_empty() {
// given
when(documentClasses.count()).thenReturn(0L);
// execute
List<ReportEntry> reportEntries = reportGenerator.generateReport(documentClasses);
// assert
assertReportEntries(reportEntries, 0, 0, 0, 0, 0, 0);
}
@Test
public void generateReport() throws Exception {
// given
when(documentClasses.count()).thenReturn(100L);
doReturn(docClassCounts).when(documentClasses).map(Mockito.any());
doReturn(2L).doReturn(3L).doReturn(4L).doReturn(5L).doReturn(6L).when(docClassCounts).reduce(Mockito.any());
// execute
List<ReportEntry> reportEntries = reportGenerator.generateReport(documentClasses);
// assert
assertReportEntries(reportEntries, 100, 2, 3, 4, 5, 6);
verify(documentClasses, Mockito.times(5)).map(mapFunction.capture());
assertArxivMapFunction(mapFunction.getAllValues().get(0));
assertWosMapFunction(mapFunction.getAllValues().get(1));
assertDdcMapFunction(mapFunction.getAllValues().get(2));
assertMeshMapFunction(mapFunction.getAllValues().get(3));
assertAcmMapFunction(mapFunction.getAllValues().get(4));
verify(docClassCounts, Mockito.times(5)).reduce(reduceFunction.capture());
for (Function2<Long, Long, Long> rf : reduceFunction.getAllValues()) {
assertReduceFunction(rf);
}
}
//------------------------ PRIVATE --------------------------
private void assertReportEntries(List<ReportEntry> reportEntries, long classifiedDocumentCount, long arxivClassCount, long wosClassCount, long ddcClassCount, long meshEuroPmcClassCount, long acmClassCount) {
assertThat(reportEntries, Matchers.contains(createCounterReportEntry(CLASSIFIED_DOCUMENTS, classifiedDocumentCount),
createCounterReportEntry(ARXIV_CLASSES, arxivClassCount),
createCounterReportEntry(WOS_CLASSES, wosClassCount),
createCounterReportEntry(DDC_CLASSES, ddcClassCount),
createCounterReportEntry(MESH_EURO_PMC_CLASSES, meshEuroPmcClassCount),
createCounterReportEntry(ACM_CLASSES, acmClassCount)));
}
private void assertArxivMapFunction(Function<DocumentToDocumentClasses, Long> function) throws Exception {
// given
when(docClass.getClasses()).thenReturn(docClasses);
when(docClasses.getArXivClasses()).thenReturn(generateDocumentClasses(2));
// execute & assert
assertEquals(2L, function.call(docClass).longValue());
}
private void assertWosMapFunction(Function<DocumentToDocumentClasses, Long> function) throws Exception {
// given
when(docClass.getClasses()).thenReturn(docClasses);
when(docClasses.getWoSClasses()).thenReturn(generateDocumentClasses(4));
// execute & assert
assertEquals(4L, function.call(docClass).longValue());
}
private void assertDdcMapFunction(Function<DocumentToDocumentClasses, Long> function) throws Exception {
// given
when(docClass.getClasses()).thenReturn(docClasses);
when(docClasses.getDDCClasses()).thenReturn(generateDocumentClasses(6));
// execute & assert
assertEquals(6L, function.call(docClass).longValue());
}
private void assertMeshMapFunction(Function<DocumentToDocumentClasses, Long> function) throws Exception {
// given
when(docClass.getClasses()).thenReturn(docClasses);
when(docClasses.getMeshEuroPMCClasses()).thenReturn(generateDocumentClasses(17));
// execute & assert
assertEquals(17L, function.call(docClass).longValue());
}
private void assertAcmMapFunction(Function<DocumentToDocumentClasses, Long> function) throws Exception {
// given
when(docClass.getClasses()).thenReturn(docClasses);
when(docClasses.getACMClasses()).thenReturn(generateDocumentClasses(7));
// execute & assert
assertEquals(7L, function.call(docClass).longValue());
}
private List<DocumentClass> generateDocumentClasses(int numberOfItems) {
List<DocumentClass> docClasses = Lists.newArrayList();
for (int i = 0; i < numberOfItems; i++) {
docClasses.add(mock(DocumentClass.class));
}
return docClasses;
}
private void assertReduceFunction(Function2<Long, Long, Long> function) throws Exception {
// execute & assert
assertEquals(4L, function.call(1L, 3L).longValue());
assertEquals(7L, function.call(4L, 3L).longValue());
}
}