package eu.dnetlib.iis.wf.citationmatching.direct.service;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import com.google.common.collect.Maps;
import eu.dnetlib.iis.common.citations.schemas.Citation;
import eu.dnetlib.iis.common.citations.schemas.CitationEntry;
import eu.dnetlib.iis.common.schemas.ReportEntry;
import pl.edu.icm.sparkutils.avro.SparkAvroSaver;
/**
* @author madryk
*/
@RunWith(MockitoJUnitRunner.class)
public class CitationMatchingDirectCounterReporterTest {
@InjectMocks
private CitationMatchingDirectCounterReporter counterReporter = new CitationMatchingDirectCounterReporter();
@Mock
private SparkAvroSaver avroSaver;
@Mock
private JavaRDD<Citation> matchedCitations;
@Mock
private JavaRDD<String> matchedCitationsDocumentIds;
@Mock
private JavaRDD<String> matchedCitationsDistinctDocumentIds;
@Mock
private JavaRDD<ReportEntry> reportCounters;
@Captor
private ArgumentCaptor<Function<Citation,String>> extractDocIdFunction;
@Captor
private ArgumentCaptor<List<ReportEntry>> reportEntriesCaptor;
//------------------------ TESTS --------------------------
@Test
public void report() throws Exception {
// given
JavaSparkContext sparkContext = mock(JavaSparkContext.class);
String reportPath = "/report/path";
when(matchedCitations.count()).thenReturn(14L);
doReturn(matchedCitationsDocumentIds).when(matchedCitations).map(any());
when(matchedCitationsDocumentIds.distinct()).thenReturn(matchedCitationsDistinctDocumentIds);
when(matchedCitationsDistinctDocumentIds.count()).thenReturn(3L);
doReturn(reportCounters).when(sparkContext).parallelize(any());
// execute
counterReporter.report(sparkContext, matchedCitations, reportPath);
// assert
verify(matchedCitations).map(extractDocIdFunction.capture());
assertExtractDocIdFunction(extractDocIdFunction.getValue());
verify(sparkContext).parallelize(reportEntriesCaptor.capture());
assertReportEntries(reportEntriesCaptor.getValue());
verify(avroSaver).saveJavaRDD(reportCounters, ReportEntry.SCHEMA$, reportPath);
}
//------------------------ PRIVATE --------------------------
private void assertExtractDocIdFunction(Function<Citation,String> function) throws Exception {
CitationEntry citationEntry = CitationEntry.newBuilder()
.setDestinationDocumentId("DEST_ID")
.setPosition(2)
.setExternalDestinationDocumentIds(Maps.newHashMap())
.build();
Citation citation = Citation.newBuilder()
.setSourceDocumentId("SOURCE_ID")
.setEntry(citationEntry)
.build();
String docId = function.call(citation);
assertEquals("SOURCE_ID", docId);
}
private void assertReportEntries(List<ReportEntry> reportEntries) {
assertEquals(2, reportEntries.size());
assertEquals("processing.citationMatching.direct.citDocReference", reportEntries.get(0).getKey());
assertEquals("14", reportEntries.get(0).getValue());
assertEquals("processing.citationMatching.direct.doc", reportEntries.get(1).getKey());
assertEquals("3", reportEntries.get(1).getValue());
}
}