package picard.analysis; import htsjdk.samtools.SAMFileReader; import htsjdk.samtools.SAMReadGroupRecord; import htsjdk.samtools.SAMRecord; import htsjdk.samtools.metrics.MetricsFile; import htsjdk.samtools.reference.ReferenceSequence; import org.testng.Assert; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import picard.metrics.MultiLevelCollector; import picard.metrics.MultilevelMetrics; import picard.metrics.PerUnitMetricCollector; import java.io.File; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import static htsjdk.samtools.util.CollectionUtil.makeSet; public class MultiLevelCollectorTest { public static File TESTFILE = new File("testdata/picard/sam/summary_alignment_stats_test_multiple.sam"); public String noneOrStr(final String str) { final String out; if(str == null) { out = ""; } else { out = str; } return out; } class TestArg { public final SAMRecord samRecord; public final ReferenceSequence refSeq; public TestArg(final SAMRecord samRecord, final ReferenceSequence refSeq) { this.samRecord = samRecord; this.refSeq = refSeq; } } /** We will just Tally up the number of times records were added to this metric and change FINISHED * to true when FINISHED is called */ class TotalNumberMetric extends MultilevelMetrics { /** The number of these encountered **/ public Integer TALLY = 0; public boolean FINISHED = false; } class RecordCountMultiLevelCollector extends MultiLevelCollector<TotalNumberMetric, Integer, TestArg> { public RecordCountMultiLevelCollector(final Set<MetricAccumulationLevel> accumulationLevels, final List<SAMReadGroupRecord> samRgRecords) { setup(accumulationLevels, samRgRecords); } //The number of times records were accepted by a RecordCountPerUnitCollectors (note since the same //samRecord might be aggregated by multiple PerUnit collectors, this may be greater than the number of //records in the file private int numProcessed = 0; public int getNumProcessed() { return numProcessed; } private final Map<String, TotalNumberMetric> unitsToMetrics = new HashMap<String, TotalNumberMetric>(); public Map<String, TotalNumberMetric> getUnitsToMetrics() { return unitsToMetrics; } @Override protected TestArg makeArg(final SAMRecord samRec, final ReferenceSequence refSeq) { return new TestArg(samRec, refSeq); } @Override protected PerUnitMetricCollector<TotalNumberMetric, Integer, TestArg> makeChildCollector(final String sample, final String library, final String readGroup) { return new RecordCountPerUnitCollector(sample, library, readGroup); } private class RecordCountPerUnitCollector implements PerUnitMetricCollector<TotalNumberMetric, Integer, TestArg>{ private final TotalNumberMetric metric; public RecordCountPerUnitCollector(final String sample, final String library, final String readGroup) { metric = new TotalNumberMetric(); metric.SAMPLE = sample; metric.LIBRARY = library; metric.READ_GROUP = readGroup; unitsToMetrics.put(noneOrStr(sample) + "_" + noneOrStr(library) + "_" + noneOrStr(readGroup), metric); } @Override public void acceptRecord(final TestArg args) { numProcessed += 1; metric.TALLY += 1; if(metric.SAMPLE != null) { Assert.assertEquals(metric.SAMPLE, args.samRecord.getReadGroup().getSample()); } if(metric.LIBRARY != null) { Assert.assertEquals(metric.LIBRARY, args.samRecord.getReadGroup().getLibrary()); } if(metric.READ_GROUP != null) { Assert.assertEquals(metric.READ_GROUP, args.samRecord.getReadGroup().getPlatformUnit()); } } @Override public void finish() { metric.FINISHED = true; } @Override public void addMetricsToFile(final MetricsFile<TotalNumberMetric, Integer> totalNumberMetricIntegerMetricsFile) { totalNumberMetricIntegerMetricsFile.addMetric(metric); } } } public static final Map<MetricAccumulationLevel, Map<String, Integer>> accumulationLevelToPerUnitReads = new HashMap<MetricAccumulationLevel, Map<String, Integer>>(); static { HashMap<String, Integer> curMap = new HashMap<String, Integer>(); curMap.put("__", 19); accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.ALL_READS, curMap); curMap = new HashMap<String, Integer>(); curMap.put("Ma__", 10); curMap.put("Pa__", 9); accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.SAMPLE, curMap); curMap = new HashMap<String, Integer>(); curMap.put("Ma_whatever_", 10); curMap.put("Pa_lib1_", 4); curMap.put("Pa_lib2_", 5); accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.LIBRARY, curMap); curMap = new HashMap<String, Integer>(); curMap.put("Ma_whatever_me", 10); curMap.put("Pa_lib1_myself", 4); curMap.put("Pa_lib2_i", 3); curMap.put("Pa_lib2_i2", 2); accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.READ_GROUP, curMap); } @DataProvider(name = "variedAccumulationLevels") public Object [][] variedAccumulationLevels() { return new Object[][] { {makeSet(MetricAccumulationLevel.ALL_READS)}, {makeSet(MetricAccumulationLevel.ALL_READS, MetricAccumulationLevel.SAMPLE)}, {makeSet(MetricAccumulationLevel.SAMPLE, MetricAccumulationLevel.LIBRARY)}, {makeSet(MetricAccumulationLevel.READ_GROUP, MetricAccumulationLevel.LIBRARY)}, {makeSet(MetricAccumulationLevel.SAMPLE, MetricAccumulationLevel.LIBRARY, MetricAccumulationLevel.READ_GROUP)}, {makeSet(MetricAccumulationLevel.SAMPLE, MetricAccumulationLevel.LIBRARY, MetricAccumulationLevel.READ_GROUP, MetricAccumulationLevel.ALL_READS)}, }; } @Test(dataProvider = "variedAccumulationLevels") public void multilevelCollectorTest(final Set<MetricAccumulationLevel> accumulationLevels) { final SAMFileReader in = new SAMFileReader(TESTFILE); final RecordCountMultiLevelCollector collector = new RecordCountMultiLevelCollector(accumulationLevels, in.getFileHeader().getReadGroups()); for (final SAMRecord rec : in) { collector.acceptRecord(rec, null); } collector.finish(); int totalProcessed = 0; int totalMetrics = 0; for(final MetricAccumulationLevel level : accumulationLevels) { final Map<String, Integer> keyToMetrics = accumulationLevelToPerUnitReads.get(level); for(final Map.Entry<String, Integer> entry : keyToMetrics.entrySet()) { final TotalNumberMetric metric = collector.getUnitsToMetrics().get(entry.getKey()); Assert.assertEquals(entry.getValue(), metric.TALLY); Assert.assertTrue(metric.FINISHED); totalProcessed += metric.TALLY; totalMetrics += 1; } } Assert.assertEquals(collector.getUnitsToMetrics().size(), totalMetrics); Assert.assertEquals(totalProcessed, collector.getNumProcessed()); } }