/* * Copyright (C) 2015 Google Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ package com.google.cloud.genomics.dataflow.pipelines; import com.google.api.client.util.Strings; import com.google.cloud.dataflow.sdk.Pipeline; import com.google.cloud.dataflow.sdk.io.TextIO; import com.google.cloud.dataflow.sdk.options.Default; import com.google.cloud.dataflow.sdk.options.Description; import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; import com.google.cloud.dataflow.sdk.transforms.Create; import com.google.cloud.dataflow.sdk.transforms.DoFn; import com.google.cloud.dataflow.sdk.transforms.Filter; import com.google.cloud.dataflow.sdk.transforms.ParDo; import com.google.cloud.dataflow.sdk.transforms.Sample; import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; import com.google.cloud.dataflow.sdk.transforms.View; import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; import com.google.cloud.dataflow.sdk.values.KV; import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.cloud.dataflow.sdk.values.PCollectionView; import com.google.cloud.dataflow.sdk.values.TupleTag; import com.google.cloud.genomics.dataflow.coders.GenericJsonCoder; import com.google.cloud.genomics.dataflow.functions.VariantFunctions; import com.google.cloud.genomics.dataflow.functions.verifybamid.LikelihoodFn; import com.google.cloud.genomics.dataflow.functions.verifybamid.ReadFunctions; import com.google.cloud.genomics.dataflow.functions.verifybamid.Solver; import com.google.cloud.genomics.dataflow.model.AlleleFreq; import com.google.cloud.genomics.dataflow.model.ReadBaseQuality; import com.google.cloud.genomics.dataflow.model.ReadBaseWithReference; import com.google.cloud.genomics.dataflow.model.ReadCounts; import com.google.cloud.genomics.dataflow.model.ReadQualityCount; import com.google.cloud.genomics.dataflow.pipelines.CalculateCoverage.CheckMatchingReferenceSet; import com.google.cloud.genomics.dataflow.readers.ReadGroupStreamer; import com.google.cloud.genomics.dataflow.readers.VariantStreamer; import com.google.cloud.genomics.dataflow.utils.CallSetNamesOptions; import com.google.cloud.genomics.dataflow.utils.GCSOutputOptions; import com.google.cloud.genomics.dataflow.utils.GenomicsOptions; import com.google.cloud.genomics.dataflow.utils.ShardOptions; import com.google.cloud.genomics.utils.GenomicsUtils; import com.google.cloud.genomics.utils.OfflineAuth; import com.google.cloud.genomics.utils.ShardBoundary; import com.google.cloud.genomics.utils.ShardUtils; import com.google.cloud.genomics.utils.ShardUtils.SexChromosomeFilter; import com.google.common.collect.ImmutableMultiset; import com.google.common.collect.Lists; import com.google.common.collect.Multiset; import com.google.genomics.v1.Position; import com.google.genomics.v1.Read; import com.google.genomics.v1.StreamVariantsRequest; import com.google.genomics.v1.Variant; import com.google.protobuf.ListValue; import java.io.IOException; import java.io.UnsupportedEncodingException; import java.nio.ByteBuffer; import java.security.GeneralSecurityException; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Vector; /** * Test a set of reads for contamination. * * Takes a set of specified ReadGroupSets of reads to test and statistics on reference allele * frequencies for SNPs with a single alternative from a specified set of VariantSets. * * See http://googlegenomics.readthedocs.org/en/latest/use_cases/perform_quality_control_checks/verify_bam_id.html * for running instructions. * * Uses the sequence data alone approach described in: * G. Jun, M. Flickinger, K. N. Hetrick, Kurt, J. M. Romm, K. F. Doheny, * G. Abecasis, M. Boehnke,and H. M. Kang, Detecting and Estimating * Contamination of Human DNA Samples in Sequencing and Array-Based Genotype * Data, American journal of human genetics doi:10.1016/j.ajhg.2012.09.004 * (volume 91 issue 5 pp.839 - 848) * http://www.sciencedirect.com/science/article/pii/S0002929712004788 */ public class VerifyBamId { /** * Options required to run this pipeline. */ public static interface Options extends // Options for call set names. CallSetNamesOptions, // Options for calculating over regions, chromosomes, or whole genomes. ShardOptions, // Options for the output destination. GCSOutputOptions { @Description("A comma delimited list of the IDs of the Google Genomics ReadGroupSets this " + "pipeline is working with. Default (empty) indicates all ReadGroupSets in InputDatasetId." + " This or InputDatasetId must be set. InputDatasetId overrides " + "ReadGroupSetIds (if InputDatasetId is set, this field will be ignored).") @Default.String("") String getReadGroupSetIds(); void setReadGroupSetIds(String readGroupSetId); @Description("The ID of the Google Genomics Dataset that the pipeline will get its input reads" + " from. Default (empty) means to use ReadGroupSetIds and VariantSetIds instead. This or" + " ReadGroupSetIds and VariantSetIds must be set. InputDatasetId overrides" + " ReadGroupSetIds and VariantSetIds (if this field is set, ReadGroupSetIds and" + " VariantSetIds will be ignored).") @Default.String("") String getInputDatasetId(); void setInputDatasetId(String inputDatasetId); public String DEFAULT_VARIANTSET = "10473108253681171589"; @Override @Description("The ID of the Google Genomics VariantSet this pipeline is working with." + " It assumes the variant set has INFO field 'AF' from which it retrieves the" + " allele frequency for the variant, such as 1,000 Genomes phase 1 or phase 3 variants." + " Defaults to the 1,000 Genomes phase 1 VariantSet with id " + DEFAULT_VARIANTSET + ".") @Default.String(DEFAULT_VARIANTSET) String getVariantSetId(); void setVariantSetId(String variantSetId); @Description("The minimum allele frequency to use in analysis. Defaults to 0.01.") @Default.Double(0.01) double getMinFrequency(); void setMinFrequency(double minFrequency); @Description("The fraction of positions to check. Defaults to 0.01.") @Default.Double(0.01) double getSamplingFraction(); void setSamplingFraction(double minFrequency); public static class Methods { public static void validateOptions(Options options) { GCSOutputOptions.Methods.validateOptions(options); } } } private static Pipeline p; private static Options pipelineOptions; private static OfflineAuth auth; /** * String prefix used for sampling hash function */ private static final String HASH_PREFIX = ""; // Tip: Use the API explorer to test which fields to include in partial responses. // https://developers.google.com/apis-explorer/#p/genomics/v1/genomics.variants.stream?fields=variants(alternateBases%252Ccalls(callSetName%252Cgenotype)%252CreferenceBases)&_h=3&resource=%257B%250A++%2522variantSetId%2522%253A+%25223049512673186936334%2522%252C%250A++%2522referenceName%2522%253A+%2522chr17%2522%252C%250A++%2522start%2522%253A+%252241196311%2522%252C%250A++%2522end%2522%253A+%252241196312%2522%252C%250A++%2522callSetIds%2522%253A+%250A++%255B%25223049512673186936334-0%2522%250A++%255D%250A%257D& private static final String VARIANT_FIELDS = "variants(alternateBases,filter,info,quality,referenceBases,referenceName,start)"; /** * Run the VerifyBamId algorithm and output the resulting contamination estimate. */ public static void main(String[] args) throws GeneralSecurityException, IOException { // Register the options so that they show up via --help PipelineOptionsFactory.register(Options.class); pipelineOptions = PipelineOptionsFactory.fromArgs(args) .withValidation().as(Options.class); // Option validation is not yet automatic, we make an explicit call here. Options.Methods.validateOptions(pipelineOptions); // Set up the prototype request and auth. StreamVariantsRequest prototype = CallSetNamesOptions.Methods.getRequestPrototype(pipelineOptions); auth = GenomicsOptions.Methods.getGenomicsAuth(pipelineOptions); p = Pipeline.create(pipelineOptions); p.getCoderRegistry().setFallbackCoderProvider(GenericJsonCoder.PROVIDER); if (pipelineOptions.getInputDatasetId().isEmpty() && pipelineOptions.getReadGroupSetIds().isEmpty()) { throw new IllegalArgumentException("InputDatasetId or ReadGroupSetIds must be specified"); } List<String> rgsIds; if (pipelineOptions.getInputDatasetId().isEmpty()) { rgsIds = Lists.newArrayList(pipelineOptions.getReadGroupSetIds().split(",")); } else { rgsIds = GenomicsUtils.getReadGroupSetIds(pipelineOptions.getInputDatasetId(), auth); } // Grab one ReferenceSetId to be used within the pipeline to confirm that all ReadGroupSets // are associated with the same ReferenceSet. String referenceSetId = GenomicsUtils.getReferenceSetId(rgsIds.get(0), auth); if (Strings.isNullOrEmpty(referenceSetId)) { throw new IllegalArgumentException("No ReferenceSetId associated with ReadGroupSetId " + rgsIds.get(0) + ". All ReadGroupSets in given input must have an associated ReferenceSet."); } // TODO: confirm that variant set also corresponds to the same reference // https://github.com/googlegenomics/api-client-java/issues/66 // Reads in Reads. PCollection<Read> reads = p.begin() .apply(Create.of(rgsIds)) .apply(ParDo.of(new CheckMatchingReferenceSet(referenceSetId, auth))) .apply(new ReadGroupStreamer(auth, ShardBoundary.Requirement.STRICT, null, SexChromosomeFilter.INCLUDE_XY)); /* TODO: We can reduce the number of requests needed to be created by doing the following: 1. Stream the Variants first (rather than concurrently with the Reads). Select a subset of them equal to some threshold (say 50K by default). 2. Create the requests for streaming Reads by running a ParDo over the selected Variants to get their ranges (we only need to stream Reads that overlap the selected Variants). 3. Stream the Reads from the created requests. */ // Reads in Variants. TODO potentially provide an option to load the Variants from a file. List<StreamVariantsRequest> variantRequests = pipelineOptions.isAllReferences() ? ShardUtils.getVariantRequests(prototype, ShardUtils.SexChromosomeFilter.INCLUDE_XY, pipelineOptions.getBasesPerShard(), auth) : ShardUtils.getVariantRequests(prototype, pipelineOptions.getBasesPerShard(), pipelineOptions.getReferences()); PCollection<Variant> variants = p.apply(Create.of(variantRequests)) .apply(new VariantStreamer(auth, ShardBoundary.Requirement.STRICT, VARIANT_FIELDS)); PCollection<KV<Position, AlleleFreq>> refFreq = getFreq(variants, pipelineOptions.getMinFrequency()); PCollection<KV<Position, ReadCounts>> readCountsTable = combineReads(reads, pipelineOptions.getSamplingFraction(), HASH_PREFIX, refFreq); // Converts our results to a single Map of Position keys to ReadCounts values. PCollectionView<Map<Position, ReadCounts>> view = readCountsTable .apply(View.<Position, ReadCounts>asMap()); // Calculates the contamination estimate based on the resulting Map above. PCollection<String> result = p.begin() .apply(Create.of("")) .apply(ParDo.of(new Maximizer(view)).withSideInputs(view)); // Writes the result to the given output location in Cloud Storage. result.apply(TextIO.Write.to(pipelineOptions.getOutput()).named("WriteOutput").withoutSharding()); p.run(); } /** * Compute a PCollection of reference allele frequencies for SNPs of interest. * The SNPs all have only a single alternate allele, and neither the * reference nor the alternate allele have a population frequency < minFreq. * The results are returned in a PCollection indexed by Position. * * @param variants a set of variant calls for a reference population * @param minFreq the minimum allele frequency for the set * @return a PCollection mapping Position to AlleleCounts */ static PCollection<KV<Position, AlleleFreq>> getFreq( PCollection<Variant> variants, double minFreq) { return variants.apply(Filter.byPredicate(VariantFunctions.IS_PASSING).named("PassingFilter")) .apply(Filter.byPredicate(VariantFunctions.IS_ON_CHROMOSOME).named("OnChromosomeFilter")) .apply(Filter.byPredicate(VariantFunctions.IS_NOT_LOW_QUALITY).named("NotLowQualityFilter")) .apply(Filter.byPredicate(VariantFunctions.IS_SINGLE_ALTERNATE_SNP).named("SNPFilter")) .apply(ParDo.of(new GetAlleleFreq())) .apply(Filter.byPredicate(new FilterFreq(minFreq))); } /** * Filter, pile up, and sample reads, then join against reference statistics. * * @param reads A PCollection of reads * @param samplingFraction Fraction of reads to keep * @param samplingPrefix A prefix used in generating hashes used in sampling * @param refCounts A PCollection mapping position to counts of alleles in * a reference population. * @return A PCollection mapping Position to a ReadCounts proto */ static PCollection<KV<Position, ReadCounts>> combineReads(PCollection<Read> reads, double samplingFraction, String samplingPrefix, PCollection<KV<Position, AlleleFreq>> refFreq) { // Runs filters on input Reads, splits into individual aligned bases (emitting the // base and quality) and grabs a sample of them based on a hash mod of Position. PCollection<KV<Position, ReadBaseQuality>> joinReadCounts = reads.apply(Filter.byPredicate(ReadFunctions.IS_ON_CHROMOSOME).named("IsOnChromosome")) .apply(Filter.byPredicate(ReadFunctions.IS_NOT_QC_FAILURE).named("IsNotQCFailure")) .apply(Filter.byPredicate(ReadFunctions.IS_NOT_DUPLICATE).named("IsNotDuplicate")) .apply(Filter.byPredicate(ReadFunctions.IS_PROPER_PLACEMENT).named("IsProperPlacement")) .apply(ParDo.of(new SplitReads())) .apply(Filter.byPredicate(new SampleReads(samplingFraction, samplingPrefix))); TupleTag<ReadBaseQuality> readCountsTag = new TupleTag<>(); TupleTag<AlleleFreq> refFreqTag = new TupleTag<>(); // Pile up read counts, then join against reference stats. PCollection<KV<Position, CoGbkResult>> joined = KeyedPCollectionTuple .of(readCountsTag, joinReadCounts) .and(refFreqTag, refFreq) .apply(CoGroupByKey.<Position>create()); return joined.apply(ParDo.of(new PileupAndJoinReads(readCountsTag, refFreqTag))); } /** * Split reads into individual aligned bases and emit base + quality. */ static class SplitReads extends DoFn<Read, KV<Position, ReadBaseQuality>> { @Override public void processElement(ProcessContext c) throws Exception { List<ReadBaseWithReference> readBases = ReadFunctions.extractReadBases(c.element()); if (!readBases.isEmpty()) { for (ReadBaseWithReference rb : readBases) { c.output(KV.of(rb.getRefPosition(), rb.getRbq())); } } } } /** * Sample bases via a hash mod of position. */ static class SampleReads implements SerializableFunction<KV<Position, ReadBaseQuality>, Boolean> { private final double samplingFraction; private final String samplingPrefix; public SampleReads(double samplingFraction, String samplingPrefix) { this.samplingFraction = samplingFraction; this.samplingPrefix = samplingPrefix; } @Override public Boolean apply(KV<Position, ReadBaseQuality> input) { if (samplingFraction == 1.0) { return true; } else { byte[] msg; Position position = input.getKey(); try { msg = (samplingPrefix + position.getReferenceName() + ":" + position.getPosition() + ":" + position.getReverseStrand()).getBytes("UTF-8"); } catch (UnsupportedEncodingException e) { throw new AssertionError("UTF-8 not available - should not happen"); } MessageDigest md; try { md = MessageDigest.getInstance("MD5"); } catch (NoSuchAlgorithmException e) { throw new AssertionError("MD5 not available - should not happen"); } byte[] digest = md.digest(msg); if (digest.length != 16) { throw new AssertionError("MD5 should return 128 bits"); } ByteBuffer buffer = ByteBuffer.allocate(Long.SIZE); buffer.put(Arrays.copyOf(digest, Long.SIZE)); return ((((double) buffer.getLong(0) / (double) ((long) 1 << 63)) + 1.0) * 0.5) < samplingFraction; } } } /** * Map a variant to a Position, AlleleFreq pair. */ static class GetAlleleFreq extends DoFn<Variant, KV<Position, AlleleFreq>> { @Override public void processElement(ProcessContext c) throws Exception { ListValue lv = c.element().getInfo().get("AF"); if (lv != null && lv.getValuesCount() > 0) { Position position = Position.newBuilder() .setPosition(c.element().getStart()) .setReferenceName(c.element().getReferenceName()) .build(); AlleleFreq af = new AlleleFreq(); af.setRefFreq(Double.parseDouble(lv.getValues(0).getStringValue())); af.setAltBases(c.element().getAlternateBasesList()); af.setRefBases(c.element().getReferenceBases()); c.output(KV.of(position, af)); } else { // AF field wasn't populated in info, so we don't have frequency information // for this Variant. // TODO instead of straight throwing an exception, log a warning. If at the end of this // step the number of AlleleFreqs retrieved is below a given threshold, then throw an // exception. throw new IllegalArgumentException("Variant " + c.element().getId() + " does not have " + "allele frequency information stored in INFO field AF."); } } } /** * Filters out AlleleFreqs for which the reference or alternate allele * frequencies are below a minimum specified at construction. */ static class FilterFreq implements SerializableFunction<KV<Position, AlleleFreq>, Boolean> { private final double minFreq; public FilterFreq(double minFreq) { this.minFreq = minFreq; } @Override public Boolean apply(KV<Position, AlleleFreq> input) { double freq = input.getValue().getRefFreq(); if (freq >= minFreq && (1.0 - freq) >= minFreq) { return true; } return false; } } /** * Piles up reads and joins them against reference population statistics. */ static class PileupAndJoinReads extends DoFn<KV<Position, CoGbkResult>, KV<Position, ReadCounts>> { private final TupleTag<ReadBaseQuality> readCountsTag; private final TupleTag<AlleleFreq> refFreqTag; public PileupAndJoinReads(TupleTag<ReadBaseQuality> readCountsTag, TupleTag<AlleleFreq> refFreqTag) { this.readCountsTag = readCountsTag; this.refFreqTag = refFreqTag; } @Override public void processElement(ProcessContext c) throws Exception { AlleleFreq af = null; af = c.element().getValue().getOnly(refFreqTag, null); if (af == null || af.getAltBases() == null) { // no ref stats return; } if (af.getAltBases().size() != 1) { throw new IllegalArgumentException("Wrong number (" + af.getAltBases().size() + ") of" + " alternate bases for Position " + c.element().getKey()); } Iterable<ReadBaseQuality> reads = c.element().getValue().getAll(readCountsTag); ImmutableMultiset.Builder<ReadQualityCount> rqSetBuilder = ImmutableMultiset.builder(); for (ReadBaseQuality r : reads) { ReadQualityCount.Base b; if (af.getRefBases().equals(r.getBase())) { b = ReadQualityCount.Base.REF; } else if (af.getAltBases().get(0).equals(r.getBase())) { b = ReadQualityCount.Base.NONREF; } else { b = ReadQualityCount.Base.OTHER; } ReadQualityCount rqc = new ReadQualityCount(); rqc.setBase(b); rqc.setQuality(r.getQuality()); rqSetBuilder.add(rqc); } ReadCounts rc = new ReadCounts(); rc.setRefFreq(af.getRefFreq()); for (Multiset.Entry<ReadQualityCount> entry : rqSetBuilder.build().entrySet()) { ReadQualityCount rq = entry.getElement(); rq.setCount(entry.getCount()); rc.addReadQualityCount(rq); } c.output(KV.of(c.element().getKey(), rc)); } } /** * Calls the Solver to maximize via a univariate function the results of the pipeline, inputted * as a PCollectionView (the best way to retrieve our results as a Map in Dataflow). */ static class Maximizer extends DoFn<Object, String> { private final PCollectionView<Map<Position, ReadCounts>> view; // Target absolute error for Brent's algorithm private static final double ABS_ERR = 0.00001; // Target relative error for Brent's algorithm private static final double REL_ERR = 0.0001; // Maximum number of evaluations of the Likelihood function in Brent's algorithm private static final int MAX_EVAL = 1000; // Maximum number of iterations of Brent's algorithm private static final int MAX_ITER = 1000; // Grid search step size private static final double GRID_STEP = 0.001; public Maximizer(PCollectionView<Map<Position, ReadCounts>> view) { this.view = view; } @Override public void processElement(ProcessContext c) throws Exception { float[] steps = new float[]{0.1f, 0.05f, 0.01f, 0.005f, 0.001f}; for (float step : steps) { c.output(Float.toString(step) + ": " + Double.toString(Solver.maximize(new LikelihoodFn(c.sideInput(view)), 0.0, 0.5, step, REL_ERR, ABS_ERR, MAX_ITER, MAX_EVAL))); } } } }